Import Debian changes 20180207-1
[hcoop/debian/mlton.git] / basis-library / arrays-and-vectors / array2.sml
(* Copyright (C) 2017 Matthew Fluet.
 * Copyright (C) 1999-2008 Henry Cejtin, Matthew Fluet, Suresh
 *    Jagannathan, and Stephen Weeks.
 * Copyright (C) 1997-2000 NEC Research Institute.
 *
 * MLton is released under a BSD-style license.
 * See the file MLton-LICENSE for details.
 *)
9
(* Two-dimensional arrays, represented as one flat array plus a row count
 * and a column count.  Functions whose names end in a prime take and
 * return SeqIndex.int indices; the unprimed versions convert from and to
 * the default int type and delegate to the primed ones.
 *
 * Bounds and size checks are performed only when Primitive.Controls.safe
 * is true.
 *)
structure Array2 : ARRAY2 =
   struct

      (* Rebind arithmetic and comparison to the SeqIndex versions, so all
       * index arithmetic below is done on SeqIndex.int.  The ?-suffixed
       * operators are presumably the non-overflow-checking variants, and
       * ltu/leu/gtu/geu presumably unsigned comparisons; their definitions
       * are not in this file. *)
      val op +? = SeqIndex.+?
      val op + = SeqIndex.+
      val op -? = SeqIndex.-?
      val op - = SeqIndex.-
      val op *? = SeqIndex.*?
      val op * = SeqIndex.*
      val op < = SeqIndex.<
      val op <= = SeqIndex.<=
      val op > = SeqIndex.>
      val op >= = SeqIndex.>=
      val ltu = SeqIndex.ltu
      val leu = SeqIndex.leu
      val gtu = SeqIndex.gtu
      val geu = SeqIndex.geu

      (* Flat storage plus dimensions.  Element (r, c) is stored at flat
       * index r times cols plus c; see unsafeSpot'. *)
      type 'a array = {array: 'a Array.array,
                       rows: SeqIndex.int,
                       cols: SeqIndex.int}

      (* Dimension accessors; the unprimed ones convert to int. *)
      fun dimensions' ({rows, cols, ...}: 'a array) = (rows, cols)
      fun dimensions ({rows, cols, ...}: 'a array) =
         (SeqIndex.toIntUnsafe rows, SeqIndex.toIntUnsafe cols)
      fun nRows' ({rows, ...}: 'a array) = rows
      fun nRows ({rows, ...}: 'a array) = SeqIndex.toIntUnsafe rows
      fun nCols' ({cols, ...}: 'a array) = cols
      fun nCols ({cols, ...}: 'a array) = SeqIndex.toIntUnsafe cols

      (* A rectangular sub-area of base starting at (row, col).  NONE for
       * nrows / ncols means "up to the last row / column". *)
      type 'a region = {base: 'a array,
                        row: int,
                        col: int,
                        nrows: int option,
                        ncols: int option}

      local
         (* checkSliceMax' (start, num, max) validates a slice along one
          * axis of length max and returns (start, stop), both as
          * SeqIndex.int, with stop exclusive.
          *
          * In safe mode it raises Subscript when start does not fit in
          * SeqIndex.int or when the slice does not lie within 0 .. max.
          * In unsafe mode no checking is done. *)
         fun checkSliceMax' (start: int,
                             num: SeqIndex.int option,
                             max: SeqIndex.int): SeqIndex.int * SeqIndex.int =
            case num of
               NONE =>
                  if Primitive.Controls.safe
                     then let
                             val start =
                                (SeqIndex.fromInt start)
                                handle Overflow => raise Subscript
                          in
                             (* gtu is presumably an unsigned comparison,
                              * which would also reject a negative start. *)
                             if gtu (start, max)
                                then raise Subscript
                             else (start, max)
                          end
                  else (SeqIndex.fromIntUnsafe start, max)
             | SOME num =>
                  if Primitive.Controls.safe
                     then let
                             val start =
                                (SeqIndex.fromInt start)
                                handle Overflow => raise Subscript
                          in
                             (* The upper-bound test is written as
                              * num > max - start rather than
                              * start + num > max: the sum is computed
                              * with the ?-suffixed add, and if that wraps
                              * around the comparison would wrongly
                              * succeed.  Once 0 <= start <= max is known,
                              * max - start cannot overflow. *)
                             if (start < 0 orelse num < 0
                                 orelse start > max
                                 orelse num > max -? start)
                                then raise Subscript
                             else (start, start +? num)
                          end
                  else (SeqIndex.fromIntUnsafe start,
                        SeqIndex.fromIntUnsafe start +? num)

         (* Same as checkSliceMax', but num is given as an int option.  In
          * safe mode a num that does not fit in SeqIndex.int is reported
          * as Subscript. *)
         fun checkSliceMax (start: int,
                            num: int option,
                            max: SeqIndex.int): SeqIndex.int * SeqIndex.int =
            if Primitive.Controls.safe
               then ((checkSliceMax' (start, Option.map SeqIndex.fromInt num, max))
                     handle Overflow => raise Subscript)
            else checkSliceMax' (start, Option.map SeqIndex.fromIntUnsafe num, max)
      in
         (* checkRegion' validates a region whose nrows / ncols are already
          * SeqIndex.int options and returns its half-open row and column
          * bounds. *)
         fun checkRegion' {base, row, col, nrows, ncols} =
            let
               val (rows, cols) = dimensions' base
               val (startRow, stopRow) = checkSliceMax' (row, nrows, rows)
               val (startCol, stopCol) = checkSliceMax' (col, ncols, cols)
            in
               {startRow = startRow, stopRow = stopRow,
                startCol = startCol, stopCol = stopCol}
            end

         (* checkRegion validates an 'a region (int option extents) and
          * returns its half-open row and column bounds. *)
         fun checkRegion {base, row, col, nrows, ncols} =
            let
               val (rows, cols) = dimensions' base
               val (startRow, stopRow) = checkSliceMax (row, nrows, rows)
               val (startCol, stopCol) = checkSliceMax (col, ncols, cols)
            in
               {startRow = startRow, stopRow = stopRow,
                startCol = startCol, stopCol = stopCol}
            end
      end

      (* The region covering all of a. *)
      fun wholeRegion (a : 'a array): 'a region =
         {base = a, row = 0, col = 0, nrows = NONE, ncols = NONE}

      datatype traversal = RowMajor | ColMajor

      local
         (* Build the record for a rows-by-cols array whose flat storage is
          * produced by doit applied to the element count.  In safe mode a
          * negative dimension raises Size; an overflowing product raises
          * Size in either mode, since the multiplication is the checked
          * one. *)
         fun make (rows, cols, doit) =
            if Primitive.Controls.safe
               andalso (rows < 0 orelse cols < 0)
               then raise Size
            else {array = doit (rows * cols handle Overflow => raise Size),
                  rows = rows,
                  cols = cols}
      in
         (* Array created with Primitive.Array.alloc, i.e. without an
          * initial element value supplied here. *)
         fun alloc' (rows, cols) =
            make (rows, cols, Primitive.Array.alloc)
         (* Array with every element set to init. *)
         fun array' (rows, cols, init) =
            make (rows, cols, fn size => Primitive.Array.new (size, init))
      end
      local
         (* Convert int dimensions to SeqIndex.int and hand them to doit.
          * In safe mode a dimension that does not fit raises Size. *)
         fun make (rows, cols, doit) =
            if Primitive.Controls.safe
               then let
                       val rows =
                          (SeqIndex.fromInt rows)
                          handle Overflow => raise Size
                       val cols =
                          (SeqIndex.fromInt cols)
                          handle Overflow => raise Size
                    in
                       doit (rows, cols)
                    end
            else doit (SeqIndex.fromIntUnsafe rows,
                       SeqIndex.fromIntUnsafe cols)
      in
         fun alloc (rows, cols) =
            make (rows, cols, fn (rows, cols) => alloc' (rows, cols))
         fun array (rows, cols, init) =
            make (rows, cols, fn (rows, cols) => array' (rows, cols, init))
      end

      (* The empty array: zero rows and zero columns. *)
      fun array0 (): 'a array =
         {array = Primitive.Array.alloc 0,
          rows = 0,
          cols = 0}

      (* Flat index of element (r, c), without any bounds check. *)
      fun unsafeSpot' ({cols, ...}: 'a array, r, c) =
         r *? cols +? c
      (* Flat index of element (r, c); in safe mode raises Subscript when r
       * or c is outside the array. *)
      fun spot' (a as {rows, cols, ...}: 'a array, r, c) =
         if Primitive.Controls.safe
            andalso (geu (r, rows) orelse geu (c, cols))
            then raise Subscript
         else unsafeSpot' (a, r, c)

      (* Element read / write on SeqIndex.int coordinates.  The unsafe
       * versions never check bounds; the others check through spot'. *)
      fun unsafeSub' (a as {array, ...}: 'a array, r, c) =
         Primitive.Array.unsafeSub (array, unsafeSpot' (a, r, c))
      fun sub' (a as {array, ...}: 'a array, r, c) =
         Primitive.Array.unsafeSub (array, spot' (a, r, c))
      fun unsafeUpdate' (a as {array, ...}: 'a array, r, c, x) =
         Primitive.Array.unsafeUpdate (array, unsafeSpot' (a, r, c), x)
      fun update' (a as {array, ...}: 'a array, r, c, x) =
         Primitive.Array.unsafeUpdate (array, spot' (a, r, c), x)

      local
         (* Convert int coordinates to SeqIndex.int and hand them to doit.
          * In safe mode a coordinate that does not fit raises Subscript. *)
         fun make (r, c, doit) =
            if Primitive.Controls.safe
               then let
                       val r =
                          (SeqIndex.fromInt r)
                          handle Overflow => raise Subscript
                       val c =
                          (SeqIndex.fromInt c)
                          handle Overflow => raise Subscript
                    in
                       doit (r, c)
                    end
            else doit (SeqIndex.fromIntUnsafe r,
                       SeqIndex.fromIntUnsafe c)
      in
         fun sub (a, r, c) =
            make (r, c, fn (r, c) => sub' (a, r, c))
         fun update (a, r, c, x) =
            make (r, c, fn (r, c) => update' (a, r, c, x))
      end

      (* Build an array from a list of rows.  The column count is the
       * length of the first row; a row that is longer or shorter than
       * that raises Size.  The empty list yields array0 (). *)
      fun 'a fromList (rows: 'a list list): 'a array =
         case rows of
            [] => array0 ()
          | row1 :: _ =>
               let
                  val cols = length row1
                  val a as {array, cols = cols', ...} =
                     alloc (length rows, cols)
                  (* Outer fold threads the flat index of the next element
                   * to write; max is the flat index just past the current
                   * row. *)
                  val _ =
                     List.foldl
                     (fn (row: 'a list, i) =>
                      let
                         val max = i +? cols'
                         val i' =
                            List.foldl (fn (x: 'a, i) =>
                                        (if i >= max
                                            then raise Size (* row too long *)
                                         else (Primitive.Array.unsafeUpdate (array, i, x)
                                               ; i +? 1)))
                            i row
                      in if i' = max
                            then i'
                         else raise Size (* row too short *)
                      end)
                     0 rows
               in
                  a
               end

      (* Row r as a vector; in safe mode raises Subscript when r is out of
       * range. *)
      fun row' ({array, rows, cols}, r) =
         if Primitive.Controls.safe andalso geu (r, rows)
            then raise Subscript
         else
            ArraySlice.vector (Primitive.Array.Slice.slice (array, r *? cols, SOME cols))
      fun row (a, r) =
         if Primitive.Controls.safe
            then let
                    val r =
                       (SeqIndex.fromInt r)
                       handle Overflow => raise Subscript
                 in
                    row' (a, r)
                 end
         else row' (a, SeqIndex.fromIntUnsafe r)
      (* Column c as a vector; in safe mode raises Subscript when c is out
       * of range. *)
      fun column' (a as {rows, cols, ...}: 'a array, c) =
         if Primitive.Controls.safe andalso geu (c, cols)
            then raise Subscript
         else
            Primitive.Vector.tabulate (rows, fn r => unsafeSub' (a, r, c))
      fun column (a, c) =
         if Primitive.Controls.safe
            then let
                    val c =
                       (SeqIndex.fromInt c)
                       handle Overflow => raise Subscript
                 in
                    column' (a, c)
                 end
         else column' (a, SeqIndex.fromIntUnsafe c)

      (* Fold f over the elements of region in the order given by trv.  f
       * receives (row, col, element, accumulator) with SeqIndex.int
       * coordinates.  The region is validated first by checkRegion. *)
      fun foldi' trv f b (region as {base, ...}) =
         let
            val {startRow, stopRow, startCol, stopCol} = checkRegion region
         in
            case trv of
               RowMajor =>
                  let
                     fun loopRow (r, b) =
                        if r >= stopRow then b
                        else let
                                fun loopCol (c, b) =
                                   if c >= stopCol then b
                                   else loopCol (c +? 1, f (r, c, sub' (base, r, c), b))
                             in
                                loopRow (r +? 1, loopCol (startCol, b))
                             end
                  in
                     loopRow (startRow, b)
                  end
             | ColMajor =>
                  let
                     fun loopCol (c, b) =
                        if c >= stopCol then b
                        else let
                                fun loopRow (r, b) =
                                   if r >= stopRow then b
                                   else loopRow (r +? 1, f (r, c, sub' (base, r, c), b))
                             in
                                loopCol (c +? 1, loopRow (startRow, b))
                             end
                  in
                     loopCol (startCol, b)
                  end
         end

      (* As foldi', but f receives int coordinates. *)
      fun foldi trv f b a =
         foldi' trv (fn (r, c, x, b) =>
                     f (SeqIndex.toIntUnsafe r,
                        SeqIndex.toIntUnsafe c,
                        x, b)) b a
      (* Fold over every element of the whole array, ignoring coordinates. *)
      fun fold trv f b a =
          foldi trv (fn (_, _, x, b) => f (x, b)) b (wholeRegion a)

      (* Apply f to (row, col, element) for each element of a region. *)
      fun appi trv f =
         foldi trv (fn (r, c, x, ()) => f (r, c, x)) ()

      (* Apply f to each element of the whole array. *)
      fun app trv f = fold trv (f o #1) ()

      (* Replace each element of the region by f (row, col, element). *)
      fun modifyi trv f (region as {base, ...}) =
         appi trv (fn (r, c, x) => update (base, r, c, f (r, c, x))) region

      (* Replace each element of the whole array by f applied to it. *)
      fun modify trv f a = modifyi trv (f o #3) (wholeRegion a)

      (* Allocate a rows-by-cols array and fill it, in trv order, with
       * f (row, col). *)
      fun tabulate trv (rows, cols, f) =
         let
            val a = alloc (rows, cols)
            val () = modifyi trv (fn (r, c, _) => f (r, c)) (wholeRegion a)
         in
            a
         end

      (* Copy the region src into dst so that the element at the region's
       * top-left corner lands at (dst_row, dst_col).  Both the source
       * region and the destination rectangle are validated before any
       * element is written.
       *
       * src and dst may share the same base array, so the traversal
       * direction is chosen such that no source cell is overwritten
       * before it has been read:
       *   - rows:    bottom-up when the source starts at or above the
       *              destination, top-down otherwise;
       *   - columns: right-to-left when the source starts at or to the
       *              left of the destination, left-to-right otherwise.
       * The column direction only matters when source and destination
       * start on the same row; in every other case the row direction
       * alone already prevents a clobber. *)
      fun copy {src = src as {base, ...}: 'a region,
                dst, dst_row, dst_col} =
         let
            val {startRow, stopRow, startCol, stopCol} = checkRegion src
            val nrows = stopRow -? startRow
            val ncols = stopCol -? startCol
            val {startRow = dst_row, startCol = dst_col, ...} =
               checkRegion' {base = dst, row = dst_row, col = dst_col,
                             nrows = SOME nrows,
                             ncols = SOME ncols}
            (* Apply f to start, start + 1, ..., stop - 1. *)
            fun forUp (start, stop, f: SeqIndex.int -> unit) =
               let
                  fun loop i =
                     if i >= stop
                        then ()
                     else (f i; loop (i + 1))
               in loop start
               end
            (* Apply f to stop - 1, stop - 2, ..., start. *)
            fun forDown (start, stop, f: SeqIndex.int -> unit) =
               let
                  fun loop i =
                     if i < start
                        then ()
                     else (f i; loop (i - 1))
               in loop (stop -? 1)
               end
            val forRows = if startRow <= dst_row then forDown else forUp
            val forCols = if startCol <= dst_col then forDown else forUp
         in forRows (0, nrows, fn r =>
            forCols (0, ncols, fn c =>
                     unsafeUpdate' (dst, dst_row +? r, dst_col +? c,
                                    unsafeSub' (base, startRow +? r, startCol +? c))))
         end
   end