Import Upstream version 20180207
[hcoop/debian/mlton.git] / benchmark / tests / matrix-multiply.sml
CommitLineData
7f918cf1
CE
(* Written by Stephen Weeks (sweeks@sweeks.com). *)
(* Alias the Basis two-dimensional Array2 structure as Array.  Every
 * Array.* reference after this declaration (nRows, nCols, sub, update,
 * tabulate, RowMajor) therefore means Array2, shadowing the Basis
 * one-dimensional Array. *)
structure Array = Array2
3
(* fold (n, b, f) threads an accumulator through f for the indices
 * 0, 1, ..., n - 1 in increasing order, starting from b, i.e. it returns
 *   f (n - 1, ... f (1, f (0, b)) ...).
 * The loop stops as soon as i >= n, so a non-positive n makes no calls and
 * returns b unchanged.  Testing with = would instead count i upward until
 * Int overflow when n is negative. *)
fun 'a fold (n : int, b : 'a, f : int * 'a -> 'a) : 'a =
   let
      fun loop (i : int, acc : 'a) : 'a =
         if i >= n
            then acc
         else loop (i + 1, f (i, acc))
   in
      loop (0, b)
   end
12
(* foreach (n, f) runs f on each index produced by fold for n, purely for
 * its side effects.  It is fold specialised to a unit accumulator: the
 * accumulator is ignored and f sees only the index. *)
fun foreach (n : int, f : int -> unit) : unit =
   fold (n, (), fn (i, ()) => f i)
15
(* mult (a1, a2) computes the matrix product of a1 (r1 rows, c1 columns)
 * and a2 (r2 rows, c2 columns).
 * Raises Fail "mult" unless c1 = r2.  Otherwise it allocates an r1 x c2
 * array filled with 0.0 and overwrites every cell (row, col) with the dot
 * product of row `row` of a1 and column `col` of a2.  Rows are the outer
 * loop and columns the inner loop.  The filled array is returned. *)
fun mult (a1 : real Array.array, a2 : real Array.array) : real Array.array =
   let
      val (r1, c1) = (Array.nRows a1, Array.nCols a1)
      val (r2, c2) = (Array.nRows a2, Array.nCols a2)

      (* a1[row, k] * a2[k, col] summed over k via fold, starting at 0.0. *)
      fun cell (row, col) =
         fold (c1, 0.0, fn (k, acc) =>
               acc + Array.sub (a1, row, k) * Array.sub (a2, k, col))

      (* Allocate the result, then fill it one row at a time. *)
      fun product () =
         let
            val out = Array2.array (r1, c2, 0.0)
            fun fillRow row =
               foreach (c2, fn col =>
                        Array.update (out, row, col, cell (row, col)))
         in
            foreach (r1, fillRow);
            out
         end
   in
      if c1 = r2
         then product ()
      else raise Fail "mult"
   end
35
(* Benchmark driver: Main.doit size squares a dim x dim matrix size times,
 * checking one entry of the product after every multiplication. *)
structure Main =
   struct
      (* One benchmark iteration.  Builds a with a[r][c] = r + c, squares it
       * with mult, and checks entry (0, 0) of the product:
       *   sum_{i=0}^{dim-1} a[0][i] * a[i][0] = sum_{i=0}^{dim-1} i * i
       *                                     = (dim - 1) * dim * (2 * dim - 1) / 6
       * That is 41541750 for dim = 500.  The expected value is derived from
       * dim rather than hard-coded, so the check stays valid if dim changes.
       * It is computed in real arithmetic to avoid Int overflow.  For
       * dim = 500 every product and partial sum is an integer far below 2^53,
       * so the real arithmetic is exact and Real.== is a sound comparison.
       * Raises Fail "bug" on a mismatch. *)
      fun doit () =
         let
            val dim = 500
            val a = Array.tabulate Array.RowMajor (dim, dim, fn (r, c) =>
                                                     Real.fromInt (r + c))
            val d = Real.fromInt dim
            val expected = (d - 1.0) * d * (2.0 * d - 1.0) / 6.0
         in
            if Real.== (expected, Array2.sub (mult (a, a), 0, 0))
               then ()
            else raise Fail "bug"
         end

      (* Repeats the single-iteration doit above `size` times; a size <= 0
       * runs nothing (testing n = 0 would count a negative size down until
       * Int overflow).  Because this is a plain `val` binding rather than
       * `val rec`/`fun`, the doit called in the body is the unit -> unit
       * function defined just above. *)
      val doit =
         fn size =>
         let
            fun loop n =
               if n <= 0
                  then ()
               else (doit (); loop (n - 1))
         in
            loop size
         end
   end