(* Copyright (C) 2017 Matthew Fluet.
 * Copyright (C) 1999-2008 Henry Cejtin, Matthew Fluet, Suresh
 *    Jagannathan, and Stephen Weeks.
 * Copyright (C) 1997-2000 NEC Research Institute.
 *
 * MLton is released under a BSD-style license.
 * See the file MLton-LICENSE for details.
 *)
10 structure Array2
: ARRAY2
=
(* Local aliases for the SeqIndex arithmetic and comparison operations that
 * the rest of this structure uses on row/column indices.  Each binding
 * simply re-exports the SeqIndex member of the same name. *)
13 val op +?
= SeqIndex
.+?
15 val op -?
= SeqIndex
.-?
17 val op *?
= SeqIndex
.*?
20 val op <= = SeqIndex
.<=
22 val op >= = SeqIndex
.>=
23 val ltu
= SeqIndex
.ltu
24 val leu
= SeqIndex
.leu
25 val gtu
= SeqIndex
.gtu
26 val geu
= SeqIndex
.geu
(* Representation of a two-dimensional array: a record whose `array` field
 * holds the elements in a one-dimensional 'a Array.array.
 * NOTE(review): the remaining fields are not visible in this excerpt; the
 * accessors that follow read `rows` and `cols` fields from this record. *)
28 type 'a array
= {array
: 'a Array
.array
,
(* dimensions' a is the pair (rows, cols) exactly as stored in the record,
 * i.e. without the conversion to int that dimensions performs. *)
fun dimensions' (a: 'a array) = (#rows a, #cols a)
(* dimensions a is the (rows, cols) pair of a, each component converted
 * with SeqIndex.toIntUnsafe. *)
fun dimensions (a: 'a array) =
   let
      val {rows = nr, cols = nc, ...} = a
   in
      (SeqIndex.toIntUnsafe nr, SeqIndex.toIntUnsafe nc)
   end
(* nRows' a is the stored row count of a, with no conversion applied. *)
fun nRows' (a: 'a array) = #rows a
(* nRows a is the row count of a, converted with SeqIndex.toIntUnsafe. *)
fun nRows (a: 'a array) = SeqIndex.toIntUnsafe (#rows a)
(* nCols' a is the stored column count of a, with no conversion applied. *)
fun nCols' (a: 'a array) = #cols a
(* nCols a is the column count of a, converted with SeqIndex.toIntUnsafe. *)
fun nCols (a: 'a array) = SeqIndex.toIntUnsafe (#cols a)
(* A rectangular region of an array, identified by its `base` array.
 * NOTE(review): the remaining fields are not visible in this excerpt;
 * wholeRegion builds a region with row, col, nrows and ncols fields. *)
40 type 'a region
= {base
: 'a array
,
(* checkSliceMax' (start, num, max) turns a slice description -- an int
 * start and an optional SeqIndex length -- into a (start, stop) pair of
 * SeqIndex values bounded by max.
 *   - With NONE the stop component is max.
 *   - With SOME num the stop component is start +? num.
 * When Primitive.Controls.safe holds, start is converted with
 * SeqIndex.fromInt and an Overflow from that conversion is re-raised as
 * Subscript; the SOME case additionally tests for a negative start, a
 * negative num, and start +? num exceeding max.  Otherwise
 * SeqIndex.fromIntUnsafe is used and no test is made.
 * NOTE(review): several lines of this definition are missing from this
 * excerpt, including the case header and the branches taken when a bounds
 * test fails -- confirm against the full file. *)
47 fun checkSliceMax
' (start
: int,
48 num
: SeqIndex
.int option
,
49 max
: SeqIndex
.int): SeqIndex
.int * SeqIndex
.int =
51 NONE
=> if Primitive
.Controls
.safe
54 (SeqIndex
.fromInt start
)
55 handle Overflow
=> raise Subscript
61 else (SeqIndex
.fromIntUnsafe start
, max
)
62 | SOME num
=> if Primitive
.Controls
.safe
65 (SeqIndex
.fromInt start
)
66 handle Overflow
=> raise Subscript
68 if (start
< 0 orelse num
< 0
69 orelse start
+? num
> max
)
71 else (start
, start
+? num
)
73 else (SeqIndex
.fromIntUnsafe start
,
74 SeqIndex
.fromIntUnsafe start
+? num
)
(* checkSliceMax (start, num, max) is the front end of checkSliceMax' for
 * callers whose optional length has not yet been converted to SeqIndex.
 * In safe mode num is mapped through SeqIndex.fromInt and any Overflow
 * raised by the whole delegated call is re-raised as Subscript; otherwise
 * num is mapped through SeqIndex.fromIntUnsafe with no handler.
 * NOTE(review): the declaration of the num parameter is missing from this
 * excerpt. *)
75 fun checkSliceMax (start
: int,
77 max
: SeqIndex
.int): SeqIndex
.int * SeqIndex
.int =
78 if Primitive
.Controls
.safe
79 then (checkSliceMax
' (start
, Option
.map SeqIndex
.fromInt num
, max
))
80 handle Overflow
=> raise Subscript
81 else checkSliceMax
' (start
, Option
.map SeqIndex
.fromIntUnsafe num
, max
)
(* checkRegion' validates a region whose nrows/ncols are already in the form
 * checkSliceMax' expects.  The row extent is checked against the base
 * array's row count first, then the column extent against its column
 * count; the result is the record of start/stop bounds for both axes. *)
fun checkRegion' {base, row, col, nrows, ncols} =
   let
      val (maxRow, maxCol) = dimensions' base
      val (r0, r1) = checkSliceMax' (row, nrows, maxRow)
      val (c0, c1) = checkSliceMax' (col, ncols, maxCol)
   in
      {startRow = r0, stopRow = r1, startCol = c0, stopCol = c1}
   end
(* checkRegion validates a region by running checkSliceMax on the row extent
 * (against the base array's row count) and then on the column extent
 * (against its column count), returning the start/stop bounds of both
 * axes as a record. *)
fun checkRegion {base, row, col, nrows, ncols} =
   let
      val (maxRow, maxCol) = dimensions' base
      val (r0, r1) = checkSliceMax (row, nrows, maxRow)
      val (c0, c1) = checkSliceMax (col, ncols, maxCol)
   in
      {startRow = r0, stopRow = r1, startCol = c0, stopCol = c1}
   end
(* wholeRegion a is the region based on a that starts at row 0, column 0
 * and leaves both extents unspecified (NONE). *)
fun wholeRegion (a: 'a array): 'a region =
   {base = a,
    nrows = NONE, ncols = NONE,
    row = 0, col = 0}
(* Traversal order selector.  Presumably this is the type of the `trv`
 * argument taken by the fold/app/modify/tabulate functions of this
 * structure, with RowMajor visiting row by row and ColMajor column by
 * column -- the dispatch on it is not visible in this excerpt; confirm. *)
106 datatype traversal
= RowMajor | ColMajor
(* make (rows, cols, doit) is the shared constructor behind alloc' and
 * array'.  In safe mode it first tests whether rows or cols is negative;
 * otherwise it builds the array record, obtaining the `array` field by
 * applying doit to rows * cols, with an Overflow from that product
 * re-raised as Size.
 * NOTE(review): the branch taken when a dimension is negative and the
 * remaining record fields are missing from this excerpt. *)
109 fun make (rows
, cols
, doit
) =
110 if Primitive
.Controls
.safe
111 andalso (rows
< 0 orelse cols
< 0)
113 else {array
= doit (rows
* cols
handle Overflow
=> raise Size
),
(* alloc' (rows, cols) builds an array through make, using
 * Primitive.Array.alloc to obtain the backing store. *)
fun alloc' (nr, nc) =
   let
      val allocate = Primitive.Array.alloc
   in
      make (nr, nc, allocate)
   end
(* array' (rows, cols, init) builds an array through make, obtaining the
 * backing store from Primitive.Array.new with init as its second
 * argument. *)
fun array' (nr, nc, x) =
   let
      fun filled n = Primitive.Array.new (n, x)
   in
      make (nr, nc, filled)
   end
(* make (rows, cols, doit) adapts int dimensions for the SeqIndex-based
 * constructors.  In safe mode rows and cols are converted with
 * SeqIndex.fromInt and an Overflow from either conversion is re-raised as
 * Size; otherwise doit receives the SeqIndex.fromIntUnsafe conversions
 * directly.
 * NOTE(review): parts of the safe branch (including its call to doit) are
 * missing from this excerpt. *)
123 fun make (rows
, cols
, doit
) =
124 if Primitive
.Controls
.safe
127 (SeqIndex
.fromInt rows
)
128 handle Overflow
=> raise Size
130 (SeqIndex
.fromInt cols
)
131 handle Overflow
=> raise Size
135 else doit (SeqIndex
.fromIntUnsafe rows
,
136 SeqIndex
.fromIntUnsafe cols
)
(* alloc (rows, cols) converts the int dimensions via make and hands the
 * converted pair to alloc'. *)
fun alloc (rows, cols) = make (rows, cols, alloc')
(* array (rows, cols, init) converts the int dimensions via make and hands
 * the converted pair, together with init, to array'. *)
fun array (rows, cols, init) =
   let
      fun build (nr, nc) = array' (nr, nc, init)
   in
      make (rows, cols, build)
   end
(* array0 () builds an array record whose backing store is
 * Primitive.Array.alloc 0.
 * NOTE(review): the remaining record fields are missing from this
 * excerpt. *)
144 fun array0 (): 'a array
=
145 {array
= Primitive
.Array
.alloc
0,
(* unsafeSpot' (a, r, c) yields the value that the accessors below pass as
 * the index into the backing array for position (r, c).  Only the cols
 * field of a is destructured.
 * NOTE(review): the body is missing from this excerpt. *)
149 fun unsafeSpot
' ({cols
, ...}: 'a array
, r
, c
) =
(* spot' (a, r, c) is the checked counterpart of unsafeSpot'.  In safe mode
 * it tests r against rows and c against cols with geu (bound to
 * SeqIndex.geu); when the test does not fire, or when safe mode is off, it
 * delegates to unsafeSpot'.
 * NOTE(review): the branch taken when an index is out of range is missing
 * from this excerpt. *)
151 fun spot
' (a
as {rows
, cols
, ...}: 'a array
, r
, c
) =
152 if Primitive
.Controls
.safe
153 andalso (geu (r
, rows
) orelse geu (c
, cols
))
155 else unsafeSpot
' (a
, r
, c
)
(* unsafeSub' (a, r, c) reads the element of a at the index computed by
 * unsafeSpot' (a, r, c), using Primitive.Array.unsafeSub. *)
fun unsafeSub' (a: 'a array, r, c) =
   let
      val i = unsafeSpot' (a, r, c)
   in
      Primitive.Array.unsafeSub (#array a, i)
   end
(* sub' (a, r, c) reads the element of a at the index computed by
 * spot' (a, r, c), using Primitive.Array.unsafeSub. *)
fun sub' (a: 'a array, r, c) =
   let
      val i = spot' (a, r, c)
   in
      Primitive.Array.unsafeSub (#array a, i)
   end
(* unsafeUpdate' (a, r, c, x) stores x in a at the index computed by
 * unsafeSpot' (a, r, c), using Primitive.Array.unsafeUpdate. *)
fun unsafeUpdate' (a: 'a array, r, c, x) =
   let
      val i = unsafeSpot' (a, r, c)
   in
      Primitive.Array.unsafeUpdate (#array a, i, x)
   end
(* update' (a, r, c, x) stores x in a at the index computed by
 * spot' (a, r, c), using Primitive.Array.unsafeUpdate. *)
fun update' (a: 'a array, r, c, x) =
   let
      val i = spot' (a, r, c)
   in
      Primitive.Array.unsafeUpdate (#array a, i, x)
   end
(* make (r, c, doit) adapts int indices for the SeqIndex-based accessors.
 * In safe mode two handlers re-raise Overflow as Subscript; otherwise doit
 * receives the SeqIndex.fromIntUnsafe conversions of r and c directly.
 * NOTE(review): most of the safe branch (the conversions being handled and
 * its call to doit) is missing from this excerpt. *)
167 fun make (r
, c
, doit
) =
168 if Primitive
.Controls
.safe
172 handle Overflow
=> raise Subscript
175 handle Overflow
=> raise Subscript
179 else doit (SeqIndex
.fromIntUnsafe r
,
180 SeqIndex
.fromIntUnsafe c
)
(* NOTE(review): the header of the next definition is missing from this
 * excerpt; its visible body routes (r, c) through make and then reads the
 * element with sub'. *)
183 make (r
, c
, fn (r
, c
) => sub
' (a
, r
, c
))
(* update (a, r, c, x) converts the int indices via make and then stores x
 * at the converted position with update'. *)
fun update (a, r, c, x) =
   let
      fun store (i, j) = update' (a, i, j, x)
   in
      make (r, c, store)
   end
(* fromList rows builds an array from a list of rows.  Visible steps: the
 * column count is taken as length row1; storage is obtained from
 * alloc (length rows, cols); elements are written into the backing array
 * with Primitive.Array.unsafeUpdate from inside a List.foldl over each
 * row, threading a running index i.
 * NOTE(review): much of this definition is missing from this excerpt,
 * including how row1 is bound, the condition guarding each element write,
 * and the final result -- confirm against the full file. *)
188 fun 'a
fromList (rows
: 'a list list
): 'a array
=
193 val cols
= length row1
194 val a
as {array
, cols
= cols
', ...} =
195 alloc (length rows
, cols
)
198 (fn (row
: 'a list
, i
) =>
202 List.foldl (fn (x
: 'a
, i
) =>
205 else (Primitive
.Array
.unsafeUpdate (array
, i
, x
)
(* row' (a, r) extracts row r as a vector: it slices the backing array at
 * offset r *? cols with length cols and converts the slice with
 * ArraySlice.vector.  In safe mode r is first tested against rows with
 * geu.
 * NOTE(review): the branch taken when that test fires is missing from this
 * excerpt. *)
217 fun row
' ({array
, rows
, cols
}, r
) =
218 if Primitive
.Controls
.safe
andalso geu (r
, rows
)
221 ArraySlice
.vector (Primitive
.Array
.Slice
.slice (array
, r
*? cols
, SOME cols
))
(* NOTE(review): the header of the next definition is missing from this
 * excerpt; its visible body re-raises Overflow as Subscript in safe mode
 * and otherwise calls row' with the SeqIndex.fromIntUnsafe conversion of
 * r. *)
223 if Primitive
.Controls
.safe
227 handle Overflow
=> raise Subscript
231 else row
' (a
, SeqIndex
.fromIntUnsafe r
)
(* column' (a, c) extracts column c as a vector by tabulating, for each of
 * the rows indices r, the element unsafeSub' (a, r, c).  In safe mode c is
 * first tested against cols with geu.
 * NOTE(review): the branch taken when that test fires is missing from this
 * excerpt. *)
232 fun column
' (a
as {rows
, cols
, ...}: 'a array
, c
) =
233 if Primitive
.Controls
.safe
andalso geu (c
, cols
)
236 Primitive
.Vector.tabulate (rows
, fn r
=> unsafeSub
' (a
, r
, c
))
(* NOTE(review): the header of the next definition is missing from this
 * excerpt; its visible body re-raises Overflow as Subscript in safe mode
 * and otherwise calls column' with the SeqIndex.fromIntUnsafe conversion
 * of c. *)
238 if Primitive
.Controls
.safe
242 handle Overflow
=> raise Subscript
246 else column
' (a
, SeqIndex
.fromIntUnsafe c
)
(* foldi' trv f b region folds f over the elements of region, threading an
 * accumulator that starts at b.  The bounds come from checkRegion, and f
 * receives (r, c, element, accumulator) with the element read by sub'.
 * Two loop nests are visible: one whose outer loop advances rows and whose
 * inner loop advances columns, and one with the nesting reversed.
 * NOTE(review): the lines that choose between the two nests (presumably a
 * dispatch on trv) and the surrounding let/fun headers are missing from
 * this excerpt. *)
248 fun foldi
' trv f
b (region
as {base
, ...}) =
250 val {startRow
, stopRow
, startCol
, stopCol
} = checkRegion region
256 if r
>= stopRow
then b
259 if c
>= stopCol
then b
260 else loopCol (c
+?
1, f (r
, c
, sub
' (base
, r
, c
), b
))
262 loopRow (r
+?
1, loopCol (startCol
, b
))
265 loopRow (startRow
, b
)
270 if c
>= stopCol
then b
273 if r
>= stopRow
then b
274 else loopRow (r
+?
1, f (r
, c
, sub
' (base
, r
, c
), b
))
276 loopCol (c
+?
1, loopRow (startRow
, b
))
279 loopCol (startCol
, b
)
(* foldi trv f b a wraps foldi' so that f receives its row and column
 * indices converted with SeqIndex.toIntUnsafe.
 * NOTE(review): the remaining arguments passed to f and to foldi' are
 * missing from this excerpt. *)
283 fun foldi trv f b a
=
284 foldi
' trv (fn (r
, c
, x
, b
) =>
285 f (SeqIndex
.toIntUnsafe r
,
286 SeqIndex
.toIntUnsafe c
,
(* NOTE(review): the header of the next definition is missing from this
 * excerpt; its visible body runs foldi over wholeRegion a while discarding
 * the two index arguments before calling f. *)
289 foldi
trv (fn (_
, _
, x
, b
) => f (x
, b
)) b (wholeRegion a
)
(* NOTE(review): the header of the next definition is missing from this
 * excerpt; its visible body runs foldi with a unit accumulator, calling f
 * on (r, c, x) at each step. *)
292 foldi
trv (fn (r
, c
, x
, ()) => f (r
, c
, x
)) ()
(* app trv f applies f to each element via fold, using unit as the
 * accumulator and passing only the element to f. *)
fun app trv f = fold trv (fn (x, _) => f x) ()
(* modifyi trv f region replaces each element x at (r, c) of the region
 * with f (r, c, x), writing the result back into the region's base array
 * with update while appi drives the traversal. *)
fun modifyi trv f (region as {base, ...}) =
   let
      fun store (i, j, x) = update (base, i, j, f (i, j, x))
   in
      appi trv store region
   end
(* modify trv f a replaces every element x of a with f x, by running
 * modifyi over wholeRegion a and ignoring the indices. *)
fun modify trv f a =
   let
      fun onElement (_, _, x) = f x
   in
      modifyi trv onElement (wholeRegion a)
   end
(* tabulate trv (rows, cols, f) allocates a rows-by-cols array with alloc
 * and then fills it by running modifyi over the whole array, storing
 * f (r, c) at each position (the previous contents are ignored).
 * NOTE(review): the let/in/end lines and the result expression are missing
 * from this excerpt. *)
301 fun tabulate
trv (rows
, cols
, f
) =
303 val a
= alloc (rows
, cols
)
304 val () = modifyi
trv (fn (r
, c
, _
) => f (r
, c
)) (wholeRegion a
)
(* copy {src, dst, dst_row, dst_col} copies every element of the region src
 * into dst, placing the region's first element at (dst_row, dst_col).
 *
 * The source region is validated with checkRegion and the destination
 * rectangle of the same extent with checkRegion'; the element transfer
 * itself uses the unchecked accessors.
 *
 * The iteration direction is chosen so that the copy is correct when src
 * and dst overlap inside the same array: whenever the destination lies at
 * or after the source along an axis, that axis is walked from its last
 * index down to its first, so no source cell is overwritten before it has
 * been read.  Both axes must follow the same rule; when source and
 * destination start in the same row, the column direction alone decides
 * correctness. *)
fun copy {src = src as {base, ...}: 'a region,
          dst, dst_row, dst_col} =
   let
      val {startRow, stopRow, startCol, stopCol} = checkRegion src
      val nrows = stopRow -? startRow
      val ncols = stopCol -? startCol
      val {startRow = dst_row, startCol = dst_col, ...} =
         checkRegion' {base = dst, row = dst_row, col = dst_col,
                       nrows = SOME nrows,
                       ncols = SOME ncols}
      (* forUp applies f to start, start + 1, ..., stop - 1. *)
      fun forUp (start, stop, f: SeqIndex.int -> unit) =
         let
            fun loop i =
               if i >= stop
                  then ()
               else (f i; loop (i + 1))
         in
            loop start
         end
      (* forDown applies f to stop - 1, stop - 2, ..., start. *)
      fun forDown (start, stop, f: SeqIndex.int -> unit) =
         let
            fun loop i =
               if i >= start
                  then (f i; loop (i - 1))
               else ()
         in
            loop (stop -? 1)
         end
      val forRows = if startRow <= dst_row then forDown else forUp
      (* Same rule as for rows: destination at or to the right of the source
       * means columns must be walked right-to-left. *)
      val forCols = if startCol <= dst_col then forDown else forUp
   in
      forRows (0, nrows, fn r =>
      forCols (0, ncols, fn c =>
         unsafeUpdate' (dst, dst_row +? r, dst_col +? c,
                        unsafeSub' (base, startRow +? r, startCol +? c))))
   end