Commit | Line | Data |
---|---|---|
7f918cf1 CE |
1 | (* From the SML/NJ benchmark suite. *) |
(* Shadows the Basis print with a function that ignores its argument and
   returns (), so nothing below writes output.  Bound to a syntactic fn
   value so it stays polymorphic ('a -> unit), exactly like the original
   `fun print _ = ()`. *)
val print = fn _ => ()
3 | signature BMARK = | |
4 | sig | |
5 | val doit : int -> unit | |
6 | end; | |
(* FFT benchmark.  Only doit is visible through BMARK; everything else is
   internal to the structure. *)
structure Main: BMARK = struct

local
  open Array Math

  (* Output helpers; they go through whatever `print` is in scope. *)
  val printr = print o (Real.fmt (StringCvt.SCI(SOME 14)))
  val printi = print o Int.toString
in

val PI = 3.14159265358979323846

val tpi = 2.0 * PI

(* Debugging aid, disabled.  When uncommented it wraps sin/cos and the
   Array sub/update operations in tracing versions that print every call.
   Note that the second definition of trace2 shadows the first, so the
   wrapped arithmetic operators do not actually trace unless that line
   is removed.  The _overload declarations are compiler-specific,
   non-standard SML syntax.

fun trace(name,f) x =
  (print name ;
   print "(" ;
   printr x ;
   print ") = " ;
   let val y = f x
   in printr y ;
      print "\n" ;
      y
   end)

fun trace2(name,f) (x,y) =
  (printr x ;
   print " ";
   print name ;
   print " ";
   printr y ;
   print " = " ;
   let val z = f(x,y)
   in printr z ;
      print "\n" ;
      z
   end)

fun trace2(_,f) = f

val op * = trace2("*", Real.* )
val op - = trace2("-", Real.-)
val op + = trace2("+", Real.+)

local
  nonfix * + -
in
  _overload * : ('a * 'a -> 'a)
    as Int.*
    and *

  _overload + : ('a * 'a -> 'a)
    as Int.+
    and +

  _overload - : ('a * 'a -> 'a)
    as Int.-
    and -
end

val sin = trace("sin", sin)
val cos = trace("cos", cos)

val sub =
  fn (a,i) =>
    let val x = sub(a,i)
    in print "sub(_, " ;
       printi i ;
       print ") = ";
       printr x ;
       print "\n" ;
       x
    end

val update =
  fn (a,i,x) =>
    (update(a,i,x);
     print "update(_, " ;
     printi i ;
     print ", " ;
     printr x ;
     print ")\n")
*)

(* fft px py np
   In-place complex FFT.  px holds the real parts and py the imaginary
   parts, using 1-based indices 1..np (slot 0 is never touched).
   The transform length n is the smallest power of two that is >= np and
   >= 2.  When n <> np, entries np+1..n are zeroed first and a message is
   printed, so both arrays must have length > n.
   The stages are:
     - split-radix butterflies (twiddle angles a and 3a) for stages
       1..m-1,
     - a final length-2 butterfly stage,
     - a bit-reversal permutation that puts the result in natural order.
   Returns n. *)
fun fft px py np =
  let
    (* Double i (tracking m = log2 i) until i >= np; starts at i = 2. *)
    fun find_num_points i m =
      if i < np
        then find_num_points (i+i) (m+1)
      else (i,m)
    val (n,m) = find_num_points 2 1
    (* val _ = (printi n ;
               print "\n" ;
               printi m ;
               print "\n") *)
  in
    (* Zero-pad up to the power-of-two length. *)
    if n <> np then
      let
        fun loop i =
          if i > n then ()
          else (update(px, i, 0.0);
                update(py, i, 0.0);
                loop (i+1))
      in
        loop (np+1);
        print "Use "; printi n; print " point fft\n"
      end
    else ();

    (* Split-radix stages k = 1 .. m-1; n2 halves each stage, from n. *)
    let
      fun loop_k k n2 =
        if k >= m then ()
        else
          let
            val n4 = n2 div 4
            val e = tpi / (real n2)
            (* j = 1 .. n4, with twiddle angle a = (j-1) * e. *)
            fun loop_j j a =
              if j > n4 then ()
              else
                let val a3 = 3.0 * a
                    val cc1 = cos(a)
                    val ss1 = sin(a)
                    val cc3 = cos(a3)
                    val ss3 = sin(a3)
                    (* is/id walk the butterfly start indices of this stage. *)
                    fun loop_is is id =
                      if is >= n
                        then ()
                      else
                        let
                          fun loop_i0 i0 =
                            if i0 >= n
                              then ()
                            else
                              let val i1 = i0 + n4
                                  val i2 = i1 + n4
                                  val i3 = i2 + n4
                                  val r1 = sub(px, i0) - sub(px, i2)
                                  val _ = update(px, i0, sub(px, i0) + sub(px, i2))
                                  val r2 = sub(px, i1) - sub(px, i3)
                                  val _ = update(px, i1, sub(px, i1) + sub(px, i3))
                                  val s1 = sub(py, i0) - sub(py, i2)
                                  val _ = update(py, i0, sub(py, i0) + sub(py, i2))
                                  val s2 = sub(py, i1) - sub(py, i3)
                                  val _ = update(py, i1, sub(py, i1) + sub(py, i3))
                                  val s3 = r1 - s2
                                  val r1 = r1 + s2
                                  val s2 = r2 - s1
                                  val r2 = r2 + s1
                                  val _ = update(px, i2, r1*cc1 - s2*ss1)
                                  val _ = update(py, i2, ~s2*cc1 - r1*ss1)
                                  val _ = update(px, i3, s3*cc3 + r2*ss3)
                                  val _ = update(py, i3, r2*cc3 - s3*ss3)
                              in
                                loop_i0 (i0 + id)
                              end
                        in
                          loop_i0 is;
                          loop_is (2 * id - n2 + j) (4 * id)
                        end
                in
                  loop_is j (2 * n2);
                  loop_j (j+1) (e * real j)
                end
          in
            loop_j 1 0.0;
            loop_k (k+1) (n2 div 2)
          end
    in
      loop_k 1 n
    end;

    (************************************)
    (* Last stage, length=2 butterfly *)
    (************************************)

    let fun loop_is is id = if is >= n then () else
          let fun loop_i0 i0 = if i0 > n then () else
                let val i1 = i0 + 1
                    val r1 = sub(px, i0)
                    val _ = update(px, i0, r1 + sub(px, i1))
                    val _ = update(px, i1, r1 - sub(px, i1))
                    val r1 = sub(py, i0)
                    val _ = update(py, i0, r1 + sub(py, i1))
                    val _ = update(py, i1, r1 - sub(py, i1))
                in
                  loop_i0 (i0 + id)
                end
          in
            loop_i0 is;
            loop_is (2*id - 1) (4 * id)
          end
    in
      loop_is 1 4
    end;

    (*************************)
    (* Bit reverse counter *)
    (*************************)

    (* For i = 1 .. n-1, swap elements i and j when i < j, where j is the
       bit-reversed counterpart maintained incrementally by loop_k. *)
    let
      fun loop_i i j =
        if i >= n
          then ()
        else
          (if i < j
             then (let val xt = sub(px, j)
                   in update(px, j, sub(px, i)); update(px, i, xt)
                   end;
                   let val xt = sub(py, j)
                   in update(py, j, sub(py, i)); update(py, i, xt)
                   end)
           else ();
           let
             fun loop_k k j =
               if k < j then loop_k (k div 2) (j-k) else j+k
             val j' = loop_k (n div 2) j
           in
             loop_i (i+1) j'
           end)
    in
      loop_i 1 1
    end;

    n
  end

(* Absolute value on reals; Basis Real.abs instead of a hand-rolled version. *)
val abs = Real.abs

(* test np
   Builds 1-based input arrays pxr/pxi of length np+2.  Their transform
   should have real part i and imaginary part 0 at index i+1, for
   i = 0 .. np-1.  Runs fft on them, then prints, through the `print`
   in scope, the larger of the maximum real and maximum imaginary
   deviations from that expectation. *)
fun test np =
  let val _ = (printi np; print "... ")
      val enp = real np
      val npm = (np div 2) - 1
      val pxr = array (np+2, 0.0)
      val pxi = array (np+2, 0.0)
      val t = PI / enp
      val _ = update(pxr, 1, (enp - 1.0) * 0.5)
      val _ = update(pxi, 1, 0.0)
      val n2 = np div 2
      val _ = update(pxr, n2+1, ~0.5)
      val _ = update(pxi, n2+1, 0.0)
      (* Remaining inputs: real part -0.5, imaginary parts +/- 0.5*cot(z),
         placed symmetrically at i+1 and np-i+1. *)
      fun loop_i i = if i > npm then () else
        let val j = np - i
            val _ = update(pxr, i+1, ~0.5)
            val _ = update(pxr, j+1, ~0.5)
            val z = t * real i
            val y = ~0.5*(cos(z)/sin(z))
            val _ = update(pxi, i+1, y)
            val _ = update(pxi, j+1, ~y)
        in
          loop_i (i+1)
        end
      val _ = loop_i 1

      (* val _ = print "\n"
         fun loop_i i = if i > 15 then () else
           (printi i; print "\t";
            printr (sub(pxr, i+1)); print "\t";
            printr (sub(pxi, i+1)); print "\n"; loop_i (i+1))
         val _ = loop_i 0
      *)
      val _ = fft pxr pxi np
      (*
         fun loop_i i = if i > 15 then () else
           (printi i; print "\t";
            printr (sub(pxr, i+1)); print "\t";
            printr (sub(pxi, i+1)); print "\n"; loop_i (i+1))
         val _ = loop_i 0
      *)

      (* Running maxima of |re - i| and |im| over i = 0 .. np-1. *)
      fun loop_i i zr zi = if i >= np then (zr,zi) else
        let val a = abs(sub(pxr, i+1) - real i)
            val zr = if zr < a then a else zr
            val a = abs(sub(pxi, i+1))
            val zi = if zi < a then a else zi
        in
          loop_i (i+1) zr zi
        end
      val (zr, zi) = loop_i 0 0.0 0.0
      val zm = if abs zr < abs zi then zi else zr
  in
    printr zm; print "\n"
  end

(* Runs test for np, 2*np, 4*np, ... for i = current .. 15. *)
fun loop_np i np = if i > 15 then () else
  (test np; loop_np (i+1) (np*2))

(* doit n: run the whole suite n times.  Each run tests the sizes
   256, 512, ..., 256 * 2^14.  Guarded with <= so that a negative count
   terminates, instead of recursing until Int overflow. *)
fun doit n =
  if n <= 0
    then ()
  else (loop_np 1 256; doit (n - 1))

end
end;