Commit | Line | Data |
---|---|---|
7f918cf1 CE |
1 | (* |
2 | * From David McClain's language study. | |
3 | * http://www.azstarnet.com/~dmcclain/LanguageStudy.html | |
4 | * | |
5 | * Stephen Weeks replaced Unsafe.Real64Array with Real64Array. | |
6 | *) | |
7 | ||
(* Shadow the Basis `print` with a polymorphic no-op: every later call to
 * `print` in this file (e.g. the timing report in MSpeed) discards its
 * argument and returns (). *)
val print = fn _ => ()
9 | ||
10 | (* array2.sml | |
11 | * | |
12 | * COPYRIGHT (c) 1998 D.McClain/MCFA | |
13 | * COPYRIGHT (c) 1997 AT&T Research. | |
14 | *) | |
15 | ||
structure FastRealArray2 :
  sig
    type array

    type region
      = {base : array,
         row : int,
         col : int,
         nrows : int option,
         ncols : int option}

    datatype traversal = RowMajor | ColMajor

    val array : int * int * real -> array
    val fromList : real list list -> array
    val tabulate : traversal -> (int * int * (int * int -> real)) -> array
    val sub : array * int * int -> real
    val update : array * int * int * real -> unit
    val dimensions : array -> int * int
    val size : array -> int
    val nCols : array -> int
    val nRows : array -> int
    val row : array * int -> real Vector.vector
    val column : array * int -> real Vector.vector

    val copy : region * array * int * int -> unit
    val appi : traversal -> (int * int * real -> unit) -> region -> unit
    val app : traversal -> (real -> unit) -> array -> unit
    val modifyi : traversal -> (int * int * real -> real) -> region -> unit
    val modify : traversal -> (real -> real) -> array -> unit
    val foldi : traversal -> (int*int*real*'a -> 'a) -> 'a -> region -> 'a
    val fold : traversal -> (real * 'a -> 'a) -> 'a -> array -> 'a

    val rmSub : array * int -> real
    val rmUpdate : array * int * real -> unit

    val unop : array * array * (real -> real) -> unit
    val unopi : array * array * (real * int -> real) -> unit
    val binop : array * array * array * (real * real -> real) -> unit
    val binopi : array * array * array * (real * real * int -> real) -> unit
    val fill : array * real -> unit
    val fillf : array * (int -> real) -> unit

    val transpose : array -> array
    val extract : region -> array

    (*
    val shift : array * int * int -> array
    *)
  end =
struct

  structure A = (*Unsafe.*)Real64Array

  type rawArray = A.array

  (* Despite the names, these are the bounds-checked Real64Array operations
   * (see the file header): an out-of-range index raises Subscript. *)
  val unsafeUpdate = A.update
  val unsafeSub = A.sub
  fun mkRawArray n = A.array (n, 0.0)

  (* A 2-D array stored flat in row-major order: element (i, j) lives at
   * data[i*ncols + j] (see unsafeIndex), and nelts = nrows * ncols. *)
  type array = {data : rawArray,
                nrows : int,
                ncols : int,
                nelts : int}

  (* A rectangular window of `base` with top-left corner (row, col).
   * NONE for nrows/ncols means "extend to the edge of the array". *)
  type region = {base : array,
                 row : int,
                 col : int,
                 nrows : int option,
                 ncols : int option}

  datatype traversal = RowMajor | ColMajor

  (* dotimes n f calls f 0, f 1, ..., f (n-1), in that order. *)
  fun dotimes n f =
    let (* going forward is twice as fast as backward! *)
      fun iter k = if k >= n then ()
                   else (f(k); iter(k+1))
    in
      iter 0
    end

  (* A fresh raw array of n elements, every one equal to v.  A.array already
   * initialises each slot, so no separate fill pass is needed. *)
  fun mkArray (n, v) = A.array (n, v)

  (* ltu (i, limit) is true iff 0 <= i < limit. *)
  fun ltu (i, limit) = (i >= 0) andalso (i < limit)

  (* Flat row-major index of (i, j), without bounds checking. *)
  fun unsafeIndex ({nrows, ncols, ...} : array, i, j) = (i*ncols + j)

  (* Flat index of (i, j); raises Subscript if either coordinate is out of range. *)
  fun index (arr, i, j) =
    if (ltu(i, #nrows arr) andalso ltu(j, #ncols arr))
      then unsafeIndex (arr, i, j)
      else raise General.Subscript

  (* Row-major (flat) index checking: raises Subscript unless 0 <= ix < nelts. *)
  fun rmIndex ({nelts, ...} : array, ix) =
    if ltu(ix, nelts) then ix
    else raise General.Subscript

  (* Arbitrary cap of 4096*4096 elements, i.e. 128 MB of 8-byte reals. *)
  val max_length = 4096 * 4096

  (* Element count for an nrows x ncols array.  Raises Size if either
   * dimension is non-positive, if the product overflows, or if it exceeds
   * max_length. *)
  fun chkSize (nrows, ncols) =
    if (nrows <= 0) orelse (ncols <= 0)
      then raise General.Size
      else let
        val n = nrows*ncols handle Overflow => raise General.Size
      in
        if (max_length < n) then raise General.Size else n
      end

  (* An nrows x ncols array with every element set to v. *)
  fun array (nrows, ncols, v) =
    let
      val nelts = chkSize (nrows, ncols)
    in
      {data = mkArray (nelts, v),
       nrows = nrows, ncols = ncols, nelts = nelts}
    end

  (* Build an array from a list of rows.  Every row must have the same
   * (non-zero) length as the first; otherwise, or for an empty list,
   * raises Size.  Rows are stored in list order. *)
  fun fromList [] = raise General.Size
    | fromList (rows as row1 :: _) = let
        val ncols = List.length row1
        (* Count the rows, rejecting any whose length differs from row1's. *)
        val nrows =
          List.foldl
            (fn (r, n) =>
               if List.length r <> ncols then raise General.Size else n + 1)
            0 rows
        val nelts = chkSize (nrows, ncols)
        val arr = mkRawArray nelts
        (* Store one row starting at flat index k; return the next free index. *)
        fun storeRow (r, k) =
          List.foldl (fn (x, k) => (unsafeUpdate (arr, k, x); k + 1)) k r
        val _ = List.foldl storeRow 0 rows
      in
        {data = arr, nrows = nrows, ncols = ncols, nelts = nelts}
      end

  (* Tabulate f over an nrows x ncols array, calling f in row-major order. *)
  fun tabulateRM (nrows, ncols, f) =
    let
      val nelts = chkSize(nrows, ncols)
      val arr = mkRawArray nelts
      fun lp1 (i, j, k) = if (i < nrows)
                            then lp2 (i, 0, k)
                            else ()
      and lp2 (i, j, k) = if (j < ncols)
                            then (
                              unsafeUpdate(arr, k, f(i, j));
                              lp2 (i, j+1, k+1))
                            else lp1 (i+1, 0, k)
    in
      lp2 (0, 0, 0);
      {data = arr, nrows = nrows, ncols = ncols, nelts = nelts}
    end

  (* Tabulate f over an nrows x ncols array, calling f in column-major order.
   * Walking down a column steps k by ncols; after the last row, k has run
   * nelts past the column's first slot, so k - (nelts-1) is the next column. *)
  fun tabulateCM (nrows, ncols, f) =
    let
      val nelts = chkSize(nrows,ncols)
      val arr = mkRawArray nelts
      val delta = nelts - 1
      fun lp1 (i, j, k) = if (j < ncols)
                            then lp2 (0, j, k)
                            else ()
      and lp2 (i, j, k) = if (i < nrows)
                            then (
                              unsafeUpdate(arr, k, f(i, j));
                              lp2 (i+1, j, k+ncols))
                            else lp1 (0, j+1, k-delta)
    in
      lp2 (0, 0, 0);
      {data = arr, nrows = nrows, ncols = ncols, nelts = nelts}
    end

  fun tabulate RowMajor = tabulateRM
    | tabulate ColMajor = tabulateCM

  fun sub (a, i, j) = unsafeSub(#data a, index(a, i, j))
  fun update (a, i, j, v) = unsafeUpdate(#data a, index(a, i, j), v)
  fun dimensions ({nrows, ncols, ...}: array) = (nrows, ncols)
  fun size ({nelts,...}: array) = nelts
  fun nCols (arr : array) = #ncols arr
  fun nRows (arr : array) = #nrows arr

  (* A copy of row i; raises Subscript unless 0 <= i < nrows. *)
  fun row ({data, nrows, ncols, ...} : array, i) =
    if ltu (i, nrows)
      then let
        val start = i * ncols
      in
        Vector.tabulate (ncols, fn j => unsafeSub (data, start + j))
      end
      else raise General.Subscript

  (* A copy of column j; raises Subscript unless 0 <= j < ncols. *)
  fun column ({data, nrows, ncols, ...} : array, j) =
    if ltu (j, ncols)
      then Vector.tabulate (nrows, fn i => unsafeSub (data, i * ncols + j))
      else raise General.Subscript

  (* One step of a region traversal: DONE, or the flat index i of the next
   * element together with its absolute coordinates (r, c). *)
  datatype index = DONE | INDX of {i:int, r:int, c:int}

  (* Validate a region against its base array and normalise it.  Returns the
   * base data, the flat index i of the region's top-left element, that
   * corner (r, c), and the extent (nr, nc).  A NONE extent runs to the
   * array edge; zero-sized regions are accepted.  Raises Subscript if the
   * region does not lie inside the array. *)
  fun chkRegion {base={data, nrows, ncols, ...}: array,
                 row, col, nrows=nr, ncols=nc}
    = let
      fun chk (start, n, NONE) =
            if ((start < 0) orelse (n < start))
              then raise General.Subscript
              else n-start
        | chk (start, n, SOME len) =
            if ((start < 0) orelse (len < 0) orelse (n < start+len))
              then raise General.Subscript
              else len
      val nr = chk (row, nrows, nr)
      val nc = chk (col, ncols, nc)
    in
      {data = data, i = (row*ncols + col), r=row, c=col, nr=nr, nc=nc}
    end

  (* Copy the elements of `region` into `dst` so that the region's top-left
   * element lands at (dst_row, dst_col).  The source is buffered first, so
   * overlapping source and destination windows in the same array are
   * handled correctly.  Raises Subscript if the source region is invalid or
   * the destination window does not fit inside dst. *)
  fun copy (region : region, dst : array, dst_row, dst_col) = let
        val {data = srcData, i = srcStart, nr, nc, ...} = chkRegion region
        val srcCols = #ncols (#base region)
        val {data = dstData, nrows = dstRows, ncols = dstCols, ...} = dst
        (* written as differences so the bounds test cannot overflow *)
        val _ = if dst_row < 0 orelse dst_col < 0
                   orelse dstRows - dst_row < nr
                   orelse dstCols - dst_col < nc
                  then raise General.Subscript
                  else ()
        val tmp = A.array (nr * nc, 0.0)
        val _ = dotimes nr (fn rr =>
                  dotimes nc (fn cc =>
                    unsafeUpdate (tmp, rr * nc + cc,
                                  unsafeSub (srcData, srcStart + rr * srcCols + cc))))
        val dstStart = dst_row * dstCols + dst_col
      in
        dotimes nr (fn rr =>
          dotimes nc (fn cc =>
            unsafeUpdate (dstData, dstStart + rr * dstCols + cc,
                          unsafeSub (tmp, rr * nc + cc))))
      end

  (* Return the base data and a generator that yields the region's elements
   * in row-major order, one INDX per call, then DONE.  Rows and columns are
   * bounded by the region's corner plus extent (not by the extent alone), and
   * each new row restarts at the region's first column, one base row
   * (ncols slots) below the previous row's start. *)
  fun iterateRM (arg as {base = {ncols, ...}, ...}) = let
        val {data, i, r, c, nr, nc} = chkRegion arg
        val endRow = r + nr and endCol = c + nc
        val ri = ref r and ci = ref c
        val rowStart = ref i   (* flat index of (!ri, c) *)
        fun iter () =
          if !ri >= endRow then DONE
          else if !ci >= endCol
            then (ri := !ri + 1; ci := c; rowStart := !rowStart + ncols; iter ())
            else let
              val col = !ci
            in
              ci := col + 1;
              INDX{i = !rowStart + (col - c), r = !ri, c = col}
            end
      in
        (data, iter)
      end

  (* As iterateRM, but in column-major order: walk down each column of the
   * region (stepping ncols slots per row), then move one column right and
   * restart at the region's first row. *)
  fun iterateCM (arg as {base = {ncols, ...}, ...}) = let
        val {data, i, r, c, nr, nc} = chkRegion arg
        val endRow = r + nr and endCol = c + nc
        val ri = ref r and ci = ref c
        val colStart = ref i   (* flat index of (r, !ci) *)
        fun iter () =
          if !ci >= endCol then DONE
          else if !ri >= endRow
            then (ci := !ci + 1; ri := r; colStart := !colStart + 1; iter ())
            else let
              val row = !ri
            in
              ri := row + 1;
              INDX{i = !colStart + (row - r) * ncols, r = row, c = !ci}
            end
      in
        (data, iter)
      end

  (* Apply f (r, c, x) to each element of the region in the given order;
   * (r, c) are absolute coordinates in the base array. *)
  fun appi order f region = let
        val (data, iter) = (case order
               of RowMajor => iterateRM region
                | ColMajor => iterateCM region
              (* end case *))
        fun app () = (case iter()
               of DONE => ()
                | INDX{i, r, c} => (f(r, c, unsafeSub(data, i)); app())
              (* end case *))
      in
        app ()
      end

  fun appRM f ({data, nelts, ...}: array) =
    let
      fun appf k =
        if k < nelts then (f(unsafeSub(data,k));
                           appf(k+1))
        else ()
    in
      appf 0
    end

  (* Column-major walk over the whole array; see tabulateCM for the
   * k - (nelts-1) step to the next column. *)
  fun appCM f {data, ncols, nrows, nelts} = let
        val delta = nelts - 1
        fun appf (i, k) = if (i < nrows)
              then (f(unsafeSub(data, k)); appf(i+1, k+ncols))
              else let
                val k = k-delta
              in
                if (k < ncols) then appf (0, k) else ()
              end
      in
        appf (0, 0)
      end
  fun app RowMajor = appRM
    | app ColMajor = appCM

  (* Replace each element x of the region by f (r, c, x), visiting elements
   * in the given order. *)
  fun modifyi order f region = let
        val (data, iter) = (case order
               of RowMajor => iterateRM region
                | ColMajor => iterateCM region
              (* end case *))
        fun modify () = (case iter()
               of DONE => ()
                | INDX{i, r, c} => (
                    unsafeUpdate (data, i, f(r, c, unsafeSub(data, i)));
                    modify())
              (* end case *))
      in
        modify ()
      end

  fun modifyRM f ({data, nelts, ...}: array) =
    let
      fun modf k =
        if k < nelts then (unsafeUpdate(data,k,f(unsafeSub(data,k)));
                           modf (k+1))
        else ()
    in
      modf 0
    end

  fun modifyCM f {data, ncols, nrows, nelts} = let
        val delta = nelts - 1
        fun modf (i, k) = if (i < nrows)
              then (unsafeUpdate(data, k, f(unsafeSub(data, k))); modf(i+1, k+ncols))
              else let
                val k = k-delta
              in
                if (k < ncols) then modf (0, k) else ()
              end
      in
        modf (0, 0)
      end
  fun modify RowMajor = modifyRM
    | modify ColMajor = modifyCM

  (* Fold f (r, c, x, acc) over the region in the given order, starting
   * from init. *)
  fun foldi order f init region = let
        val (data, iter) = (case order
               of RowMajor => iterateRM region
                | ColMajor => iterateCM region
              (* end case *))
        fun fold accum = (case iter()
               of DONE => accum
                | INDX{i, r, c} => fold(f(r, c, unsafeSub(data, i), accum))
              (* end case *))
      in
        fold init
      end

  fun foldRM f init ({data, nelts, ...}: array) =
    let
      fun foldf (k, accum) =
        if k < nelts then foldf(k+1,f(unsafeSub(data,k),accum))
        else accum
    in
      foldf (0,init)
    end

  fun foldCM f init {data, ncols, nrows, nelts} = let
        val delta = nelts - 1
        fun foldf (i, k, accum) = if (i < nrows)
              then foldf (i+1, k+ncols, f(unsafeSub(data, k), accum))
              else let
                val k = k-delta
              in
                if (k < ncols) then foldf (0, k, accum) else accum
              end
      in
        foldf (0, 0, init)
      end
  fun fold RowMajor = foldRM
    | fold ColMajor = foldCM

  (* A new ncols x nrows array whose (j, i) element is the source's (i, j).
   * The source is read in row-major order (k); the destination slot k'
   * advances by nrows per element and wraps by nelts-1 at each row end. *)
  fun transpose {data, nrows, ncols, nelts} =
    let
      val dst = mkRawArray nelts
      val delta = nelts - 1
      fun iter (k,k') =
        if k >= nelts then {data = dst,
                            nrows = ncols,
                            ncols = nrows,
                            nelts = nelts}
        else (if k' >= nelts then iter(k,k' - delta)
              else (unsafeUpdate(dst,k',unsafeSub(data,k));
                    iter(k+1,k'+nrows)))
    in
      iter(0,0)
    end

  (* A fresh array holding a copy of the region.  Raises Subscript if the
   * region is out of bounds (checked by chkRegion) or has zero rows or
   * columns, since an array must be at least 1 x 1. *)
  fun extract (region : region) = let
        val {nr, nc, ...} = chkRegion region
        val _ = if nr = 0 orelse nc = 0 then raise General.Subscript else ()
        val n = nr * nc
        val dst = mkRawArray n
        val (data, iter) = iterateRM region
        fun store k = (case iter ()
               of DONE => ()
                | INDX{i, ...} =>
                    (unsafeUpdate (dst, k, unsafeSub (data, i)); store (k + 1))
              (* end case *))
      in
        store 0;
        {data = dst, nrows = nr, ncols = nc, nelts = n}
      end

  (* Read / write by flat row-major index; raises Subscript if out of range. *)
  fun rmSub (arr as {data,...}: array,ix) =
    unsafeSub(data,rmIndex(arr, ix))

  fun rmUpdate(arr as {data,...}: array,ix,v) =
    unsafeUpdate(data,rmIndex(arr, ix),v)

  (* The element-wise operations below iterate over dst's nelts and do not
   * compare shapes: a source with fewer elements than dst makes A.sub
   * raise Subscript part-way through. *)

  (* dst[k] := f (src1[k], src2[k]) for every flat index k. *)
  fun binop ({data=dst,nelts=nelts,...}: array,
             {data=src1,...}: array,
             {data=src2,...}: array,
             f) =
    dotimes nelts
      (fn (ix) => unsafeUpdate(dst,ix,f(unsafeSub(src1,ix),
                                        unsafeSub(src2,ix))))

  (* dst[k] := f src[k] for every flat index k. *)
  fun unop ({data=dst,nelts=nelts,...}: array,
            {data=src,...}: array,
            f) =
    dotimes nelts
      (fn (ix) => unsafeUpdate(dst,ix,f(unsafeSub(src,ix))))

  (* dst[k] := f (src1[k], src2[k], k) for every flat index k. *)
  fun binopi ({data=dst,nelts=nelts,...}: array,
              {data=src1,...}: array,
              {data=src2,...}: array,
              f) =
    dotimes nelts
      (fn ix => unsafeUpdate(dst,ix,f(unsafeSub(src1,ix),
                                      unsafeSub(src2,ix),
                                      ix)))

  (* dst[k] := f (src[k], k) for every flat index k. *)
  fun unopi ({data=dst,nelts=nelts,...}: array,
             {data=src,...}: array,
             f) =
    dotimes nelts
      (fn ix => unsafeUpdate(dst,ix,f(unsafeSub(src,ix),ix)))

  (* Set every element to v. *)
  fun fill ({data=dst,nelts=nelts,...}: array,v) =
    dotimes nelts
      (fn ix => unsafeUpdate(dst,ix,v))

  (* dst[k] := f k for every flat index k. *)
  fun fillf ({data=dst,nelts=nelts,...}: array,f) =
    dotimes nelts
      (fn ix => unsafeUpdate(dst,ix,f(ix)))

end
510 | ||
(* Test of Zernike phase-screen E-field generation. *)
(* This is 1.9 times faster than IDL!!!! *)
structure MSpeed =
struct
  structure F = FastRealArray2

  val sin = Math.sin
  val cos = Math.cos

  val fromInt = LargeReal.fromInt

  (* collect n f builds [f 1, f 2, ..., f n], calling f from n down to 1. *)
  fun collect n f =
    let
      fun build (k, acc) =
        if k = 0 then acc else build (k - 1, f k :: acc)
    in
      build (n, [])
    end

  (* Problem size: ncoefs screens of ny x nx elements each. *)
  val ncoefs = 15
  val nx = 128
  val ny = nx
  val nel = nx * ny

  (* dst := sf * a, element by element. *)
  fun mulsv (dst, sf, a) = F.unop (dst, a, fn x => sf * x)

  (* Store cos a into rpart and sin a into ipart; return (rpart, ipart). *)
  fun cisv (a, rpart, ipart) =
    let
      val () = F.unop (rpart, a, cos)
      val () = F.unop (ipart, a, sin)
    in
      (rpart, ipart)
    end

  (* dst := dst + sf * src, element by element. *)
  fun mpadd dst (sf, src) =
    F.binop (dst, dst, src, fn (acc, x) => acc + sf * x)

  (* E-field from Zernike screens: dst := sum_k coefs[k] * zerns[k], then
   * (rpart, ipart) := (cos dst, sin dst).  Raises Empty if either list is
   * empty. *)
  fun zern (dst, rpart, ipart, coefs, zerns) =
    let
      val () = mulsv (dst, hd coefs, hd zerns)
      val () = ListPair.app (mpadd dst) (tl coefs, tl zerns)
    in
      cisv (dst, rpart, ipart)
    end

  (* Print the elapsed time (via the file-level print) and return
   * throughput figures for niter iterations over nel elements each. *)
  fun report_times (niter, nel, (start, stop)) =
    let
      val elapsed = Time.- (stop, start)
      val micros = Time.toReal elapsed * 1.0E6
      val ops_per_us = (fromInt niter * fromInt nel) / micros
      val ns_per_op = 1000.0 / ops_per_us
    in
      print (Time.toString elapsed);
      print "\n";
      {ops_per_us = ops_per_us, ns_per_op = ns_per_op}
    end

  (* Run f niter times; return the timestamps taken before and after. *)
  fun time_iterations f niter =
    let
      val start = Time.now ()
      fun loop 0 = Time.now ()
        | loop k = (ignore (f ()); loop (k - 1))
    in
      (start, loop niter)
    end

  (* Time niter trials.  Each trial builds the screens, computes the E-field
   * and checks one known output value, raising Fail on a mismatch. *)
  fun ztest niter =
    let
      fun trial () =
        let
          val sum = F.array (ny, nx, 0.0)
          val rpart = F.array (ny, nx, 0.0)
          val ipart = F.array (ny, nx, 0.0)
          val coefs = collect ncoefs (fn k => real (1 + k))
          val zerns =
            collect ncoefs
              (fn _ => F.tabulate F.RowMajor
                         (ny, nx, fn (r, c) => 0.01 * real (nx * r + c)))
          val (re, _) = zern (sum, rpart, ipart, coefs, zerns)
        in
          if Real.abs (F.sub (re, 0, 1) - 0.219) < 0.001
            then ()
            else raise Fail "compiler bug"
        end
    in
      report_times (niter, nel, time_iterations trial niter)
    end
end
600 | ||
structure Main =
struct
  (* Entry point: run MSpeed.ztest for the given number of timed iterations
   * and return its throughput record. *)
  fun doit iterations = MSpeed.ztest iterations
end