Import Upstream version 20180207
[hcoop/debian/mlton.git] / benchmark / tests / tensor.sml
1 (* Obtained at http://www.arrakis.es/~worm/ *)
2
signature MONO_VECTOR =
sig
  (* Monomorphic immutable vectors of 'elem'. *)
  type elem
  type vector

  val maxLen : int

  (* construction *)
  val fromList : elem list -> vector
  val tabulate : int * (int -> elem) -> vector
  val concat : vector list -> vector

  (* inspection *)
  val length : vector -> int
  val sub : vector * int -> elem
  val extract : vector * int * int option -> vector

  (* slice traversals; the (vector, int, int option) triple presumably
     follows the SML'97 Basis slice convention -- confirm *)
  val mapi : (int * elem -> elem) -> vector * int * int option -> vector
  val appi : (int * elem -> unit) -> vector * int * int option -> unit
  val foldli : (int * elem * 'a -> 'a) -> 'a -> vector * int * int option -> 'a
  val foldri : (int * elem * 'a -> 'a) -> 'a -> vector * int * int option -> 'a

  (* whole-vector traversals *)
  val map : (elem -> elem) -> vector -> vector
  val app : (elem -> unit) -> vector -> unit
  val foldl : (elem * 'a -> 'a) -> 'a -> vector -> 'a
  val foldr : (elem * 'a -> 'a) -> 'a -> vector -> 'a
end
23
24 (*
25 Copyright (c) Juan Jose Garcia Ripoll.
26 All rights reserved.
27
28 Refer to the COPYRIGHT file for license conditions
29 *)
30
31 (* COPYRIGHT
32
33 Redistribution and use in source and binary forms, with or
34 without modification, are permitted provided that the following
35 conditions are met:
36
37 1. Redistributions of source code must retain the above copyright
38 notice, this list of conditions and the following disclaimer.
39
40 2. Redistributions in binary form must reproduce the above
41 copyright notice, this list of conditions and the following
42 disclaimer in the documentation and/or other materials provided
43 with the distribution.
44
45 3. All advertising materials mentioning features or use of this
46 software must display the following acknowledgement:
47 This product includes software developed by Juan Jose
48 Garcia Ripoll.
49
50 4. The name of Juan Jose Garcia Ripoll may not be used to endorse
51 or promote products derived from this software without
52 specific prior written permission.
53
54 THIS SOFTWARE IS PROVIDED BY JUAN JOSE GARCIA RIPOLL ``AS IS''
55 AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
56 TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
57 PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL HE BE
58 LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY,
59 OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
60 PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA,
61 OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
62 THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR
63 TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT
64 OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY
65 OF SUCH DAMAGE.
66 *)
67
structure EvalTimer =
struct
  local
    (* Instant recorded by the most recent timerOn (initially the time at
       which this structure was elaborated). *)
    val started = ref (Time.now ())
  in
    (* Restart the stopwatch. *)
    fun timerOn () = started := Time.now ()

    (* Milliseconds elapsed since the last timerOn, as a LargeInt. *)
    fun timerRead () = Time.toMilliseconds (Time.- (Time.now (), !started))

    (* Print "Elapsed: <ms> ms" followed by a newline. *)
    fun timerOff () =
        let val ms = LargeInt.toString (timerRead ())
        in print (String.concat ["Elapsed: ", ms, " ms\n"]) end

    (* Run f with the stopwatch restarted, then report the elapsed time. *)
    fun time f =
        let val () = timerOn ()
            val _ = f ()
        in timerOff () end
  end
end
structure Loop =
struct
  (* Integer-range combinators.  all/any test the closed range [a, b];
     app, app' and appi' walk the half-open range [a, b). *)

  (* true iff f i holds for every i in [a, b]; stops at the first false. *)
  fun all (a, b, f) = a > b orelse (f a andalso all (a + 1, b, f))

  (* true iff f i holds for some i in [a, b]; stops at the first true. *)
  fun any (a, b, f) = a <= b andalso (f a orelse any (a + 1, b, f))

  (* Calls f a, f (a+1), ..., f (b-1), discarding the results. *)
  fun app (a, b, f) = if a >= b then () else (f a; app (a + 1, b, f))

  (* Calls f a, f (a+d), f (a+2d), ... while the argument is below b. *)
  fun app' (a, b, d, f) = if a >= b then () else (f a; app' (a + d, b, d, f))

  (* Same traversal as app'. *)
  fun appi' (a, b, d, f) = app' (a, b, d, f)
end
123 (*
124 INDEX -Signature-
125
126 Indices are an enumerable finite set of data with an order and a map
127 to a continuous nonnegative interval of integers. In the sample
128 implementation, Index, each index is a list of integers,
129 [i1,...,in]
130 and each set of indices is defined by a shape, which has the same
131 shape of an index but with each integer incremented by one
132 shape = [k1,...,kn]
133 0 <= i1 < k1
134
135 type storage = RowMajor | ColumnMajor
136 order : storage
137 Identifies:
138 1) the underlying algorithms for this structure
139 2) the most significant index
140 3) the index that varies more slowly
141 4) the total order
142 RowMajor means that first index is most significant and varies
143 more slowly, while ColumnMajor means that last index is the most
144 significant and varies more slowly. For instance
145 RowMajor => [0,0]<[0,1]<[1,0]<[1,1] (C, C++, Pascal)
146 ColumnMajor => [0,0]<[1,0]<[0,1]<[1,1] (Fortran)
147 last shape
148 first shape
149 Returns the last/first index that belongs to the set defined by
150 'shape'.
151 inBounds shape index
152 Checks whether 'index' belongs to the set defined by 'shape'.
153 toInt shape index
154 As we said, indices can be sorted and mapped to a finite set of
155 integers. 'toInt' obtains the integer number that corresponds to
156 a certain index.
157 indexer shape
158 It is equivalent to the partial evaluation 'toInt shape' but
159 optimized for 'shape'.
160
161 next shape index
162 prev shape index
163 next' shape index
164 prev' shape index
165 Obtain the following or previous index to the one we supply.
166 next and prev return an object of type 'index option' so that
167 if there is no such following/previous, the output is NONE.
168 On the other hand, next'/prev' raise an exception when the
169 output is not well defined and their output is always of type
170 index. next/prev/next'/prev' raise an exception if 'index'
171 does not belong to the set of 'shape'.
172
173 all shape f
174 any shape f
175 app shape f
176 Iterates 'f' over every index of the set defined by 'shape'.
177 'all' stops when 'f' first returns false, 'any' stops when
178 'f' first returns true and 'app' does not stop and discards the
179 output of 'f'.
180
181 compare(a,b)
182 Returns LESS/GREATER/EQUAL according to the total order which
183 is defined in the set of all indices.
184 <,>,eq,<=,>=,<>
185 Reduced comparisons which are defined in terms of 'compare'.
186
187 validShape t
188 validIndex t
189 Checks whether 't' is a valid shape or index.
190
191 iteri shape f
192 *)
193
signature INDEX =
sig
  (* Indices and shapes share one representation. *)
  type t
  (* A precomputed map from an index to its linear position. *)
  type indexer = t -> int
  datatype storage = RowMajor | ColumnMajor

  exception Index
  exception Shape

  val order : storage

  (* shape predicates and sizes *)
  val validShape : t -> bool
  val validIndex : t -> bool
  val length : t -> int
  val first : t -> t
  val last : t -> t

  (* linearisation *)
  val toInt : t -> t -> int
  val indexer : t -> t -> int
  val inBounds : t -> t -> bool

  (* stepping through the indices of a shape *)
  val next : t -> t -> t option
  val prev : t -> t -> t option
  val next' : t -> t -> t
  val prev' : t -> t -> t

  (* ordering and arithmetic *)
  val compare : t * t -> order
  val < : t * t -> bool
  val > : t * t -> bool
  val eq : t * t -> bool
  val <= : t * t -> bool
  val >= : t * t -> bool
  val <> : t * t -> bool
  val - : t * t -> t

  (* iteration over every index of a shape *)
  val all : t -> (t -> bool) -> bool
  val any : t -> (t -> bool) -> bool
  val app : t -> (t -> unit) -> unit
end

structure Index : INDEX =
struct
  type t = int list
  type indexer = t -> int
  datatype storage = RowMajor | ColumnMajor

  exception Index
  exception Shape

  (* The first component varies fastest: toInt gives it stride 1 and the
     last component is the most significant. *)
  val order = ColumnMajor

  (* A shape is valid when every dimension is strictly positive. *)
  fun validShape shape = List.all (fn x => x > 0) shape

  (* An index is valid when every component is non-negative. *)
  fun validIndex index = List.all (fn x => x >= 0) index

  (* Linear position of 'index' within 'shape'.  Raises Index when the ranks
     differ or a component lies outside [0, dim). *)
  fun toInt shape index =
      let fun loop ([], [], accum, _) = accum
            | loop ([], _, _, _) = raise Index
            | loop (_, [], _, _) = raise Index
            | loop (i :: ri, l :: rl, accum, fac) =
                if i >= 0 andalso i < l
                then loop (ri, rl, i * fac + accum, fac * l)
                else raise Index
      in loop (index, shape, 0, 1) end

  (* ----- CACHED LINEAR INDEXER -----
     'indexer shape' returns a function equivalent to 'toInt shape' with the
     strides precomputed.  Ranks 1 to 3 get specialised closures; any other
     rank (including 0) uses the general walker.  Every component must
     satisfy 0 <= i < dim, exactly as in toInt. *)
  local
    fun doindexer [a] [_] =
          let fun f [x] = if x >= 0 andalso x < a then x else raise Index
                | f _ = raise Index
          in f end
      | doindexer [a, b] [_, dy] =
          let fun f [x, y] =
                    if x >= 0 andalso x < a andalso y >= 0 andalso y < b
                    then x + dy * y
                    else raise Index
                | f _ = raise Index
          in f end
      | doindexer [a, b, c] [_, dy, dz] =
          let fun f [x, y, z] =
                    if x >= 0 andalso x < a andalso
                       y >= 0 andalso y < b andalso
                       z >= 0 andalso z < c
                    then x + dy * y + dz * z
                    else raise Index
                | f _ = raise Index
          in f end
      | doindexer shape strides =
          let
            (* walk strides, index components and dimensions in lockstep *)
            fun f ([], [], accum, []) = accum
              | f (fac :: rf, ndx :: ri, accum, dim :: rd) =
                  if ndx >= 0 andalso ndx < dim
                  then f (rf, ri, accum + ndx * fac, rd)
                  else raise Index
              | f _ = raise Index
          in
            fn index => f (strides, index, 0, shape)
          end
  in
    fun indexer shape =
        let (* strides: 1, d1, d1*d2, ... *)
            fun memoize _ [] = []
              | memoize accum (dim :: rd) = accum :: memoize (dim * accum) rd
        in
          if validShape shape
          then doindexer shape (memoize 1 shape)
          else raise Shape
        end
  end

  (* Number of indices in 'shape'; raises Shape if any dimension is negative. *)
  fun length shape =
      List.foldl (fn (dim, acc) => if dim < 0 then raise Shape else dim * acc) 1 shape

  (* The all-zero index of the same rank as 'shape'. *)
  fun first shape = List.map (fn _ => 0) shape

  (* The index whose components are dim - 1; raises Shape if a dim is < 1. *)
  fun last [] = []
    | last (size :: rest) =
        if size < 1 then raise Shape else size - 1 :: last rest

  (* Successor in storage order (first component fastest).  Raises
     Subscript when stepping past the last index, Index on rank mismatch. *)
  fun next' [] [] = raise Subscript
    | next' _ [] = raise Index
    | next' [] _ = raise Index
    | next' (dimension :: restd) (index :: resti) =
        if index + 1 < dimension
        then (index + 1) :: resti
        else 0 :: next' restd resti

  (* Predecessor in storage order; raises Subscript before the first index. *)
  fun prev' [] [] = raise Subscript
    | prev' _ [] = raise Index
    | prev' [] _ = raise Index
    | prev' (dimension :: restd) (index :: resti) =
        if index > 0
        then index - 1 :: resti
        else dimension - 1 :: prev' restd resti

  fun next shape index = SOME (next' shape index) handle Subscript => NONE

  fun prev shape index = SOME (prev' shape index) handle Subscript => NONE

  (* True when 'index' has the same rank as 'shape' and every component lies
     in [0, dim). *)
  fun inBounds shape index =
      let fun check ([], []) = true
            | check (x :: xs, d :: ds) = x >= 0 andalso x < d andalso check (xs, ds)
            | check _ = false
      in check (index, shape) end

  (* Total order matching 'order' = ColumnMajor: the last component is the
     most significant, so compare agrees with toInt and next.  Raises Index
     when the ranks differ. *)
  fun compare (a, b) =
      let fun cmp ([], []) = EQUAL
            | cmp (x :: xs, y :: ys) =
                (case Int.compare (x, y) of
                     EQUAL => cmp (xs, ys)
                   | other => other)
            | cmp _ = raise Index
      in
        if List.length a = List.length b
        then cmp (List.rev a, List.rev b)
        else raise Index
      end

  (* Iterators visit every index of 'shape' in storage order (first
     component fastest).  They walk the reversed shape so that the outermost
     loop drives the last component.  An empty shape has exactly one
     index, []. *)

  (* true iff f holds for every index; stops at the first false. *)
  fun all shape f =
      let fun walk ([], accum) = f accum
            | walk (dim :: rest, accum) =
                let fun step i =
                        i >= dim orelse (walk (rest, i :: accum) andalso step (i + 1))
                in step 0 end
      in walk (List.rev shape, []) end

  (* true iff f holds for some index; stops at the first true. *)
  fun any shape f =
      let fun walk ([], accum) = f accum
            | walk (dim :: rest, accum) =
                let fun step i =
                        i < dim andalso (walk (rest, i :: accum) orelse step (i + 1))
                in step 0 end
      in walk (List.rev shape, []) end

  (* Applies f to every index for its side effects. *)
  fun app shape f =
      let fun walk ([], accum) = f accum
            | walk (dim :: rest, accum) =
                let fun step i =
                        if i < dim then (walk (rest, i :: accum); step (i + 1)) else ()
                in step 0 end
      in walk (List.rev shape, []) end

  (* Comparisons derived from compare; these shadow the integer operators
     from here on, so they must stay at the end of the structure. *)
  fun a < b = compare (a, b) = LESS
  fun a > b = compare (a, b) = GREATER
  fun eq (a, b) = compare (a, b) = EQUAL
  fun a <> b = not (a = b)
  fun a <= b = not (a > b)
  fun a >= b = not (a < b)
  fun a - b = ListPair.map Int.- (a, b)
end
460 (*
461 Copyright (c) Juan Jose Garcia Ripoll.
462 All rights reserved.
463
464 Refer to the COPYRIGHT file for license conditions
465 *)
466
467 (*
468 TENSOR - Signature -
469
470 Polymorphic tensors of any type. With 'tensor' we denote a (mutable)
471 array of any rank, with as many indices as one wishes, and that may
472 be traversed (map, fold, etc) according to any of those indices.
473
474 type 'a tensor
475 Polymorphic tensor whose elements are all of type 'a.
476 val storage = RowMajor | ColumnMajor
477 RowMajor = data is stored in consecutive cells, last index
478 varying fastest (C,C++,Pascal,CommonLisp convention)
479 ColumnMajor = data is stored in consecutive cells, first
480 index varying fastest (FORTRAN convention)
481 new ([i1,...,in],init)
482 Build a new tensor with n indices, each of sizes i1...in,
483 filled with 'init'.
484 fromArray (shape,data)
485 fromList (shape,data)
486 Use 'data' to fill a tensor of that shape. An exception is
487 raised if 'data' is too large or too small to properly
488 fill the vector. Later use of a 'data' array is disregarded
489 -- one must think that the tensor now owns the array.
490 length tensor
491 rank tensor
492 shape tensor
493 Return the number of elements, the number of indices and
494 the shape (size of each index) of the tensor.
495 toArray tensor
496 Return the data of the tensor in the form of an array.
497 Mutation of this array may lead to unexpected behavior.
498
499 sub (tensor,[i1,...,in])
500 update (tensor,[i1,...,in],new_value)
501 Access the element that is indexed by the numbers [i1,..,in]
502
503 app f a
504 appi f a
505 The same as 'map' and 'mapi' but the function 'f' outputs
506 nothing and no new array is produced, i.e. one only seeks
507 the side effect that 'f' may produce.
508 map2 operation a b
509 Apply function 'f' to pairs of elements of 'a' and 'b'
510 and build a new tensor with the output. Both operands
511 must have the same shape or an exception is raised.
512 The procedure is sequential, as specified by 'storage'.
513 foldl operation a n
514 Fold-left the elements of tensor 'a' along the n-th
515 index.
516 all test a
517 any test a
518 Folded boolean tests on the elements of the tensor.
519 *)
520
signature TENSOR =
sig
  structure Array : ARRAY
  structure Index : INDEX

  type index = Index.t
  (* Mutable tensor whose elements all have type 'a. *)
  type 'a tensor

  (* construction *)
  val new : index * 'a -> 'a tensor
  val tabulate : index * (index -> 'a) -> 'a tensor
  val fromList : index * 'a list -> 'a tensor
  val fromArray : index * 'a array -> 'a tensor
  val reshape : index -> 'a tensor -> 'a tensor

  (* queries *)
  val length : 'a tensor -> int
  val rank : 'a tensor -> int
  val shape : 'a tensor -> index
  val toArray : 'a tensor -> 'a array

  (* element access *)
  val sub : 'a tensor * index -> 'a
  val update : 'a tensor * index * 'a -> unit

  (* traversals *)
  val map : ('a -> 'b) -> 'a tensor -> 'b tensor
  val map2 : ('a * 'b -> 'c) -> 'a tensor -> 'b tensor -> 'c tensor
  val app : ('a -> unit) -> 'a tensor -> unit
  val appi : (int * 'a -> unit) -> 'a tensor -> unit
  val foldl : ('c * 'a -> 'c) -> 'c -> 'a tensor -> int -> 'c tensor
  val all : ('a -> bool) -> 'a tensor -> bool
  val any : ('a -> bool) -> 'a tensor -> bool
end
548
549 (*
550 Copyright (c) Juan Jose Garcia Ripoll.
551 All rights reserved.
552
553 Refer to the COPYRIGHT file for license conditions
554 *)
555
structure Tensor : TENSOR =
struct
  structure Array = Array
  structure Index = Index

  type index = Index.t
  (* 'indexer' caches Index.indexer shape so sub/update need not rebuild it. *)
  type 'a tensor = {shape : index, indexer : Index.indexer, data : 'a array}

  exception Shape
  exception Match
  exception Index

  local
    (*----- LOCALS -----*)

    (* Wrap 'data' as a tensor of 'shape'; the caller checks the size. *)
    fun make' (shape, data) =
        {shape = shape, indexer = Index.indexer shape, data = data}

    (* Linear position of 'index' inside a tensor. *)
    fun toInt {shape, indexer, data} index = indexer index

    (* Fresh array holding f applied to each element of a, in index order. *)
    fun array_map f a =
        let fun apply index = f (Array.sub (a, index))
        in Array.tabulate (Array.length a, apply) end

    (* splitList (l, place) = (before, x, after) where x is the selected
       element of l.  place >= 1 selects the place-th element (1-based);
       place <= 0 counts from the end: 0 is the last element, ~1 the one
       before it, and so on.  Raises Index when l is empty or place is out
       of range. *)
    fun splitList ([], _) = raise Index
      | splitList (first :: rest, place) =
          let fun loop (left, here, right) 0 = (List.rev left, here, right)
                | loop (_, _, []) _ = raise Index
                | loop (left, here, following :: right) n =
                    loop (here :: left, following, right) (n - 1)
          in
            if place <= 0
            then loop ([], first, rest) (List.length rest + place)
            else loop ([], first, rest) (place - 1)
          end

  in
    (*----- STRUCTURAL OPERATIONS & QUERIES ------*)

    (* Tensor of 'shape' with every cell set to 'init'; raises Shape for an
       invalid shape. *)
    fun new (shape, init) =
        if not (Index.validShape shape) then
          raise Shape
        else
          {shape = shape,
           indexer = Index.indexer shape,
           data = Array.array (Index.length shape, init)}

    (* The underlying storage (shared, not copied). *)
    fun toArray {shape, indexer, data} = data

    fun length {shape, indexer, data} = Array.length data

    fun shape {shape, indexer, data} = shape

    fun rank t = List.length (shape t)

    (* Same storage viewed with 'new_shape'; raises Shape for an invalid
       shape and Match when the element counts differ. *)
    fun reshape new_shape tensor =
        if Index.validShape new_shape then
          if Index.length new_shape = length tensor
          then make' (new_shape, toArray tensor)
          else raise Match
        else
          raise Shape

    (* Adopt array 'a' as storage; raises Shape on a size mismatch. *)
    fun fromArray (s, a) =
        if Index.validShape s andalso Index.length s = Array.length a
        then make' (s, a)
        else raise Shape

    fun fromList (s, a) = fromArray (s, Array.fromList a)

    (* Tensor whose cell at index i holds f i.  Cells are filled from the
       last index backwards with Index.prev', which walks linear positions
       len-1, len-2, ..., 0; f is called exactly once per index. *)
    fun tabulate (shape, f) =
        if Index.validShape shape then
          let val last = Index.last shape
              val len = Index.length shape
              (* the last cell is seeded here *)
              val c = Array.array (len, f last)
              fun fill (indices, i) =
                  (Array.update (c, i, f indices);
                   if i = 0 then () else fill (Index.prev' shape indices, i - 1))
          in
            if len > 1 then fill (Index.prev' shape last, len - 2) else ();
            make' (shape, c)
          end
        else
          raise Shape

    (*----- ELEMENTWISE OPERATIONS -----*)

    fun sub (t, index) = Array.sub (#data t, toInt t index)

    fun update (t, index, value) =
        Array.update (toArray t, toInt t index, value)

    fun map f {shape, indexer, data} =
        {shape = shape, indexer = indexer, data = array_map f data}

    (* Pairwise combination; raises Match unless both shapes are equal. *)
    fun map2 f t1 t2 =
        let val {shape, indexer, data} = t1
            val {shape = shape2, indexer = indexer2, data = data2} = t2
            fun apply i = f (Array.sub (data, i), Array.sub (data2, i))
            val len = Array.length data
        in
          if Index.eq (shape, shape2) then
            {shape = shape,
             indexer = indexer,
             data = Array.tabulate (len, apply)}
          else
            raise Match
        end

    (* f receives the linear position and the element. *)
    fun appi f tensor = Array.appi f (toArray tensor)

    fun app f tensor = Array.app f (toArray tensor)

    fun all f tensor =
        let val a = toArray tensor
        in Loop.all (0, length tensor - 1, fn i => f (Array.sub (a, i))) end

    fun any f tensor =
        let val a = toArray tensor
        in Loop.any (0, length tensor - 1, fn i => f (Array.sub (a, i))) end

    (* foldl f init t n folds f over index n of t (chosen as in splitList:
       1-based, or counted from the end when n <= 0).  The result has t's
       shape with that index removed; each cell starts at 'init' and is
       updated as f (acc, element) in increasing order along the index. *)
    fun foldl f init {shape, indexer, data = a} index =
        let val (head, lk, tail) = splitList (shape, index)
            val li = Index.length head
            val lj = Index.length tail
            val c = Array.array (li * lj, init)
            (* fold one contiguous run of li source elements into c[ic..] *)
            fun loopi (0, _, _) = ()
              | loopi (i, ia, ic) =
                  (Array.update (c, ic, f (Array.sub (c, ic), Array.sub (a, ia)));
                   loopi (i - 1, ia + 1, ic + 1))
            (* fold the lk runs of one block; returns the next source offset *)
            fun loopk (0, ia, _) = ia
              | loopk (k, ia, ic) = (loopi (li, ia, ic); loopk (k - 1, ia + li, ic))
            fun loopj (0, _, _) = ()
              | loopj (j, ia, ic) = loopj (j - 1, loopk (lk, ia, ic), ic + li)
        in
          loopj (lj, 0, 0);
          make' (head @ tail, c)
        end
  end
end (* Tensor *)
707
708 (*
709 Copyright (c) Juan Jose Garcia Ripoll.
710 All rights reserved.
711
712 Refer to the COPYRIGHT file for license conditions
713 *)
714
715 (*
716 MONO_TENSOR - signature -
717
718 Monomorphic tensor of arbitrary data (not only numbers). Operations
719 should be provided to run the data in several ways, according to one
720 index.
721
722 type tensor
723 The type of the tensor itself
724 type elem
725 The type of every element
726 val storage = RowMajor | ColumnMajor
727 RowMajor = data is stored in consecutive cells, last index
728 varying fastest (C,C++,Pascal,CommonLisp convention)
729 ColumnMajor = data is stored in consecutive cells, first
730 index varying fastest (FORTRAN convention)
731 new ([i1,...,in],init)
732 Build a new tensor with n indices, each of sizes i1...in,
733 filled with 'init'.
734 fromArray (shape,data)
735 fromList (shape,data)
736 Use 'data' to fill a tensor of that shape. An exception is
737 raised if 'data' is too large or too small to properly
738 fill the vector. Later use of a 'data' array is disregarded
739 -- one must think that the tensor now owns the array.
740 length tensor
741 rank tensor
742 shape tensor
743 Return the number of elements, the number of indices and
744 the shape (size of each index) of the tensor.
745 toArray tensor
746 Return the data of the tensor in the form of an array.
747 Mutation of this array may lead to unexpected behavior.
748 The data in the array is stored according to `storage'.
749
750 sub (tensor,[i1,...,in])
751 update (tensor,[i1,...,in],new_value)
752 Access the element that is indexed by the numbers [i1,..,in]
753
754 map f a
755 mapi f a
756 Produce a new array by mapping the function sequentially
757 as specified by 'storage', to each element of tensor 'a'.
758 In 'mapi' the function receives a (indices,value) tuple,
759 while in 'map' it only receives the value.
760 app f a
761 appi f a
762 The same as 'map' and 'mapi' but the function 'f' outputs
763 nothing and no new array is produced, i.e. one only seeks
764 the side effect that 'f' may produce.
765 map2 operation a b
766 Apply function 'f' to pairs of elements of 'a' and 'b'
767 and build a new tensor with the output. Both operands
768 must have the same shape or an exception is raised.
769 The procedure is sequential, as specified by 'storage'.
770 foldl operation a n
771 Fold-left the elements of tensor 'a' along the n-th
772 index.
773 all test a
774 any test a
775 Folded boolean tests on the elements of the tensor.
776
777 map', map2', foldl'
778 Polymorphic versions of map, map2, foldl.
779 *)
780
signature MONO_TENSOR =
sig
  structure Array : MONO_ARRAY
  structure Index : INDEX

  type index = Index.t
  type elem
  type tensor
  type t = tensor

  (* construction *)
  val new : index * elem -> tensor
  val tabulate : index * (index -> elem) -> tensor
  val fromList : index * elem list -> tensor
  val fromArray : index * Array.array -> tensor
  val reshape : index -> tensor -> tensor

  (* queries *)
  val length : tensor -> int
  val rank : tensor -> int
  val shape : tensor -> index
  val toArray : tensor -> Array.array

  (* element access *)
  val sub : tensor * index -> elem
  val update : tensor * index * elem -> unit

  (* monomorphic traversals *)
  val map : (elem -> elem) -> tensor -> tensor
  val map2 : (elem * elem -> elem) -> tensor -> tensor -> tensor
  val app : (elem -> unit) -> tensor -> unit
  val appi : (int * elem -> unit) -> tensor -> unit
  (* NOTE(review): the result is 'tensor' although f produces 'a; confirm
     against the implementing structures. *)
  val foldl : (elem * 'a -> 'a) -> 'a -> tensor -> tensor
  val foldln : (elem * elem -> elem) -> elem -> tensor -> int -> tensor
  val all : (elem -> bool) -> tensor -> bool
  val any : (elem -> bool) -> tensor -> bool

  (* polymorphic variants producing a Tensor.tensor *)
  val map' : (elem -> 'a) -> tensor -> 'a Tensor.tensor
  val map2' : (elem * elem -> 'a) -> tensor -> tensor -> 'a Tensor.tensor
  val foldl' : ('a * elem -> 'a) -> 'a -> tensor -> int -> 'a Tensor.tensor
end
815
816 (*
817 NUMBER - Signature -
818
819 Guarantees a structure with a minimal number of mathematical operations
820 so as to build an algebraic structure named Tensor.
821 *)
822
(* NOTE(review): this signature is redefined by the NUMBER signature that
   immediately follows it, so no later code can refer to this version. *)
signature NUMBER =
sig
  type t

  val zero : t
  val one : t

  val ~ : t -> t
  val + : t * t -> t
  val - : t * t -> t
  val * : t * t -> t
  val / : t * t -> t

  val toString : t -> string
end
835
signature NUMBER =
sig
  type t

  (* constants *)
  val zero : t
  val one : t

  (* ring operations, plus product-and-add/subtract combinations *)
  val + : t * t -> t
  val - : t * t -> t
  val * : t * t -> t
  val *+ : t * t * t -> t
  val *- : t * t * t -> t
  val ** : t * int -> t

  (* unary operations *)
  val ~ : t -> t
  val abs : t -> t
  val signum : t -> t

  (* equality *)
  val == : t * t -> bool
  val != : t * t -> bool

  (* conversions *)
  val toString : t -> string
  val fromInt : int -> t
  val scan : (char, 'a) StringCvt.reader -> (t, 'a) StringCvt.reader
end
860
signature INTEGRAL_NUMBER =
sig
  include NUMBER

  (* integer division; presumably with Basis INTEGER semantics -- confirm *)
  val quot : t * t -> t
  val rem : t * t -> t
  val div : t * t -> t
  val mod : t * t -> t

  (* ordering *)
  val compare : t * t -> order
  val < : t * t -> bool
  val <= : t * t -> bool
  val > : t * t -> bool
  val >= : t * t -> bool

  val min : t * t -> t
  val max : t * t -> t
end
879
signature FRACTIONAL_NUMBER =
sig
  include NUMBER

  (* constants *)
  val pi : t
  val e : t

  (* division *)
  val / : t * t -> t
  val recip : t -> t

  (* exponentials, logarithms and roots *)
  val exp : t -> t
  val ln : t -> t
  val pow : t * t -> t
  val sqrt : t -> t

  (* circular functions and their inverses *)
  val sin : t -> t
  val cos : t -> t
  val tan : t -> t
  val asin : t -> t
  val acos : t -> t
  val atan : t -> t
  val atan2 : t * t -> t

  (* hyperbolic functions and their inverses *)
  val sinh : t -> t
  val cosh : t -> t
  val tanh : t -> t
  val asinh : t -> t
  val acosh : t -> t
  val atanh : t -> t
end
910
signature REAL_NUMBER =
sig
  include FRACTIONAL_NUMBER

  (* ordering *)
  val compare : t * t -> order
  val < : t * t -> bool
  val <= : t * t -> bool
  val > : t * t -> bool
  val >= : t * t -> bool

  val min : t * t -> t
  val max : t * t -> t
end
924
signature COMPLEX_NUMBER =
sig
  include FRACTIONAL_NUMBER

  (* the real field the components are drawn from *)
  structure Real : REAL_NUMBER
  type real = Real.t

  (* conversion to and from (real part, imaginary part) *)
  val make : real * real -> t
  val split : t -> real * real
  val realPart : t -> real
  val imagPart : t -> real

  (* squared modulus *)
  val abs2 : t -> real
end
938
structure INumber : INTEGRAL_NUMBER =
struct
  open Int

  type t = Int.int

  val zero = 0
  val one = 1

  infix **
  (* i ** n by repeated squaring; raises Domain when n < 0. *)
  fun i ** n =
      let fun loop 0 = 1
            | loop 1 = i
            | loop n =
                let val x = loop (Int.div (n, 2))
                    val m = Int.mod (n, 2)
                in if m = 0 then x * x else x * x * i end
      in if n < 0 then raise Domain else loop n end

  (* ~1, 0 or 1 according to the sign of i. *)
  fun signum i =
      case compare (i, 0) of
          GREATER => 1
        | EQUAL => 0
        | LESS => ~1

  infix ==
  infix !=
  fun a == b = a = b
  fun a != b = a <> b

  (* b * c + a *)
  fun *+ (b, c, a) = b * c + a
  (* b * c - a, matching the Real.*- argument convention *)
  fun *- (b, c, a) = b * c - a

  (* Reads a decimal integer. *)
  fun scan getc = Int.scan StringCvt.DEC getc
end
978
structure RNumber : REAL_NUMBER =
struct
  open Real
  open Real.Math

  type t = Real.real

  val zero = 0.0
  val one = 1.0

  (* Sign of x as a real: ~1.0, 0.0 or 1.0 (Real.compare raises on NaN). *)
  fun signum x =
      case compare (x, 0.0) of
          GREATER => 1.0
        | EQUAL => 0.0
        | LESS => ~1.0

  (* Multiplicative inverse. *)
  fun recip x = 1.0 / x

  infix **
  (* x ** k for an int k >= 0 by repeated squaring; raises Domain when k < 0. *)
  fun x ** k =
      let
        fun power 0 = one
          | power 1 = x
          | power j =
              let val h = power (Int.div (j, 2))
                  val sq = h * h
              in if Int.mod (j, 2) = 0 then sq else sq * x end
      in
        if Int.< (k, 0) then raise Domain else power k
      end

  (* Larger / smaller argument; the first argument wins ties and NaNs. *)
  fun max (a, b) = if b > a then b else a
  fun min (a, b) = if b > a then a else b

  (* Inverse hyperbolic functions via their logarithmic forms. *)
  fun asinh x = ln (x + sqrt (1.0 + x * x))
  fun acosh x =
      let val xp1 = x + 1.0
      in ln (x + xp1 * sqrt ((x - 1.0) / xp1)) end
  fun atanh x = ln ((1.0 + x) / sqrt (1.0 - x * x))
end
1020 (*
1021 CNumber - Structure -
1022
1023 Provides support for complex numbers based on tuples. Should be
1024 highly efficient as most operations can be inlined.
1025 *)
1026
1027 structure CNumber : COMPLEX_NUMBER =
1028 struct
1029 structure Real = RNumber
1030
1031 type t = Real.t * Real.t
1032 type real = Real.t
1033
1034 val zero = (0.0,0.0)
1035 val one = (1.0,0.0)
1036 val pi = (Real.pi, 0.0)
1037 val e = (Real.e, 0.0)
1038
1039 fun make (r,i) = (r,i) : t
1040 fun split z = z
1041 fun realPart (r,_) = r
1042 fun imagPart (_,i) = i
1043
1044 fun abs2 (r,i) = Real.+(Real.*(r,r),Real.*(i,i)) (* FIXME!!! *)
1045 fun arg (r,i) = Real.atan2(i,r)
1046 fun modulus z = Real.sqrt(abs2 z)
1047 fun abs z = (modulus z, 0.0)
1048 fun signum (z as (r,i)) =
1049 let val m = modulus z
1050 in (Real./(r,m), Real./(i,m))
1051 end
1052
1053 fun ~ (r1,i1) = (Real.~ r1, Real.~ i1)
1054 fun (r1,i1) + (r2,i2) = (Real.+(r1,r2), Real.+(i1,i2))
1055 fun (r1,i1) - (r2,i2) = (Real.-(r1,r2), Real.-(i1,i1))
1056 fun (r1,i1) * (r2,i2) = (Real.-(Real.*(r1,r2),Real.*(i1,i2)),
1057 Real.+(Real.*(r1,i2),Real.*(r2,i1)))
1058 fun (r1,i1) / (r2,i2) =
1059 let val modulus = abs2(r2,i2)
1060 val (nr,ni) = (r1,i1) * (r2,i2)
1061 in
1062 (Real./(nr,modulus), Real./(ni,modulus))
1063 end
1064 fun *+((r1,i1),(r2,i2),(r0,i0)) =
1065 (Real.*+(Real.~ i1, i2, Real.*+(r1,r2,r0)),
1066 Real.*+(r2, i2, Real.*+(r1,i2,i0)))
1067 fun *-((r1,i1),(r2,i2),(r0,i0)) =
1068 (Real.*+(Real.~ i1, i2, Real.*-(r1,r2,r0)),
1069 Real.*+(r2, i2, Real.*-(r1,i2,i0)))
1070
1071 infix **
1072 fun i ** n =
1073 let fun loop 0 = one
1074 | loop 1 = i
1075 | loop n =
1076 let val x = loop (Int.div(n, 2))
1077 val m = Int.mod(n, 2)
1078 in
1079 if m = 0 then
1080 x * x
1081 else
1082 x * x * i
1083 end
1084 in if Int.<(n, 0)
1085 then raise Domain
1086 else loop n
1087 end
1088
1089 fun recip (r1, i1) =
1090 let val modulus = abs2(r1, i1)
1091 in (Real./(r1, modulus), Real./(Real.~ i1, modulus))
1092 end
1093 fun ==(z, w) = Real.==(realPart z, realPart w) andalso Real.==(imagPart z, imagPart w)
1094 fun !=(z, w) = Real.!=(realPart z, realPart w) andalso Real.!=(imagPart z, imagPart w)
1095 fun fromInt i = (Real.fromInt i, 0.0)
1096 fun toString (r,i) =
1097 String.concat ["(",Real.toString r,",",Real.toString i,")"]
1098
1099 fun exp (x, y) =
1100 let val expx = Real.exp x
1101 in (Real.*(x, (Real.cos y)), Real.*(x, (Real.sin y)))
1102 end
1103
1104 local
1105 val half = Real.recip (Real.fromInt 2)
1106 in
1107 fun sqrt (z as (x,y)) =
1108 if Real.==(x, 0.0) andalso Real.==(y, 0.0) then
1109 zero
1110 else
1111 let val m = Real.+(modulus z, Real.abs x)
1112 val u' = Real.sqrt (Real.*(m, half))
1113 val v' = Real./(Real.abs y , Real.+(u',u'))
1114 val (u,v) = if Real.<(x, 0.0) then (v',u') else (u',v')
1115 in (u, if Real.<(y, 0.0) then Real.~ v else v)
1116 end
1117 end
1118 fun ln z = (Real.ln (modulus z), arg z)
1119
1120 fun pow (z, n) =
1121 let val l = ln z
1122 in exp (l * n)
1123 end
1124
1125 fun sin (x, y) = (Real.*(Real.sin x, Real.cosh y),
1126 Real.*(Real.cos x, Real.sinh y))
1127 fun cos (x, y) = (Real.*(Real.cos x, Real.cosh y),
1128 Real.~ (Real.*(Real.sin x, Real.sinh y)))
1129 fun tan (x, y) =
1130 let val (sx, cx) = (Real.sin x, Real.cos x)
1131 val (shy, chy) = (Real.sinh y, Real.cosh y)
1132 val a = (Real.*(sx, chy), Real.*(cx, shy))
1133 val b = (Real.*(cx, chy), Real.*(Real.~ sx, shy))
1134 in a / b
1135 end
1136
1137 fun sinh (x, y) = (Real.*(Real.cos y, Real.sinh x),
1138 Real.*(Real.sin y, Real.cosh x))
1139 fun cosh (x, y) = (Real.*(Real.cos y, Real.cosh x),
1140 Real.*(Real.sin y, Real.sinh x))
1141 fun tanh (x, y) =
1142 let val (sy, cy) = (Real.sin y, Real.cos y)
1143 val (shx, chx) = (Real.sinh x, Real.cosh x)
1144 val a = (Real.*(cy, shx), Real.*(sy, chx))
1145 val b = (Real.*(cy, chx), Real.*(sy, shx))
1146 in a / b
1147 end
1148
(* asin z = -i ln(iz + sqrt(1 - z^2)), where iz = (~y, x);
   multiplying (p, q) by -i gives (q, ~p). *)
fun asin (z as (x,y)) =
    let val w = sqrt (one - z * z)
        val (x',y') = ln ((Real.~ y, x) + w)
    in (y', Real.~ x')
    end

(* acos z = -i ln(z + i sqrt(1 - z^2)).
   i * (p, q) = (~q, p) and -i * (p, q) = (q, ~p).
   The radicand is 1 - z^2 (using 1 + z^2 gives e.g. acos 1 /= 0). *)
fun acos z =
    let val (x', y') = sqrt (one - z * z)
        val (x'', y'') = ln (z + (Real.~ y', x'))
    in (y'', Real.~ x'')
    end

(* atan z = -i ln((1 + iz) / sqrt(1 + z^2)); 1 + iz = (1 - y, x). *)
fun atan (z as (x,y)) =
    let val w = sqrt (one + z*z)
        val (x',y') = ln ((Real.-(1.0, y), x) / w)
    in (y', Real.~ x')
    end

(* Computed simply as atan (y / x); unlike Real.atan2 no quadrant
   information is recovered. *)
fun atan2 (y, x) = atan(y / x)

(* Inverse hyperbolic functions via their logarithmic forms. *)
fun asinh x = ln (x + sqrt(one + x * x))
fun acosh x = ln (x + (x + one) * sqrt((x - one)/(x + one)))
fun atanh x = ln ((one + x) / sqrt(one - x * x))
1172
(* Reader for a complex number written as two consecutive reals
   (real part, then imaginary part), both parsed with Real.scan.
   Yields NONE if either component fails to parse. *)
fun scan getc =
    let
        val readReal = Real.scan getc
        fun readPair stream =
            case readReal stream of
                NONE => NONE
              | SOME (re, afterRe) =>
                (case readReal afterRe of
                     NONE => NONE
                   | SOME (im, afterIm) => SOME (make (re, im), afterIm))
    in
        readPair
    end
1183
1184 end (* ComplexNumber *)
1185
1186 (*
1187 Copyright (c) Juan Jose Garcia Ripoll.
1188 All rights reserved.
1189
1190 Refer to the COPYRIGHT file for license conditions
1191 *)
1192
(* Integer arrays: the Basis Array structure with its types fixed to
   INumber.t, plus map/mapi/map2 that build fresh arrays. *)
structure INumberArray =
struct
    open Array
    type array = INumber.t array
    type vector = INumber.t vector
    type elem = INumber.t
    (* Matching vector structure: Basis Vector with monomorphic types. *)
    structure Vector =
    struct
        open Vector
        type vector = INumber.t Vector.vector
        type elem = INumber.t
    end
    local
        (* New array as long as src whose element k is gen k. *)
        fun build src gen = tabulate (length src, gen)
    in
        fun map f src = build src (fn k => f (sub (src, k)))
        fun mapi f src = build src (fn k => f (k, sub (src, k)))
        (* Length is taken from the first array. *)
        fun map2 f src other = build src (fn k => f (sub (src, k), sub (other, k)))
    end
end
1209
(* Real arrays: Real64Array whose sub/update are replaced by MLton's
   bounds-unchecked Unsafe versions; every index must already be in range. *)
structure RNumberArray =
struct
    open Real64Array
    val sub = Unsafe.Real64Array.sub
    val update = Unsafe.Real64Array.update
    local
        (* New array as long as src whose element k is gen k. *)
        fun build src gen = tabulate (length src, gen)
    in
        fun map f src = build src (fn k => f (sub (src, k)))
        fun mapi f src = build src (fn k => f (k, sub (src, k)))
        (* Length is taken from the first array; the second must be at least
           as long because reads are unchecked. *)
        fun map2 f src other = build src (fn k => f (sub (src, k), sub (other, k)))
    end
end
1219
1220 (*--------------------- COMPLEX ARRAY -------------------------*)
1221
(* Complex arrays stored as a pair of equally long real arrays:
   the first component holds the real parts, the second the imaginary parts. *)
structure BasicCNumberArray =
struct
    structure Complex : COMPLEX_NUMBER = CNumber
    (* NOTE: RNumberArray binds sub/update to the bounds-unchecked
       Unsafe.Real64Array operations, so every index used below must be
       validated before it reaches Array.sub/Array.update. *)
    structure Array : MONO_ARRAY = RNumberArray

    type elem = Complex.t
    type array = Array.array * Array.array

    val maxLen = Array.maxLen

    (* Number of complex elements (length of the real-part array). *)
    fun length (a,b) = Array.length a

    fun sub ((a,b),index) = Complex.make(Array.sub(a,index),Array.sub(b,index))

    fun update ((a,b),index,z) =
        let val (re,im) = Complex.split z in
            Array.update(a, index, re);
            Array.update(b, index, im)
        end

    local
        (* makeRange (a, start, lastOpt) validates a slice of a and returns
           (a, start, count).
           NOTE: lastOpt is an inclusive LAST INDEX, not a length; NONE means
           "through the end of the array".
           Raises Subscript if start < 0, start > length a, or
           last >= length a.  start = length a, or last < start, denotes an
           empty range (count 0), so empty arrays can be folded/traversed. *)
        fun makeRange (a, start, NONE) = makeRange(a, start, SOME (length a - 1))
          | makeRange (a, start, SOME last) =
            let val len = length a
                val diff = last - start
            in
                if start < 0 orelse start > len orelse last >= len then
                    raise Subscript
                else if diff < 0 then
                    (a, start, 0)
                else
                    (a, start, diff + 1)
            end
    in
        (* size elements, each equal to z. *)
        fun array (size, z : elem) =
            let val r = Complex.realPart z
                val i = Complex.imagPart z
            in (Array.array(size,r), Array.array(size,i))
            end

        fun zeroarray size =
            (Array.array(size,Complex.Real.zero),
             Array.array(size,Complex.Real.zero))

        (* Element i is f i, computed in increasing index order. *)
        fun tabulate (size,f) =
            let val a = array(size, Complex.zero)
                fun loop i =
                    case i = size of
                        true => a
                      | false => (update(a, i, f i); loop (i+1))
            in
                loop 0
            end

        fun fromList list =
            let val length = List.length list
                val a = zeroarray length
                fun loop (_, []) = a
                  | loop (i, z::rest) = (update(a, i, z);
                                         loop (i+1, rest))
            in
                loop(0,list)
            end

        (* Fresh copy of the range described by makeRange. *)
        fun extract range =
            let val (a, start, len) = makeRange range
                fun copy i = sub(a, i + start)
            in tabulate(len, copy)
            end

        (* Concatenate the arrays in order into one fresh array. *)
        fun concat array_list =
            let val total_length = List.foldl (op +) 0 (List.map length array_list)
                val a = array(total_length, Complex.zero)
                fun copy (_, []) = a
                  | copy (pos, v::rest) =
                    let val n = length v
                        (* Copy indices n-1 down to 0 inclusive; an empty v
                           copies nothing. *)
                        fun loop i =
                            if i < 0 then ()
                            else (update(a, i+pos, sub(v, i)); loop (i-1))
                    in
                        loop (n - 1);
                        copy (n + pos, rest)
                    end
            in
                copy(0, array_list)
            end

        (* Copy the src range (si .. last) into dst starting at di; len is the
           inclusive last index understood by makeRange.  Raises Subscript
           instead of writing past the end of dst. *)
        fun copy {src : array, si : int, len : int option, dst : array, di : int } =
            let val (a, ia, la) = makeRange (src, si, len)
                val (b, ib, _) = makeRange (dst, di, len)
                fun loop i =
                    if i < 0 then ()
                    else (update(b, i+ib, sub(a, i+ia)); loop (i-1))
            in
                if ib + la > length b then raise Subscript
                else loop (la - 1)
            end

        val copyVec = copy

        fun modifyi f range =
            let val (a, start, len) = makeRange range
                val last = start + len
                fun loop i =
                    case i >= last of
                        true => ()
                      | false => (update(a, i, f(i, sub(a,i))); loop (i+1))
            in loop start
            end

        fun modify f a =
            let val last = length a
                fun loop i =
                    case i >= last of
                        true => ()
                      | false => (update(a, i, f(sub(a,i))); loop (i+1))
            in loop 0
            end

        fun app f a =
            let val size = length a
                fun loop i =
                    case i = size of
                        true => ()
                      | false => (f(sub(a,i)); loop (i+1))
            in
                loop 0
            end

        fun appi f range =
            let val (a, start, len) = makeRange range
                val last = start + len
                fun loop i =
                    case i >= last of
                        true => ()
                      | false => (f(i, sub(a,i)); loop (i+1))
            in
                loop start
            end

        (* New array c with c[i] = f a[i]; a is left untouched. *)
        fun map f a =
            let val len = length a
                val c = zeroarray len
                fun loop ~1 = c
                  | loop i = (update(c, i, f(sub(a,i))); loop (i-1))
            in loop (len-1)
            end

        (* New array c with c[i] = f (a[i], b[i]); length taken from a. *)
        fun map2 f a b =
            let val len = length a
                val c = zeroarray len
                fun loop ~1 = c
                  | loop i = (update(c, i, f(sub(a,i),sub(b,i)));
                              loop (i-1))
            in loop (len-1)
            end

        fun mapi f range =
            let val (a, start, len) = makeRange range
                fun rule i = f (i+start, sub(a, i+start))
            in tabulate(len, rule)
            end

        fun foldli f init range =
            let val (a, start, len) = makeRange range
                val last = start + len - 1
                fun loop (i, accum) =
                    case i > last of
                        true => accum
                      | false => loop (i+1, f(i, sub(a,i), accum))
            in loop (start, init)
            end

        fun foldri f init range =
            let val (a, start, len) = makeRange range
                val last = start + len - 1
                fun loop (i, accum) =
                    case i < start of
                        true => accum
                      | false => loop (i-1, f(i, sub(a,i), accum))
            in loop (last, init)
            end

        fun foldl f init arr = foldli (fn (_, x, acc) => f(x, acc)) init (arr, 0, NONE)
        fun foldr f init arr = foldri (fn (_, x, acc) => f(x, acc)) init (arr, 0, NONE)
    end
end (* BasicCNumberArray *)
1409
1410
(* Complex arrays with a companion Vector structure.  All array operations
   come straight from BasicCNumberArray. *)
structure CNumberArray =
struct
    open BasicCNumberArray
    (* The same representation viewed through MONO_VECTOR: a complex
       "vector" is just a complex array. *)
    structure Vector =
    struct
        open BasicCNumberArray
        type vector = array
    end : MONO_VECTOR
    type vector = Vector.vector
end (* CNumberArray *)
(* Integral numbers backed by the default Int structure. *)
structure INumber : INTEGRAL_NUMBER =
struct
    open Int
    type t = Int.int
    val zero = 0
    val one = 1
    infix **
    (* i ** n: i raised to the non-negative power n by repeated squaring.
       Raises Domain when n < 0. *)
    fun i ** n =
        let fun loop 0 = 1
              | loop 1 = i
              | loop n =
                let val x = loop (Int.div(n, 2))
                    val m = Int.mod(n, 2)
                in
                    if m = 0 then
                        x * x
                    else
                        x * x * i
                end
        in if n < 0
           then raise Domain
           else loop n
        end
    fun signum i = case compare(i, 0) of
                       GREATER => 1
                     | EQUAL => 0
                     | LESS => ~1
    infix ==
    infix !=
    fun a == b = a = b
    fun a != b = (a <> b)
    (* Multiply-add / multiply-subtract, argument order as in Real.*+ and
       Real.*- : *+(b,c,a) = b*c + a and *-(b,c,a) = b*c - a.
       (The subtrahend is the third argument a, not b.) *)
    fun *+(b,c,a) = b * c + a
    fun *-(b,c,a) = b * c - a
    fun scan getc = Int.scan StringCvt.DEC getc
end
(* Real numbers backed by the default Real structure and Real.Math. *)
structure RNumber : REAL_NUMBER =
struct
    open Real
    open Real.Math
    type t = Real.real
    val zero = 0.0
    val one = 1.0
    (* Sign of x as a real: ~1.0, 0.0 or 1.0. *)
    fun signum x =
        case compare (x, 0.0) of
            GREATER => 1.0
          | EQUAL => 0.0
          | LESS => ~1.0
    fun recip x = 1.0 / x
    infix **
    (* base ** n: base raised to the non-negative integer power n by
       repeated squaring; negative exponents raise Domain. *)
    fun base ** n =
        let
            fun power 0 = one
              | power 1 = base
              | power k =
                let val h = power (Int.div (k, 2))
                    val hh = h * h
                in
                    if Int.mod (k, 2) = 0 then hh else hh * base
                end
        in
            if Int.< (n, 0) then raise Domain else power n
        end
    fun max (a, b) = if b > a then b else a
    fun min (a, b) = if a < b then a else b
    (* Inverse hyperbolic functions via their logarithmic forms. *)
    fun asinh x = ln (x + sqrt (1.0 + x * x))
    fun acosh x = ln (x + (x + 1.0) * sqrt ((x - 1.0) / (x + 1.0)))
    fun atanh x = ln ((1.0 + x) / sqrt (1.0 - x * x))
end
(*
 CNumber - complex numbers over RNumber.
 Provides support for complex numbers represented as (real, imaginary)
 tuples. This should be highly efficient, as most operations can be inlined.
*)
structure CNumber : COMPLEX_NUMBER =
struct
    structure Real = RNumber
    type t = Real.t * Real.t
    type real = Real.t
    val zero = (0.0,0.0)
    val one = (1.0,0.0)
    val pi = (Real.pi, 0.0)
    val e = (Real.e, 0.0)
    fun make (r,i) = (r,i) : t
    fun split z = z
    fun realPart (r,_) = r
    fun imagPart (_,i) = i
    (* Squared modulus r^2 + i^2.
       FIXME: computed directly, so it can overflow/underflow for very large
       or very small components. *)
    fun abs2 (r,i) = Real.+(Real.*(r,r),Real.*(i,i))
    (* Phase angle of (r, i), via Real.atan2 (i, r). *)
    fun arg (r,i) = Real.atan2(i,r)
    fun modulus z = Real.sqrt(abs2 z)
    fun abs z = (modulus z, 0.0)
    (* z / |z|; zero maps to zero rather than (nan, nan), matching
       RNumber.signum 0.0 = 0.0. *)
    fun signum (z as (r,i)) =
        let val m = modulus z
        in if Real.==(m, 0.0) then zero
           else (Real./(r,m), Real./(i,m))
        end
    fun ~ (r1,i1) = (Real.~ r1, Real.~ i1)
    fun (r1,i1) + (r2,i2) = (Real.+(r1,r2), Real.+(i1,i2))
    fun (r1,i1) - (r2,i2) = (Real.-(r1,r2), Real.-(i1,i2))
    fun (r1,i1) * (r2,i2) = (Real.-(Real.*(r1,r2),Real.*(i1,i2)),
                             Real.+(Real.*(r1,i2),Real.*(r2,i1)))
    (* z / w = z * conj(w) / |w|^2. *)
    fun (r1,i1) / (r2,i2) =
        let val modulus = abs2(r2,i2)
            val (nr,ni) = (r1,i1) * (r2, Real.~ i2)
        in
            (Real./(nr,modulus), Real./(ni,modulus))
        end
    (* *+(z1,z2,z0) = z1*z2 + z0 and *-(z1,z2,z0) = z1*z2 - z0, built from
       Real.*+ / Real.*- (a*b + c / a*b - c).
       Real part: r1*r2 - i1*i2 (+/-) r0.
       Imaginary part: r1*i2 + i1*r2 (+/-) i0. *)
    fun *+((r1,i1),(r2,i2),(r0,i0)) =
        (Real.*+(Real.~ i1, i2, Real.*+(r1,r2,r0)),
         Real.*+(r2, i1, Real.*+(r1,i2,i0)))
    fun *-((r1,i1),(r2,i2),(r0,i0)) =
        (Real.*+(Real.~ i1, i2, Real.*-(r1,r2,r0)),
         Real.*+(r2, i1, Real.*-(r1,i2,i0)))
    infix **
    (* z ** n for non-negative integer n by repeated squaring; raises Domain
       for negative n. *)
    fun i ** n =
        let fun loop 0 = one
              | loop 1 = i
              | loop n =
                let val x = loop (Int.div(n, 2))
                    val m = Int.mod(n, 2)
                in
                    if m = 0 then
                        x * x
                    else
                        x * x * i
                end
        in if Int.<(n, 0)
           then raise Domain
           else loop n
        end
    (* 1 / z = conj(z) / |z|^2. *)
    fun recip (r1, i1) =
        let val modulus = abs2(r1, i1)
        in (Real./(r1, modulus), Real./(Real.~ i1, modulus))
        end
    (* Equal iff both components are equal. *)
    fun ==(z, w) = Real.==(realPart z, realPart w) andalso Real.==(imagPart z, imagPart w)
    (* Different iff EITHER component differs (negation of ==). *)
    fun !=(z, w) = Real.!=(realPart z, realPart w) orelse Real.!=(imagPart z, imagPart w)
    fun fromInt i = (Real.fromInt i, 0.0)
    fun toString (r,i) =
        String.concat ["(",Real.toString r,",",Real.toString i,")"]
    (* exp(x + iy) = e^x (cos y + i sin y). *)
    fun exp (x, y) =
        let val expx = Real.exp x
        in (Real.*(expx, Real.cos y), Real.*(expx, Real.sin y))
        end
    local
        val half = Real.recip (Real.fromInt 2)
    in
        (* Principal square root; sqrt 0 = 0.  The imaginary part carries the
           sign of y. *)
        fun sqrt (z as (x,y)) =
            if Real.==(x, 0.0) andalso Real.==(y, 0.0) then
                zero
            else
                let val m = Real.+(modulus z, Real.abs x)
                    val u' = Real.sqrt (Real.*(m, half))
                    val v' = Real./(Real.abs y , Real.+(u',u'))
                    val (u,v) = if Real.<(x, 0.0) then (v',u') else (u',v')
                in (u, if Real.<(y, 0.0) then Real.~ v else v)
                end
    end
    (* Principal logarithm: ln |z| + i arg z. *)
    fun ln z = (Real.ln (modulus z), arg z)
    (* z^n = exp(ln z * n) for complex n. *)
    fun pow (z, n) =
        let val l = ln z
        in exp (l * n)
        end
    fun sin (x, y) = (Real.*(Real.sin x, Real.cosh y),
                      Real.*(Real.cos x, Real.sinh y))
    fun cos (x, y) = (Real.*(Real.cos x, Real.cosh y),
                      Real.~ (Real.*(Real.sin x, Real.sinh y)))
    (* tan z = sin z / cos z. *)
    fun tan (x, y) =
        let val (sx, cx) = (Real.sin x, Real.cos x)
            val (shy, chy) = (Real.sinh y, Real.cosh y)
            val a = (Real.*(sx, chy), Real.*(cx, shy))
            val b = (Real.*(cx, chy), Real.*(Real.~ sx, shy))
        in a / b
        end
    fun sinh (x, y) = (Real.*(Real.cos y, Real.sinh x),
                       Real.*(Real.sin y, Real.cosh x))
    fun cosh (x, y) = (Real.*(Real.cos y, Real.cosh x),
                       Real.*(Real.sin y, Real.sinh x))
    (* tanh z = sinh z / cosh z. *)
    fun tanh (x, y) =
        let val (sy, cy) = (Real.sin y, Real.cos y)
            val (shx, chx) = (Real.sinh x, Real.cosh x)
            val a = (Real.*(cy, shx), Real.*(sy, chx))
            val b = (Real.*(cy, chx), Real.*(sy, shx))
        in a / b
        end
    (* asin z = -i ln(iz + sqrt(1 - z^2)), iz = (~y, x). *)
    fun asin (z as (x,y)) =
        let val w = sqrt (one - z * z)
            val (x',y') = ln ((Real.~ y, x) + w)
        in (y', Real.~ x')
        end
    (* acos z = -i ln(z + i sqrt(1 - z^2)); i * (p, q) = (~q, p). *)
    fun acos z =
        let val (x', y') = sqrt (one - z * z)
            val (x'', y'') = ln (z + (Real.~ y', x'))
        in (y'', Real.~ x'')
        end
    (* atan z = -i ln((1 + iz) / sqrt(1 + z^2)), 1 + iz = (1 - y, x). *)
    fun atan (z as (x,y)) =
        let val w = sqrt (one + z*z)
            val (x',y') = ln ((Real.-(1.0, y), x) / w)
        in (y', Real.~ x')
        end
    (* Simply atan (y / x); no quadrant information is recovered. *)
    fun atan2 (y, x) = atan(y / x)
    fun asinh x = ln (x + sqrt(one + x * x))
    fun acosh x = ln (x + (x + one) * sqrt((x - one)/(x + one)))
    fun atanh x = ln ((one + x) / sqrt(one - x * x))
    (* Two consecutive reals: real part, then imaginary part. *)
    fun scan getc =
        let val scanner = Real.scan getc
        in fn stream =>
              case scanner stream of
                  NONE => NONE
                | SOME (a, rest) =>
                  case scanner rest of
                      NONE => NONE
                    | SOME (b, rest) => SOME (make(a,b), rest)
        end
end (* CNumber *)
1635 (*
1636 Copyright (c) Juan Jose Garcia Ripoll.
1637 All rights reserved.
1638 Refer to the COPYRIGHT file for license conditions
1639 *)
(* Small printing helpers for lists, arrays, generic sequences and
   heterogeneous "modifier" lists, all writing to standard output. *)
structure PrettyPrint :>
sig
    datatype modifier =
        Int of int |
        Real of real |
        Complex of CNumber.t |
        String of string
    val list : ('a -> string) -> 'a list -> unit
    val intList : int list -> unit
    val realList : real list -> unit
    val stringList : string list -> unit
    val array : ('a -> string) -> 'a array -> unit
    val intArray : int array -> unit
    val realArray : real array -> unit
    val stringArray : string array -> unit
    val sequence :
        int -> ((int * 'a -> unit) -> 'b -> unit) -> ('a -> string) -> 'b -> unit
    val print : modifier list -> unit
end =
struct
    datatype modifier =
        Int of int |
        Real of real |
        Complex of CNumber.t |
        String of string

    (* "[x1, x2, ..., xn]" with no trailing newline; "[]" when empty. *)
    fun list _ [] = print "[]"
      | list cvt (first :: others) =
        let fun emit (x, []) = (print (cvt x); print "]")
              | emit (x, next :: more) =
                (print (cvt x); print ", "; emit (next, more))
        in
            print "[";
            emit (first, others)
        end
    fun boolList xs = list Bool.toString xs
    fun intList xs = list Int.toString xs
    fun realList xs = list Real.toString xs
    fun stringList xs = list (fn s => s) xs

    (* Elements separated by ", " (no brackets, no newline). *)
    fun array cvt arr =
        let val lastIdx = Array.length arr - 1
            fun show (idx, x) =
                (print (cvt x);
                 if idx <> lastIdx then print ", " else ())
        in
            Array.appi show arr
        end
    fun boolArray xs = array Bool.toString xs
    fun intArray xs = array Int.toString xs
    fun realArray xs = array Real.toString xs
    fun stringArray xs = array (fn s => s) xs

    (* "[e0, e1, ...]\n" for any container traversed by an appi-style
       function; len is the element count, used to place the separators. *)
    fun sequence len appi cvt seq =
        let val lastIdx = len - 1
            fun show (idx : int, x) =
                (print (cvt x);
                 if idx <> lastIdx then print ", " else ())
        in
            print "[";
            appi show seq;
            print "]\n"
        end

    (* Each modifier rendered in order, with no separators. *)
    fun print mods =
        let fun render (Int n) = INumber.toString n
              | render (Real r) = RNumber.toString r
              | render (Complex z) = CNumber.toString z
              | render (String s) = s
        in
            List.app (TextIO.print o render) mods
        end
end (* PrettyPrint *)
(* Print each string of the list in order, with no separators. *)
val print' : string list -> unit = List.app print
1706 (*
1707 Copyright (c) Juan Jose Garcia Ripoll.
1708 All rights reserved.
1709 Refer to the COPYRIGHT file for license conditions
1710 *)
(* Integer arrays: the Basis Array structure with its types fixed to
   INumber.t, plus map/mapi/map2 that build fresh arrays. *)
structure INumberArray =
struct
    open Array
    type array = INumber.t array
    type vector = INumber.t vector
    type elem = INumber.t
    (* Matching vector structure: Basis Vector with monomorphic types. *)
    structure Vector =
    struct
        open Vector
        type vector = INumber.t Vector.vector
        type elem = INumber.t
    end
    local
        (* New array as long as src whose element k is gen k. *)
        fun build src gen = tabulate (length src, gen)
    in
        fun map f src = build src (fn k => f (sub (src, k)))
        fun mapi f src = build src (fn k => f (k, sub (src, k)))
        (* Length is taken from the first array. *)
        fun map2 f src other = build src (fn k => f (sub (src, k), sub (other, k)))
    end
end
(* Real arrays: Real64Array whose sub/update are replaced by MLton's
   bounds-unchecked Unsafe versions; every index must already be in range. *)
structure RNumberArray =
struct
    open Real64Array
    val sub = Unsafe.Real64Array.sub
    val update = Unsafe.Real64Array.update
    local
        (* New array as long as src whose element k is gen k. *)
        fun build src gen = tabulate (length src, gen)
    in
        fun map f src = build src (fn k => f (sub (src, k)))
        fun mapi f src = build src (fn k => f (k, sub (src, k)))
        (* Length is taken from the first array; the second must be at least
           as long because reads are unchecked. *)
        fun map2 f src other = build src (fn k => f (sub (src, k), sub (other, k)))
    end
end
1736 (*--------------------- COMPLEX ARRAY -------------------------*)
(* Complex arrays stored as a pair of equally long real arrays:
   the first component holds the real parts, the second the imaginary parts. *)
structure BasicCNumberArray =
struct
    structure Complex : COMPLEX_NUMBER = CNumber
    (* NOTE: RNumberArray binds sub/update to the bounds-unchecked
       Unsafe.Real64Array operations, so every index used below must be
       validated before it reaches Array.sub/Array.update. *)
    structure Array : MONO_ARRAY = RNumberArray

    type elem = Complex.t
    type array = Array.array * Array.array

    val maxLen = Array.maxLen

    (* Number of complex elements (length of the real-part array). *)
    fun length (a,b) = Array.length a

    fun sub ((a,b),index) = Complex.make(Array.sub(a,index),Array.sub(b,index))

    fun update ((a,b),index,z) =
        let val (re,im) = Complex.split z in
            Array.update(a, index, re);
            Array.update(b, index, im)
        end

    local
        (* makeRange (a, start, lastOpt) validates a slice of a and returns
           (a, start, count).
           NOTE: lastOpt is an inclusive LAST INDEX, not a length; NONE means
           "through the end of the array".
           Raises Subscript if start < 0, start > length a, or
           last >= length a.  start = length a, or last < start, denotes an
           empty range (count 0), so empty arrays can be folded/traversed. *)
        fun makeRange (a, start, NONE) = makeRange(a, start, SOME (length a - 1))
          | makeRange (a, start, SOME last) =
            let val len = length a
                val diff = last - start
            in
                if start < 0 orelse start > len orelse last >= len then
                    raise Subscript
                else if diff < 0 then
                    (a, start, 0)
                else
                    (a, start, diff + 1)
            end
    in
        (* size elements, each equal to z. *)
        fun array (size, z : elem) =
            let val r = Complex.realPart z
                val i = Complex.imagPart z
            in (Array.array(size,r), Array.array(size,i))
            end

        fun zeroarray size =
            (Array.array(size,Complex.Real.zero),
             Array.array(size,Complex.Real.zero))

        (* Element i is f i, computed in increasing index order. *)
        fun tabulate (size,f) =
            let val a = array(size, Complex.zero)
                fun loop i =
                    case i = size of
                        true => a
                      | false => (update(a, i, f i); loop (i+1))
            in
                loop 0
            end

        fun fromList list =
            let val length = List.length list
                val a = zeroarray length
                fun loop (_, []) = a
                  | loop (i, z::rest) = (update(a, i, z);
                                         loop (i+1, rest))
            in
                loop(0,list)
            end

        (* Fresh copy of the range described by makeRange. *)
        fun extract range =
            let val (a, start, len) = makeRange range
                fun copy i = sub(a, i + start)
            in tabulate(len, copy)
            end

        (* Concatenate the arrays in order into one fresh array. *)
        fun concat array_list =
            let val total_length = List.foldl (op +) 0 (List.map length array_list)
                val a = array(total_length, Complex.zero)
                fun copy (_, []) = a
                  | copy (pos, v::rest) =
                    let val n = length v
                        (* Copy indices n-1 down to 0 inclusive; an empty v
                           copies nothing. *)
                        fun loop i =
                            if i < 0 then ()
                            else (update(a, i+pos, sub(v, i)); loop (i-1))
                    in
                        loop (n - 1);
                        copy (n + pos, rest)
                    end
            in
                copy(0, array_list)
            end

        (* Copy the src range (si .. last) into dst starting at di; len is the
           inclusive last index understood by makeRange.  Raises Subscript
           instead of writing past the end of dst. *)
        fun copy {src : array, si : int, len : int option, dst : array, di : int } =
            let val (a, ia, la) = makeRange (src, si, len)
                val (b, ib, _) = makeRange (dst, di, len)
                fun loop i =
                    if i < 0 then ()
                    else (update(b, i+ib, sub(a, i+ia)); loop (i-1))
            in
                if ib + la > length b then raise Subscript
                else loop (la - 1)
            end

        val copyVec = copy

        fun modifyi f range =
            let val (a, start, len) = makeRange range
                val last = start + len
                fun loop i =
                    case i >= last of
                        true => ()
                      | false => (update(a, i, f(i, sub(a,i))); loop (i+1))
            in loop start
            end

        fun modify f a =
            let val last = length a
                fun loop i =
                    case i >= last of
                        true => ()
                      | false => (update(a, i, f(sub(a,i))); loop (i+1))
            in loop 0
            end

        fun app f a =
            let val size = length a
                fun loop i =
                    case i = size of
                        true => ()
                      | false => (f(sub(a,i)); loop (i+1))
            in
                loop 0
            end

        fun appi f range =
            let val (a, start, len) = makeRange range
                val last = start + len
                fun loop i =
                    case i >= last of
                        true => ()
                      | false => (f(i, sub(a,i)); loop (i+1))
            in
                loop start
            end

        (* New array c with c[i] = f a[i]; a is left untouched. *)
        fun map f a =
            let val len = length a
                val c = zeroarray len
                fun loop ~1 = c
                  | loop i = (update(c, i, f(sub(a,i))); loop (i-1))
            in loop (len-1)
            end

        (* New array c with c[i] = f (a[i], b[i]); length taken from a. *)
        fun map2 f a b =
            let val len = length a
                val c = zeroarray len
                fun loop ~1 = c
                  | loop i = (update(c, i, f(sub(a,i),sub(b,i)));
                              loop (i-1))
            in loop (len-1)
            end

        fun mapi f range =
            let val (a, start, len) = makeRange range
                fun rule i = f (i+start, sub(a, i+start))
            in tabulate(len, rule)
            end

        fun foldli f init range =
            let val (a, start, len) = makeRange range
                val last = start + len - 1
                fun loop (i, accum) =
                    case i > last of
                        true => accum
                      | false => loop (i+1, f(i, sub(a,i), accum))
            in loop (start, init)
            end

        fun foldri f init range =
            let val (a, start, len) = makeRange range
                val last = start + len - 1
                fun loop (i, accum) =
                    case i < start of
                        true => accum
                      | false => loop (i-1, f(i, sub(a,i), accum))
            in loop (last, init)
            end

        fun foldl f init arr = foldli (fn (_, x, acc) => f(x, acc)) init (arr, 0, NONE)
        fun foldr f init arr = foldri (fn (_, x, acc) => f(x, acc)) init (arr, 0, NONE)
    end
end (* BasicCNumberArray *)
(* Complex arrays with a companion Vector structure.  All array operations
   come straight from BasicCNumberArray. *)
structure CNumberArray =
struct
    open BasicCNumberArray
    (* The same representation viewed through MONO_VECTOR: a complex
       "vector" is just a complex array. *)
    structure Vector =
    struct
        open BasicCNumberArray
        type vector = array
    end : MONO_VECTOR
    type vector = Vector.vector
end (* CNumberArray *)
structure ITensor =
struct
structure Number = INumber
structure Array = INumberArray
(*
 Copyright (c) Juan Jose Garcia Ripoll.
 All rights reserved.
 Refer to the COPYRIGHT file for license conditions
*)
structure MonoTensor =
struct
(* PARAMETERS
   structure Array = Array
*)
    structure Index = Index
    type elem = Array.elem
    type index = Index.t
    (* A tensor couples its shape, the indexer derived from that shape, and a
       flat data array.  The fold/contraction loops below treat the first
       index as the fastest-varying one in data. *)
    type tensor = {shape : index, indexer : Index.indexer, data : Array.array}
    type t = tensor
    exception Shape
    exception Match
    exception Index
    local
        (*----- LOCALS -----*)
        fun make' (shape, data) =
            {shape = shape, indexer = Index.indexer shape, data = data}
        fun toInt {shape, indexer, data} index = indexer index
        (* splitList (l, place) returns (before, selected, after) around one
           element of l.  place >= 1 selects the place-th element (1-based);
           place <= 0 counts from the end (0 = last, ~1 = second to last, ...).
           Raises Index when place falls outside l.  An empty l is not
           handled (nonexhaustive pattern). *)
        fun splitList (l as (a::rest), place) =
            let fun loop (left,here,right) 0 = (List.rev left,here,right)
                  | loop (_,_,[]) place = raise Index
                  | loop (left,here,a::right) place =
                    loop (here::left,a,right) (place-1)
            in
                if place <= 0 then
                    loop ([],a,rest) (List.length rest + place)
                else
                    loop ([],a,rest) (place - 1)
            end
    in
        (*----- STRUCTURAL OPERATIONS & QUERIES ------*)
        fun new (shape, init) =
            if not (Index.validShape shape) then
                raise Shape
            else
                let val length = Index.length shape in
                    {shape = shape,
                     indexer = Index.indexer shape,
                     data = Array.array(length,init)}
                end
        fun toArray {shape, indexer, data} = data
        fun length {shape, indexer, data} = Array.length data
        fun shape {shape, indexer, data} = shape
        fun rank t = List.length (shape t)
        fun reshape new_shape tensor =
            if Index.validShape new_shape then
                case (Index.length new_shape) = length tensor of
                    true => make'(new_shape, toArray tensor)
                  | false => raise Match
            else
                raise Shape
        fun fromArray (s, a) =
            case Index.validShape s andalso
                 ((Index.length s) = (Array.length a)) of
                true => make'(s, a)
              | false => raise Shape
        fun fromList (s, a) = fromArray (s, Array.fromList a)
        (* Tensor whose element at index ix is f ix.  Position length-1 holds
           f last from the initial fill; positions length-2 down to 0
           (inclusive) are filled while stepping the index back with
           Index.prev'.  A one-element tensor needs no stepping. *)
        fun tabulate (shape,f) =
            if Index.validShape shape then
                let val last = Index.last shape
                    val length = Index.length shape
                    val c = Array.array(length, f last)
                    fun dotable (c, indices, i) =
                        (Array.update(c, i, f indices);
                         if i <= 0
                         then c
                         else dotable(c, Index.prev' shape indices, i-1))
                in
                    if length <= 1 then make'(shape, c)
                    else make'(shape, dotable(c, Index.prev' shape last, length-2))
                end
            else
                raise Shape
        (*----- ELEMENTWISE OPERATIONS -----*)
        fun sub (t, index) = Array.sub(#data t, toInt t index)
        fun update (t, index, value) =
            Array.update(toArray t, toInt t index, value)
        fun map f {shape, indexer, data} =
            {shape = shape, indexer = indexer, data = Array.map f data}
        fun map2 f t1 t2 =
            let val {shape=shape1, indexer=indexer1, data=data1} = t1
                val {shape=shape2, indexer=indexer2, data=data2} = t2
            in
                if Index.eq(shape1,shape2) then
                    {shape = shape1,
                     indexer = indexer1,
                     data = Array.map2 f data1 data2}
                else
                    raise Match
            end
        fun appi f tensor = Array.appi f (toArray tensor)
        fun app f tensor = Array.app f (toArray tensor)
        fun all f tensor =
            let val a = toArray tensor
            in Loop.all(0, length tensor - 1, fn i =>
                                                 f (Array.sub(a, i)))
            end
        fun any f tensor =
            let val a = toArray tensor
            in Loop.any(0, length tensor - 1, fn i =>
                                                 f (Array.sub(a, i)))
            end
        fun foldl f init tensor = Array.foldl f init (toArray tensor)
        (* Fold f along the dimension selected by index (as in splitList);
           the result has the remaining dimensions, head @ tail. *)
        fun foldln f init {shape, indexer, data=a} index =
            let val (head,lk,tail) = splitList(shape, index)
                val li = Index.length head
                val lj = Index.length tail
                val c = Array.array(li * lj,init)
                fun loopi (0, _, _) = ()
                  | loopi (i, ia, ic) =
                    (Array.update(c, ic, f(Array.sub(c,ic), Array.sub(a,ia)));
                     loopi (i-1, ia+1, ic+1))
                fun loopk (0, ia, _) = ia
                  | loopk (k, ia, ic) = (loopi (li, ia, ic);
                                         loopk (k-1, ia+li, ic))
                fun loopj (0, _, _) = ()
                  | loopj (j, ia, ic) = loopj (j-1, loopk(lk,ia,ic), ic+li)
            in
                loopj (lj, 0, 0);
                make'(head @ tail, c)
            end
        (* --- POLYMORPHIC ELEMENTWISE OPERATIONS --- *)
        fun array_map' f a =
            let fun apply index = f(Array.sub(a,index)) in
                Tensor.Array.tabulate(Array.length a, apply)
            end
        fun map' f t = Tensor.fromArray(shape t, array_map' f (toArray t))
        fun map2' f t1 t2 =
            let val d1 = toArray t1
                val d2 = toArray t2
                fun apply i = f (Array.sub(d1,i), Array.sub(d2,i))
                val len = Array.length d1
            in
                if Index.eq(shape t1, shape t2) then
                    Tensor.fromArray(shape t1, Tensor.Array.tabulate(len,apply))
                else
                    raise Match
            end
        fun foldl' f init {shape, indexer, data=a} index =
            let val (head,lk,tail) = splitList(shape, index)
                val li = Index.length head
                val lj = Index.length tail
                val c = Tensor.Array.array(li * lj,init)
                fun loopi (0, _, _) = ()
                  | loopi (i, ia, ic) =
                    (Tensor.Array.update(c,ic,f(Tensor.Array.sub(c,ic),Array.sub(a,ia)));
                     loopi (i-1, ia+1, ic+1))
                fun loopk (0, ia, _) = ia
                  | loopk (k, ia, ic) = (loopi (li, ia, ic);
                                         loopk (k-1, ia+li, ic))
                fun loopj (0, _, _) = ()
                  | loopj (j, ia, ic) = loopj (j-1, loopk(lk,ia,ic), ic+li)
            in
                loopj (lj, 0, 0);
                make'(head @ tail, c)
            end
    end
end (* MonoTensor *)
open MonoTensor
local
    (*
     LEFT INDEX CONTRACTION:
     a = a(i1,i2,...,in)
     b = b(j1,j2,...,jn)
     c = c(i2,...,in,j2,...,jn)
       = sum(a(k,i2,...,in)*b(k,j2,...,jn)) forall k
     MEANINGFUL VARIABLES:
     lk = i1 = j1
     li = i2*...*in
     lj = j2*...*jn
     Parameters are taken in the order lk li lj, matching the call in +*:
     the inner loop runs over a's li columns, the outer over b's lj columns.
    *)
    fun do_fold_first a b c lk li lj =
        let fun loopk (0, _, _, accum) = accum
              | loopk (k, ia, ib, accum) =
                let val delta = Number.*(Array.sub(a,ia),Array.sub(b,ib))
                in loopk (k-1, ia+1, ib+1, Number.+(delta,accum))
                end
            fun loopj (0, ib, ic) = c
              | loopj (j, ib, ic) =
                let fun loopi (0, ia, ic) = ic
                      | loopi (i, ia, ic) =
                        (Array.update(c, ic, loopk(lk, ia, ib, Number.zero));
                         loopi(i-1, ia+lk, ic+1))
                in
                    loopj(j-1, ib+lk, loopi(li, 0, ic))
                end
        in loopj(lj, 0, 0)
        end
in
    (* Contract the first index of ta with the first index of tb.
       Raises Match if those dimensions differ. *)
    fun +* ta tb =
        let val (lk::rest_a, a) = (shape ta, toArray ta)
            val (lk2::rest_b, b) = (shape tb, toArray tb)
        in if not(lk = lk2)
           then raise Match
           else let val li = Index.length rest_a
                    val lj = Index.length rest_b
                    val c = Array.array(li*lj,Number.zero)
                in fromArray(rest_a @ rest_b,
                             do_fold_first a b c lk li lj)
                end
        end
end
local
    (*
     LAST INDEX CONTRACTION:
     a = a(i1,i2,...,in)
     b = b(j1,j2,...,jn)
     c = c(i1,...,i(n-1),j1,...,j(n-1))
       = sum(mult(a(i1,i2,...,k),b(j1,j2,...,k))) forall k
     MEANINGFUL VARIABLES:
     lk = in = jn
     li = i1*...*i(n-1)
     lj = j1*...*j(n-1)
     Parameters are taken in the order lk li lj, matching the call in *+.
    *)
    fun do_fold_last a b c lk li lj =
        let fun loopi (0, ia, ic, fac) = ()
              | loopi (i, ia, ic, fac) =
                let val old = Array.sub(c,ic)
                    val inc = Number.*(Array.sub(a,ia),fac)
                in
                    Array.update(c,ic,Number.+(old,inc));
                    loopi(i-1, ia+1, ic+1, fac)
                end
            fun loopj (j, ib, ic) =
                let fun loopk (0, ia, ib) = ()
                      | loopk (k, ia, ib) =
                        (loopi(li, ia, ic, Array.sub(b,ib));
                         loopk(k-1, ia+li, ib+lj))
                in case j of
                       0 => c
                     | _ => (loopk(lk, 0, ib);
                             loopj(j-1, ib+1, ic+li))
                end (* loopj *)
        in
            loopj(lj, 0, 0)
        end
in
    (* Contract the last index of ta with the last index of tb.
       Raises Match if those dimensions differ. *)
    fun *+ ta tb =
        let val (shape_a, a) = (shape ta, toArray ta)
            val (shape_b, b) = (shape tb, toArray tb)
            val (lk::rest_a) = List.rev shape_a
            val (lk2::rest_b) = List.rev shape_b
        in if not(lk = lk2)
           then raise Match
           else let val li = Index.length rest_a
                    val lj = Index.length rest_b
                    val c = Array.array(li*lj,Number.zero)
                in fromArray(List.rev rest_a @ List.rev rest_b,
                             do_fold_last a b c lk li lj)
                end
        end
end
(* ALGEBRAIC OPERATIONS *)
infix **
infix ==
infix !=
fun a + b = map2 Number.+ a b
fun a - b = map2 Number.- a b
fun a * b = map2 Number.* a b
fun a ** i = map (fn x => (Number.**(x,i))) a
fun ~ a = map Number.~ a
fun abs a = map Number.abs a
fun signum a = map Number.signum a
fun a == b = map2' Number.== a b
fun a != b = map2' Number.!= a b
fun toString a = raise Domain
fun fromInt a = new([1], Number.fromInt a)
(* TENSOR SPECIFIC OPERATIONS *)
fun *> n = map (fn x => Number.*(n,x))
fun print t =
    (PrettyPrint.intList (shape t);
     TextIO.print "\n";
     PrettyPrint.sequence (length t) appi Number.toString t)
(* Largest absolute value of any element (zero for an all-zero tensor). *)
fun normInf a =
    let fun accum (y,x) = Number.max(x,Number.abs y)
    in foldl accum Number.zero a
    end
end (* ITensor *)
2194 structure RTensor =
2195 struct
2196 structure Number = RNumber
2197 structure Array = RNumberArray
2198 (*
2199 Copyright (c) Juan Jose Garcia Ripoll.
2200 All rights reserved.
2201 Refer to the COPYRIGHT file for license conditions
2202 *)
2203 structure MonoTensor =
2204 struct
2205 (* PARAMETERS
2206 structure Array = Array
2207 *)
2208 structure Index = Index
2209 type elem = Array.elem
2210 type index = Index.t
2211 type tensor = {shape : index, indexer : Index.indexer, data : Array.array}
2212 type t = tensor
2213 exception Shape
2214 exception Match
2215 exception Index
2216 local
2217 (*----- LOCALS -----*)
fun make' (shape, data) =
    {shape = shape, indexer = Index.indexer shape, data = data}
fun toInt {shape, indexer, data} index = indexer index
(* splitList (l, place) returns (before, selected, after) around one element
   of l.  place >= 1 selects the place-th element (1-based); place <= 0
   counts from the end (0 = last, ~1 = second to last, ...).
   Raises Index when place falls outside l.  An empty l is not handled
   (nonexhaustive pattern). *)
fun splitList (l as (a::rest), place) =
    let fun loop (left,here,right) 0 = (List.rev left,here,right)
          | loop (_,_,[]) place = raise Index
          | loop (left,here,a::right) place =
            loop (here::left,a,right) (place-1)
    in
        if place <= 0 then
            loop ([],a,rest) (List.length rest + place)
        else
            loop ([],a,rest) (place - 1)
    end
2232 in
2233 (*----- STRUCTURAL OPERATIONS & QUERIES ------*)
(* Tensor of the given shape with every element set to init.
   Raises Shape when Index.validShape rejects the shape. *)
fun new (shape, init) =
    if Index.validShape shape then
        let val size = Index.length shape
        in {shape = shape,
            indexer = Index.indexer shape,
            data = Array.array (size, init)}
        end
    else
        raise Shape
fun toArray {shape = _, indexer = _, data} = data
fun length {shape = _, indexer = _, data} = Array.length data
fun shape {shape = s, indexer = _, data = _} = s
fun rank t = List.length (shape t)
(* Reuse the same data under a new shape: Shape if the new shape is invalid,
   Match if its total size differs from the tensor's. *)
fun reshape new_shape tensor =
    if not (Index.validShape new_shape) then raise Shape
    else if Index.length new_shape = length tensor
    then make' (new_shape, toArray tensor)
    else raise Match
(* Wrap an existing data array; raises Shape on an invalid shape or a
   size mismatch. *)
fun fromArray (s, a) =
    if Index.validShape s andalso Index.length s = Array.length a
    then make' (s, a)
    else raise Shape
fun fromList (s, a) = fromArray (s, Array.fromList a)
2260 fun tabulate (shape,f) =
2261 if Index.validShape shape then
2262 let val last = Index.last shape
2263 val length = Index.length shape
2264 val c = Array.array(length, f last)
2265 fun dotable (c, indices, i) =
2266 (Array.update(c, i, f indices);
2267 if i <= 1
2268 then c
2269 else dotable(c, Index.prev' shape indices, i-1))
2270 in make'(shape,dotable(c, Index.prev' shape last, length-2))
2271 end
2272 else
2273 raise Shape
2274 (*----- ELEMENTWISE OPERATIONS -----*)
2275 fun sub (t, index) = Array.sub(#data t, toInt t index)
2276 fun update (t, index, value) =
2277 Array.update(toArray t, toInt t index, value)
2278 fun map f {shape, indexer, data} =
2279 {shape = shape, indexer = indexer, data = Array.map f data}
2280 fun map2 f t1 t2=
2281 let val {shape=shape1, indexer=indexer1, data=data1} = t1
2282 val {shape=shape2, indexer=indexer2, data=data2} = t2
2283 in
2284 if Index.eq(shape1,shape2) then
2285 {shape = shape1,
2286 indexer = indexer1,
2287 data = Array.map2 f data1 data2}
2288 else
2289 raise Match
2290 end
2291 fun appi f tensor = Array.appi f (toArray tensor)
2292 fun app f tensor = Array.app f (toArray tensor)
2293 fun all f tensor =
2294 let val a = toArray tensor
2295 in Loop.all(0, length tensor - 1, fn i =>
2296 f (Array.sub(a, i)))
2297 end
2298 fun any f tensor =
2299 let val a = toArray tensor
2300 in Loop.any(0, length tensor - 1, fn i =>
2301 f (Array.sub(a, i)))
2302 end
2303 fun foldl f init tensor = Array.foldl f init (toArray tensor)
2304 fun foldln f init {shape, indexer, data=a} index =
2305 let val (head,lk,tail) = splitList(shape, index)
2306 val li = Index.length head
2307 val lj = Index.length tail
2308 val c = Array.array(li * lj,init)
2309 fun loopi (0, _, _) = ()
2310 | loopi (i, ia, ic) =
2311 (Array.update(c, ic, f(Array.sub(c,ic), Array.sub(a,ia)));
2312 loopi (i-1, ia+1, ic+1))
2313 fun loopk (0, ia, _) = ia
2314 | loopk (k, ia, ic) = (loopi (li, ia, ic);
2315 loopk (k-1, ia+li, ic))
2316 fun loopj (0, _, _) = ()
2317 | loopj (j, ia, ic) = loopj (j-1, loopk(lk,ia,ic), ic+li)
2318 in
2319 loopj (lj, 0, 0);
2320 make'(head @ tail, c)
2321 end
2322 (* --- POLYMORPHIC ELEMENTWISE OPERATIONS --- *)
2323 fun array_map' f a =
2324 let fun apply index = f(Array.sub(a,index)) in
2325 Tensor.Array.tabulate(Array.length a, apply)
2326 end
2327 fun map' f t = Tensor.fromArray(shape t, array_map' f (toArray t))
2328 fun map2' f t1 t2 =
2329 let val d1 = toArray t1
2330 val d2 = toArray t2
2331 fun apply i = f (Array.sub(d1,i), Array.sub(d2,i))
2332 val len = Array.length d1
2333 in
2334 if Index.eq(shape t1, shape t2) then
2335 Tensor.fromArray(shape t1, Tensor.Array.tabulate(len,apply))
2336 else
2337 raise Match
2338 end
2339 fun foldl' f init {shape, indexer, data=a} index =
2340 let val (head,lk,tail) = splitList(shape, index)
2341 val li = Index.length head
2342 val lj = Index.length tail
2343 val c = Tensor.Array.array(li * lj,init)
2344 fun loopi (0, _, _) = ()
2345 | loopi (i, ia, ic) =
2346 (Tensor.Array.update(c,ic,f(Tensor.Array.sub(c,ic),Array.sub(a,ia)));
2347 loopi (i-1, ia+1, ic+1))
2348 fun loopk (0, ia, _) = ia
2349 | loopk (k, ia, ic) = (loopi (li, ia, ic);
2350 loopk (k-1, ia+li, ic))
2351 fun loopj (0, _, _) = ()
2352 | loopj (j, ia, ic) = loopj (j-1, loopk(lk,ia,ic), ic+li)
2353 in
2354 loopj (lj, 0, 0);
2355 make'(head @ tail, c)
2356 end
2357 end
2358 end (* MonoTensor *)
2359 open MonoTensor
2360 local
2361 (*
2362 LEFT INDEX CONTRACTION:
2363 a = a(i1,i2,...,in)
2364 b = b(j1,j2,...,jn)
2365 c = c(i2,...,in,j2,...,jn)
2366 = sum(a(k,i2,...,jn)*b(k,j2,...jn)) forall k
2367 MEANINGFUL VARIABLES:
2368 lk = i1 = j1
2369 li = i2*...*in
2370 lj = j2*...*jn
2371 *)
2372 fun do_fold_first a b c lk lj li =
2373 let fun loopk (0, _, _, accum) = accum
2374 | loopk (k, ia, ib, accum) =
2375 let val delta = Number.*(Array.sub(a,ia),Array.sub(b,ib))
2376 in loopk (k-1, ia+1, ib+1, Number.+(delta,accum))
2377 end
2378 fun loopj (0, ib, ic) = c
2379 | loopj (j, ib, ic) =
2380 let fun loopi (0, ia, ic) = ic
2381 | loopi (i, ia, ic) =
2382 (Array.update(c, ic, loopk(lk, ia, ib, Number.zero));
2383 loopi(i-1, ia+lk, ic+1))
2384 in
2385 loopj(j-1, ib+lk, loopi(li, 0, ic))
2386 end
2387 in loopj(lj, 0, 0)
2388 end
2389 in
2390 fun +* ta tb =
2391 let val (rank_a,lk::rest_a,a) = (rank ta, shape ta, toArray ta)
2392 val (rank_b,lk2::rest_b,b) = (rank tb, shape tb, toArray tb)
2393 in if not(lk = lk2)
2394 then raise Match
2395 else let val li = Index.length rest_a
2396 val lj = Index.length rest_b
2397 val c = Array.array(li*lj,Number.zero)
2398 in fromArray(rest_a @ rest_b,
2399 do_fold_first a b c lk li lj)
2400 end
2401 end
2402 end
2403 local
2404 (*
2405 LAST INDEX CONTRACTION:
2406 a = a(i1,i2,...,in)
2407 b = b(j1,j2,...,jn)
2408 c = c(i2,...,in,j2,...,jn)
2409 = sum(mult(a(i1,i2,...,k),b(j1,j2,...,k))) forall k
2410 MEANINGFUL VARIABLES:
2411 lk = in = jn
2412 li = i1*...*i(n-1)
2413 lj = j1*...*j(n-1)
2414 *)
2415 fun do_fold_last a b c lk lj li =
2416 let fun loopi (0, ia, ic, fac) = ()
2417 | loopi (i, ia, ic, fac) =
2418 let val old = Array.sub(c,ic)
2419 val inc = Number.*(Array.sub(a,ia),fac)
2420 in
2421 Array.update(c,ic,Number.+(old,inc));
2422 loopi(i-1, ia+1, ic+1, fac)
2423 end
2424 fun loopj (j, ib, ic) =
2425 let fun loopk (0, ia, ib) = ()
2426 | loopk (k, ia, ib) =
2427 (loopi(li, ia, ic, Array.sub(b,ib));
2428 loopk(k-1, ia+li, ib+lj))
2429 in case j of
2430 0 => c
2431 | _ => (loopk(lk, 0, ib);
2432 loopj(j-1, ib+1, ic+li))
2433 end (* loopj *)
2434 in
2435 loopj(lj, 0, 0)
2436 end
2437 in
2438 fun *+ ta tb =
2439 let val (rank_a,shape_a,a) = (rank ta, shape ta, toArray ta)
2440 val (rank_b,shape_b,b) = (rank tb, shape tb, toArray tb)
2441 val (lk::rest_a) = List.rev shape_a
2442 val (lk2::rest_b) = List.rev shape_b
2443 in if not(lk = lk2)
2444 then raise Match
2445 else let val li = Index.length rest_a
2446 val lj = Index.length rest_b
2447 val c = Array.array(li*lj,Number.zero)
2448 in fromArray(List.rev rest_a @ List.rev rest_b,
2449 do_fold_last a b c lk li lj)
2450 end
2451 end
2452 end
2453 (* ALGEBRAIC OPERATIONS *)
2454 infix **
2455 infix ==
2456 infix !=
2457 fun a + b = map2 Number.+ a b
2458 fun a - b = map2 Number.- a b
2459 fun a * b = map2 Number.* a b
2460 fun a ** i = map (fn x => (Number.**(x,i))) a
2461 fun ~ a = map Number.~ a
2462 fun abs a = map Number.abs a
2463 fun signum a = map Number.signum a
2464 fun a == b = map2' Number.== a b
2465 fun a != b = map2' Number.!= a b
2466 fun toString a = raise Domain
2467 fun fromInt a = new([1], Number.fromInt a)
2468 (* TENSOR SPECIFIC OPERATIONS *)
2469 fun *> n = map (fn x => Number.*(n,x))
2470 fun print t =
2471 (PrettyPrint.intList (shape t);
2472 TextIO.print "\n";
2473 PrettyPrint.sequence (length t) appi Number.toString t)
2474 fun a / b = map2 Number./ a b
2475 fun recip a = map Number.recip a
2476 fun ln a = map Number.ln a
2477 fun pow (a, b) = map (fn x => (Number.pow(x,b))) a
2478 fun exp a = map Number.exp a
2479 fun sqrt a = map Number.sqrt a
2480 fun cos a = map Number.cos a
2481 fun sin a = map Number.sin a
2482 fun tan a = map Number.tan a
2483 fun sinh a = map Number.sinh a
2484 fun cosh a = map Number.cosh a
2485 fun tanh a = map Number.tanh a
2486 fun asin a = map Number.asin a
2487 fun acos a = map Number.acos a
2488 fun atan a = map Number.atan a
2489 fun asinh a = map Number.asinh a
2490 fun acosh a = map Number.acosh a
2491 fun atanh a = map Number.atanh a
2492 fun atan2 (a,b) = map2 Number.atan2 a b
2493 fun normInf a =
2494 let fun accum (y,x) = Number.max(x,Number.abs y)
2495 in foldl accum Number.zero a
2496 end
2497 fun norm1 a =
2498 let fun accum (y,x) = Number.+(x,Number.abs y)
2499 in foldl accum Number.zero a
2500 end
2501 fun norm2 a =
2502 let fun accum (y,x) = Number.+(x, Number.*(y,y))
2503 in Number.sqrt(foldl accum Number.zero a)
2504 end
2505 end (* RTensor *)
structure CTensor =
struct
structure Number = CNumber
structure Array = CNumberArray
(*
 Copyright (c) Juan Jose Garcia Ripoll.
 All rights reserved.
 Refer to the COPYRIGHT file for license conditions
*)
(* Dense tensors of CNumber elements.  A tensor is a record holding its
   shape, an indexer that maps a multi-index to a flat offset, and the
   flat element array (a CNumberArray). *)
structure MonoTensor =
struct
    (* PARAMETERS
       structure Array = Array
    *)
    structure Index = Index
    type elem = Array.elem
    type index = Index.t
    type tensor = {shape : index, indexer : Index.indexer, data : Array.array}
    type t = tensor
    exception Shape
    exception Match
    exception Index
    local
        (*----- LOCALS -----*)

        (* Wrap an existing data array with a shape; performs no validation. *)
        fun make' (shape, data) =
            {shape = shape, indexer = Index.indexer shape, data = data}

        (* Flat offset of a multi-index inside a tensor. *)
        fun toInt {shape, indexer, data} index = indexer index

        (* splitList (l, place) returns (before, here, after), where here is
           the selected element of l.  A positive place is a 1-based
           position counted from the left.  A non-positive place counts
           from the right: 0 is the last element and ~1 the one before it.
           Raises Index when place lies outside the list.  l must be
           non-empty. *)
        fun splitList (l as (a::rest), place) =
            let fun loop (left,here,right) 0 = (List.rev left,here,right)
                  | loop (_,_,[]) place = raise Index
                  | loop (left,here,a::right) place =
                    loop (here::left,a,right) (place-1)
            in
                if place <= 0 then
                    (* Steps to take = (length l - 1) + place.  The old code
                       subtracted place, which overshot the list and raised
                       Index for every negative place. *)
                    loop ([],a,rest) (List.length rest + place)
                else
                    loop ([],a,rest) (place - 1)
            end
    in
        (*----- STRUCTURAL OPERATIONS & QUERIES ------*)

        (* Fresh tensor of the given shape with every element set to init.
           Raises Shape if Index.validShape rejects the shape. *)
        fun new (shape, init) =
            if not (Index.validShape shape) then
                raise Shape
            else
                let val length = Index.length shape in
                    {shape = shape,
                     indexer = Index.indexer shape,
                     data = Array.array(length,init)}
                end
        fun toArray {shape, indexer, data} = data
        fun length {shape, indexer, data} = Array.length data
        fun shape {shape, indexer, data} = shape
        fun rank t = List.length (shape t)

        (* View the same data array under a new shape (the array is shared,
           not copied).  Raises Shape for an invalid shape and Match when
           the element counts differ. *)
        fun reshape new_shape tensor =
            if Index.validShape new_shape then
                case (Index.length new_shape) = length tensor of
                    true => make'(new_shape, toArray tensor)
                  | false => raise Match
            else
                raise Shape

        (* Wrap array a with shape s.  Raises Shape unless s is valid and
           has exactly Array.length a elements. *)
        fun fromArray (s, a) =
            case Index.validShape s andalso
                 ((Index.length s) = (Array.length a)) of
                true => make'(s, a)
              | false => raise Shape
        fun fromList (s, a) = fromArray (s, Array.fromList a)

        (* Tensor whose element at each multi-index is f applied to that
           index.  Every slot starts as f (Index.last shape).  The other
           slots are then filled from flat position length-2 down to 0,
           stepping the multi-index backwards with Index.prev' (assumed to
           follow the flat storage order; TODO confirm against Index). *)
        fun tabulate (shape,f) =
            if Index.validShape shape then
                let val last = Index.last shape
                    val length = Index.length shape
                    val c = Array.array(length, f last)
                    (* Stops after writing slot 0.  The old test "i <= 1"
                       left slot 0 holding f last for any tensor with three
                       or more elements. *)
                    fun dotable (c, indices, i) =
                        (Array.update(c, i, f indices);
                         if i <= 0
                         then c
                         else dotable(c, Index.prev' shape indices, i-1))
                in
                    (* A tensor with at most one element is already filled.
                       Starting the walk there would update slot ~1 and
                       raise Subscript. *)
                    if length <= 1 then make'(shape, c)
                    else make'(shape,dotable(c, Index.prev' shape last, length-2))
                end
            else
                raise Shape

        (*----- ELEMENTWISE OPERATIONS -----*)
        fun sub (t, index) = Array.sub(#data t, toInt t index)
        fun update (t, index, value) =
            Array.update(toArray t, toInt t index, value)
        fun map f {shape, indexer, data} =
            {shape = shape, indexer = indexer, data = Array.map f data}

        (* Combine two equally shaped tensors elementwise; raises Match when
           the shapes differ. *)
        fun map2 f t1 t2=
            let val {shape=shape1, indexer=indexer1, data=data1} = t1
                val {shape=shape2, indexer=indexer2, data=data2} = t2
            in
                if Index.eq(shape1,shape2) then
                    {shape = shape1,
                     indexer = indexer1,
                     data = Array.map2 f data1 data2}
                else
                    raise Match
            end
        fun appi f tensor = Array.appi f (toArray tensor, 0, NONE)
        fun app f tensor = Array.app f (toArray tensor)
        fun all f tensor =
            let val a = toArray tensor
            in Loop.all(0, length tensor - 1, fn i =>
                        f (Array.sub(a, i)))
            end
        fun any f tensor =
            let val a = toArray tensor
            in Loop.any(0, length tensor - 1, fn i =>
                        f (Array.sub(a, i)))
            end
        fun foldl f init tensor = Array.foldl f init (toArray tensor)

        (* Fold f along dimension `index` (selected as in splitList).  The
           result drops that dimension from the shape.  Every output slot
           starts at init and is updated as f (accumulator, element).  The
           offsets assume the first index varies fastest: li is the size of
           the dimensions before the folded one, lj the size of those
           after it. *)
        fun foldln f init {shape, indexer, data=a} index =
            let val (head,lk,tail) = splitList(shape, index)
                val li = Index.length head
                val lj = Index.length tail
                val c = Array.array(li * lj,init)
                fun loopi (0, _, _) = ()
                  | loopi (i, ia, ic) =
                    (Array.update(c, ic, f(Array.sub(c,ic), Array.sub(a,ia)));
                     loopi (i-1, ia+1, ic+1))
                fun loopk (0, ia, _) = ia
                  | loopk (k, ia, ic) = (loopi (li, ia, ic);
                                         loopk (k-1, ia+li, ic))
                fun loopj (0, _, _) = ()
                  | loopj (j, ia, ic) = loopj (j-1, loopk(lk,ia,ic), ic+li)
            in
                loopj (lj, 0, 0);
                make'(head @ tail, c)
            end

        (* --- POLYMORPHIC ELEMENTWISE OPERATIONS --- *)
        (* These build results in the polymorphic Tensor structure, so f
           may return a type other than elem. *)
        fun array_map' f a =
            let fun apply index = f(Array.sub(a,index)) in
                Tensor.Array.tabulate(Array.length a, apply)
            end
        fun map' f t = Tensor.fromArray(shape t, array_map' f (toArray t))
        fun map2' f t1 t2 =
            let val d1 = toArray t1
                val d2 = toArray t2
                fun apply i = f (Array.sub(d1,i), Array.sub(d2,i))
                val len = Array.length d1
            in
                if Index.eq(shape t1, shape t2) then
                    Tensor.fromArray(shape t1, Tensor.Array.tabulate(len,apply))
                else
                    raise Match
            end
        fun foldl' f init {shape, indexer, data=a} index =
            let val (head,lk,tail) = splitList(shape, index)
                val li = Index.length head
                val lj = Index.length tail
                val c = Tensor.Array.array(li * lj,init)
                fun loopi (0, _, _) = ()
                  | loopi (i, ia, ic) =
                    (Tensor.Array.update(c,ic,f(Tensor.Array.sub(c,ic),Array.sub(a,ia)));
                     loopi (i-1, ia+1, ic+1))
                fun loopk (0, ia, _) = ia
                  | loopk (k, ia, ic) = (loopi (li, ia, ic);
                                         loopk (k-1, ia+li, ic))
                fun loopj (0, _, _) = ()
                  | loopj (j, ia, ic) = loopj (j-1, loopk(lk,ia,ic), ic+li)
            in
                loopj (lj, 0, 0);
                make'(head @ tail, c)
            end
    end
end (* MonoTensor *)
open MonoTensor
local
    (*
     LEFT INDEX CONTRACTION:
     a = a(i1,i2,...,in)
     b = b(j1,j2,...,jn)
     c = c(i2,...,in,j2,...,jn)
       = sum(a(k,i2,...,jn)*b(k,j2,...jn)) forall k
     MEANINGFUL VARIABLES:
     lk = i1 = j1
     li = i2*...*in
     lj = j2*...*jn
    *)
    (* Parameter order now matches the caller (lk li lj).  It used to be
       declared as lk lj li, which swapped the row and column extents and
       broke every contraction of non-square operands.  Elements are
       destructured as (real, imaginary) pairs and the complex product is
       expanded by hand. *)
    fun do_fold_first a b c lk li lj =
        let fun loopk (0, _, _, r, i) = Number.make(r,i)
              | loopk (k, ia, ib, r, i) =
                let val (ar, ai) = Array.sub(a,ia)
                    val (br, bi) = Array.sub(b,ib)
                    val dr = ar * br - ai * bi
                    val di = ar * bi + ai * br
                in loopk (k-1, ia+1, ib+1, r+dr, i+di)
                end
            fun loopj (0, ib, ic) = c
              | loopj (j, ib, ic) =
                let fun loopi (0, ia, ic) = ic
                      | loopi (i, ia, ic) =
                        (Array.update(c, ic, loopk(lk, ia, ib, RNumber.zero, RNumber.zero));
                         loopi(i-1, ia+lk, ic+1))
                in loopj(j-1, ib+lk, loopi(li, 0, ic))
                end
        in loopj(lj, 0, 0)
        end
in
    (* Contract the first index of ta with the first index of tb.  Raises
       Match when those dimensions differ. *)
    fun +* ta tb =
        let val (rank_a,lk::rest_a,a) = (rank ta, shape ta, toArray ta)
            val (rank_b,lk2::rest_b,b) = (rank tb, shape tb, toArray tb)
        in if not(lk = lk2)
           then raise Match
           else let val li = Index.length rest_a
                    val lj = Index.length rest_b
                    val c = Array.array(li*lj,Number.zero)
                in fromArray(rest_a @ rest_b, do_fold_first a b c lk li lj)
                end
        end
end
local
    (*
     LAST INDEX CONTRACTION:
     a = a(i1,i2,...,in)
     b = b(j1,j2,...,jn)
     c = c(i2,...,in,j2,...,jn)
       = sum(mult(a(i1,i2,...,k),b(j1,j2,...,k))) forall k
     MEANINGFUL VARIABLES:
     lk = in = jn
     li = i1*...*i(n-1)
     lj = j1*...*j(n-1)
    *)
    (* Parameter order now matches the caller (lk li lj); see
       do_fold_first.  c must be zero-initialised because products are
       accumulated into it. *)
    fun do_fold_last a b c lk li lj =
        let fun loopi(0, _, _, _, _) = ()
              | loopi(i, ia, ic, br, bi) =
                let val (cr,ci) = Array.sub(c,ic)
                    val (ar,ai) = Array.sub(a,ia)
                    val dr = (ar * br - ai * bi)
                    val di = (ar * bi + ai * br)
                in
                    Array.update(c,ic,Number.make(cr+dr,ci+di));
                    loopi(i-1, ia+1, ic+1, br, bi)
                end
            fun loopj(j, ib, ic) =
                let fun loopk(0, _, _) = ()
                      | loopk(k, ia, ib) =
                        let val (br, bi) = Array.sub(b,ib)
                        in
                            loopi(li, ia, ic, br, bi);
                            loopk(k-1, ia+li, ib+lj)
                        end
                in case j of
                       0 => c
                     | _ => (loopk(lk, 0, ib);
                             loopj(j-1, ib+1, ic+li))
                end (* loopj *)
        in
            loopj(lj, 0, 0)
        end
in
    (* Contract the last index of ta with the last index of tb.  Raises
       Match when those dimensions differ. *)
    fun *+ ta tb =
        let val (rank_a,shape_a,a) = (rank ta, shape ta, toArray ta)
            val (rank_b,shape_b,b) = (rank tb, shape tb, toArray tb)
            val (lk::rest_a) = List.rev shape_a
            val (lk2::rest_b) = List.rev shape_b
        in
            if not(lk = lk2) then
                raise Match
            else
                let val li = Index.length rest_a
                    val lj = Index.length rest_b
                    val c = Array.array(li*lj,Number.zero)
                in
                    fromArray(List.rev rest_a @ List.rev rest_b,
                              do_fold_last a b c lk li lj)
                end
        end
end
(* ALGEBRAIC OPERATIONS *)
infix **
infix ==
infix !=
fun a + b = map2 Number.+ a b
fun a - b = map2 Number.- a b
fun a * b = map2 Number.* a b
fun a ** i = map (fn x => (Number.**(x,i))) a
fun ~ a = map Number.~ a
fun abs a = map Number.abs a
fun signum a = map Number.signum a
fun a == b = map2' Number.== a b
fun a != b = map2' Number.!= a b
fun toString a = raise Domain
fun fromInt a = new([1], Number.fromInt a)
(* TENSOR SPECIFIC OPERATIONS *)
fun *> n = map (fn x => Number.*(n,x))
fun print t =
    (PrettyPrint.intList (shape t);
     TextIO.print "\n";
     PrettyPrint.sequence (length t) appi Number.toString t)
fun a / b = map2 Number./ a b
fun recip a = map Number.recip a
fun ln a = map Number.ln a
fun pow (a, b) = map (fn x => (Number.pow(x,b))) a
fun exp a = map Number.exp a
fun sqrt a = map Number.sqrt a
fun cos a = map Number.cos a
fun sin a = map Number.sin a
fun tan a = map Number.tan a
fun sinh a = map Number.sinh a
fun cosh a = map Number.cosh a
fun tanh a = map Number.tanh a
fun asin a = map Number.asin a
fun acos a = map Number.acos a
fun atan a = map Number.atan a
fun asinh a = map Number.asinh a
fun acosh a = map Number.acosh a
fun atanh a = map Number.atanh a
fun atan2 (a,b) = map2 Number.atan2 a b
(* Infinity norm as an RNumber: largest real part of Number.abs over the
   elements. *)
fun normInf a =
    let fun accum (y,x) = RNumber.max(x, Number.realPart(Number.abs y))
    in foldl accum RNumber.zero a
    end
(* 1-norm as an RNumber: sum of the real parts of Number.abs. *)
fun norm1 a =
    let fun accum (y,x) = RNumber.+(x, Number.realPart(Number.abs y))
    in foldl accum RNumber.zero a
    end
(* 2-norm as an RNumber: square root of the sum of Number.abs2 values. *)
fun norm2 a =
    let fun accum (y,x) = RNumber.+(x, Number.abs2 y)
    in RNumber.sqrt(foldl accum RNumber.zero a)
    end
end (* CTensor *)
structure MathFile =
struct

(* Text I/O for numbers, lists and tensors.  A list is written as its
   length followed by its elements, one per line.  A tensor is written as
   its shape (as an int list) followed by its elements in flat-array
   order. *)

type file = TextIO.instream

exception Data

(* Unwrap a scan result; a failed scan (NONE) raises Data. *)
fun assert result =
    case result of
        SOME value => value
      | NONE => raise Data

(* ------------------ INPUT --------------------- *)

fun intRead file = assert (TextIO.scanStream INumber.scan file)
fun realRead file = assert (TextIO.scanStream RNumber.scan file)
fun complexRead file = assert (TextIO.scanStream CNumber.scan file)

(* Read a count (raising Data if it is negative) followed by that many
   elements parsed with eltScan.  The elements are returned in input
   order. *)
fun listRead eltScan file =
    let
        val count = intRead file
        fun readElt () = assert (TextIO.scanStream eltScan file)
        fun collect (0, acc) = List.rev acc
          | collect (k, acc) = collect (k - 1, readElt () :: acc)
    in
        if count < 0 then raise Data else collect (count, [])
    end

fun intListRead file = listRead INumber.scan file
fun realListRead file = listRead RNumber.scan file
fun complexListRead file = listRead CNumber.scan file

local
    (* Store readElt file into slots 1, 2, ... of arr, in increasing
       order, stopping when the slot number reaches len.  Slot 0 already
       holds the element used to create the array. *)
    fun fillRest (update, readElt) (arr, len) file =
        let
            fun go slot =
                if slot = len then ()
                else (update (arr, slot, readElt file); go (slot + 1))
        in
            go 1
        end
in

(* Read a shape, then one element per slot of that shape. *)
fun intTensorRead file =
    let
        val shape = intListRead file
        val len = Index.length shape
        val arr = ITensor.Array.array (len, intRead file)
    in
        fillRest (ITensor.Array.update, intRead) (arr, len) file;
        ITensor.fromArray (shape, arr)
    end

fun realTensorRead file =
    let
        val shape = intListRead file
        val len = Index.length shape
        val arr = RTensor.Array.array (len, realRead file)
    in
        fillRest (RTensor.Array.update, realRead) (arr, len) file;
        RTensor.fromArray (shape, arr)
    end

fun complexTensorRead file =
    let
        val shape = intListRead file
        val len = Index.length shape
        val arr = CTensor.Array.array (len, complexRead file)
    in
        fillRest (CTensor.Array.update, complexRead) (arr, len) file;
        CTensor.fromArray (shape, arr)
    end

end

(* ------------------ OUTPUT -------------------- *)

(* Write the string x followed by a newline. *)
fun linedOutput (file, x) =
    (TextIO.output (file, x); TextIO.output (file, "\n"))

fun intWrite file x = linedOutput (file, INumber.toString x)
fun realWrite file x = linedOutput (file, RNumber.toString x)

(* Write a complex number as "<real> <imaginary>" on one line. *)
fun complexWrite file x =
    case CNumber.split x of
        (re, im) => linedOutput (file, concat [RNumber.toString re, " ", RNumber.toString im])

(* Write the element count, then each converted element on its own line. *)
fun listWrite converter file x =
    (intWrite file (length x);
     List.app (fn elt => linedOutput (file, converter elt)) x)

fun intListWrite file x = listWrite INumber.toString file x
fun realListWrite file x = listWrite RNumber.toString file x
fun complexListWrite file x = listWrite CNumber.toString file x

(* Write the shape, then every element in flat-array order. *)
fun intTensorWrite file x =
    (intListWrite file (ITensor.shape x); ITensor.app (intWrite file) x)
fun realTensorWrite file x =
    (intListWrite file (RTensor.shape x); RTensor.app (realWrite file) x)
fun complexTensorWrite file x =
    (intListWrite file (CTensor.shape x); CTensor.app (complexWrite file) x)
end
2914
(* Call f () n times, for side effects only. *)
fun loop n f =
    if n = 0 then () else (f (); loop (n - 1) f)
2917
(* For every size in list_sizes, print one row: the size, then for each
   (times, f) in list_op the average timer reading per call.  Each average
   is left-padded to 6 columns and measured by applying f (x, x) `times`
   times to x = new size. *)
fun test_operator new list_op list_sizes =
    let
        fun timeOne size (times, f) =
            let
                val operand = new size
                val _ = EvalTimer.timerOn ()
                val _ = loop times (fn _ => f (operand, operand))
                val average = LargeInt.toInt (EvalTimer.timerRead ()) div times
            in
                print (StringCvt.padLeft #" " 6 (Int.toString average))
            end
        fun timeRow size =
            (print (Int.toString size);
             print " ";
             List.app (timeOne size) list_op;
             print "\n")
    in
        List.app timeRow list_sizes
    end
2937
structure Main =
struct
    (* One benchmark pass.  Times elementwise +, * and / (20 repetitions
       each) and the contractions +* and *+ (4 repetitions each) on square
       tensors of side 100 to 500.  The real-element tensors are timed
       first, then the complex ones. *)
    fun one() =
        let
            val _ =
                let val operators = [(20, RTensor.+), (20, RTensor.* ), (20, RTensor./),
                                     (4, fn (a,b) => RTensor.+* a b),
                                     (4, fn (a,b) => RTensor.*+ a b)]
                    fun constructor size = RTensor.new([size,size],1.0)
                in
                    print "Real tensors: (+, *, /, +*, *+)\n";
                    test_operator constructor operators [100,200,300,400,500];
                    print "\n\n"
                end

            val _ =
                let val operators = [(20, CTensor.+), (20, CTensor.* ), (20, CTensor./),
                                     (4, fn (a,b) => CTensor.+* a b),
                                     (4, fn (a,b) => CTensor.*+ a b)]
                    fun constructor size = CTensor.new([size,size],CNumber.one)
                in
                    (* This section times CTensor.  It was previously
                       labelled as real tensors. *)
                    print "Complex tensors: (+, *, /, +*, *+)\n";
                    test_operator constructor operators [100,200,300,400,500];
                    print "\n\n"
                end
        in ()
        end

    (* Run the benchmark pass n times. *)
    fun doit n =
        if n = 0
        then ()
        else (one ()
              ; doit (n - 1))
end