Commit | Line | Data |
---|---|---|
7f918cf1 CE |
1 | (* Obtained at http://www.arrakis.es/~worm/ *) |
2 | ||
(* MONO_VECTOR: immutable vectors whose elements all share the single
   type 'elem'.  Slice-based traversals take a triple
   (vector, start index, optional length); the whole-vector variants take
   just the vector. *)
signature MONO_VECTOR =
sig
  type vector
  type elem

  (* Implementation-defined upper bound on vector length (by its name;
     the value is supplied by each implementation). *)
  val maxLen : int

  (* ----- Construction ----- *)
  val fromList : elem list -> vector
  val tabulate : (int * (int -> elem)) -> vector
  val concat : vector list -> vector

  (* ----- Queries and slicing ----- *)
  val length : vector -> int
  val sub : (vector * int) -> elem
  val extract : (vector * int * int option) -> vector

  (* ----- Traversals over a slice (vector, start, optional length) ----- *)
  val mapi : ((int * elem) -> elem) -> (vector * int * int option) -> vector
  val appi : ((int * elem) -> unit) -> (vector * int * int option) -> unit
  val foldli : ((int * elem * 'a) -> 'a) -> 'a -> (vector * int * int option) -> 'a
  val foldri : ((int * elem * 'a) -> 'a) -> 'a -> (vector * int * int option) -> 'a

  (* ----- Traversals over the whole vector ----- *)
  val map : (elem -> elem) -> vector -> vector
  val app : (elem -> unit) -> vector -> unit
  val foldl : ((elem * 'a) -> 'a) -> 'a -> vector -> 'a
  val foldr : ((elem * 'a) -> 'a) -> 'a -> vector -> 'a
end
23 | ||
24 | (* | |
25 | Copyright (c) Juan Jose Garcia Ripoll. | |
26 | All rights reserved. | |
27 | ||
28 | Refer to the COPYRIGHT file for license conditions | |
29 | *) | |
30 | ||
31 | (* COPYRIGHT | |
32 | ||
33 | Redistribution and use in source and binary forms, with or | |
34 | without modification, are permitted provided that the following | |
35 | conditions are met: | |
36 | ||
37 | 1. Redistributions of source code must retain the above copyright | |
38 | notice, this list of conditions and the following disclaimer. | |
39 | ||
40 | 2. Redistributions in binary form must reproduce the above | |
41 | copyright notice, this list of conditions and the following | |
42 | disclaimer in the documentation and/or other materials provided | |
43 | with the distribution. | |
44 | ||
45 | 3. All advertising materials mentioning features or use of this | |
46 | software must display the following acknowledgement: | |
47 | This product includes software developed by Juan Jose | |
48 | Garcia Ripoll. | |
49 | ||
50 | 4. The name of Juan Jose Garcia Ripoll may not be used to endorse | |
51 | or promote products derived from this software without | |
52 | specific prior written permission. | |
53 | ||
54 | THIS SOFTWARE IS PROVIDED BY JUAN JOSE GARCIA RIPOLL ``AS IS'' | |
55 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED | |
56 | TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A | |
57 | PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL HE BE | |
58 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, | |
59 | OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, | |
60 | PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, | |
61 | OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY | |
62 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR | |
63 | TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT | |
64 | OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY | |
65 | OF SUCH DAMAGE. | |
66 | *) | |
67 | ||
(* EvalTimer: one global stopwatch for crude wall-clock measurements.
   timerOn restarts it; timerRead returns the milliseconds elapsed since
   the last restart (or since this structure was elaborated); timerOff
   prints that figure; 'time f' runs f () between the two. *)
structure EvalTimer =
struct
  local
    (* Moment of the most recent timerOn (initially: elaboration time). *)
    val started = ref (Time.now ())
  in
    fun timerOn () = started := Time.now ()

    fun timerRead () =
      let val elapsed = Time.- (Time.now (), !started)
      in Time.toMilliseconds elapsed end

    fun timerOff () =
      let val ms = LargeInt.toString (timerRead ())
      in
        print "Elapsed: ";
        print ms;
        print " ms\n"
      end

    (* The value returned by f is discarded. *)
    fun time f = (timerOn (); f (); timerOff ())
  end
end
(* Loop: small combinators over integer ranges.
   all / any test the inclusive range [a, b]; app, app' and appi' walk the
   half-open range [a, b), app' and appi' stepping by d. *)
structure Loop =
struct
  (* true iff f i holds for every a <= i <= b; stops at the first failure. *)
  fun all (a, b, f) = a > b orelse (f a andalso all (a + 1, b, f))

  (* true iff f i holds for some a <= i <= b; stops at the first success. *)
  fun any (a, b, f) = a <= b andalso (f a orelse any (a + 1, b, f))

  (* Calls f a, f (a+1), ..., f (b-1) for their side effects. *)
  fun app (a, b, f) =
    if a >= b then () else (f a; app (a + 1, b, f))

  (* Calls f a, f (a+d), f (a+2d), ... while the argument is below b. *)
  fun app' (a, b, d, f) =
    if a >= b then () else (f a; app' (a + d, b, d, f))

  (* Performs exactly the same traversal as app'. *)
  fun appi' (a, b, d, f) =
    if a >= b then () else (f a; appi' (a + d, b, d, f))
end
123 | (* | |
124 | INDEX -Signature- | |
125 | ||
126 | Indices form an enumerable finite set of data with a total order and a map | |
127 | onto a contiguous nonnegative interval of integers. In the sample | |
128 | implementation, Index, each index is a list of integers, | |
129 | [i1,...,in] | |
130 | and each set of indices is defined by a shape, which has the same | |
131 | form as an index but with each entry one more than the largest value of that component | |
132 | shape = [k1,...,kn] | |
133 | 0 <= ij < kj for every j = 1..n | |
134 | ||
135 | type storage = RowMajor | ColumnMajor | |
136 | order : storage | |
137 | Identifies: | |
138 | 1) the underlying algorithms for this structure | |
139 | 2) the most significant index | |
140 | 3) the index that varies more slowly | |
141 | 4) the total order | |
142 | RowMajor means that first index is most significant and varies | |
143 | more slowly, while ColumnMajor means that last index is the most | |
144 | significant and varies more slowly. For instance | |
145 | RowMajor => [0,0]<[0,1]<[1,0]<[1,1] (C, C++, Pascal) | |
146 | ColumnMajor => [0,0]<[1,0]<[0,1]<[1,1] (Fortran) | |
147 | last shape | |
148 | first shape | |
149 | Returns the last/first index that belongs to the set defined by | |
150 | 'shape'. | |
151 | inBounds shape index | |
152 | Checks whether 'index' belongs to the set defined by 'shape'. | |
153 | toInt shape index | |
154 | As we said, indices can be sorted and mapped to a finite set of | |
155 | integers. 'toInt' obtains the integer number that corresponds to | |
156 | a certain index. | |
157 | indexer shape | |
158 | It is equivalent to the partial evaluation 'toInt shape' but | |
159 | optimized for 'shape'. | |
160 | ||
161 | next shape index | |
162 | prev shape index | |
163 | next' shape index | |
164 | prev' shape index | |
165 | Obtain the following or previous index to the one we supply. | |
166 | next and prev return an object of type 'index option' so that | |
167 | if there is no such following/previous, the output is NONE. | |
168 | On the other hand, next'/prev' raise an exception when the | |
169 | output is not well defined and their output is always of type | |
170 | index. next/prev/next'/prev' raise an exception if 'index' | |
171 | does not belong to the set of 'shape'. | |
172 | ||
173 | all shape f | |
174 | any shape f | |
175 | app shape f | |
176 | Iterates 'f' over every index of the set defined by 'shape'. | |
177 | 'all' stops when 'f' first returns false, 'any' stops when | |
178 | 'f' first returns true and 'app' does not stop and discards the | |
179 | output of 'f'. | |
180 | ||
181 | compare(a,b) | |
182 | Returns LESS/GREATER/EQUAL according to the total order which | |
183 | is defined in the set of all indices. | |
184 | <,>,eq,<=,>=,<> | |
185 | Reduced comparisons which are defined in terms of 'compare'. | |
186 | ||
187 | validShape t | |
188 | validIndex t | |
189 | Checks whether 't' is a valid shape or index. | |
190 | ||
191 | iteri shape f | |
192 | *) | |
193 | ||
(* INDEX: multi-dimensional indices and the shapes that bound them.
   Shapes and indices share the representation type t; the Index
   structure implements t as an int list. *)
signature INDEX =
sig
  type t
  (* A function mapping an index to its linear position. *)
  type indexer = t -> int
  datatype storage = RowMajor | ColumnMajor

  (* Raised when an index does not fit its shape. *)
  exception Index
  (* Raised when a shape is invalid. *)
  exception Shape

  (* Storage order followed by the implementation. *)
  val order : storage

  (* ----- Shapes and the index sets they define ----- *)
  val validShape : t -> bool
  val validIndex : t -> bool
  val length : t -> int
  val first : t -> t
  val last : t -> t
  val inBounds : t -> t -> bool

  (* ----- Linearisation: shape -> index -> position ----- *)
  val toInt : t -> t -> int
  val indexer : t -> (t -> int)

  (* ----- Stepping: shape -> index -> neighbouring index ----- *)
  val next : t -> t -> t option
  val prev : t -> t -> t option
  val next' : t -> t -> t
  val prev' : t -> t -> t

  (* ----- Total order on indices, and componentwise difference ----- *)
  val compare : t * t -> order
  val < : t * t -> bool
  val > : t * t -> bool
  val eq : t * t -> bool
  val <= : t * t -> bool
  val >= : t * t -> bool
  val <> : t * t -> bool
  val - : t * t -> t

  (* ----- Iteration over every index of a shape ----- *)
  val all : t -> (t -> bool) -> bool
  val any : t -> (t -> bool) -> bool
  val app : t -> (t -> unit) -> unit
end
(* Index: sample INDEX implementation in which shapes and indices are
   int lists.  The first position of an index varies fastest
   (column-major / Fortran order); toInt, indexer, next' and the
   iterators all follow that order. *)
structure Index : INDEX =
struct
  type t = int list
  type indexer = t -> int
  datatype storage = RowMajor | ColumnMajor

  exception Index
  exception Shape

  val order = ColumnMajor

  (* A shape is valid when every dimension is strictly positive. *)
  fun validShape shape = List.all (fn x => x > 0) shape

  (* An index is valid when every component is nonnegative. *)
  fun validIndex index = List.all (fn x => x >= 0) index

  (* Linear position of 'index' inside 'shape' (first position fastest).
     Raises Index if the ranks differ or a component is out of range. *)
  fun toInt shape index =
    let fun loop ([], [], accum, _) = accum
          | loop (i :: ri, l :: rl, accum, fac) =
              if i >= 0 andalso i < l
              then loop (ri, rl, i * fac + accum, fac * l)
              else raise Index
          | loop _ = raise Index
    in loop (index, shape, 0, 1) end

  (* ----- CACHED LINEAR INDEXER -----
     'indexer shape' validates the shape once and returns a function that
     bounds-checks an index and maps it to the same position as
     'toInt shape'.  Ranks 0 to 3 get dedicated closures; higher ranks use
     precomputed strides. *)
  local
    (* Returns i when 0 <= i < dim; raises Index otherwise. *)
    fun checked (i, dim) =
      if i >= 0 andalso i < dim then i else raise Index

    (* 'strides' holds the multiplier of each index position:
       [1, k1, k1*k2, ...] for shape [k1, k2, ...]. *)
    fun doindexer [] _ =
          (* Rank 0: the empty shape contains exactly one index, []. *)
          (fn [] => 0 | _ => raise Index)
      | doindexer [a] _ =
          (fn [x] => checked (x, a) | _ => raise Index)
      | doindexer [a, b] [_, dy] =
          (fn [x, y] => checked (x, a) + dy * checked (y, b)
            | _ => raise Index)
      | doindexer [a, b, c] [_, dy, dz] =
          (fn [x, y, z] =>
                checked (x, a) + dy * checked (y, b) + dz * checked (z, c)
            | _ => raise Index)
      | doindexer shape strides =
          let fun go (s :: rs) (i :: ri) (d :: rd) accum =
                    go rs ri rd (accum + s * checked (i, d))
                | go [] [] [] accum = accum
                | go _ _ _ _ = raise Index
          in fn index => go strides index shape 0 end

    fun strides shape =
      let fun build _ [] = []
            | build accum (dim :: rd) = accum :: build (dim * accum) rd
      in build 1 shape end
  in
    fun indexer shape =
      if validShape shape then doindexer shape (strides shape) else raise Shape
  end

  (* Number of indices in the set defined by 'shape' (1 for the empty
     shape).  Raises Shape if some dimension is negative. *)
  fun length shape =
    List.foldl (fn (dim, acc) => if dim < 0 then raise Shape else dim * acc)
               1 shape

  fun first shape = List.map (fn _ => 0) shape

  (* Largest index of 'shape'; raises Shape on a dimension below 1. *)
  fun last [] = []
    | last (size :: rest) =
        if size < 1 then raise Shape else (size - 1) :: last rest

  (* Successor / predecessor in column-major order.  Raise Subscript when
     no such index exists and Index when the ranks differ. *)
  fun next' [] [] = raise Subscript
    | next' (dim :: restd) (i :: resti) =
        if i + 1 < dim then (i + 1) :: resti else 0 :: next' restd resti
    | next' _ _ = raise Index

  fun prev' [] [] = raise Subscript
    | prev' (dim :: restd) (i :: resti) =
        if i > 0 then (i - 1) :: resti else (dim - 1) :: prev' restd resti
    | prev' _ _ = raise Index

  fun next shape index = SOME (next' shape index) handle Subscript => NONE

  fun prev shape index = SOME (prev' shape index) handle Subscript => NONE

  (* true iff 'index' has the rank of 'shape' and every component lies in
     [0, dim).  The explicit rank check is needed because ListPair.all
     ignores the surplus elements of the longer list. *)
  fun inBounds shape index =
    List.length shape = List.length index andalso
    ListPair.all (fn (i, dim) => i >= 0 andalso i < dim) (index, shape)

  (* Lexicographic comparison; raises Index when the ranks differ. *)
  fun compare ([], []) = EQUAL
    | compare (a :: ra, b :: rb) =
        (case Int.compare (a, b) of
             EQUAL => compare (ra, rb)
           | other => other)
    | compare _ = raise Index

  local
    (* Calls 'visit' on every index of 'shape' in column-major order (first
       position fastest) until 'visit' returns false.  Returns false if it
       stopped early and true otherwise.  Dimensions are processed last to
       first, so consing each new component onto the already-fixed suffix
       yields the index in shape order. *)
    fun walk shape visit =
      let fun go [] suffix = visit suffix
            | go (dim :: dims) suffix =
                let fun loop i =
                      i >= dim orelse (go dims (i :: suffix) andalso loop (i + 1))
                in loop 0 end
      in go (List.rev shape) [] end
  in
    (* Stops at the first index for which f is false. *)
    fun all shape f = walk shape f

    (* Stops at the first index for which f is true. *)
    fun any shape f = not (walk shape (fn index => not (f index)))

    fun app shape f = ignore (walk shape (fn index => (f index; true)))
  end

  (* Reduced comparisons, all defined through 'compare'.  They are declared
     last because they shadow the integer operators for the rest of the
     structure. *)
  fun a < b = compare (a, b) = LESS
  fun a > b = compare (a, b) = GREATER
  fun eq (a, b) = compare (a, b) = EQUAL
  fun a <> b = not (eq (a, b))
  fun a <= b = not (a > b)
  fun a >= b = not (a < b)

  (* Componentwise difference; surplus components of the longer index are
     dropped. *)
  fun a - b = ListPair.map Int.- (a, b)
end
460 | (* | |
461 | Copyright (c) Juan Jose Garcia Ripoll. | |
462 | All rights reserved. | |
463 | ||
464 | Refer to the COPYRIGHT file for license conditions | |
465 | *) | |
466 | ||
467 | (* | |
468 | TENSOR - Signature - | |
469 | ||
470 | Polymorphic tensors of any type. With 'tensor' we denote a (mutable) | |
471 | array of any rank, with as many indices as one wishes, and that may | |
472 | be traversed (map, fold, etc) according to any of those indices. | |
473 | ||
474 | type 'a tensor | |
475 | Polymorphic tensor whose elements are all of type 'a. | |
476 | val storage = RowMajor | ColumnMajor | |
477 | RowMajor = data is stored in consecutive cells, last index | |
478 | varying fastest (C,C++,Pascal,CommonLisp convention) | |
479 | ColumnMajor = data is stored in consecutive cells, first | |
480 | index varying fastest (FORTRAN convention) | |
481 | new ([i1,...,in],init) | |
482 | Build a new tensor with n indices, each of sizes i1...in, | |
483 | filled with 'init'. | |
484 | fromArray (shape,data) | |
485 | fromList (shape,data) | |
486 | Use 'data' to fill a tensor of that shape. An exception is | |
487 | raised if 'data' is too large or too small to properly | |
488 | fill the tensor. Later use of a 'data' array is disregarded | |
489 | -- one must think that the tensor now owns the array. | |
490 | length tensor | |
491 | rank tensor | |
492 | shape tensor | |
493 | Return the number of elements, the number of indices and | |
494 | the shape (size of each index) of the tensor. | |
495 | toArray tensor | |
496 | Return the data of the tensor in the form of an array. | |
497 | Mutation of this array may lead to unexpected behavior. | |
498 | ||
499 | sub (tensor,[i1,...,in]) | |
500 | update (tensor,[i1,...,in],new_value) | |
501 | Access the element that is indexed by the numbers [i1,..,in] | |
502 | ||
503 | app f a | |
504 | appi f a | |
505 | The same as 'map' and 'mapi' but the function 'f' outputs | |
506 | nothing and no new array is produced, i.e. one only seeks | |
507 | the side effect that 'f' may produce. | |
508 | map2 operation a b | |
509 | Apply 'operation' to pairs of elements of 'a' and 'b' | |
510 | and build a new tensor with the output. Both operands | |
511 | must have the same shape or an exception is raised. | |
512 | The procedure is sequential, as specified by 'storage'. | |
513 | foldl operation a n | |
514 | Fold-left the elements of tensor 'a' along the n-th | |
515 | index. | |
516 | all test a | |
517 | any test a | |
518 | Folded boolean tests on the elements of the tensor. | |
519 | *) | |
520 | ||
(* TENSOR: polymorphic, mutable arrays of arbitrary rank whose layout is
   described by an INDEX shape. *)
signature TENSOR =
sig
  structure Array : ARRAY
  structure Index : INDEX
  type index = Index.t
  type 'a tensor

  (* ----- Construction ----- *)
  val new : index * 'a -> 'a tensor
  val tabulate : index * (index -> 'a) -> 'a tensor
  val fromList : index * 'a list -> 'a tensor
  val fromArray : index * 'a array -> 'a tensor

  (* ----- Structure and queries ----- *)
  val length : 'a tensor -> int
  val rank : 'a tensor -> int
  val shape : 'a tensor -> (index)
  val reshape : index -> 'a tensor -> 'a tensor
  val toArray : 'a tensor -> 'a array

  (* ----- Element access ----- *)
  val sub : 'a tensor * index -> 'a
  val update : 'a tensor * index * 'a -> unit

  (* ----- Traversals ----- *)
  val map : ('a -> 'b) -> 'a tensor -> 'b tensor
  val map2 : ('a * 'b -> 'c) -> 'a tensor -> 'b tensor -> 'c tensor
  val app : ('a -> unit) -> 'a tensor -> unit
  val appi : (int * 'a -> unit) -> 'a tensor -> unit
  (* foldl f init t n folds t along its n-th index. *)
  val foldl : ('c * 'a -> 'c) -> 'c -> 'a tensor -> int -> 'c tensor
  val all : ('a -> bool) -> 'a tensor -> bool
  val any : ('a -> bool) -> 'a tensor -> bool
end
548 | ||
549 | (* | |
550 | Copyright (c) Juan Jose Garcia Ripoll. | |
551 | All rights reserved. | |
552 | ||
553 | Refer to the COPYRIGHT file for license conditions | |
554 | *) | |
555 | ||
(* Tensor: polymorphic tensors stored in a flat array, in the order given
   by Index (first index position varies fastest). *)
structure Tensor : TENSOR =
struct
  structure Array = Array
  structure Index = Index

  type index = Index.t
  (* 'indexer' is Index.indexer shape, cached so element access does not
     rebuild it. *)
  type 'a tensor = {shape : index, indexer : Index.indexer, data : 'a array}

  (* Shape and Index alias the Index structure's exceptions.  TENSOR exports
     no exceptions, so this lets callers catch them as Index.Shape and
     Index.Index.  Match signals operands whose sizes or shapes disagree. *)
  exception Shape = Index.Shape
  exception Match
  exception Index = Index.Index

  local
    (*----- LOCALS -----*)

    fun make' (shape, data) =
      {shape = shape, indexer = Index.indexer shape, data = data}

    fun toInt {shape, indexer, data} index = indexer index

    fun array_map f a =
      Array.tabulate (Array.length a, fn i => f (Array.sub (a, i)))

    (* Splits a shape around one position and returns (before, at, after).
       Positions 1, 2, ... count from the front; 0 selects the last
       position, ~1 the one before it, and so on.  Raises Index when the
       position is out of range or the shape is empty. *)
    fun splitList ([], _) = raise Index
      | splitList (a :: rest, place) =
          let fun loop (left, here, right) 0 = (List.rev left, here, right)
                | loop (_, _, []) _ = raise Index
                | loop (left, here, x :: right) n =
                    loop (here :: left, x, right) (n - 1)
              (* 0-based offset of the selected position after 'a'. *)
              val steps =
                if place <= 0 then List.length rest + place else place - 1
          in
            if steps < 0 then raise Index else loop ([], a, rest) steps
          end
  in
    (*----- STRUCTURAL OPERATIONS & QUERIES ------*)

    fun new (shape, init) =
      if Index.validShape shape
      then make' (shape, Array.array (Index.length shape, init))
      else raise Shape

    fun toArray {shape, indexer, data} = data

    fun length {shape, indexer, data} = Array.length data

    fun shape {shape, indexer, data} = shape

    fun rank t = List.length (shape t)

    fun reshape new_shape tensor =
      if not (Index.validShape new_shape) then raise Shape
      else if Index.length new_shape = length tensor
      then make' (new_shape, toArray tensor)
      else raise Match

    fun fromArray (s, a) =
      if Index.validShape s andalso Index.length s = Array.length a
      then make' (s, a)
      else raise Shape

    fun fromList (s, a) = fromArray (s, Array.fromList a)

    (* Builds a tensor whose element at every index i is f i.  f is called
       exactly once per index, in the order produced by Index.next starting
       from Index.first; that order matches the storage layout. *)
    fun tabulate (shape, f) =
      if Index.validShape shape then
        let val start = Index.first shape
            val c = Array.array (Index.length shape, f start)
            fun fill (_, NONE) = ()
              | fill (pos, SOME idx) =
                  (Array.update (c, pos, f idx);
                   fill (pos + 1, Index.next shape idx))
        in
          fill (1, Index.next shape start);
          make' (shape, c)
        end
      else raise Shape

    (*----- ELEMENTWISE OPERATIONS -----*)

    fun sub (t, index) = Array.sub (#data t, toInt t index)

    fun update (t, index, value) =
      Array.update (toArray t, toInt t index, value)

    fun map f {shape, indexer, data} =
      {shape = shape, indexer = indexer, data = array_map f data}

    (* Raises Match when the shapes differ (Index.eq raises Index when the
       ranks differ). *)
    fun map2 f t1 t2 =
      let val {shape, indexer, data} = t1
          val {shape = shape2, indexer = indexer2, data = data2} = t2
          fun apply i = f (Array.sub (data, i), Array.sub (data2, i))
      in
        if Index.eq (shape, shape2) then
          {shape = shape,
           indexer = indexer,
           data = Array.tabulate (Array.length data, apply)}
        else
          raise Match
      end

    fun appi f tensor = Array.appi f (toArray tensor)

    fun app f tensor = Array.app f (toArray tensor)

    fun all f tensor =
      let val a = toArray tensor
      in Loop.all (0, length tensor - 1, fn i => f (Array.sub (a, i))) end

    fun any f tensor =
      let val a = toArray tensor
      in Loop.any (0, length tensor - 1, fn i => f (Array.sub (a, i))) end

    (* foldl f init t n folds t along its n-th index (position convention
       of splitList), starting every accumulator at init.  The result has
       t's shape with that index removed.  The loops rely on the first
       index position varying fastest in 'data'. *)
    fun foldl f init {shape, indexer, data = a} index =
      let val (head, lk, tail) = splitList (shape, index)
          val li = Index.length head
          val lj = Index.length tail
          val c = Array.array (li * lj, init)
          (* Folds one contiguous run of li source elements into c. *)
          fun loopi (0, _, _) = ()
            | loopi (i, ia, ic) =
                (Array.update (c, ic, f (Array.sub (c, ic), Array.sub (a, ia)));
                 loopi (i - 1, ia + 1, ic + 1))
          (* Walks the lk slices along the folded index and returns the next
             source position. *)
          fun loopk (0, ia, _) = ia
            | loopk (k, ia, ic) = (loopi (li, ia, ic); loopk (k - 1, ia + li, ic))
          fun loopj (0, _, _) = ()
            | loopj (j, ia, ic) = loopj (j - 1, loopk (lk, ia, ic), ic + li)
      in
        loopj (lj, 0, 0);
        make' (head @ tail, c)
      end
  end
end (* Tensor *)
707 | ||
708 | (* | |
709 | Copyright (c) Juan Jose Garcia Ripoll. | |
710 | All rights reserved. | |
711 | ||
712 | Refer to the COPYRIGHT file for license conditions | |
713 | *) | |
714 | ||
715 | (* | |
716 | MONO_TENSOR - signature - | |
717 | ||
718 | Monomorphic tensor of arbitrary data (not only numbers). Operations | |
719 | should be provided to run the data in several ways, according to one | |
720 | index. | |
721 | ||
722 | type tensor | |
723 | The type of the tensor itself | |
724 | type elem | |
725 | The type of every element | |
726 | val storage = RowMajor | ColumnMajor | |
727 | RowMajor = data is stored in consecutive cells, last index | |
728 | varying fastest (C,C++,Pascal,CommonLisp convention) | |
729 | ColumnMajor = data is stored in consecutive cells, first | |
730 | index varying fastest (FORTRAN convention) | |
731 | new ([i1,...,in],init) | |
732 | Build a new tensor with n indices, each of sizes i1...in, | |
733 | filled with 'init'. | |
734 | fromArray (shape,data) | |
735 | fromList (shape,data) | |
736 | Use 'data' to fill a tensor of that shape. An exception is | |
737 | raised if 'data' is too large or too small to properly | |
738 | fill the tensor. Later use of a 'data' array is disregarded | |
739 | -- one must think that the tensor now owns the array. | |
740 | length tensor | |
741 | rank tensor | |
742 | shape tensor | |
743 | Return the number of elements, the number of indices and | |
744 | the shape (size of each index) of the tensor. | |
745 | toArray tensor | |
746 | Return the data of the tensor in the form of an array. | |
747 | Mutation of this array may lead to unexpected behavior. | |
748 | The data in the array is stored according to `storage'. | |
749 | ||
750 | sub (tensor,[i1,...,in]) | |
751 | update (tensor,[i1,...,in],new_value) | |
752 | Access the element that is indexed by the numbers [i1,..,in] | |
753 | ||
754 | map f a | |
755 | mapi f a | |
756 | Produce a new array by mapping the function sequentially | |
757 | as specified by 'storage', to each element of tensor 'a'. | |
758 | In 'mapi' the function receives a (indices,value) tuple, | |
759 | while in 'map' it only receives the value. | |
760 | app f a | |
761 | appi f a | |
762 | The same as 'map' and 'mapi' but the function 'f' outputs | |
763 | nothing and no new array is produced, i.e. one only seeks | |
764 | the side effect that 'f' may produce. | |
765 | map2 operation a b | |
766 | Apply 'operation' to pairs of elements of 'a' and 'b' | |
767 | and build a new tensor with the output. Both operands | |
768 | must have the same shape or an exception is raised. | |
769 | The procedure is sequential, as specified by 'storage'. | |
770 | foldl operation a n | |
771 | Fold-left the elements of tensor 'a' along the n-th | |
772 | index. | |
773 | all test a | |
774 | any test a | |
775 | Folded boolean tests on the elements of the tensor. | |
776 | ||
777 | map', map2', foldl' | |
778 | Polymorphic versions of map, map2, foldl. | |
779 | *) | |
780 | ||
(* MONO_TENSOR: tensors whose elements all have the single type 'elem',
   stored in a MONO_ARRAY.  The primed operations produce polymorphic
   Tensor.tensor results. *)
signature MONO_TENSOR =
sig
  structure Array : MONO_ARRAY
  structure Index : INDEX
  type index = Index.t
  type elem
  type tensor
  type t = tensor

  (* ----- Construction ----- *)
  val new : index * elem -> tensor
  val tabulate : index * (index -> elem) -> tensor
  val fromList : index * elem list -> tensor
  val fromArray : index * Array.array -> tensor

  (* ----- Structure and queries ----- *)
  val length : tensor -> int
  val rank : tensor -> int
  val shape : tensor -> (index)
  val reshape : index -> tensor -> tensor
  val toArray : tensor -> Array.array

  (* ----- Element access ----- *)
  val sub : tensor * index -> elem
  val update : tensor * index * elem -> unit

  (* ----- Monomorphic traversals ----- *)
  val map : (elem -> elem) -> tensor -> tensor
  val map2 : (elem * elem -> elem) -> tensor -> tensor -> tensor
  val app : (elem -> unit) -> tensor -> unit
  val appi : (int * elem -> unit) -> tensor -> unit
  (* NOTE(review): unlike foldln and foldl', this takes no index argument
     and returns a tensor rather than an 'a; confirm against the
     implementations. *)
  val foldl : (elem * 'a -> 'a) -> 'a -> tensor -> tensor
  val foldln : (elem * elem -> elem) -> elem -> tensor -> int -> tensor
  val all : (elem -> bool) -> tensor -> bool
  val any : (elem -> bool) -> tensor -> bool

  (* ----- Polymorphic results ----- *)
  val map' : (elem -> 'a) -> tensor -> 'a Tensor.tensor
  val map2' : (elem * elem -> 'a) -> tensor -> tensor -> 'a Tensor.tensor
  val foldl' : ('a * elem -> 'a) -> 'a -> tensor -> int -> 'a Tensor.tensor
end
815 | ||
816 | (* | |
817 | NUMBER - Signature - | |
818 | ||
819 | Guarantees a structure with a minimal number of mathematical operations | |
820 | so as to build an algebraic structure named Tensor. | |
821 | *) | |
822 | ||
(* An earlier, smaller NUMBER signature (zero, one, ~, +, -, *, /,
   toString) used to be declared here.  The complete NUMBER signature that
   immediately follows shadowed it, so no code could refer to it.  It has
   been dropped so the name has a single definition.  Division is
   specified by FRACTIONAL_NUMBER. *)
835 | ||
(* NUMBER: operations shared by every numeric structure in this file. *)
signature NUMBER =
sig
  type t

  (* ----- Constants and conversions ----- *)
  val zero : t
  val one : t
  val fromInt : int -> t
  val toString : t -> string
  val scan : (char,'a) StringCvt.reader -> (t,'a) StringCvt.reader

  (* ----- Ring operations ----- *)
  val + : t * t -> t
  val - : t * t -> t
  val * : t * t -> t
  (* Multiply-add forms: given (b, c, a), *+ gives b*c + a and *- gives
     b*c - a. *)
  val *+ : t * t * t -> t
  val *- : t * t * t -> t
  (* Power by an int exponent; the implementations here reject negative
     exponents. *)
  val ** : t * int -> t

  (* ----- Sign ----- *)
  val ~ : t -> t
  val abs : t -> t
  val signum : t -> t

  (* ----- Equality ----- *)
  val == : t * t -> bool
  val != : t * t -> bool
end
860 | ||
(* INTEGRAL_NUMBER: a NUMBER with integer division and ordering. *)
signature INTEGRAL_NUMBER =
sig
  include NUMBER

  (* ----- Integer division: the quot/rem and div/mod pairs, named as in
     the Basis INTEGER signature ----- *)
  val quot : t * t -> t
  val rem : t * t -> t
  val div : t * t -> t
  val mod : t * t -> t

  (* ----- Ordering ----- *)
  val compare : t * t -> order
  val < : t * t -> bool
  val > : t * t -> bool
  val <= : t * t -> bool
  val >= : t * t -> bool
  val max : t * t -> t
  val min : t * t -> t
end
879 | ||
(* FRACTIONAL_NUMBER: a NUMBER with division and the elementary
   transcendental functions. *)
signature FRACTIONAL_NUMBER =
sig
  include NUMBER

  (* ----- Constants ----- *)
  val pi : t
  val e : t

  (* ----- Division ----- *)
  val / : t * t -> t
  val recip : t -> t

  (* ----- Roots, exponentials and logarithms ----- *)
  val sqrt : t -> t
  val exp : t -> t
  val ln : t -> t
  val pow : t * t -> t

  (* ----- Trigonometric functions and their inverses ----- *)
  val sin : t -> t
  val cos : t -> t
  val tan : t -> t
  val asin : t -> t
  val acos : t -> t
  val atan : t -> t
  val atan2 : t * t -> t

  (* ----- Hyperbolic functions and their inverses ----- *)
  val sinh : t -> t
  val cosh : t -> t
  val tanh : t -> t
  val asinh : t -> t
  val acosh : t -> t
  val atanh : t -> t
end
910 | ||
(* REAL_NUMBER: a FRACTIONAL_NUMBER with ordering operations. *)
signature REAL_NUMBER =
sig
  include FRACTIONAL_NUMBER

  (* ----- Ordering ----- *)
  val compare : t * t -> order
  val < : t * t -> bool
  val <= : t * t -> bool
  val > : t * t -> bool
  val >= : t * t -> bool

  (* ----- Extremes ----- *)
  val max : t * t -> t
  val min : t * t -> t
end
924 | ||
(* COMPLEX_NUMBER: a FRACTIONAL_NUMBER whose values are built from pairs
   of REAL_NUMBER components. *)
signature COMPLEX_NUMBER =
sig
  include FRACTIONAL_NUMBER

  structure Real : REAL_NUMBER
  type real = Real.t

  (* Conversions to and from the component pair, presumably
     (real part, imaginary part); confirm in the implementing functor. *)
  val make : real * real -> t
  val split : t -> real * real
  val realPart : t -> real
  val imagPart : t -> real
  (* By its name, the squared magnitude; confirm in the implementation. *)
  val abs2 : t -> real
end
938 | ||
(* INumber: INTEGRAL_NUMBER over the default int type; most operations are
   taken directly from Int. *)
structure INumber : INTEGRAL_NUMBER =
struct
  open Int
  type t = Int.int
  val zero = 0
  val one = 1

  infix **
  (* i ** n by repeated squaring.  Raises Domain for n < 0; Int arithmetic
     raises Overflow if the result does not fit. *)
  fun i ** n =
    let fun loop 0 = 1
          | loop 1 = i
          | loop n =
              let val x = loop (Int.div (n, 2))
              in if Int.mod (n, 2) = 0 then x * x else x * x * i end
    in if n < 0 then raise Domain else loop n end

  (* ~1, 0 or 1 according to the sign of i. *)
  fun signum i = case compare (i, 0) of
                     GREATER => 1
                   | EQUAL => 0
                   | LESS => ~1

  infix ==
  infix !=
  fun a == b = a = b
  fun a != b = a <> b

  (* Given (b, c, a): *+ yields b*c + a and *- yields b*c - a.  The minus
     form previously subtracted b instead of a. *)
  fun *+ (b, c, a) = b * c + a
  fun *- (b, c, a) = b * c - a

  fun scan getc = Int.scan StringCvt.DEC getc
end
978 | ||
(* RNumber: REAL_NUMBER implementation over the default Real.real, taking
   the elementary functions from Real.Math. *)
structure RNumber : REAL_NUMBER =
struct
  open Real
  open Real.Math
  type t = Real.real
  val zero = 0.0
  val one = 1.0

  (* Sign of x as ~1.0, 0.0 or 1.0 (compare raises on NaN). *)
  fun signum x =
      case compare (x, 0.0) of
          GREATER => 1.0
        | EQUAL => 0.0
        | LESS => ~1.0

  fun recip x = one / x

  (* base ** expo for a non-negative int exponent, by square-and-multiply;
     raises Domain for a negative exponent. *)
  infix **
  fun base ** expo =
      let
        fun power 0 = one
          | power 1 = base
          | power k =
            let val half = power (Int.div (k, 2))
                val sq = half * half
            in if Int.mod (k, 2) = 0 then sq else sq * base
            end
      in
        if Int.< (expo, 0) then raise Domain else power expo
      end

  (* Larger / smaller of two values; when a < b is false (including NaN)
     the first argument wins for max and the second for min. *)
  fun max (a, b) = if b > a then b else a
  fun min (a, b) = if b > a then a else b

  (* Inverse hyperbolic functions via their logarithmic forms. *)
  fun asinh x = ln (x + sqrt (one + x * x))
  fun acosh x =
      let val xp1 = x + one
      in ln (x + xp1 * sqrt ((x - one) / xp1))
      end
  fun atanh x = ln ((one + x) / sqrt (one - x * x))
end
1020 | (* | |
CNumber - complex numbers over RNumber (a structure, not a functor) -

Provides support for complex numbers represented as (real, imaginary)
tuples. Should be highly efficient, as most operations can be inlined.
1025 | *) | |
1026 | ||
(* CNumber: complex numbers represented as (real part, imaginary part)
   pairs of RNumber.t. Operations are small functions on pairs, so most of
   them can be inlined. *)
structure CNumber : COMPLEX_NUMBER =
struct
  structure Real = RNumber

  type t = Real.t * Real.t
  type real = Real.t

  val zero = (0.0, 0.0)
  val one = (1.0, 0.0)
  val pi = (Real.pi, 0.0)
  val e = (Real.e, 0.0)

  (*----- construction and access -----*)
  fun make (r, i) = (r, i) : t
  fun split z = z
  fun realPart (r, _) = r
  fun imagPart (_, i) = i

  (* Squared modulus r*r + i*i.
     NOTE(review): the original carried a bare FIXME here with no reason
     given; a likely concern is overflow/underflow for components of very
     large or small magnitude (a scaled, hypot-style form would avoid it). *)
  fun abs2 (r, i) = Real.+ (Real.* (r, r), Real.* (i, i))

  (* Phase angle, as returned by Real.atan2 (im, re). *)
  fun arg (r, i) = Real.atan2 (i, r)
  fun modulus z = Real.sqrt (abs2 z)
  fun abs z = (modulus z, 0.0)

  (* z / |z|. Bug fix: zero now maps to zero (as RNumber.signum does for
     0.0) instead of producing NaN components from 0/0. *)
  fun signum (z as (r, i)) =
      let val m = modulus z
      in if Real.== (m, 0.0) then zero
         else (Real./ (r, m), Real./ (i, m))
      end

  (*----- arithmetic -----*)
  fun ~ (r1, i1) = (Real.~ r1, Real.~ i1)
  fun (r1, i1) + (r2, i2) = (Real.+ (r1, r2), Real.+ (i1, i2))
  (* Bug fix: the imaginary part used to be i1 - i1. *)
  fun (r1, i1) - (r2, i2) = (Real.- (r1, r2), Real.- (i1, i2))
  fun (r1, i1) * (r2, i2) =
      (Real.- (Real.* (r1, r2), Real.* (i1, i2)),
       Real.+ (Real.* (r1, i2), Real.* (r2, i1)))
  (* z / w = z * conj w / |w|^2.
     Bug fix: the numerator used to multiply by w instead of conj w. *)
  fun (r1, i1) / (r2, i2) =
      let val m2 = abs2 (r2, i2)
          val (nr, ni) = (r1, i1) * (r2, Real.~ i2)
      in (Real./ (nr, m2), Real./ (ni, m2))
      end

  (* *+ (z1, z2, z0) = z1*z2 + z0 and *- (z1, z2, z0) = z1*z2 - z0.
     Bug fix: the imaginary parts used r2*i2 where r2*i1 is required. *)
  fun *+ ((r1, i1), (r2, i2), (r0, i0)) =
      (Real.*+ (Real.~ i1, i2, Real.*+ (r1, r2, r0)),
       Real.*+ (r2, i1, Real.*+ (r1, i2, i0)))
  fun *- ((r1, i1), (r2, i2), (r0, i0)) =
      (Real.*+ (Real.~ i1, i2, Real.*- (r1, r2, r0)),
       Real.*+ (r2, i1, Real.*- (r1, i2, i0)))

  (* z ** n for a non-negative int n by repeated squaring; raises Domain
     for n < 0. *)
  infix **
  fun z ** n =
      let fun loop 0 = one
            | loop 1 = z
            | loop k =
              let val x = loop (Int.div (k, 2))
              in if Int.mod (k, 2) = 0 then x * x else x * x * z
              end
      in if Int.< (n, 0) then raise Domain else loop n
      end

  (* 1 / z = conj z / |z|^2 *)
  fun recip (r1, i1) =
      let val m2 = abs2 (r1, i1)
      in (Real./ (r1, m2), Real./ (Real.~ i1, m2))
      end

  fun == (z, w) =
      Real.== (realPart z, realPart w) andalso Real.== (imagPart z, imagPart w)
  (* Bug fix: two complex numbers differ when EITHER component differs
     (previously both had to differ). *)
  fun != (z, w) =
      Real.!= (realPart z, realPart w) orelse Real.!= (imagPart z, imagPart w)

  fun fromInt i = (Real.fromInt i, 0.0)
  fun toString (r, i) =
      String.concat ["(", Real.toString r, ",", Real.toString i, ")"]

  (*----- exponentials, logarithms, roots -----*)

  (* exp (x + iy) = e^x (cos y + i sin y).
     Bug fix: the result used to be scaled by x instead of e^x. *)
  fun exp (x, y) =
      let val ex = Real.exp x
      in (Real.* (ex, Real.cos y), Real.* (ex, Real.sin y))
      end

  local
    val half = Real.recip (Real.fromInt 2)
  in
    (* Principal square root: non-negative real part, imaginary part with
       the sign of y. *)
    fun sqrt (z as (x, y)) =
        if Real.== (x, 0.0) andalso Real.== (y, 0.0) then
          zero
        else
          let val m = Real.+ (modulus z, Real.abs x)
              val u' = Real.sqrt (Real.* (m, half))
              val v' = Real./ (Real.abs y, Real.+ (u', u'))
              val (u, v) = if Real.< (x, 0.0) then (v', u') else (u', v')
          in (u, if Real.< (y, 0.0) then Real.~ v else v)
          end
  end

  fun ln z = (Real.ln (modulus z), arg z)

  (* z^w = exp (w * ln z) *)
  fun pow (z, n) = exp (ln z * n)

  (*----- circular functions -----*)
  fun sin (x, y) = (Real.* (Real.sin x, Real.cosh y),
                    Real.* (Real.cos x, Real.sinh y))
  fun cos (x, y) = (Real.* (Real.cos x, Real.cosh y),
                    Real.~ (Real.* (Real.sin x, Real.sinh y)))
  fun tan (x, y) =
      let val (sx, cx) = (Real.sin x, Real.cos x)
          val (shy, chy) = (Real.sinh y, Real.cosh y)
          val a = (Real.* (sx, chy), Real.* (cx, shy))
          val b = (Real.* (cx, chy), Real.* (Real.~ sx, shy))
      in a / b
      end

  (*----- hyperbolic functions -----*)
  fun sinh (x, y) = (Real.* (Real.cos y, Real.sinh x),
                     Real.* (Real.sin y, Real.cosh x))
  fun cosh (x, y) = (Real.* (Real.cos y, Real.cosh x),
                     Real.* (Real.sin y, Real.sinh x))
  fun tanh (x, y) =
      let val (sy, cy) = (Real.sin y, Real.cos y)
          val (shx, chx) = (Real.sinh x, Real.cosh x)
          val a = (Real.* (cy, shx), Real.* (sy, chx))
          val b = (Real.* (cy, chx), Real.* (sy, shx))
      in a / b
      end

  (*----- inverse functions -----*)

  (* asin z = -i ln (iz + sqrt (1 - z^2)) *)
  fun asin (z as (x, y)) =
      let val w = sqrt (one - z * z)
          val (x', y') = ln ((Real.~ y, x) + w)
      in (y', Real.~ x')
      end

  (* acos z = -i ln (z + i sqrt (1 - z^2)).
     Bug fix: the square root used to be taken of 1 + z^2. *)
  fun acos z =
      let val (x', y') = sqrt (one - z * z)
          val (x'', y'') = ln (z + (Real.~ y', x'))
      in (y'', Real.~ x'')
      end

  (* atan z = -i ln ((1 + iz) / sqrt (1 + z^2)) *)
  fun atan (z as (x, y)) =
      let val w = sqrt (one + z * z)
          val (x', y') = ln ((Real.- (1.0, y), x) / w)
      in (y', Real.~ x')
      end

  (* NOTE(review): this is simply atan (y / x); unlike Real.atan2 it keeps
     no quadrant information, and x = zero divides by zero. *)
  fun atan2 (y, x) = atan (y / x)

  fun asinh x = ln (x + sqrt (one + x * x))
  fun acosh x = ln (x + (x + one) * sqrt ((x - one) / (x + one)))
  fun atanh x = ln ((one + x) / sqrt (one - x * x))

  (* Reads two reals (real part, then imaginary part) from the stream. *)
  fun scan getc =
      let val scanner = Real.scan getc
      in fn stream =>
            case scanner stream of
                NONE => NONE
              | SOME (a, rest) =>
                case scanner rest of
                    NONE => NONE
                  | SOME (b, rest) => SOME (make (a, b), rest)
      end
end (* ComplexNumber *)
1185 | ||
1186 | (* | |
1187 | Copyright (c) Juan Jose Garcia Ripoll. | |
1188 | All rights reserved. | |
1189 | ||
1190 | Refer to the COPYRIGHT file for license conditions | |
1191 | *) | |
1192 | ||
(* INumberArray: the polymorphic Array structure specialised to INumber.t
   elements, with a matching Vector substructure and map helpers that build
   a fresh array via tabulate. *)
structure INumberArray =
struct
  open Array

  type elem = INumber.t
  type array = INumber.t array
  type vector = INumber.t vector

  structure Vector =
  struct
    open Vector
    type elem = INumber.t
    type vector = INumber.t Vector.vector
  end

  (* New array holding f applied to each element of a, index 0 first. *)
  fun map f a = tabulate (length a, f o (fn k => sub (a, k)))

  (* As map, but f also receives the element index. *)
  fun mapi f a = tabulate (length a, fn k => f (k, sub (a, k)))

  (* Pointwise combination; the result has the length of a, and the
     bounds-checked sub raises Subscript if b is shorter. *)
  fun map2 f a b =
      let val n = length a
          fun pair k = (sub (a, k), sub (b, k))
      in tabulate (n, f o pair)
      end
end
1209 | ||
(* RNumberArray: Real64Array with sub/update rebound to the unchecked
   Unsafe.Real64Array operations, plus map helpers built on tabulate.
   Callers of sub/update must keep indices in range themselves. *)
structure RNumberArray =
struct
  open Real64Array
  val sub = Unsafe.Real64Array.sub
  val update = Unsafe.Real64Array.update

  fun map f a = tabulate (length a, fn x => f (sub (a, x)))
  fun mapi f a = tabulate (length a, fn x => f (x, sub (a, x)))

  (* Pointwise combination into a new array of length a.
     Bug fix: b is read with the unchecked sub, so a b shorter than a used
     to read out of bounds; it is now rejected up front with Subscript. *)
  fun map2 f a b =
      let val n = length a
      in if length b < n then raise Subscript
         else tabulate (n, fn x => f (sub (a, x), sub (b, x)))
      end
end
1219 | ||
1220 | (*--------------------- COMPLEX ARRAY -------------------------*) | |
1221 | ||
(* BasicCNumberArray: mutable arrays of complex numbers, stored as a pair of
   parallel real arrays (real parts, imaginary parts) of equal length.

   Element access goes through RNumberArray (bound to Array below), whose
   sub/update are Unsafe.Real64Array.sub/update. The public sub and update
   here therefore do NOT bounds-check their index. The range-based
   operations validate their range with makeRange before touching elements.

   A range is (array, start, last), where `last` is an optional INDEX of the
   final element (NONE = through the end), not a length. *)
structure BasicCNumberArray =
struct
  structure Complex : COMPLEX_NUMBER = CNumber
  structure Array : MONO_ARRAY = RNumberArray

  type elem = Complex.t
  type array = Array.array * Array.array

  val maxLen = Array.maxLen

  (* Element count, taken from the real-part array. *)
  fun length (a, _) = Array.length a

  (* Unchecked read of the element at index. *)
  fun sub ((a, b), index) = Complex.make (Array.sub (a, index), Array.sub (b, index))

  (* Unchecked write of z at index. *)
  fun update ((a, b), index, z) =
      let val (re, im) = Complex.split z
      in Array.update (a, index, re);
         Array.update (b, index, im)
      end

  local
    (* Turns (arr, start, last option) into (arr, start, count).
       Raises Subscript when start < 0, start > length arr, or
       last >= length arr.
       Bug fixes:
       - Negative starts used to slip through to the unchecked accessors.
       - start = length arr (an empty range, which any whole-array operation
         on an empty array produces) used to raise Subscript. *)
    fun makeRange (a, start, NONE) = makeRange (a, start, SOME (length a - 1))
      | makeRange (a, start, SOME last) =
        let val len = length a
            val diff = last - start
        in
          if start < 0 orelse start > len orelse last >= len then
            raise Subscript
          else if diff < 0 then
            (a, start, 0)
          else
            (a, start, diff + 1)
        end
  in
    (* New array of `size` copies of z. *)
    fun array (size, z : elem) =
        (Array.array (size, Complex.realPart z),
         Array.array (size, Complex.imagPart z))

    (* New array of `size` complex zeros. *)
    fun zeroarray size =
        (Array.array (size, Complex.Real.zero),
         Array.array (size, Complex.Real.zero))

    (* New array whose element i is f i, filled from index 0 upwards. *)
    fun tabulate (size, f) =
        let val a = array (size, Complex.zero)
            fun loop i =
                if i = size then a
                else (update (a, i, f i); loop (i + 1))
        in loop 0
        end

    fun fromList list =
        let val a = zeroarray (List.length list)
            fun loop (_, []) = a
              | loop (i, z :: rest) = (update (a, i, z); loop (i + 1, rest))
        in loop (0, list)
        end

    (* Fresh copy of the elements in the given range. *)
    fun extract range =
        let val (a, start, len) = makeRange range
        in tabulate (len, fn i => sub (a, i + start))
        end

    (* Concatenation of all arrays, in order. `foldl` and `map` here are the
       top-level List functions; this structure's own versions are defined
       further down.
       Bug fix: the copy loop stopped at index 1, so element 0 of every
       input was left as zero in the result. *)
    fun concat array_list =
        let val total_length = foldl (op +) 0 (map length array_list)
            val a = array (total_length, Complex.zero)
            fun copy (_, []) = a
              | copy (pos, v :: rest) =
                let val n = length v
                    fun loop i =
                        if i < 0 then ()
                        else (update (a, i + pos, sub (v, i)); loop (i - 1))
                in loop (n - 1);
                   copy (pos + n, rest)
                end
        in copy (0, array_list)
        end

    (* Copies the src range (si, len) into dst starting at di. `len` is
       interpreted by makeRange (a last index, not a count) and, as before,
       is also checked against dst.
       Bug fixes:
       - dst must now have room for every copied element; update is
         unchecked and used to write past its end.
       - The copy direction is chosen so that overlapping copies within one
         array are correct. *)
    fun copy {src : array, si : int, len : int option, dst : array, di : int} =
        let val (a, ia, count) = makeRange (src, si, len)
            val (b, ib, _) = makeRange (dst, di, len)
            val () = if ib + count > length b then raise Subscript else ()
            fun move k = update (b, ib + k, sub (a, ia + k))
            fun forward k = if k >= count then () else (move k; forward (k + 1))
            fun backward k = if k < 0 then () else (move k; backward (k - 1))
        in
          if ib <= ia then forward 0 else backward (count - 1)
        end

    val copyVec = copy

    fun modifyi f range =
        let val (a, start, len) = makeRange range
            val last = start + len
            fun loop i =
                if i >= last then ()
                else (update (a, i, f (i, sub (a, i))); loop (i + 1))
        in loop start
        end

    fun modify f a =
        let val last = length a
            fun loop i =
                if i >= last then ()
                else (update (a, i, f (sub (a, i))); loop (i + 1))
        in loop 0
        end

    fun app f a =
        let val size = length a
            fun loop i =
                if i = size then ()
                else (f (sub (a, i)); loop (i + 1))
        in loop 0
        end

    fun appi f range =
        let val (a, start, len) = makeRange range
            val last = start + len
            fun loop i =
                if i >= last then ()
                else (f (i, sub (a, i)); loop (i + 1))
        in loop start
        end

    (* New array holding f applied to each element (f is applied from the
       last index down).
       Bug fix: results used to be written back into the input array, and
       the returned array was left all zeros. *)
    fun map f a =
        let val len = length a
            val c = zeroarray len
            fun loop ~1 = c
              | loop i = (update (c, i, f (sub (a, i))); loop (i - 1))
        in loop (len - 1)
        end

    (* Pointwise combination of a and b into a new array of length a.
       Bug fix: b must now be at least as long as a (otherwise Subscript),
       because its elements are read without bounds checks. *)
    fun map2 f a b =
        let val len = length a
            val () = if length b < len then raise Subscript else ()
            val c = zeroarray len
            fun loop ~1 = c
              | loop i = (update (c, i, f (sub (a, i), sub (b, i))); loop (i - 1))
        in loop (len - 1)
        end

    fun mapi f range =
        let val (a, start, len) = makeRange range
        in tabulate (len, fn i => f (i + start, sub (a, i + start)))
        end

    fun foldli f init range =
        let val (a, start, len) = makeRange range
            val last = start + len - 1
            fun loop (i, accum) =
                if i > last then accum
                else loop (i + 1, f (i, sub (a, i), accum))
        in loop (start, init)
        end

    fun foldri f init range =
        let val (a, start, len) = makeRange range
            fun loop (i, accum) =
                if i < start then accum
                else loop (i - 1, f (i, sub (a, i), accum))
        in loop (start + len - 1, init)
        end

    fun foldl f init a = foldli (fn (_, x, acc) => f (x, acc)) init (a, 0, NONE)
    fun foldr f init a = foldri (fn (_, x, acc) => f (x, acc)) init (a, 0, NONE)
  end
end (* BasicCNumberArray *)
1409 | ||
1410 | ||
(* CNumberArray: BasicCNumberArray together with a MONO_VECTOR view of the
   same representation; a complex "vector" is just a complex array. *)
structure CNumberArray =
struct
  structure Vector : MONO_VECTOR =
  struct
    open BasicCNumberArray
    type vector = array
  end
  type vector = Vector.vector
  open BasicCNumberArray
end (* CNumberArray *)
(* INumber: INTEGRAL_NUMBER implementation over the default Int.int. *)
structure INumber : INTEGRAL_NUMBER =
struct
  open Int
  type t = Int.int
  val zero = 0
  val one = 1

  (* i ** n: integer power by repeated squaring; raises Domain for n < 0. *)
  infix **
  fun i ** n =
      let fun loop 0 = 1
            | loop 1 = i
            | loop k =
              let val x = loop (Int.div (k, 2))
              in if Int.mod (k, 2) = 0 then x * x else x * x * i
              end
      in if n < 0 then raise Domain else loop n
      end

  (* Sign of i as ~1, 0 or 1. *)
  fun signum i = case compare (i, 0) of
                     GREATER => 1
                   | EQUAL => 0
                   | LESS => ~1

  infix ==
  infix !=
  fun a == b = a = b
  fun a != b = (a <> b)

  (* *+ (b, c, a) = b*c + a and *- (b, c, a) = b*c - a, with the same
     argument order as Real.*+ / Real.*-.
     Bug fix: *- used to compute b*c - b. *)
  fun *+ (b, c, a) = b * c + a
  fun *- (b, c, a) = b * c - a

  (* Reads a decimal integer from a character stream. *)
  fun scan getc = Int.scan StringCvt.DEC getc
end
(* RNumber: REAL_NUMBER implementation over the default Real.real, taking
   the elementary functions from Real.Math. *)
structure RNumber : REAL_NUMBER =
struct
  open Real
  open Real.Math
  type t = Real.real
  val zero = 0.0
  val one = 1.0

  (* Sign of x as ~1.0, 0.0 or 1.0 (compare raises on NaN). *)
  fun signum x =
      case compare (x, 0.0) of
          GREATER => 1.0
        | EQUAL => 0.0
        | LESS => ~1.0

  fun recip x = one / x

  (* base ** expo for a non-negative int exponent, by square-and-multiply;
     raises Domain for a negative exponent. *)
  infix **
  fun base ** expo =
      let
        fun power 0 = one
          | power 1 = base
          | power k =
            let val half = power (Int.div (k, 2))
                val sq = half * half
            in if Int.mod (k, 2) = 0 then sq else sq * base
            end
      in
        if Int.< (expo, 0) then raise Domain else power expo
      end

  (* Larger / smaller of two values; when a < b is false (including NaN)
     the first argument wins for max and the second for min. *)
  fun max (a, b) = if b > a then b else a
  fun min (a, b) = if b > a then a else b

  (* Inverse hyperbolic functions via their logarithmic forms. *)
  fun asinh x = ln (x + sqrt (one + x * x))
  fun acosh x =
      let val xp1 = x + one
      in ln (x + xp1 * sqrt ((x - one) / xp1))
      end
  fun atanh x = ln ((one + x) / sqrt (one - x * x))
end
1491 | (* | |
CNumber - complex numbers over RNumber (a structure, not a functor) -
Provides support for complex numbers represented as (real, imaginary)
tuples. Should be highly efficient, as most operations can be inlined.
1495 | *) | |
(* CNumber: complex numbers represented as (real part, imaginary part)
   pairs of RNumber.t. Operations are small functions on pairs, so most of
   them can be inlined. *)
structure CNumber : COMPLEX_NUMBER =
struct
  structure Real = RNumber

  type t = Real.t * Real.t
  type real = Real.t

  val zero = (0.0, 0.0)
  val one = (1.0, 0.0)
  val pi = (Real.pi, 0.0)
  val e = (Real.e, 0.0)

  (*----- construction and access -----*)
  fun make (r, i) = (r, i) : t
  fun split z = z
  fun realPart (r, _) = r
  fun imagPart (_, i) = i

  (* Squared modulus r*r + i*i.
     NOTE(review): the original carried a bare FIXME here with no reason
     given; a likely concern is overflow/underflow for components of very
     large or small magnitude (a scaled, hypot-style form would avoid it). *)
  fun abs2 (r, i) = Real.+ (Real.* (r, r), Real.* (i, i))

  (* Phase angle, as returned by Real.atan2 (im, re). *)
  fun arg (r, i) = Real.atan2 (i, r)
  fun modulus z = Real.sqrt (abs2 z)
  fun abs z = (modulus z, 0.0)

  (* z / |z|. Bug fix: zero now maps to zero (as RNumber.signum does for
     0.0) instead of producing NaN components from 0/0. *)
  fun signum (z as (r, i)) =
      let val m = modulus z
      in if Real.== (m, 0.0) then zero
         else (Real./ (r, m), Real./ (i, m))
      end

  (*----- arithmetic -----*)
  fun ~ (r1, i1) = (Real.~ r1, Real.~ i1)
  fun (r1, i1) + (r2, i2) = (Real.+ (r1, r2), Real.+ (i1, i2))
  (* Bug fix: the imaginary part used to be i1 - i1. *)
  fun (r1, i1) - (r2, i2) = (Real.- (r1, r2), Real.- (i1, i2))
  fun (r1, i1) * (r2, i2) =
      (Real.- (Real.* (r1, r2), Real.* (i1, i2)),
       Real.+ (Real.* (r1, i2), Real.* (r2, i1)))
  (* z / w = z * conj w / |w|^2.
     Bug fix: the numerator used to multiply by w instead of conj w. *)
  fun (r1, i1) / (r2, i2) =
      let val m2 = abs2 (r2, i2)
          val (nr, ni) = (r1, i1) * (r2, Real.~ i2)
      in (Real./ (nr, m2), Real./ (ni, m2))
      end

  (* *+ (z1, z2, z0) = z1*z2 + z0 and *- (z1, z2, z0) = z1*z2 - z0.
     Bug fix: the imaginary parts used r2*i2 where r2*i1 is required. *)
  fun *+ ((r1, i1), (r2, i2), (r0, i0)) =
      (Real.*+ (Real.~ i1, i2, Real.*+ (r1, r2, r0)),
       Real.*+ (r2, i1, Real.*+ (r1, i2, i0)))
  fun *- ((r1, i1), (r2, i2), (r0, i0)) =
      (Real.*+ (Real.~ i1, i2, Real.*- (r1, r2, r0)),
       Real.*+ (r2, i1, Real.*- (r1, i2, i0)))

  (* z ** n for a non-negative int n by repeated squaring; raises Domain
     for n < 0. *)
  infix **
  fun z ** n =
      let fun loop 0 = one
            | loop 1 = z
            | loop k =
              let val x = loop (Int.div (k, 2))
              in if Int.mod (k, 2) = 0 then x * x else x * x * z
              end
      in if Int.< (n, 0) then raise Domain else loop n
      end

  (* 1 / z = conj z / |z|^2 *)
  fun recip (r1, i1) =
      let val m2 = abs2 (r1, i1)
      in (Real./ (r1, m2), Real./ (Real.~ i1, m2))
      end

  fun == (z, w) =
      Real.== (realPart z, realPart w) andalso Real.== (imagPart z, imagPart w)
  (* Bug fix: two complex numbers differ when EITHER component differs
     (previously both had to differ). *)
  fun != (z, w) =
      Real.!= (realPart z, realPart w) orelse Real.!= (imagPart z, imagPart w)

  fun fromInt i = (Real.fromInt i, 0.0)
  fun toString (r, i) =
      String.concat ["(", Real.toString r, ",", Real.toString i, ")"]

  (*----- exponentials, logarithms, roots -----*)

  (* exp (x + iy) = e^x (cos y + i sin y).
     Bug fix: the result used to be scaled by x instead of e^x. *)
  fun exp (x, y) =
      let val ex = Real.exp x
      in (Real.* (ex, Real.cos y), Real.* (ex, Real.sin y))
      end

  local
    val half = Real.recip (Real.fromInt 2)
  in
    (* Principal square root: non-negative real part, imaginary part with
       the sign of y. *)
    fun sqrt (z as (x, y)) =
        if Real.== (x, 0.0) andalso Real.== (y, 0.0) then
          zero
        else
          let val m = Real.+ (modulus z, Real.abs x)
              val u' = Real.sqrt (Real.* (m, half))
              val v' = Real./ (Real.abs y, Real.+ (u', u'))
              val (u, v) = if Real.< (x, 0.0) then (v', u') else (u', v')
          in (u, if Real.< (y, 0.0) then Real.~ v else v)
          end
  end

  fun ln z = (Real.ln (modulus z), arg z)

  (* z^w = exp (w * ln z) *)
  fun pow (z, n) = exp (ln z * n)

  (*----- circular functions -----*)
  fun sin (x, y) = (Real.* (Real.sin x, Real.cosh y),
                    Real.* (Real.cos x, Real.sinh y))
  fun cos (x, y) = (Real.* (Real.cos x, Real.cosh y),
                    Real.~ (Real.* (Real.sin x, Real.sinh y)))
  fun tan (x, y) =
      let val (sx, cx) = (Real.sin x, Real.cos x)
          val (shy, chy) = (Real.sinh y, Real.cosh y)
          val a = (Real.* (sx, chy), Real.* (cx, shy))
          val b = (Real.* (cx, chy), Real.* (Real.~ sx, shy))
      in a / b
      end

  (*----- hyperbolic functions -----*)
  fun sinh (x, y) = (Real.* (Real.cos y, Real.sinh x),
                     Real.* (Real.sin y, Real.cosh x))
  fun cosh (x, y) = (Real.* (Real.cos y, Real.cosh x),
                     Real.* (Real.sin y, Real.sinh x))
  fun tanh (x, y) =
      let val (sy, cy) = (Real.sin y, Real.cos y)
          val (shx, chx) = (Real.sinh x, Real.cosh x)
          val a = (Real.* (cy, shx), Real.* (sy, chx))
          val b = (Real.* (cy, chx), Real.* (sy, shx))
      in a / b
      end

  (*----- inverse functions -----*)

  (* asin z = -i ln (iz + sqrt (1 - z^2)) *)
  fun asin (z as (x, y)) =
      let val w = sqrt (one - z * z)
          val (x', y') = ln ((Real.~ y, x) + w)
      in (y', Real.~ x')
      end

  (* acos z = -i ln (z + i sqrt (1 - z^2)).
     Bug fix: the square root used to be taken of 1 + z^2. *)
  fun acos z =
      let val (x', y') = sqrt (one - z * z)
          val (x'', y'') = ln (z + (Real.~ y', x'))
      in (y'', Real.~ x'')
      end

  (* atan z = -i ln ((1 + iz) / sqrt (1 + z^2)) *)
  fun atan (z as (x, y)) =
      let val w = sqrt (one + z * z)
          val (x', y') = ln ((Real.- (1.0, y), x) / w)
      in (y', Real.~ x')
      end

  (* NOTE(review): this is simply atan (y / x); unlike Real.atan2 it keeps
     no quadrant information, and x = zero divides by zero. *)
  fun atan2 (y, x) = atan (y / x)

  fun asinh x = ln (x + sqrt (one + x * x))
  fun acosh x = ln (x + (x + one) * sqrt ((x - one) / (x + one)))
  fun atanh x = ln ((one + x) / sqrt (one - x * x))

  (* Reads two reals (real part, then imaginary part) from the stream. *)
  fun scan getc =
      let val scanner = Real.scan getc
      in fn stream =>
            case scanner stream of
                NONE => NONE
              | SOME (a, rest) =>
                case scanner rest of
                    NONE => NONE
                  | SOME (b, rest) => SOME (make (a, b), rest)
      end
end (* ComplexNumber *)
1635 | (* | |
1636 | Copyright (c) Juan Jose Garcia Ripoll. | |
1637 | All rights reserved. | |
1638 | Refer to the COPYRIGHT file for license conditions | |
1639 | *) | |
(* PrettyPrint: helpers that write lists, arrays and generic sequences to
   standard output, plus a `print` that renders a list of tagged values. *)
structure PrettyPrint :>
sig
  datatype modifier =
      Int of int
    | Real of real
    | Complex of CNumber.t
    | String of string
  val list : ('a -> string) -> 'a list -> unit
  val intList : int list -> unit
  val realList : real list -> unit
  val stringList : string list -> unit
  val array : ('a -> string) -> 'a array -> unit
  val intArray : int array -> unit
  val realArray : real array -> unit
  val stringArray : string array -> unit
  val sequence :
      int -> ((int * 'a -> unit) -> 'b -> unit) -> ('a -> string) -> 'b -> unit
  val print : modifier list -> unit
end =
struct
  datatype modifier =
      Int of int
    | Real of real
    | Complex of CNumber.t
    | String of string

  (* Prints "[x1, x2, ...]" with no trailing newline; "[]" when empty. *)
  fun list _ [] = print "[]"
    | list cvt (a :: resta) =
      let fun loop a [] = (print (cvt a); print "]")
            | loop a (b :: restb) = (print (cvt a); print ", "; loop b restb)
      in print "[";
         loop a resta
      end

  fun boolList a = list Bool.toString a
  fun intList a = list Int.toString a
  fun realList a = list Real.toString a
  fun stringList a = list (fn x => x) a

  (* Prints "[x1, x2, ...]" with no trailing newline, like `list`.
     Bug fix: the brackets used to be missing, so an empty array printed
     nothing at all and the output format differed from `list`. *)
  fun array cvt a =
      let val last = Array.length a - 1
          fun printOne (i, x) =
              (print (cvt x); if i = last then () else print ", ")
      in print "[";
         Array.appi printOne a;
         print "]"
      end

  fun boolArray a = array Bool.toString a
  fun intArray a = array Int.toString a
  fun realArray a = array Real.toString a
  fun stringArray a = array (fn x => x) a

  (* Prints "[x1, x2, ...]\n" for any container, given its element count
     and an appi-style traversal over (index, element) pairs. *)
  fun sequence length appi cvt seq =
      let val last = length - 1
          fun printOne (i : int, x) =
              (print (cvt x); if i = last then () else print ", ")
      in print "[";
         appi printOne seq;
         print "]\n"
      end

  (* Renders each tagged value with its type's toString and writes the
     pieces in order, with no separators. *)
  fun print b =
      let fun printer (Int a) = INumber.toString a
            | printer (Real a) = RNumber.toString a
            | printer (Complex a) = CNumber.toString a
            | printer (String a) = a
      in List.app (fn x => TextIO.print (printer x)) b
      end
end (* PrettyPrint *)
(* print' strs: writes every string in strs to standard output, in order. *)
fun print' strs = List.app (fn s => print s) strs
1706 | (* | |
1707 | Copyright (c) Juan Jose Garcia Ripoll. | |
1708 | All rights reserved. | |
1709 | Refer to the COPYRIGHT file for license conditions | |
1710 | *) | |
(* INumberArray: the polymorphic Array structure specialised to INumber.t
   elements, with a matching Vector substructure and map helpers that build
   a fresh array via tabulate. *)
structure INumberArray =
struct
  open Array

  type elem = INumber.t
  type array = INumber.t array
  type vector = INumber.t vector

  structure Vector =
  struct
    open Vector
    type elem = INumber.t
    type vector = INumber.t Vector.vector
  end

  (* New array holding f applied to each element of a, index 0 first. *)
  fun map f a = tabulate (length a, f o (fn k => sub (a, k)))

  (* As map, but f also receives the element index. *)
  fun mapi f a = tabulate (length a, fn k => f (k, sub (a, k)))

  (* Pointwise combination; the result has the length of a, and the
     bounds-checked sub raises Subscript if b is shorter. *)
  fun map2 f a b =
      let val n = length a
          fun pair k = (sub (a, k), sub (b, k))
      in tabulate (n, f o pair)
      end
end
(* RNumberArray: Real64Array with sub/update rebound to the unchecked
   Unsafe.Real64Array operations, plus map helpers built on tabulate.
   Callers of sub/update must keep indices in range themselves. *)
structure RNumberArray =
struct
  open Real64Array
  val sub = Unsafe.Real64Array.sub
  val update = Unsafe.Real64Array.update

  fun map f a = tabulate (length a, fn x => f (sub (a, x)))
  fun mapi f a = tabulate (length a, fn x => f (x, sub (a, x)))

  (* Pointwise combination into a new array of length a.
     Bug fix: b is read with the unchecked sub, so a b shorter than a used
     to read out of bounds; it is now rejected up front with Subscript. *)
  fun map2 f a b =
      let val n = length a
      in if length b < n then raise Subscript
         else tabulate (n, fn x => f (sub (a, x), sub (b, x)))
      end
end
1736 | (*--------------------- COMPLEX ARRAY -------------------------*) | |
1737 | structure BasicCNumberArray = | |
1738 | struct | |
1739 | structure Complex : COMPLEX_NUMBER = CNumber | |
1740 | structure Array : MONO_ARRAY = RNumberArray | |
1741 | type elem = Complex.t | |
1742 | type array = Array.array * Array.array | |
1743 | val maxLen = Array.maxLen | |
1744 | fun length (a,b) = Array.length a | |
1745 | fun sub ((a,b),index) = Complex.make(Array.sub(a,index),Array.sub(b,index)) | |
1746 | fun update ((a,b),index,z) = | |
1747 | let val (re,im) = Complex.split z in | |
1748 | Array.update(a, index, re); | |
1749 | Array.update(b, index, im) | |
1750 | end | |
1751 | local | |
1752 | fun makeRange (a, start, NONE) = makeRange(a, start, SOME (length a - 1)) | |
1753 | | makeRange (a, start, SOME last) = | |
1754 | let val len = length a | |
1755 | val diff = last - start | |
1756 | in | |
1757 | if (start >= len) orelse (last >= len) then | |
1758 | raise Subscript | |
1759 | else if diff < 0 then | |
1760 | (a, start, 0) | |
1761 | else | |
1762 | (a, start, diff + 1) | |
1763 | end | |
1764 | in | |
1765 | fun array (size,z:elem) = | |
1766 | let val realsize = size * 2 | |
1767 | val r = Complex.realPart z | |
1768 | val i = Complex.imagPart z in | |
1769 | (Array.array(size,r), Array.array(size,i)) | |
1770 | end | |
1771 | fun zeroarray size = | |
1772 | (Array.array(size,Complex.Real.zero), | |
1773 | Array.array(size,Complex.Real.zero)) | |
1774 | fun tabulate (size,f) = | |
1775 | let val a = array(size, Complex.zero) | |
1776 | fun loop i = | |
1777 | case i = size of | |
1778 | true => a | |
1779 | | false => (update(a, i, f i); loop (i+1)) | |
1780 | in | |
1781 | loop 0 | |
1782 | end | |
1783 | fun fromList list = | |
1784 | let val length = List.length list | |
1785 | val a = zeroarray length | |
1786 | fun loop (_, []) = a | |
1787 | | loop (i, z::rest) = (update(a, i, z); | |
1788 | loop (i+1, rest)) | |
1789 | in | |
1790 | loop(0,list) | |
1791 | end | |
1792 | fun extract range = | |
1793 | let val (a, start, len) = makeRange range | |
1794 | fun copy i = sub(a, i + start) | |
1795 | in tabulate(len, copy) | |
1796 | end | |
1797 | fun concat array_list = | |
1798 | let val total_length = foldl (op +) 0 (map length array_list) | |
1799 | val a = array(total_length, Complex.zero) | |
1800 | fun copy (_, []) = a | |
1801 | | copy (pos, v::rest) = | |
1802 | let fun loop i = | |
1803 | case i = 0 of | |
1804 | true => () | |
1805 | | false => (update(a, i+pos, sub(v, i)); loop (i-1)) | |
1806 | in (loop (length v - 1); copy(length v + pos, rest)) | |
1807 | end | |
1808 | in | |
1809 | copy(0, array_list) | |
1810 | end | |
1811 | fun copy {src : array, si : int, len : int option, dst : array, di : int } = | |
1812 | let val (a, ia, la) = makeRange (src, si, len) | |
1813 | val (b, ib, lb) = makeRange (dst, di, len) | |
1814 | fun copy i = | |
1815 | case i < 0 of | |
1816 | true => () | |
1817 | | false => (update(b, i+ib, sub(a, i+ia)); copy (i-1)) | |
1818 | in copy (la - 1) | |
1819 | end | |
1820 | val copyVec = copy | |
1821 | fun modifyi f range = | |
1822 | let val (a, start, len) = makeRange range | |
1823 | val last = start + len | |
1824 | fun loop i = | |
1825 | case i >= last of | |
1826 | true => () | |
1827 | | false => (update(a, i, f(i, sub(a,i))); loop (i+1)) | |
1828 | in loop start | |
1829 | end | |
1830 | fun modify f a = | |
1831 | let val last = length a | |
1832 | fun loop i = | |
1833 | case i >= last of | |
1834 | true => () | |
1835 | | false => (update(a, i, f(sub(a,i))); loop (i+1)) | |
1836 | in loop 0 | |
1837 | end | |
1838 | fun app f a = | |
1839 | let val size = length a | |
1840 | fun loop i = | |
1841 | case i = size of | |
1842 | true => () | |
1843 | | false => (f(sub(a,i)); loop (i+1)) | |
1844 | in | |
1845 | loop 0 | |
1846 | end | |
1847 | fun appi f range = | |
1848 | let val (a, start, len) = makeRange range | |
1849 | val last = start + len | |
1850 | fun loop i = | |
1851 | case i >= last of | |
1852 | true => () | |
1853 | | false => (f(i, sub(a,i)); loop (i+1)) | |
1854 | in | |
1855 | loop start | |
1856 | end | |
1857 | fun map f a = | |
1858 | let val len = length a | |
1859 | val c = zeroarray len | |
1860 | fun loop ~1 = c | |
1861 | | loop i = (update(a, i, f(sub(a,i))); loop (i-1)) | |
1862 | in loop (len-1) | |
1863 | end | |
1864 | fun map2 f a b = | |
1865 | let val len = length a | |
1866 | val c = zeroarray len | |
1867 | fun loop ~1 = c | |
1868 | | loop i = (update(c, i, f(sub(a,i),sub(b,i))); | |
1869 | loop (i-1)) | |
1870 | in loop (len-1) | |
1871 | end | |
1872 | fun mapi f range = | |
1873 | let val (a, start, len) = makeRange range | |
1874 | fun rule i = f (i+start, sub(a, i+start)) | |
1875 | in tabulate(len, rule) | |
1876 | end | |
1877 | fun foldli f init range = | |
1878 | let val (a, start, len) = makeRange range | |
1879 | val last = start + len - 1 | |
1880 | fun loop (i, accum) = | |
1881 | case i > last of | |
1882 | true => accum | |
1883 | | false => loop (i+1, f(i, sub(a,i), accum)) | |
1884 | in loop (start, init) | |
1885 | end | |
(* foldri f init range: right-to-left fold over the slice described by
   range, calling f (index, element, accumulator). *)
fun foldri f init range =
  let
    val (arr, first, count) = makeRange range
    fun go (k, acc) =
      if k >= first then go (k - 1, f (k, sub (arr, k), acc)) else acc
  in
    go (first + count - 1, init)
  end
(* foldl f init a: left fold over the whole array, ignoring indices. *)
fun foldl f init a = foldli (fn (_, elem, acc) => f (elem, acc)) init (a, 0, NONE)
(* foldr f init a: right fold over the whole array, ignoring indices. *)
fun foldr f init a = foldri (fn (_, elem, acc) => f (elem, acc)) init (a, 0, NONE)
1897 | end | |
1898 | end (* BasicCNumberArray *) | |
(* CNumberArray: everything from BasicCNumberArray, plus a Vector
   substructure that presents the same array type through the
   MONO_VECTOR signature.  The ascription is transparent, so
   Vector.vector and array remain the same type. *)
structure CNumberArray =
struct
  structure Vector =
  struct
    open BasicCNumberArray
    type vector = BasicCNumberArray.array
  end : MONO_VECTOR

  type vector = Vector.vector
  open BasicCNumberArray
end (* CNumberArray *)
(* ITensor: dense tensors whose elements are INumber values, stored in a
   flat INumberArray and addressed through Index.indexer. *)
structure ITensor =
struct
  structure Number = INumber
  structure Array = INumberArray
  (*
   Copyright (c) Juan Jose Garcia Ripoll.
   All rights reserved.
   Refer to the COPYRIGHT file for license conditions
  *)
  structure MonoTensor =
  struct
    (* PARAMETERS
       structure Array = Array
    *)
    structure Index = Index
    type elem = Array.elem
    type index = Index.t
    type tensor = {shape : index, indexer : Index.indexer, data : Array.array}
    type t = tensor
    exception Shape
    exception Match
    exception Index
    local
      (*----- LOCALS -----*)
      (* Build a tensor record, deriving the indexer from the shape. *)
      fun make' (shape, data) =
        {shape = shape, indexer = Index.indexer shape, data = data}
      (* Flat array position of a multi-index. *)
      fun toInt {shape, indexer, data} index = indexer index
      (* splitList (l, place) = (before, x, after), where x is the selected
         element of l.  place > 0 selects the place-th element (1-based).
         place <= 0 counts from the end: 0 is the last element, ~1 the
         one before it.  Raises Index if l is empty or place is out of
         range. *)
      fun splitList ([], _) = raise Index
        | splitList (a :: rest, place) =
          let fun loop (left, here, right) 0 = (List.rev left, here, right)
                | loop (_, _, []) _ = raise Index
                | loop (left, here, x :: right) n =
                  loop (here :: left, x, right) (n - 1)
          in
            if place <= 0 then
              (* the last element is List.length rest steps past a, so
                 a non-positive place moves that far back from it *)
              loop ([], a, rest) (List.length rest + place)
            else
              loop ([], a, rest) (place - 1)
          end
    in
      (*----- STRUCTURAL OPERATIONS & QUERIES ------*)
      (* new (shape, init): tensor of the given shape with every element
         set to init.  Raises Shape if Index.validShape rejects shape. *)
      fun new (shape, init) =
        if not (Index.validShape shape) then
          raise Shape
        else
          let val length = Index.length shape in
            {shape = shape,
             indexer = Index.indexer shape,
             data = Array.array(length,init)}
          end
      fun toArray {shape, indexer, data} = data
      fun length {shape, indexer, data} = Array.length data
      fun shape {shape, indexer, data} = shape
      fun rank t = List.length (shape t)
      (* reshape new_shape t: the same data viewed under new_shape.
         Raises Shape for an invalid shape, Match if the element count
         differs. *)
      fun reshape new_shape tensor =
        if Index.validShape new_shape then
          case (Index.length new_shape) = length tensor of
            true => make'(new_shape, toArray tensor)
          | false => raise Match
        else
          raise Shape
      (* fromArray (s, a): wrap a (not copied) as a tensor of shape s.
         Raises Shape if s is invalid or its size differs from a's length. *)
      fun fromArray (s, a) =
        case Index.validShape s andalso
             ((Index.length s) = (Array.length a)) of
          true => make'(s, a)
        | false => raise Shape
      fun fromList (s, a) = fromArray (s, Array.fromList a)
      (* tabulate (shape, f): tensor holding f idx at every index idx of
         shape.  Positions are filled from the last index backwards using
         Index.prev'.  Raises Shape for an invalid shape. *)
      fun tabulate (shape,f) =
        if Index.validShape shape then
          let val last = Index.last shape
              val length = Index.length shape
              val c = Array.array(length, f last)
              (* Write f indices at flat position i, then step back.
                 Recursion continues down to and including position 0,
                 so the first element is not left holding f last. *)
              fun dotable (c, indices, i) =
                (Array.update(c, i, f indices);
                 if i <= 0
                 then c
                 else dotable(c, Index.prev' shape indices, i-1))
          in
            if length <= 1 then
              (* the only slot, if any, already holds f last *)
              make'(shape, c)
            else
              make'(shape, dotable(c, Index.prev' shape last, length-2))
          end
        else
          raise Shape
      (*----- ELEMENTWISE OPERATIONS -----*)
      fun sub (t, index) = Array.sub(#data t, toInt t index)
      fun update (t, index, value) =
        Array.update(toArray t, toInt t index, value)
      (* map f t: new data array; shape and indexer are shared with t. *)
      fun map f {shape, indexer, data} =
        {shape = shape, indexer = indexer, data = Array.map f data}
      (* map2 f t1 t2: elementwise combination; raises Match unless
         Index.eq reports equal shapes. *)
      fun map2 f t1 t2 =
        let val {shape=shape1, indexer=indexer1, data=data1} = t1
            val {shape=shape2, indexer=indexer2, data=data2} = t2
        in
          if Index.eq(shape1,shape2) then
            {shape = shape1,
             indexer = indexer1,
             data = Array.map2 f data1 data2}
          else
            raise Match
        end
      fun appi f tensor = Array.appi f (toArray tensor)
      fun app f tensor = Array.app f (toArray tensor)
      (* all / any: test f on the elements at flat positions
         0 .. length-1 via Loop.all / Loop.any. *)
      fun all f tensor =
        let val a = toArray tensor
        in Loop.all(0, length tensor - 1, fn i =>
                    f (Array.sub(a, i)))
        end
      fun any f tensor =
        let val a = toArray tensor
        in Loop.any(0, length tensor - 1, fn i =>
                    f (Array.sub(a, i)))
        end
      fun foldl f init tensor = Array.foldl f init (toArray tensor)
      (* foldln f init t index: fold f (accumulator first) along the
         dimension selected by splitList (shape, index).  The result has
         the remaining dimensions, in their original order. *)
      fun foldln f init {shape, indexer, data=a} index =
        let val (head,lk,tail) = splitList(shape, index)
            val li = Index.length head
            val lj = Index.length tail
            val c = Array.array(li * lj,init)
            fun loopi (0, _, _) = ()
              | loopi (i, ia, ic) =
                (Array.update(c, ic, f(Array.sub(c,ic), Array.sub(a,ia)));
                 loopi (i-1, ia+1, ic+1))
            fun loopk (0, ia, _) = ia
              | loopk (k, ia, ic) = (loopi (li, ia, ic);
                                     loopk (k-1, ia+li, ic))
            fun loopj (0, _, _) = ()
              | loopj (j, ia, ic) = loopj (j-1, loopk(lk,ia,ic), ic+li)
        in
          loopj (lj, 0, 0);
          make'(head @ tail, c)
        end
      (* --- POLYMORPHIC ELEMENTWISE OPERATIONS --- *)
      (* These build their results through Tensor.Array / Tensor.fromArray,
         so the result element type need not be elem. *)
      fun array_map' f a =
        let fun apply index = f(Array.sub(a,index)) in
          Tensor.Array.tabulate(Array.length a, apply)
        end
      fun map' f t = Tensor.fromArray(shape t, array_map' f (toArray t))
      fun map2' f t1 t2 =
        let val d1 = toArray t1
            val d2 = toArray t2
            fun apply i = f (Array.sub(d1,i), Array.sub(d2,i))
            val len = Array.length d1
        in
          if Index.eq(shape t1, shape t2) then
            Tensor.fromArray(shape t1, Tensor.Array.tabulate(len,apply))
          else
            raise Match
        end
      fun foldl' f init {shape, indexer, data=a} index =
        let val (head,lk,tail) = splitList(shape, index)
            val li = Index.length head
            val lj = Index.length tail
            val c = Tensor.Array.array(li * lj,init)
            fun loopi (0, _, _) = ()
              | loopi (i, ia, ic) =
                (Tensor.Array.update(c,ic,f(Tensor.Array.sub(c,ic),Array.sub(a,ia)));
                 loopi (i-1, ia+1, ic+1))
            fun loopk (0, ia, _) = ia
              | loopk (k, ia, ic) = (loopi (li, ia, ic);
                                     loopk (k-1, ia+li, ic))
            fun loopj (0, _, _) = ()
              | loopj (j, ia, ic) = loopj (j-1, loopk(lk,ia,ic), ic+li)
        in
          loopj (lj, 0, 0);
          make'(head @ tail, c)
        end
    end
  end (* MonoTensor *)
  open MonoTensor
  local
    (*
     LEFT INDEX CONTRACTION:
     a = a(i1,i2,...,in)
     b = b(j1,j2,...,jn)
     c = c(i2,...,in,j2,...,jn)
       = sum(a(k,i2,...,in)*b(k,j2,...,jn)) forall k
     MEANINGFUL VARIABLES:
     lk = i1 = j1
     li = i2*...*in
     lj = j2*...*jn
     The loops address a at ia = k + lk*i, b at ib = k + lk*j and
     c at ic = i + li*j.
    *)
    fun do_fold_first a b c lk lj li =
      let fun loopk (0, _, _, accum) = accum
            | loopk (k, ia, ib, accum) =
              let val delta = Number.*(Array.sub(a,ia),Array.sub(b,ib))
              in loopk (k-1, ia+1, ib+1, Number.+(delta,accum))
              end
          fun loopj (0, ib, ic) = c
            | loopj (j, ib, ic) =
              let fun loopi (0, ia, ic) = ic
                    | loopi (i, ia, ic) =
                      (Array.update(c, ic, loopk(lk, ia, ib, Number.zero));
                       loopi(i-1, ia+lk, ic+1))
              in
                loopj(j-1, ib+lk, loopi(li, 0, ic))
              end
      in loopj(lj, 0, 0)
      end
  in
    (* ta +* tb: contract the first index of ta with the first index of
       tb.  Raises Match if those dimensions differ and Shape if either
       operand has rank 0. *)
    fun +* ta tb =
      case (shape ta, shape tb) of
        (lk :: rest_a, lk2 :: rest_b) =>
          if lk <> lk2 then raise Match
          else
            let val li = Index.length rest_a
                val lj = Index.length rest_b
                val c = Array.array(li*lj, Number.zero)
            in
              (* do_fold_first takes lj before li *)
              fromArray(rest_a @ rest_b,
                        do_fold_first (toArray ta) (toArray tb) c lk lj li)
            end
      | _ => raise Shape
  end
  local
    (*
     LAST INDEX CONTRACTION:
     a = a(i1,i2,...,in)
     b = b(j1,j2,...,jn)
     c = c(i1,...,i(n-1),j1,...,j(n-1))
       = sum(mult(a(i1,i2,...,k),b(j1,j2,...,k))) forall k
     MEANINGFUL VARIABLES:
     lk = in = jn
     li = i1*...*i(n-1)
     lj = j1*...*j(n-1)
     The loops address a at i + li*k, b at j + lj*k and c at i + li*j,
     accumulating into c (which must start out zeroed).
    *)
    fun do_fold_last a b c lk lj li =
      let fun loopi (0, ia, ic, fac) = ()
            | loopi (i, ia, ic, fac) =
              let val old = Array.sub(c,ic)
                  val inc = Number.*(Array.sub(a,ia),fac)
              in
                Array.update(c,ic,Number.+(old,inc));
                loopi(i-1, ia+1, ic+1, fac)
              end
          fun loopj (j, ib, ic) =
            let fun loopk (0, ia, ib) = ()
                  | loopk (k, ia, ib) =
                    (loopi(li, ia, ic, Array.sub(b,ib));
                     loopk(k-1, ia+li, ib+lj))
            in case j of
                 0 => c
               | _ => (loopk(lk, 0, ib);
                       loopj(j-1, ib+1, ic+li))
            end (* loopj *)
      in
        loopj(lj, 0, 0)
      end
  in
    (* ta *+ tb: contract the last index of ta with the last index of
       tb.  Raises Match if those dimensions differ and Shape if either
       operand has rank 0. *)
    fun *+ ta tb =
      case (List.rev (shape ta), List.rev (shape tb)) of
        (lk :: rest_a, lk2 :: rest_b) =>
          if lk <> lk2 then raise Match
          else
            let val li = Index.length rest_a
                val lj = Index.length rest_b
                val c = Array.array(li*lj, Number.zero)
            in
              (* do_fold_last takes lj before li *)
              fromArray(List.rev rest_a @ List.rev rest_b,
                        do_fold_last (toArray ta) (toArray tb) c lk lj li)
            end
      | _ => raise Shape
  end
  (* ALGEBRAIC OPERATIONS *)
  infix **
  infix ==
  infix !=
  fun a + b = map2 Number.+ a b
  fun a - b = map2 Number.- a b
  fun a * b = map2 Number.* a b
  fun a ** i = map (fn x => (Number.**(x,i))) a
  fun ~ a = map Number.~ a
  fun abs a = map Number.abs a
  fun signum a = map Number.signum a
  fun a == b = map2' Number.== a b
  fun a != b = map2' Number.!= a b
  fun toString a = raise Domain
  fun fromInt a = new([1], Number.fromInt a)
  (* TENSOR SPECIFIC OPERATIONS *)
  fun *> n = map (fn x => Number.*(n,x))
  fun print t =
    (PrettyPrint.intList (shape t);
     TextIO.print "\n";
     PrettyPrint.sequence (length t) appi Number.toString t)
  (* normInf t: largest absolute value among the elements (Number.zero
     for an empty tensor). *)
  fun normInf a =
    let fun accum (y,x) = Number.max(x,Number.abs y)
    in foldl accum Number.zero a
    end
end (* ITensor *)
(* RTensor: dense tensors whose elements are RNumber values, stored in a
   flat RNumberArray and addressed through Index.indexer.  Adds the
   real-valued elementwise functions and norms on top of the common
   tensor operations. *)
structure RTensor =
struct
  structure Number = RNumber
  structure Array = RNumberArray
  (*
   Copyright (c) Juan Jose Garcia Ripoll.
   All rights reserved.
   Refer to the COPYRIGHT file for license conditions
  *)
  structure MonoTensor =
  struct
    (* PARAMETERS
       structure Array = Array
    *)
    structure Index = Index
    type elem = Array.elem
    type index = Index.t
    type tensor = {shape : index, indexer : Index.indexer, data : Array.array}
    type t = tensor
    exception Shape
    exception Match
    exception Index
    local
      (*----- LOCALS -----*)
      (* Build a tensor record, deriving the indexer from the shape. *)
      fun make' (shape, data) =
        {shape = shape, indexer = Index.indexer shape, data = data}
      (* Flat array position of a multi-index. *)
      fun toInt {shape, indexer, data} index = indexer index
      (* splitList (l, place) = (before, x, after), where x is the selected
         element of l.  place > 0 selects the place-th element (1-based).
         place <= 0 counts from the end: 0 is the last element, ~1 the
         one before it.  Raises Index if l is empty or place is out of
         range. *)
      fun splitList ([], _) = raise Index
        | splitList (a :: rest, place) =
          let fun loop (left, here, right) 0 = (List.rev left, here, right)
                | loop (_, _, []) _ = raise Index
                | loop (left, here, x :: right) n =
                  loop (here :: left, x, right) (n - 1)
          in
            if place <= 0 then
              (* the last element is List.length rest steps past a, so
                 a non-positive place moves that far back from it *)
              loop ([], a, rest) (List.length rest + place)
            else
              loop ([], a, rest) (place - 1)
          end
    in
      (*----- STRUCTURAL OPERATIONS & QUERIES ------*)
      (* new (shape, init): tensor of the given shape with every element
         set to init.  Raises Shape if Index.validShape rejects shape. *)
      fun new (shape, init) =
        if not (Index.validShape shape) then
          raise Shape
        else
          let val length = Index.length shape in
            {shape = shape,
             indexer = Index.indexer shape,
             data = Array.array(length,init)}
          end
      fun toArray {shape, indexer, data} = data
      fun length {shape, indexer, data} = Array.length data
      fun shape {shape, indexer, data} = shape
      fun rank t = List.length (shape t)
      (* reshape new_shape t: the same data viewed under new_shape.
         Raises Shape for an invalid shape, Match if the element count
         differs. *)
      fun reshape new_shape tensor =
        if Index.validShape new_shape then
          case (Index.length new_shape) = length tensor of
            true => make'(new_shape, toArray tensor)
          | false => raise Match
        else
          raise Shape
      (* fromArray (s, a): wrap a (not copied) as a tensor of shape s.
         Raises Shape if s is invalid or its size differs from a's length. *)
      fun fromArray (s, a) =
        case Index.validShape s andalso
             ((Index.length s) = (Array.length a)) of
          true => make'(s, a)
        | false => raise Shape
      fun fromList (s, a) = fromArray (s, Array.fromList a)
      (* tabulate (shape, f): tensor holding f idx at every index idx of
         shape.  Positions are filled from the last index backwards using
         Index.prev'.  Raises Shape for an invalid shape. *)
      fun tabulate (shape,f) =
        if Index.validShape shape then
          let val last = Index.last shape
              val length = Index.length shape
              val c = Array.array(length, f last)
              (* Write f indices at flat position i, then step back.
                 Recursion continues down to and including position 0,
                 so the first element is not left holding f last. *)
              fun dotable (c, indices, i) =
                (Array.update(c, i, f indices);
                 if i <= 0
                 then c
                 else dotable(c, Index.prev' shape indices, i-1))
          in
            if length <= 1 then
              (* the only slot, if any, already holds f last *)
              make'(shape, c)
            else
              make'(shape, dotable(c, Index.prev' shape last, length-2))
          end
        else
          raise Shape
      (*----- ELEMENTWISE OPERATIONS -----*)
      fun sub (t, index) = Array.sub(#data t, toInt t index)
      fun update (t, index, value) =
        Array.update(toArray t, toInt t index, value)
      (* map f t: new data array; shape and indexer are shared with t. *)
      fun map f {shape, indexer, data} =
        {shape = shape, indexer = indexer, data = Array.map f data}
      (* map2 f t1 t2: elementwise combination; raises Match unless
         Index.eq reports equal shapes. *)
      fun map2 f t1 t2 =
        let val {shape=shape1, indexer=indexer1, data=data1} = t1
            val {shape=shape2, indexer=indexer2, data=data2} = t2
        in
          if Index.eq(shape1,shape2) then
            {shape = shape1,
             indexer = indexer1,
             data = Array.map2 f data1 data2}
          else
            raise Match
        end
      fun appi f tensor = Array.appi f (toArray tensor)
      fun app f tensor = Array.app f (toArray tensor)
      (* all / any: test f on the elements at flat positions
         0 .. length-1 via Loop.all / Loop.any. *)
      fun all f tensor =
        let val a = toArray tensor
        in Loop.all(0, length tensor - 1, fn i =>
                    f (Array.sub(a, i)))
        end
      fun any f tensor =
        let val a = toArray tensor
        in Loop.any(0, length tensor - 1, fn i =>
                    f (Array.sub(a, i)))
        end
      fun foldl f init tensor = Array.foldl f init (toArray tensor)
      (* foldln f init t index: fold f (accumulator first) along the
         dimension selected by splitList (shape, index).  The result has
         the remaining dimensions, in their original order. *)
      fun foldln f init {shape, indexer, data=a} index =
        let val (head,lk,tail) = splitList(shape, index)
            val li = Index.length head
            val lj = Index.length tail
            val c = Array.array(li * lj,init)
            fun loopi (0, _, _) = ()
              | loopi (i, ia, ic) =
                (Array.update(c, ic, f(Array.sub(c,ic), Array.sub(a,ia)));
                 loopi (i-1, ia+1, ic+1))
            fun loopk (0, ia, _) = ia
              | loopk (k, ia, ic) = (loopi (li, ia, ic);
                                     loopk (k-1, ia+li, ic))
            fun loopj (0, _, _) = ()
              | loopj (j, ia, ic) = loopj (j-1, loopk(lk,ia,ic), ic+li)
        in
          loopj (lj, 0, 0);
          make'(head @ tail, c)
        end
      (* --- POLYMORPHIC ELEMENTWISE OPERATIONS --- *)
      (* These build their results through Tensor.Array / Tensor.fromArray,
         so the result element type need not be elem. *)
      fun array_map' f a =
        let fun apply index = f(Array.sub(a,index)) in
          Tensor.Array.tabulate(Array.length a, apply)
        end
      fun map' f t = Tensor.fromArray(shape t, array_map' f (toArray t))
      fun map2' f t1 t2 =
        let val d1 = toArray t1
            val d2 = toArray t2
            fun apply i = f (Array.sub(d1,i), Array.sub(d2,i))
            val len = Array.length d1
        in
          if Index.eq(shape t1, shape t2) then
            Tensor.fromArray(shape t1, Tensor.Array.tabulate(len,apply))
          else
            raise Match
        end
      fun foldl' f init {shape, indexer, data=a} index =
        let val (head,lk,tail) = splitList(shape, index)
            val li = Index.length head
            val lj = Index.length tail
            val c = Tensor.Array.array(li * lj,init)
            fun loopi (0, _, _) = ()
              | loopi (i, ia, ic) =
                (Tensor.Array.update(c,ic,f(Tensor.Array.sub(c,ic),Array.sub(a,ia)));
                 loopi (i-1, ia+1, ic+1))
            fun loopk (0, ia, _) = ia
              | loopk (k, ia, ic) = (loopi (li, ia, ic);
                                     loopk (k-1, ia+li, ic))
            fun loopj (0, _, _) = ()
              | loopj (j, ia, ic) = loopj (j-1, loopk(lk,ia,ic), ic+li)
        in
          loopj (lj, 0, 0);
          make'(head @ tail, c)
        end
    end
  end (* MonoTensor *)
  open MonoTensor
  local
    (*
     LEFT INDEX CONTRACTION:
     a = a(i1,i2,...,in)
     b = b(j1,j2,...,jn)
     c = c(i2,...,in,j2,...,jn)
       = sum(a(k,i2,...,in)*b(k,j2,...,jn)) forall k
     MEANINGFUL VARIABLES:
     lk = i1 = j1
     li = i2*...*in
     lj = j2*...*jn
     The loops address a at ia = k + lk*i, b at ib = k + lk*j and
     c at ic = i + li*j.
    *)
    fun do_fold_first a b c lk lj li =
      let fun loopk (0, _, _, accum) = accum
            | loopk (k, ia, ib, accum) =
              let val delta = Number.*(Array.sub(a,ia),Array.sub(b,ib))
              in loopk (k-1, ia+1, ib+1, Number.+(delta,accum))
              end
          fun loopj (0, ib, ic) = c
            | loopj (j, ib, ic) =
              let fun loopi (0, ia, ic) = ic
                    | loopi (i, ia, ic) =
                      (Array.update(c, ic, loopk(lk, ia, ib, Number.zero));
                       loopi(i-1, ia+lk, ic+1))
              in
                loopj(j-1, ib+lk, loopi(li, 0, ic))
              end
      in loopj(lj, 0, 0)
      end
  in
    (* ta +* tb: contract the first index of ta with the first index of
       tb.  Raises Match if those dimensions differ and Shape if either
       operand has rank 0. *)
    fun +* ta tb =
      case (shape ta, shape tb) of
        (lk :: rest_a, lk2 :: rest_b) =>
          if lk <> lk2 then raise Match
          else
            let val li = Index.length rest_a
                val lj = Index.length rest_b
                val c = Array.array(li*lj, Number.zero)
            in
              (* do_fold_first takes lj before li *)
              fromArray(rest_a @ rest_b,
                        do_fold_first (toArray ta) (toArray tb) c lk lj li)
            end
      | _ => raise Shape
  end
  local
    (*
     LAST INDEX CONTRACTION:
     a = a(i1,i2,...,in)
     b = b(j1,j2,...,jn)
     c = c(i1,...,i(n-1),j1,...,j(n-1))
       = sum(mult(a(i1,i2,...,k),b(j1,j2,...,k))) forall k
     MEANINGFUL VARIABLES:
     lk = in = jn
     li = i1*...*i(n-1)
     lj = j1*...*j(n-1)
     The loops address a at i + li*k, b at j + lj*k and c at i + li*j,
     accumulating into c (which must start out zeroed).
    *)
    fun do_fold_last a b c lk lj li =
      let fun loopi (0, ia, ic, fac) = ()
            | loopi (i, ia, ic, fac) =
              let val old = Array.sub(c,ic)
                  val inc = Number.*(Array.sub(a,ia),fac)
              in
                Array.update(c,ic,Number.+(old,inc));
                loopi(i-1, ia+1, ic+1, fac)
              end
          fun loopj (j, ib, ic) =
            let fun loopk (0, ia, ib) = ()
                  | loopk (k, ia, ib) =
                    (loopi(li, ia, ic, Array.sub(b,ib));
                     loopk(k-1, ia+li, ib+lj))
            in case j of
                 0 => c
               | _ => (loopk(lk, 0, ib);
                       loopj(j-1, ib+1, ic+li))
            end (* loopj *)
      in
        loopj(lj, 0, 0)
      end
  in
    (* ta *+ tb: contract the last index of ta with the last index of
       tb.  Raises Match if those dimensions differ and Shape if either
       operand has rank 0. *)
    fun *+ ta tb =
      case (List.rev (shape ta), List.rev (shape tb)) of
        (lk :: rest_a, lk2 :: rest_b) =>
          if lk <> lk2 then raise Match
          else
            let val li = Index.length rest_a
                val lj = Index.length rest_b
                val c = Array.array(li*lj, Number.zero)
            in
              (* do_fold_last takes lj before li *)
              fromArray(List.rev rest_a @ List.rev rest_b,
                        do_fold_last (toArray ta) (toArray tb) c lk lj li)
            end
      | _ => raise Shape
  end
  (* ALGEBRAIC OPERATIONS *)
  infix **
  infix ==
  infix !=
  fun a + b = map2 Number.+ a b
  fun a - b = map2 Number.- a b
  fun a * b = map2 Number.* a b
  fun a ** i = map (fn x => (Number.**(x,i))) a
  fun ~ a = map Number.~ a
  fun abs a = map Number.abs a
  fun signum a = map Number.signum a
  fun a == b = map2' Number.== a b
  fun a != b = map2' Number.!= a b
  fun toString a = raise Domain
  fun fromInt a = new([1], Number.fromInt a)
  (* TENSOR SPECIFIC OPERATIONS *)
  fun *> n = map (fn x => Number.*(n,x))
  fun print t =
    (PrettyPrint.intList (shape t);
     TextIO.print "\n";
     PrettyPrint.sequence (length t) appi Number.toString t)
  fun a / b = map2 Number./ a b
  fun recip a = map Number.recip a
  fun ln a = map Number.ln a
  fun pow (a, b) = map (fn x => (Number.pow(x,b))) a
  fun exp a = map Number.exp a
  fun sqrt a = map Number.sqrt a
  fun cos a = map Number.cos a
  fun sin a = map Number.sin a
  fun tan a = map Number.tan a
  fun sinh a = map Number.sinh a
  fun cosh a = map Number.cosh a
  fun tanh a = map Number.tanh a
  fun asin a = map Number.asin a
  fun acos a = map Number.acos a
  fun atan a = map Number.atan a
  fun asinh a = map Number.asinh a
  fun acosh a = map Number.acosh a
  fun atanh a = map Number.atanh a
  fun atan2 (a,b) = map2 Number.atan2 a b
  (* normInf t: largest absolute element value (Number.zero if empty). *)
  fun normInf a =
    let fun accum (y,x) = Number.max(x,Number.abs y)
    in foldl accum Number.zero a
    end
  (* norm1 t: sum of absolute element values. *)
  fun norm1 a =
    let fun accum (y,x) = Number.+(x,Number.abs y)
    in foldl accum Number.zero a
    end
  (* norm2 t: square root of the sum of squared elements. *)
  fun norm2 a =
    let fun accum (y,x) = Number.+(x, Number.*(y,y))
    in Number.sqrt(foldl accum Number.zero a)
    end
end (* RTensor *)
2506 | structure CTensor = | |
2507 | struct | |
2508 | structure Number = CNumber | |
2509 | structure Array = CNumberArray | |
2510 | (* | |
2511 | Copyright (c) Juan Jose Garcia Ripoll. | |
2512 | All rights reserved. | |
2513 | Refer to the COPYRIGHT file for license conditions | |
2514 | *) | |
structure MonoTensor =
struct
  (* PARAMETERS
     structure Array = Array
  *)
  structure Index = Index
  type elem = Array.elem
  type index = Index.t
  type tensor = {shape : index, indexer : Index.indexer, data : Array.array}
  type t = tensor
  exception Shape
  exception Match
  exception Index
  local
    (*----- LOCALS -----*)
    (* Build a tensor record, deriving the indexer from the shape. *)
    fun make' (shape, data) =
      {shape = shape, indexer = Index.indexer shape, data = data}
    (* Flat array position of a multi-index. *)
    fun toInt {shape, indexer, data} index = indexer index
    (* splitList (l, place) = (before, x, after), where x is the selected
       element of l.  place > 0 selects the place-th element (1-based).
       place <= 0 counts from the end: 0 is the last element, ~1 the one
       before it.  Raises Index if l is empty or place is out of range. *)
    fun splitList ([], _) = raise Index
      | splitList (a :: rest, place) =
        let fun loop (left, here, right) 0 = (List.rev left, here, right)
              | loop (_, _, []) _ = raise Index
              | loop (left, here, x :: right) n =
                loop (here :: left, x, right) (n - 1)
        in
          if place <= 0 then
            (* the last element is List.length rest steps past a, so a
               non-positive place moves that far back from it *)
            loop ([], a, rest) (List.length rest + place)
          else
            loop ([], a, rest) (place - 1)
        end
  in
    (*----- STRUCTURAL OPERATIONS & QUERIES ------*)
    (* new (shape, init): tensor of the given shape with every element
       set to init.  Raises Shape if Index.validShape rejects shape. *)
    fun new (shape, init) =
      if not (Index.validShape shape) then
        raise Shape
      else
        let val length = Index.length shape in
          {shape = shape,
           indexer = Index.indexer shape,
           data = Array.array(length,init)}
        end
    fun toArray {shape, indexer, data} = data
    fun length {shape, indexer, data} = Array.length data
    fun shape {shape, indexer, data} = shape
    fun rank t = List.length (shape t)
    (* reshape new_shape t: the same data viewed under new_shape.  Raises
       Shape for an invalid shape, Match if the element count differs. *)
    fun reshape new_shape tensor =
      if Index.validShape new_shape then
        case (Index.length new_shape) = length tensor of
          true => make'(new_shape, toArray tensor)
        | false => raise Match
      else
        raise Shape
    (* fromArray (s, a): wrap a (not copied) as a tensor of shape s.
       Raises Shape if s is invalid or its size differs from a's length. *)
    fun fromArray (s, a) =
      case Index.validShape s andalso
           ((Index.length s) = (Array.length a)) of
        true => make'(s, a)
      | false => raise Shape
    fun fromList (s, a) = fromArray (s, Array.fromList a)
    (* tabulate (shape, f): tensor holding f idx at every index idx of
       shape.  Positions are filled from the last index backwards using
       Index.prev'.  Raises Shape for an invalid shape. *)
    fun tabulate (shape,f) =
      if Index.validShape shape then
        let val last = Index.last shape
            val length = Index.length shape
            val c = Array.array(length, f last)
            (* Write f indices at flat position i, then step back.
               Recursion continues down to and including position 0, so
               the first element is not left holding f last. *)
            fun dotable (c, indices, i) =
              (Array.update(c, i, f indices);
               if i <= 0
               then c
               else dotable(c, Index.prev' shape indices, i-1))
        in
          if length <= 1 then
            (* the only slot, if any, already holds f last *)
            make'(shape, c)
          else
            make'(shape, dotable(c, Index.prev' shape last, length-2))
        end
      else
        raise Shape
    (*----- ELEMENTWISE OPERATIONS -----*)
    fun sub (t, index) = Array.sub(#data t, toInt t index)
    fun update (t, index, value) =
      Array.update(toArray t, toInt t index, value)
    (* map f t: new data array; shape and indexer are shared with t. *)
    fun map f {shape, indexer, data} =
      {shape = shape, indexer = indexer, data = Array.map f data}
    (* map2 f t1 t2: elementwise combination; raises Match unless
       Index.eq reports equal shapes. *)
    fun map2 f t1 t2 =
      let val {shape=shape1, indexer=indexer1, data=data1} = t1
          val {shape=shape2, indexer=indexer2, data=data2} = t2
      in
        if Index.eq(shape1,shape2) then
          {shape = shape1,
           indexer = indexer1,
           data = Array.map2 f data1 data2}
        else
          raise Match
      end
    (* The complex array's appi takes a (array, start, length option)
       slice, so the whole array is passed as (data, 0, NONE). *)
    fun appi f tensor = Array.appi f (toArray tensor, 0, NONE)
    fun app f tensor = Array.app f (toArray tensor)
    (* all / any: test f on the elements at flat positions 0 .. length-1
       via Loop.all / Loop.any. *)
    fun all f tensor =
      let val a = toArray tensor
      in Loop.all(0, length tensor - 1, fn i =>
                  f (Array.sub(a, i)))
      end
    fun any f tensor =
      let val a = toArray tensor
      in Loop.any(0, length tensor - 1, fn i =>
                  f (Array.sub(a, i)))
      end
    fun foldl f init tensor = Array.foldl f init (toArray tensor)
    (* foldln f init t index: fold f (accumulator first) along the
       dimension selected by splitList (shape, index).  The result has
       the remaining dimensions, in their original order. *)
    fun foldln f init {shape, indexer, data=a} index =
      let val (head,lk,tail) = splitList(shape, index)
          val li = Index.length head
          val lj = Index.length tail
          val c = Array.array(li * lj,init)
          fun loopi (0, _, _) = ()
            | loopi (i, ia, ic) =
              (Array.update(c, ic, f(Array.sub(c,ic), Array.sub(a,ia)));
               loopi (i-1, ia+1, ic+1))
          fun loopk (0, ia, _) = ia
            | loopk (k, ia, ic) = (loopi (li, ia, ic);
                                   loopk (k-1, ia+li, ic))
          fun loopj (0, _, _) = ()
            | loopj (j, ia, ic) = loopj (j-1, loopk(lk,ia,ic), ic+li)
      in
        loopj (lj, 0, 0);
        make'(head @ tail, c)
      end
    (* --- POLYMORPHIC ELEMENTWISE OPERATIONS --- *)
    (* These build their results through Tensor.Array / Tensor.fromArray,
       so the result element type need not be elem. *)
    fun array_map' f a =
      let fun apply index = f(Array.sub(a,index)) in
        Tensor.Array.tabulate(Array.length a, apply)
      end
    fun map' f t = Tensor.fromArray(shape t, array_map' f (toArray t))
    fun map2' f t1 t2 =
      let val d1 = toArray t1
          val d2 = toArray t2
          fun apply i = f (Array.sub(d1,i), Array.sub(d2,i))
          val len = Array.length d1
      in
        if Index.eq(shape t1, shape t2) then
          Tensor.fromArray(shape t1, Tensor.Array.tabulate(len,apply))
        else
          raise Match
      end
    fun foldl' f init {shape, indexer, data=a} index =
      let val (head,lk,tail) = splitList(shape, index)
          val li = Index.length head
          val lj = Index.length tail
          val c = Tensor.Array.array(li * lj,init)
          fun loopi (0, _, _) = ()
            | loopi (i, ia, ic) =
              (Tensor.Array.update(c,ic,f(Tensor.Array.sub(c,ic),Array.sub(a,ia)));
               loopi (i-1, ia+1, ic+1))
          fun loopk (0, ia, _) = ia
            | loopk (k, ia, ic) = (loopi (li, ia, ic);
                                   loopk (k-1, ia+li, ic))
          fun loopj (0, _, _) = ()
            | loopj (j, ia, ic) = loopj (j-1, loopk(lk,ia,ic), ic+li)
      in
        loopj (lj, 0, 0);
        make'(head @ tail, c)
      end
  end
end (* MonoTensor *)
open MonoTensor
local
  (*
   LEFT INDEX CONTRACTION:
   a = a(i1,i2,...,in)
   b = b(j1,j2,...,jn)
   c = c(i2,...,in,j2,...,jn)
     = sum(a(k,i2,...,in)*b(k,j2,...,jn)) forall k
   MEANINGFUL VARIABLES:
   lk = i1 = j1
   li = i2*...*in
   lj = j2*...*jn
   The loops address a at ia = k + lk*i, b at ib = k + lk*j and
   c at ic = i + li*j.  Elements are (real, imaginary) pairs; the
   products are expanded by hand and summed in two running parts.
  *)
  fun do_fold_first a b c lk lj li =
    let fun loopk (0, _, _, r, i) = Number.make(r,i)
          | loopk (k, ia, ib, r, i) =
            let val (ar, ai) = Array.sub(a,ia)
                val (br, bi) = Array.sub(b,ib)
                val dr = ar * br - ai * bi
                val di = ar * bi + ai * br
            in loopk (k-1, ia+1, ib+1, r+dr, i+di)
            end
        fun loopj (0, ib, ic) = c
          | loopj (j, ib, ic) =
            let fun loopi (0, ia, ic) = ic
                  | loopi (i, ia, ic) =
                    (Array.update(c, ic, loopk(lk, ia, ib, RNumber.zero, RNumber.zero));
                     loopi(i-1, ia+lk, ic+1))
            in loopj(j-1, ib+lk, loopi(li, 0, ic))
            end
    in loopj(lj, 0, 0)
    end
in
  (* ta +* tb: contract the first index of ta with the first index of
     tb.  Raises Match if those dimensions differ and Shape if either
     operand has rank 0. *)
  fun +* ta tb =
    case (shape ta, shape tb) of
      (lk :: rest_a, lk2 :: rest_b) =>
        if lk <> lk2 then raise Match
        else
          let val li = Index.length rest_a
              val lj = Index.length rest_b
              val c = Array.array(li*lj, Number.zero)
          in
            (* do_fold_first takes lj before li *)
            fromArray(rest_a @ rest_b,
                      do_fold_first (toArray ta) (toArray tb) c lk lj li)
          end
    | _ => raise Shape
end
local
  (*
   LAST INDEX CONTRACTION:
   a = a(i1,i2,...,in)
   b = b(j1,j2,...,jn)
   c = c(i1,...,i(n-1),j1,...,j(n-1))
     = sum(mult(a(i1,i2,...,k),b(j1,j2,...,k))) forall k
   MEANINGFUL VARIABLES:
   lk = in = jn
   li = i1*...*i(n-1)
   lj = j1*...*j(n-1)
   The loops address a at i + li*k, b at j + lj*k and c at i + li*j,
   accumulating complex products into c (which must start out zeroed).
  *)
  fun do_fold_last a b c lk lj li =
    let fun loopi (0, _, _, _, _) = ()
          | loopi (i, ia, ic, br, bi) =
            let val (cr, ci) = Array.sub(c,ic)
                val (ar, ai) = Array.sub(a,ia)
                val dr = (ar * br - ai * bi)
                val di = (ar * bi + ai * br)
            in
              Array.update(c,ic,Number.make(cr+dr,ci+di));
              loopi(i-1, ia+1, ic+1, br, bi)
            end
        fun loopj (j, ib, ic) =
          let fun loopk (0, _, _) = ()
                | loopk (k, ia, ib) =
                  let val (br, bi) = Array.sub(b,ib)
                  in
                    loopi(li, ia, ic, br, bi);
                    loopk(k-1, ia+li, ib+lj)
                  end
          in case j of
               0 => c
             | _ => (loopk(lk, 0, ib);
                     loopj(j-1, ib+1, ic+li))
          end (* loopj *)
    in
      loopj(lj, 0, 0)
    end
in
  (* ta *+ tb: contract the last index of ta with the last index of tb.
     Raises Match if those dimensions differ and Shape if either operand
     has rank 0. *)
  fun *+ ta tb =
    case (List.rev (shape ta), List.rev (shape tb)) of
      (lk :: rest_a, lk2 :: rest_b) =>
        if lk <> lk2 then raise Match
        else
          let val li = Index.length rest_a
              val lj = Index.length rest_b
              val c = Array.array(li*lj, Number.zero)
          in
            (* do_fold_last takes lj before li *)
            fromArray(List.rev rest_a @ List.rev rest_b,
                      do_fold_last (toArray ta) (toArray tb) c lk lj li)
          end
    | _ => raise Shape
end
(* ALGEBRAIC OPERATIONS *)
(* Element-wise operations lifted from the scalar Number structure. Binary
   operators combine two tensors with map2, or map2' for the comparisons;
   unary ones apply the scalar function to every element with map. *)
infix **
infix ==
infix !=
fun op+ (x, y) = map2 Number.+ x y
fun op- (x, y) = map2 Number.- x y
fun op* (x, y) = map2 Number.* x y
fun op** (t, n) = map (fn e => Number.** (e, n)) t
fun ~ t = map Number.~ t
fun abs t = map Number.abs t
fun signum t = map Number.signum t
fun op== (x, y) = map2' Number.== x y
fun op!= (x, y) = map2' Number.!= x y
(* NOTE(review): not implemented; always raises Domain. *)
fun toString _ = raise Domain
(* A one-element tensor of shape [1] holding Number.fromInt n. *)
fun fromInt n = new ([1], Number.fromInt n)
(* TENSOR SPECIFIC OPERATIONS *)
(* Multiply every element of a tensor by the scalar n, with n on the left. *)
fun *> scalar = map (fn elt => Number.* (scalar, elt))
(* Print the shape of t and a newline, then the elements of t through
   PrettyPrint.sequence using Number.toString. *)
fun print t =
    let val dims = shape t
    in PrettyPrint.intList dims;
       TextIO.print "\n";
       PrettyPrint.sequence (length t) appi Number.toString t
    end
(* Element-wise division and elementary functions, each lifted from the
   scalar operation of the same name in Number. *)
fun op/ (x, y) = map2 Number./ x y
fun recip t = map Number.recip t
fun ln t = map Number.ln t
fun pow (t, e) = map (fn x => Number.pow (x, e)) t
fun exp t = map Number.exp t
fun sqrt t = map Number.sqrt t
fun cos t = map Number.cos t
fun sin t = map Number.sin t
fun tan t = map Number.tan t
fun sinh t = map Number.sinh t
fun cosh t = map Number.cosh t
fun tanh t = map Number.tanh t
fun asin t = map Number.asin t
fun acos t = map Number.acos t
fun atan t = map Number.atan t
fun asinh t = map Number.asinh t
fun acosh t = map Number.acosh t
fun atanh t = map Number.atanh t
fun atan2 (x, y) = map2 Number.atan2 x y
(* Infinity norm: the RNumber.max, over all elements y, of
   Number.realPart (Number.abs y), folded from RNumber.zero. *)
fun normInf a =
    foldl (fn (elt, best) => RNumber.max (best, Number.realPart (Number.abs elt)))
          RNumber.zero a
(* 1-norm: the sum, over all elements y, of Number.realPart (Number.abs y),
   starting from RNumber.zero. *)
fun norm1 a =
    foldl (fn (elt, total) => RNumber.+ (total, Number.realPart (Number.abs elt)))
          RNumber.zero a
(* 2-norm: square root of the sum of Number.abs2 over all elements. *)
fun norm2 a =
    let val sumSq = foldl (fn (elt, acc) => RNumber.+ (acc, Number.abs2 elt)) RNumber.zero a
    in RNumber.sqrt sumSq
    end
2826 | end (* CTensor *) | |
structure MathFile =
struct

type file = TextIO.instream

(* Raised whenever the input stream does not hold the expected data. *)
exception Data

(* Unwrap the result of a scan, turning a failed scan into Data. *)
fun assert opt =
    case opt of
        SOME v => v
      | NONE => raise Data

(* ------------------ INPUT --------------------- *)

(* Scalar readers: parse one value from the stream or raise Data. *)
fun intRead file = assert (TextIO.scanStream INumber.scan file)
fun realRead file = assert (TextIO.scanStream RNumber.scan file)
fun complexRead file = assert (TextIO.scanStream CNumber.scan file)

(* Read an integer count n, then n elements parsed with eltScan, returned in
   input order. A negative count raises Data. *)
fun listRead eltScan file =
    let val count = intRead file
        fun next () = assert (TextIO.scanStream eltScan file)
        fun gather (0, acc) = List.rev acc
          | gather (k, acc) = gather (k - 1, next () :: acc)
    in if count < 0 then raise Data else gather (count, [])
    end

fun intListRead file = listRead INumber.scan file
fun realListRead file = listRead RNumber.scan file
fun complexListRead file = listRead CNumber.scan file

local
  (* Shared tensor reader: read the shape as a list, then Index.length shape
     elements in order. The first element is read before the array exists
     because it is also the array's initial fill value; the remaining ones
     are stored at indices 1, 2, ... up to the length.
     NOTE(review): for a zero-element shape this still reads one element and
     then updates index 1, which presumably raises Subscript, so a tensor with
     a zero dimension written by the tensor writers below cannot be read
     back. Confirm whether that case needs support. *)
  fun tensorRead (newArray, update, build, eltRead) file =
      let val shape = intListRead file
          val len = Index.length shape
          val arr = newArray (len, eltRead file)
          fun fill idx =
              if idx = len
              then build (shape, arr)
              else (update (arr, idx, eltRead file); fill (idx + 1))
      in fill 1
      end
in
fun intTensorRead file =
    tensorRead (ITensor.Array.array, ITensor.Array.update, ITensor.fromArray, intRead) file
fun realTensorRead file =
    tensorRead (RTensor.Array.array, RTensor.Array.update, RTensor.fromArray, realRead) file
fun complexTensorRead file =
    tensorRead (CTensor.Array.array, CTensor.Array.update, CTensor.fromArray, complexRead) file
end

(* ------------------ OUTPUT -------------------- *)

(* Write the string x to file, followed by a newline. *)
fun linedOutput (file, x) =
    (TextIO.output (file, x);
     TextIO.output (file, "\n"))

(* Scalar writers: one value per line. A complex value is written as the two
   components returned by CNumber.split, presumably real and imaginary parts,
   separated by a space. *)
fun intWrite file x = linedOutput (file, INumber.toString x)
fun realWrite file x = linedOutput (file, RNumber.toString x)
fun complexWrite file x =
    let val (re, im) = CNumber.split x
    in linedOutput (file, concat [RNumber.toString re, " ", RNumber.toString im])
    end

(* Write the list length, then each element on its own line via converter. *)
fun listWrite converter file items =
    (intWrite file (length items);
     List.app (fn item => linedOutput (file, converter item)) items)

fun intListWrite file x = listWrite INumber.toString file x
fun realListWrite file x = listWrite RNumber.toString file x
(* NOTE(review): elements are formatted with CNumber.toString here, while
   complexWrite writes two space-separated parts; confirm that CNumber.scan
   accepts both forms. *)
fun complexListWrite file x = listWrite CNumber.toString file x

(* Tensor writers: the shape as a list, then every element on its own line,
   in the order visited by the tensor's app. *)
fun intTensorWrite file x =
    (intListWrite file (ITensor.shape x);
     ITensor.app (fn e => intWrite file e) x)
fun realTensorWrite file x =
    (intListWrite file (RTensor.shape x);
     RTensor.app (fn e => realWrite file e) x)
fun complexTensorWrite file x =
    (intListWrite file (CTensor.shape x);
     CTensor.app (fn e => complexWrite file e) x)
end
2914 | ||
(* Call f () exactly n times and return unit. A count of zero or less makes
   no calls; a negative count used to recurse forever because it never
   reached the 0 clause. *)
fun loop n f =
    if n <= 0
    then ()
    else (f (); loop (n - 1) f)
2917 | ||
(* Time every (repetitions, operator) pair of list_op on arguments built by
   `new size`, for each size in list_sizes. Each size produces one line: the
   size, then, for every operator, the timer reading divided by the
   repetition count, right-aligned in 6 columns. Both operands of every call
   are the same freshly built value. *)
fun test_operator new list_op list_sizes =
    let fun timeOne size (times, f) =
            let val arg = new size
                val _ = EvalTimer.timerOn ()
                val _ = loop times (fn _ => f (arg, arg))
                val avg = LargeInt.toInt (EvalTimer.timerRead ()) div times
            in print (StringCvt.padLeft #" " 6 (Int.toString avg))
            end
        fun testSize size =
            (print (Int.toString size);
             print " ";
             List.app (timeOne size) list_op;
             print "\n")
    in List.app testSize list_sizes
    end
2937 | ||
structure Main =
struct
  (* One benchmark round. For real tensors and then for complex tensors, time
     the element-wise +, * and / operators at 20 calls each, and the two
     contractions +* and *+ at 4 calls each. The operands are n-by-n tensors
     filled with one, for n = 100, 200, 300, 400, 500. *)
  fun one () =
      let (* Print a section title, run the timings, then leave a blank gap. *)
          fun section (title, operators, constructor) =
              (print title;
               test_operator constructor operators [100,200,300,400,500];
               print "\n\n")
      in section ("Real tensors: (+, *, /, +*, *+)\n",
                  [(20, RTensor.+), (20, RTensor.* ), (20, RTensor./),
                   (4, fn (a,b) => RTensor.+* a b),
                   (4, fn (a,b) => RTensor.*+ a b)],
                  fn size => RTensor.new([size,size],1.0));
         section ("Complex tensors: (+, *, /, +*, *+)\n",
                  [(20, CTensor.+), (20, CTensor.* ), (20, CTensor./),
                   (4, fn (a,b) => CTensor.+* a b),
                   (4, fn (a,b) => CTensor.*+ a b)],
                  fn size => CTensor.new([size,size],CNumber.one))
      end

  (* Run the benchmark round n times; a count of zero or less does nothing. *)
  fun doit n =
      if n <= 0
      then ()
      else (one ()
            ; doit (n - 1))
end