Import Upstream version 20180207
[hcoop/debian/mlton.git] / benchmark / tests / fft.sml
CommitLineData
7f918cf1
CE
1(* From the SML/NJ benchmark suite. *)
(* Benchmark-quiet output: rebinds print to a function that ignores its
   argument and returns unit, so every later print in this file (and the
   printr/printi helpers built on it) produces no output. Kept polymorphic
   ('a -> unit) like the original clausal definition. *)
val print = fn _ => ()
(* Interface exported by the benchmark: a single entry point taking an
   integer (Main.doit in this file treats it as a repetition count). *)
signature BMARK = sig val doit : int -> unit end;
structure Main: BMARK = struct

local
open Array Math

(* Output helpers. Both go through the file's top-level print, which this
   file rebinds to a no-op, so the benchmark produces no visible output. *)
val printr = print o (Real.fmt (StringCvt.SCI(SOME 14)))
val printi = print o Int.toString
in

val PI = 3.14159265358979323846

val tpi = 2.0 * PI

(* Disabled debugging scaffolding: uncommenting it wraps arithmetic, the
   trig functions, and array sub/update so that every operation fft
   performs is traced. The _overload declarations appear to be
   MLton-internal syntax and will not compile elsewhere. *)
(*
fun trace(name,f) x =
  (print name ;
   print "(" ;
   printr x ;
   print ") = " ;
   let val y = f x
   in printr y ;
      print "\n" ;
      y
   end)

fun trace2(name,f) (x,y) =
  (printr x ;
   print " ";
   print name ;
   print " ";
   printr y ;
   print " = " ;
   let val z = f(x,y)
   in printr z ;
      print "\n" ;
      z
   end)

fun trace2(_,f) = f

val op * = trace2("*", Real.* )
val op - = trace2("-", Real.-)
val op + = trace2("+", Real.+)

local
  nonfix * + -
in
  _overload * : ('a * 'a -> 'a)
  as Int.*
  and *

  _overload + : ('a * 'a -> 'a)
  as Int.+
  and +

  _overload - : ('a * 'a -> 'a)
  as Int.-
  and -
end

val sin = trace("sin", sin)
val cos = trace("cos", cos)

val sub =
  fn (a,i) =>
  let val x = sub(a,i)
  in print "sub(_, " ;
     printi i ;
     print ") = ";
     printr x ;
     print "\n" ;
     x
  end

val update =
  fn (a,i,x) =>
  (update(a,i,x);
   print "update(_, " ;
   printi i ;
   print ", " ;
   printr x ;
   print ")\n")

*)

(* fft px py np

   In-place complex FFT on parallel arrays: px holds the real parts and py
   the imaginary parts. Only indices 1 .. n are used (index 0 is never
   touched), where n is np rounded up to a power of two (at least 2).

   When n > np, entries np+1 .. n are zero-filled first. Both arrays must
   therefore have length at least n+1, otherwise Array.update raises
   Subscript.

   Stages:
     1. m-1 passes (n = 2^m) over groups of length n2 = n, n/2, ...;
        each pass uses twiddle factors at angles a and 3a over
        quarter-length groups (the split-radix butterfly pattern);
     2. a final pass of length-2 butterflies;
     3. a bit-reversal permutation of the result.

   Returns n. *)
fun fft px py np =
    let
      (* Smallest n = 2^m with n >= np, starting from n = 2, m = 1. *)
      fun find_num_points i m =
          if i < np
          then find_num_points (i+i) (m+1)
          else (i,m)
      val (n,m) = find_num_points 2 1
(*    val _ = (printi n ;
               print "\n" ;
               printi m ;
               print "\n") *)
    in
      (* np is not a power of two: pad entries np+1 .. n with zeros. *)
      if n <> np then
        let
          fun loop i =
              if i > n then ()
              else (update(px, i, 0.0);
                    update(py, i, 0.0);
                    loop (i+1))
        in
          loop (np+1);
          print "Use "; printi n; print " point fft\n"
        end
      else ();

      (* Stage 1: passes k = 1 .. m-1; the group length n2 halves each pass. *)
      let
        fun loop_k k n2 =
            if k >= m then ()
            else
              let
                val n4 = n2 div 4
                val e = tpi / (real n2)
                (* j = 1 .. n4; on entry a = (j-1) * e. *)
                fun loop_j j a =
                    if j > n4 then ()
                    else
                      let val a3 = 3.0 * a
                          val cc1 = cos(a)
                          val ss1 = sin(a)
                          val cc3 = cos(a3)
                          val ss3 = sin(a3)
                          (* Visit block starts i0 = is, is+id, ... below n,
                             then move on to the next (is, id) pair. *)
                          fun loop_is is id =
                              if is >= n
                              then ()
                              else
                                let
                                  fun loop_i0 i0 =
                                      if i0 >= n
                                      then ()
                                      else
                                        let val i1 = i0 + n4
                                            val i2 = i1 + n4
                                            val i3 = i2 + n4
                                            (* Each difference is read before
                                               its slot is overwritten by the
                                               corresponding sum; keep this order. *)
                                            val r1 = sub(px, i0) - sub(px, i2)
                                            val _ = update(px, i0, sub(px, i0) + sub(px, i2))
                                            val r2 = sub(px, i1) - sub(px, i3)
                                            val _ = update(px, i1, sub(px, i1) + sub(px, i3))
                                            val s1 = sub(py, i0) - sub(py, i2)
                                            val _ = update(py, i0, sub(py, i0) + sub(py, i2))
                                            val s2 = sub(py, i1) - sub(py, i3)
                                            val _ = update(py, i1, sub(py, i1) + sub(py, i3))
                                            val s3 = r1 - s2
                                            val r1 = r1 + s2
                                            val s2 = r2 - s1
                                            val r2 = r2 + s1
                                            (* Rotate by the a and 3a twiddles. *)
                                            val _ = update(px, i2, r1*cc1 - s2*ss1)
                                            val _ = update(py, i2, ~s2*cc1 - r1*ss1)
                                            val _ = update(px, i3, s3*cc3 + r2*ss3)
                                            val _ = update(py, i3, r2*cc3 - s3*ss3)
                                        in
                                          loop_i0 (i0 + id)
                                        end
                                in
                                  loop_i0 is;
                                  loop_is (2 * id - n2 + j) (4 * id)
                                end
                      in
                        loop_is j (2 * n2);
                        loop_j (j+1) (e * real j)
                      end
              in
                loop_j 1 0.0;
                loop_k (k+1) (n2 div 2)
              end
      in
        loop_k 1 n
      end;

      (************************************)
      (* Last stage, length=2 butterfly   *)
      (************************************)

      (* Stage 2: butterflies on the pairs (i0, i0+1). *)
      let fun loop_is is id = if is >= n then () else
            let fun loop_i0 i0 = if i0 > n then () else
                  let val i1 = i0 + 1
                      val r1 = sub(px, i0)
                      val _ = update(px, i0, r1 + sub(px, i1))
                      val _ = update(px, i1, r1 - sub(px, i1))
                      val r1 = sub(py, i0)
                      val _ = update(py, i0, r1 + sub(py, i1))
                      val _ = update(py, i1, r1 - sub(py, i1))
                  in
                    loop_i0 (i0 + id)
                  end
            in
              loop_i0 is;
              loop_is (2*id - 1) (4 * id)
            end
      in
        loop_is 1 4
      end;

      (*************************)
      (* Bit reverse counter   *)
      (*************************)

      (* Stage 3: swap entries i and j (i < j), where j is advanced as a
         bit-reversed counter over 1 .. n-1. *)
      let
        fun loop_i i j =
            if i >= n
            then ()
            else
              (if i < j
               then (let val xt = sub(px, j)
                     in update(px, j, sub(px, i)); update(px, i, xt)
                     end;
                     let val xt = sub(py, j)
                     in update(py, j, sub(py, i)); update(py, i, xt)
                     end)
               else ();
               let
                 fun loop_k k j =
                     if k < j then loop_k (k div 2) (j-k) else j+k
                 val j' = loop_k (n div 2) j
               in
                 loop_i (i+1) j'
               end)
      in
        loop_i 1 1
      end;

      n
    end

(* test np

   Builds 1-based input arrays (length np+2) such that, if fft is correct,
   the result has real part i and imaginary part 0 at index i+1 for
   i = 0 .. np-1 (the cotangent terms suggest this is the scaled DFT of
   the ramp 0, 1, ..., np-1 -- not verified here). It then runs fft and
   reports the largest deviation from that target.

   np should be a power of two; otherwise fft's zero-fill would index past
   the np+2-element arrays. *)
fun test np =
    let val _ = (printi np; print "... ")
        val enp = real np
        val npm = (np div 2) - 1
        val pxr = array (np+2, 0.0)
        val pxi = array (np+2, 0.0)
        val t = PI / enp
        val _ = update(pxr, 1, (enp - 1.0) * 0.5)
        val _ = update(pxi, 1, 0.0)
        val n2 = np div 2
        val _ = update(pxr, n2+1, ~0.5)
        val _ = update(pxi, n2+1, 0.0)
        (* Fill the mirrored entries i and np - i for i = 1 .. np/2 - 1. *)
        fun loop_i i = if i > npm then () else
            let val j = np - i
                val _ = update(pxr, i+1, ~0.5)
                val _ = update(pxr, j+1, ~0.5)
                val z = t * real i
                val y = ~0.5*(cos(z)/sin(z))
                val _ = update(pxi, i+1, y)
                val _ = update(pxi, j+1, ~y)
            in
              loop_i (i+1)
            end
        val _ = loop_i 1

(*      val _ = print "\n"
        fun loop_i i = if i > 15 then () else
           (printi i; print "\t";
            printr (sub(pxr, i+1)); print "\t";
            printr (sub(pxi, i+1)); print "\n"; loop_i (i+1))
        val _ = loop_i 0
*)
        val _ = fft pxr pxi np
(*
        fun loop_i i = if i > 15 then () else
           (printi i; print "\t";
            printr (sub(pxr, i+1)); print "\t";
            printr (sub(pxi, i+1)); print "\n"; loop_i (i+1))
        val _ = loop_i 0
*)
        (* Running maxima of |pxr[i+1] - i| (zr) and |pxi[i+1]| (zi)
           over i = 0 .. np-1. *)
        fun max_err i zr zi =
            if i >= np then (zr, zi)
            else
              let val er = Real.abs (sub(pxr, i+1) - real i)
                  val ei = Real.abs (sub(pxi, i+1))
                  val zr = if zr < er then er else zr
                  val zi = if zi < ei then ei else zi
              in
                max_err (i+1) zr zi
              end
        val (zr, zi) = max_err 0 0.0 0.0
        (* Both maxima start at 0.0 and only grow, so they are already
           non-negative; report the larger one. *)
        val zm = if zr < zi then zi else zr
    in
      printr zm; print "\n"
    end

(* The sweep covers num_sizes problem sizes, doubling from first_size:
   256, 512, ..., 256 * 2^14. *)
val first_size = 256
val num_sizes = 15

fun loop_np i np =
    if i > num_sizes then ()
    else (test np; loop_np (i+1) (np*2))

(* Run the full size sweep n times. The n <= 0 guard (rather than n = 0)
   makes a negative count terminate instead of recursing forever. *)
fun doit n =
    if n <= 0
    then ()
    else (loop_np 1 first_size; doit (n - 1))

end
end;
296 then ()
297 else (loop_np 1 256; doit (n - 1))
298
299end
300end;