Import Upstream version 20180207
[hcoop/debian/mlton.git] / benchmark / tests / zern.sml
1 (*
2 * From David McClain's language study.
3 * http://www.azstarnet.com/~dmcclain/LanguageStudy.html
4 *
5 * Stephen Weeks replaced Unsafe.Real64Array with Real64Array.
6 *)
7
(* Shadow the top-level print with a no-op so the benchmark produces no output. *)
val print = fn _ => ()
9
10 (* array2.sml
11 *
12 * COPYRIGHT (c) 1998 D.McClain/MCFA
13 * COPYRIGHT (c) 1997 AT&T Research.
14 *)
15
(* FastRealArray2: two-dimensional arrays of reals stored row-major in a
 * flat Real64Array.  Element (i, j) of an array with ncols columns lives at
 * flat index i*ncols + j.  Arrays always have at least one row and one
 * column (see chkSize).
 *)
structure FastRealArray2 :
  sig
    type array

    type region
      = {base : array,
         row : int,
         col : int,
         nrows : int option,
         ncols : int option}

    datatype traversal = RowMajor | ColMajor

    val array : int * int * real -> array
    val fromList : real list list -> array
    val tabulate : traversal -> (int * int * (int * int -> real)) -> array
    val sub : array * int * int -> real
    val update : array * int * int * real -> unit
    val dimensions : array -> int * int
    val size : array -> int
    val nCols : array -> int
    val nRows : array -> int
    val row : array * int -> real Vector.vector
    val column : array * int -> real Vector.vector

    val copy : region * array * int * int -> unit
    val appi : traversal -> (int * int * real -> unit) -> region -> unit
    val app : traversal -> (real -> unit) -> array -> unit
    val modifyi : traversal -> (int * int * real -> real) -> region -> unit
    val modify : traversal -> (real -> real) -> array -> unit
    val foldi : traversal -> (int*int*real*'a -> 'a) -> 'a -> region -> 'a
    val fold : traversal -> (real * 'a -> 'a) -> 'a -> array -> 'a

    val rmSub : array * int -> real
    val rmUpdate : array * int * real -> unit

    val unop : array * array * (real -> real) -> unit
    val unopi : array * array * (real * int -> real) -> unit
    val binop : array * array * array * (real * real -> real) -> unit
    val binopi : array * array * array * (real * real * int -> real) -> unit
    val fill : array * real -> unit
    val fillf : array * (int -> real) -> unit

    val transpose : array -> array
    val extract : region -> array

(*
    val shift : array * int * int -> array
*)
  end =
struct

    structure A = (*Unsafe.*)Real64Array

    type rawArray = A.array

    (* Named "unsafe" for historical reasons; with the checked Real64Array
     * these still raise Subscript on an out-of-range index. *)
    val unsafeUpdate = A.update
    val unsafeSub = A.sub
    fun mkRawArray n = A.array (n, 0.0)


    type array = {data : rawArray,
                  nrows : int,
                  ncols : int,
                  nelts : int}

    type region = {base : array,
                   row : int,
                   col : int,
                   nrows : int option,
                   ncols : int option}

    datatype traversal = RowMajor | ColMajor


    (* dotimes n f: apply f to 0, 1, ..., n-1 in increasing order. *)
    fun dotimes n f =
      let (* going forward is twice as fast as backward! *)
        fun iter k = if k >= n then ()
                     else (f k; iter (k + 1))
      in
        iter 0
      end


    (* mkArray (n, v): fresh raw array of n elements, each set to v. *)
    fun mkArray (n, v) =
      let
        val arr = mkRawArray n
      in
        dotimes n (fn ix => unsafeUpdate (arr, ix, v));
        arr
      end

    (* compute the index of an array element *)
    fun ltu (i, limit) = (i >= 0) andalso (i < limit)
    fun unsafeIndex ({ncols, ...} : array, i, j) = i * ncols + j
    fun index (arr, i, j) =
      if ltu (i, #nrows arr) andalso ltu (j, #ncols arr)
        then unsafeIndex (arr, i, j)
        else raise General.Subscript
    (* row major index checking *)
    fun rmIndex ({nelts, ...} : array, ix) =
      if ltu (ix, nelts) then ix
      else raise General.Subscript

    val max_length = 4096 * 4096  (* arbitrary - but this is 128 MB *)

    (* chkSize (nrows, ncols): element count of an nrows x ncols array.
     * Raises Size if either dimension is not positive, if the product
     * overflows, or if it exceeds max_length. *)
    fun chkSize (nrows, ncols) =
      if (nrows <= 0) orelse (ncols <= 0)
        then raise General.Size
        else
          let
            val n = nrows * ncols handle Overflow => raise General.Size
          in
            if max_length < n then raise General.Size else n
          end

    (* array (nrows, ncols, v): new array with every element equal to v. *)
    fun array (nrows, ncols, v) =
      let
        val nelts = chkSize (nrows, ncols)
      in
        {data = mkArray (nelts, v),
         nrows = nrows, ncols = ncols, nelts = nelts}
      end

    (* fromList rows: build an array whose i-th row is the i-th list.
     * Raises Size if the outer list is empty, if any row's length differs
     * from the first row's, or if chkSize rejects the dimensions.
     *
     * The previous implementation accumulated rows 2..n by prepending, so
     * with three or more rows everything after the first row was stored in
     * reverse row order.  Rows are now written in the order given. *)
    fun fromList [] = raise General.Size
      | fromList (rows as row1 :: _) =
          let
            val ncols = List.length row1
            (* count the rows, rejecting any whose length is not ncols *)
            fun countRows ([], n) = n
              | countRows (r :: rs, n) =
                  if List.length r = ncols
                    then countRows (rs, n + 1)
                    else raise General.Size
            val nrows = countRows (rows, 0)
            val nelts = chkSize (nrows, ncols)
            val arr = mkRawArray nelts
            (* write one row starting at flat index k; return the next index *)
            fun storeRow (r, k) =
              List.foldl (fn (x, pos) => (unsafeUpdate (arr, pos, x); pos + 1)) k r
          in
            ignore (List.foldl storeRow 0 rows);
            {data = arr, nrows = nrows, ncols = ncols, nelts = nelts}
          end

    (* tabulateRM: fill by calling f (i, j) in row-major order. *)
    fun tabulateRM (nrows, ncols, f) =
      let
        val nelts = chkSize (nrows, ncols)
        val arr = mkRawArray nelts
        fun lp1 (i, k) = if i < nrows then lp2 (i, 0, k) else ()
        and lp2 (i, j, k) =
          if j < ncols
            then (unsafeUpdate (arr, k, f (i, j));
                  lp2 (i, j + 1, k + 1))
            else lp1 (i + 1, k)
      in
        lp2 (0, 0, 0);
        {data = arr, nrows = nrows, ncols = ncols, nelts = nelts}
      end

    (* tabulateCM: fill by calling f (i, j) in column-major order. *)
    fun tabulateCM (nrows, ncols, f) =
      let
        val nelts = chkSize (nrows, ncols)
        val arr = mkRawArray nelts
        (* after walking one column k has advanced by nelts; subtracting
         * nelts-1 lands on the top of the next column *)
        val delta = nelts - 1
        fun lp1 (j, k) = if j < ncols then lp2 (0, j, k) else ()
        and lp2 (i, j, k) =
          if i < nrows
            then (unsafeUpdate (arr, k, f (i, j));
                  lp2 (i + 1, j, k + ncols))
            else lp1 (j + 1, k - delta)
      in
        lp2 (0, 0, 0);
        {data = arr, nrows = nrows, ncols = ncols, nelts = nelts}
      end

    fun tabulate RowMajor = tabulateRM
      | tabulate ColMajor = tabulateCM

    (* Checked element access; both raise Subscript when (i, j) is outside
     * the array. *)
    fun sub (a, i, j) = unsafeSub (#data a, index (a, i, j))
    fun update (a, i, j, v) = unsafeUpdate (#data a, index (a, i, j), v)
    fun dimensions ({nrows, ncols, ...} : array) = (nrows, ncols)
    fun size ({nelts, ...} : array) = nelts
    fun nCols (arr : array) = #ncols arr
    fun nRows (arr : array) = #nrows arr

    (* row (a, i): row i as a vector; raises Subscript if i is out of range. *)
    fun row ({data, nrows, ncols, ...} : array, i) =
      if ltu (i, nrows)
        then
          let
            val first = i * ncols
            (* walk the row backwards so the list comes out in order *)
            fun mkVec (j, acc) =
              if j < first
                then Vector.fromList acc
                else mkVec (j - 1, unsafeSub (data, j) :: acc)
          in
            mkVec (first + ncols - 1, [])
          end
        else raise General.Subscript

    (* column (a, j): column j as a vector; raises Subscript if j is out of
     * range. *)
    fun column ({data, ncols, nelts, ...} : array, j) =
      if ltu (j, ncols)
        then
          let
            (* start in the last row and step up one row (ncols) at a time *)
            fun mkVec (k, acc) =
              if k < 0
                then Vector.fromList acc
                else mkVec (k - ncols, unsafeSub (data, k) :: acc)
          in
            mkVec ((nelts - ncols) + j, [])
          end
        else raise General.Subscript

    datatype index = DONE | INDX of {i:int, r:int, c:int}

    (* chkRegion: validate a region against its base array and resolve the
     * optional extents.  Returns the raw data, the flat index i of the
     * region's first element, its origin (r, c) and its size (nr, nc).
     * Raises Subscript if the region does not fit inside the base. *)
    fun chkRegion {base = {data, nrows, ncols, ...} : array,
                   row, col, nrows = nr, ncols = nc} =
      let
        fun chk (start, n, NONE) =
              if (start < 0) orelse (n < start)
                then raise General.Subscript
                else n - start
          | chk (start, n, SOME len) =
              if (start < 0) orelse (len < 0) orelse (n < start + len)
                then raise General.Subscript
                else len
        val nr = chk (row, nrows, nr)
        val nc = chk (col, ncols, nc)
      in
        {data = data, i = row * ncols + col, r = row, c = col, nr = nr, nc = nc}
      end

    fun copy (region, dst, dst_row, dst_col) =
      raise Fail "Array2.copy unimplemented"


    (* this function generates a stream of indices for the given region in
     * row-major order.  Each INDX carries the flat index i and the (r, c)
     * position in the base array.
     *
     * The previous version compared the absolute column/row against the
     * region's extents, reset the column to 0 and never skipped the part of
     * a base row lying outside the region, so it was only correct for a
     * region starting at (0, 0) and spanning every column.
     *)
    fun iterateRM (arg as {base = {ncols, ...}, ...} : region) =
      let
        val {data, i, r, c, nr, nc} = chkRegion arg
        val rowEnd = r + nr        (* one past the region's last row *)
        val colEnd = c + nc        (* one past the region's last column *)
        val rowSkip = ncols - nc   (* elements between region rows *)
        val ii = ref i and ri = ref r and ci = ref c
        fun iter () =
          let
            val curRow = !ri and curCol = !ci
          in
            if curRow >= rowEnd
              then DONE
            else if curCol < colEnd
              then
                let val k = !ii
                in
                  ii := k + 1;
                  ci := curCol + 1;
                  INDX {i = k, r = curRow, c = curCol}
                end
            else (ii := !ii + rowSkip; ci := c; ri := curRow + 1; iter ())
          end
      in
        (data, iter)
      end

    (* this function generates a stream of indices for the given region in
     * col-major order.  Same contract and same sub-region fix as iterateRM.
     *)
    fun iterateCM (arg as {base = {ncols, ...}, ...} : region) =
      let
        val {data, i, r, c, nr, nc} = chkRegion arg
        val rowEnd = r + nr
        val colEnd = c + nc
        (* after nr steps of ncols, go back to the region's first row and
         * move one column to the right *)
        val rewind = nr * ncols - 1
        val ii = ref i and ri = ref r and ci = ref c
        fun iter () =
          let
            val curRow = !ri and curCol = !ci
          in
            if curCol >= colEnd
              then DONE
            else if curRow < rowEnd
              then
                let val k = !ii
                in
                  ii := k + ncols;
                  ri := curRow + 1;
                  INDX {i = k, r = curRow, c = curCol}
                end
            else (ii := !ii - rewind; ri := r; ci := curCol + 1; iter ())
          end
      in
        (data, iter)
      end

    (* appi order f region: call f (r, c, value) for each region element. *)
    fun appi order f region =
      let
        val (data, iter) =
          (case order
            of RowMajor => iterateRM region
             | ColMajor => iterateCM region
          (* end case *))
        fun app () =
          (case iter ()
            of DONE => ()
             | INDX {i, r, c} => (f (r, c, unsafeSub (data, i)); app ())
          (* end case *))
      in
        app ()
      end

    fun appRM f ({data, nelts, ...} : array) =
      let
        fun appf k =
          if k < nelts
            then (f (unsafeSub (data, k)); appf (k + 1))
            else ()
      in
        appf 0
      end

    fun appCM f {data, ncols, nrows, nelts} =
      let
        val delta = nelts - 1
        fun appf (i, k) =
          if i < nrows
            then (f (unsafeSub (data, k)); appf (i + 1, k + ncols))
            else
              let
                val k = k - delta   (* top of the next column *)
              in
                if k < ncols then appf (0, k) else ()
              end
      in
        appf (0, 0)
      end

    fun app RowMajor = appRM
      | app ColMajor = appCM

    (* modifyi order f region: replace each region element v at (r, c) by
     * f (r, c, v). *)
    fun modifyi order f region =
      let
        val (data, iter) =
          (case order
            of RowMajor => iterateRM region
             | ColMajor => iterateCM region
          (* end case *))
        fun modify () =
          (case iter ()
            of DONE => ()
             | INDX {i, r, c} =>
                 (unsafeUpdate (data, i, f (r, c, unsafeSub (data, i)));
                  modify ())
          (* end case *))
      in
        modify ()
      end

    fun modifyRM f ({data, nelts, ...} : array) =
      let
        fun modf k =
          if k < nelts
            then (unsafeUpdate (data, k, f (unsafeSub (data, k)));
                  modf (k + 1))
            else ()
      in
        modf 0
      end

    fun modifyCM f {data, ncols, nrows, nelts} =
      let
        val delta = nelts - 1
        fun modf (i, k) =
          if i < nrows
            then (unsafeUpdate (data, k, f (unsafeSub (data, k)));
                  modf (i + 1, k + ncols))
            else
              let
                val k = k - delta
              in
                if k < ncols then modf (0, k) else ()
              end
      in
        modf (0, 0)
      end

    fun modify RowMajor = modifyRM
      | modify ColMajor = modifyCM

    (* foldi order f init region: fold f (r, c, value, acc) over the region. *)
    fun foldi order f init region =
      let
        val (data, iter) =
          (case order
            of RowMajor => iterateRM region
             | ColMajor => iterateCM region
          (* end case *))
        fun fold accum =
          (case iter ()
            of DONE => accum
             | INDX {i, r, c} => fold (f (r, c, unsafeSub (data, i), accum))
          (* end case *))
      in
        fold init
      end

    fun foldRM f init ({data, nelts, ...} : array) =
      let
        fun foldf (k, accum) =
          if k < nelts
            then foldf (k + 1, f (unsafeSub (data, k), accum))
            else accum
      in
        foldf (0, init)
      end

    fun foldCM f init {data, ncols, nrows, nelts} =
      let
        val delta = nelts - 1
        fun foldf (i, k, accum) =
          if i < nrows
            then foldf (i + 1, k + ncols, f (unsafeSub (data, k), accum))
            else
              let
                val k = k - delta
              in
                if k < ncols then foldf (0, k, accum) else accum
              end
      in
        foldf (0, 0, init)
      end

    fun fold RowMajor = foldRM
      | fold ColMajor = foldCM


    (* transpose a: new array t with t(j, i) = a(i, j). *)
    fun transpose {data, nrows, ncols, nelts} =
      let
        val dst = mkRawArray nelts
        val delta = nelts - 1
        (* k walks the source row-major; k' is the matching destination
         * index, advancing by nrows and wrapping to the next destination
         * column at the end of each source row *)
        fun iter (k, k') =
          if k >= nelts
            then {data = dst,
                  nrows = ncols,
                  ncols = nrows,
                  nelts = nelts}
          else if k' >= nelts
            then iter (k, k' - delta)
          else (unsafeUpdate (dst, k', unsafeSub (data, k));
                iter (k + 1, k' + nrows))
      in
        iter (0, 0)
      end

    (* extract region: copy a non-empty region into a fresh array.
     * Raises Subscript if the region's origin is outside the base, if an
     * explicit extent is not positive, or if the region overruns the base.
     * (Previously a negative extent or negative origin could slip through
     * this check and fail later with Size or an incidental Subscript.) *)
    fun extract (region as {base, row, col, nrows, ncols} : region) =
      let
        fun chk (start, limit, NONE) =
              if ltu (start, limit)
                then limit - start
                else raise General.Subscript
          | chk (start, limit, SOME len) =
              if len > 0 andalso ltu (start, limit) andalso len <= limit - start
                then len
                else raise General.Subscript

        val nr = chk (row, nRows base, nrows)
        val nc = chk (col, nCols base, ncols)
        val n = nr * nc
        val dst = mkRawArray n
        val (data, iter) = iterateRM region
        fun copyNext k =
          (case iter ()
            of DONE => {data = dst,
                        nrows = nr,
                        ncols = nc,
                        nelts = n}
             | INDX {i, ...} =>
                 (unsafeUpdate (dst, k, unsafeSub (data, i));
                  copyNext (k + 1)))
      in
        copyNext 0
      end

    (* Access by flat row-major index; Subscript if ix is out of range. *)
    fun rmSub (arr as {data, ...} : array, ix) =
      unsafeSub (data, rmIndex (arr, ix))

    fun rmUpdate (arr as {data, ...} : array, ix, v) =
      unsafeUpdate (data, rmIndex (arr, ix), v)

    (* Element-wise operations.  Each walks flat indices 0 .. nelts-1 of the
     * destination (the first argument) and reads the sources at the same
     * flat index; the "i" variants also pass that index to f. *)
    fun binop ({data = dst, nelts = nelts, ...} : array,
               {data = src1, ...} : array,
               {data = src2, ...} : array,
               f) =
      dotimes nelts
        (fn ix => unsafeUpdate (dst, ix, f (unsafeSub (src1, ix),
                                            unsafeSub (src2, ix))))

    fun unop ({data = dst, nelts = nelts, ...} : array,
              {data = src, ...} : array,
              f) =
      dotimes nelts
        (fn ix => unsafeUpdate (dst, ix, f (unsafeSub (src, ix))))

    fun binopi ({data = dst, nelts = nelts, ...} : array,
                {data = src1, ...} : array,
                {data = src2, ...} : array,
                f) =
      dotimes nelts
        (fn ix => unsafeUpdate (dst, ix, f (unsafeSub (src1, ix),
                                            unsafeSub (src2, ix),
                                            ix)))

    fun unopi ({data = dst, nelts = nelts, ...} : array,
               {data = src, ...} : array,
               f) =
      dotimes nelts
        (fn ix => unsafeUpdate (dst, ix, f (unsafeSub (src, ix), ix)))

    (* fill (a, v): set every element to v. *)
    fun fill ({data = dst, nelts = nelts, ...} : array, v) =
      dotimes nelts
        (fn ix => unsafeUpdate (dst, ix, v))

    (* fillf (a, f): set the element at flat index ix to f ix. *)
    fun fillf ({data = dst, nelts = nelts, ...} : array, f) =
      dotimes nelts
        (fn ix => unsafeUpdate (dst, ix, f ix))

end
510
(* Test of Zernike phase screen E-field generation. *)
(* This is 1.9 times faster than IDL!!!! *)
structure MSpeed =
struct
  structure F = FastRealArray2

  val sin = Math.sin
  val cos = Math.cos

  val fromInt = LargeReal.fromInt

  (* setup working vectors and arrays *)

  (* collect n f = [f 1, f 2, ..., f n].  f is applied to n first and to 1
   * last. *)
  fun collect n f =
    let
      fun build (0, acc) = acc
        | build (k, acc) = build (k - 1, f k :: acc)
    in
      build (n, nil)
    end

  val ncoefs = 15
  val nx = 128
  val ny = nx
  val nel = nx * ny

  (* generate an array from a scaled vector: dst := sf * a, element-wise *)
  fun mulsv (dst, sf, a) =
    F.unop (dst, a, fn v => sf * v)


  (* compute the complex exponential of an array:
   * rpart := cos a, ipart := sin a; returns (rpart, ipart) *)
  fun cisv (a, rpart, ipart) =
    let
      val () = F.unop (rpart, a, cos)
      val () = F.unop (ipart, a, sin)
    in
      (rpart, ipart)
    end

  (* accumulate scaled vectors into an array: dst := dst + sf * src *)
  fun mpadd dst (sf, src) =
    F.binop (dst, dst, src, fn (acc, v) => acc + sf * v)


  (* compute an E-field from a set of Zernike screens: dst becomes the
   * coefficient-weighted sum of the screens, then rpart/ipart receive its
   * cosine and sine *)
  fun zern (dst, rpart, ipart, coefs, zerns) =
    let
      val () = mulsv (dst, hd coefs, hd zerns)
      val () = ListPair.app (mpadd dst) (tl coefs, tl zerns)
    in
      cisv (dst, rpart, ipart)
    end

  (* timing tests and reporting *)

  (* Turn a (start, stop) time pair for niter iterations over nel elements
   * into throughput figures; also prints the elapsed time. *)
  fun report_times (niter, nel, (start, stop)) =
    let
      val secs = Time.- (stop, start)
      val dur = Time.toReal secs * 1.0E6   (* elapsed microseconds *)
      val ops_per_us = (fromInt niter * fromInt nel) / dur
      val ns_per_op = 1000.0 / ops_per_us
      val () = print (Time.toString secs)
      val () = print "\n"
    in
      {ops_per_us = ops_per_us, ns_per_op = ns_per_op}
    end

  (* Run f niter times; return the clock readings taken before the first
   * call and after the last. *)
  fun time_iterations f niter =
    let
      fun loop 0 = Time.now ()
        | loop k = (ignore (f ()); loop (k - 1))
      val start = Time.now ()
      val stop = loop niter
    in
      (start, stop)
    end

  (* One benchmark run builds ncoefs screens, combines them with zern and
   * sanity-checks a single element of the real part. *)
  fun ztest niter =
    let
      fun oneRun () =
        let
          val sum = F.array (ny, nx, 0.0)
          val rpart = F.array (ny, nx, 0.0)
          val ipart = F.array (ny, nx, 0.0)
          val coefs = collect ncoefs (fn x => real (1 + x))
          fun screen _ =
            F.tabulate F.RowMajor
              (ny, nx, fn (r, c) => 0.01 * real (nx * r + c))
          val zerns = collect ncoefs screen
          val (rpart, _) = zern (sum, rpart, ipart, coefs, zerns)
          val err = Real.abs (F.sub (rpart, 0, 1) - 0.219)
        in
          if err < 0.001
            then ()
            else raise Fail "compiler bug"
        end
    in
      report_times (niter, nel, time_iterations oneRun niter)
    end
end
600
(* Benchmark entry point: doit n runs the Zernike test for n iterations. *)
structure Main =
struct
  val doit = MSpeed.ztest
end