Import Upstream version 20180207
[hcoop/debian/mlton.git] / lib / mlton / basic / vector.fun
1 (* Copyright (C) 1999-2006 Henry Cejtin, Matthew Fluet, Suresh
2 * Jagannathan, and Stephen Weeks.
3 *
4 * MLton is released under a BSD-style license.
5 * See the file MLton-LICENSE for details.
6 *)
7
8 functor Vector (S: sig
9 include VECTOR_STRUCTS
10 val unsafeSub: 'a t * int -> 'a
11 end): VECTOR =
12 struct
13
14 open S
15
(* [size] is an alias for [length]. *)
fun size v = length v

(* [unfold (n, seed, step)] is [unfoldi] with an index-blind step function:
 * [step] only ever sees the threaded state. *)
fun unfold (n, seed, step) = unfoldi (n, seed, fn (_, s) => step s)

(* [tabulate (n, f)] builds an [n]-element vector whose i-th slot is [f i]
 * (assuming [unfoldi] stores the result of step i at index i), threading a
 * dummy unit state through [unfoldi]. *)
fun tabulate (n, f) =
   let
      val (vec, _) = unfoldi (n, (), fn (i, ()) => (f i, ()))
   in
      vec
   end

(* [fromArray arr] copies the contents of a basis-library array into a vector. *)
fun fromArray arr =
   let
      val len = Pervasive.Array.length arr
   in
      tabulate (len, fn i => Pervasive.Array.sub (arr, i))
   end

(* [toArray v] copies the contents of [v] into a fresh basis-library array. *)
fun toArray v =
   let
      val len = length v
   in
      Pervasive.Array.tabulate (len, fn i => sub (v, i))
   end
27
(* Outcome of one step of [fold']. *)
datatype ('a, 'b) continue =
   Continue of 'a  (* keep scanning with this new accumulator *)
 | Done of 'b      (* stop immediately with this final result *)

(* [first v] is the element at index 0; fails via Error.bug when [v] is empty. *)
fun first v =
   if length v = 0
      then Error.bug "Vector.first"
   else unsafeSub (v, 0)

(* [fold' (v, start, b, f, g)] scans [v] left to right beginning at index
 * [start], calling [f (i, x, acc)] on each element.  [f] answers
 * [Continue acc'] to keep going or [Done r] to stop early with result [r].
 * If the end of [v] is reached, the result is [g acc] for the final
 * accumulator.  Fails via Error.bug unless 0 <= start <= length v. *)
fun fold' (v, start, b, f, g) =
   let
      val n = length v
      fun step (i, acc) =
         if i < n
            then (case f (i, unsafeSub (v, i), acc) of
                     Continue acc' => step (i + 1, acc')
                   | Done r => r)
         else g acc
   in
      if start < 0 orelse start > n
         then Error.bug "Vector.fold'"
      else step (start, b)
   end
56
(* [foldFrom (v, start, b, f)] folds [f (x, acc)] over [v] left to right,
 * beginning at index [start]; [start] is range-checked by fold'. *)
fun foldFrom (v, start, b, f) =
   let
      fun step (_, x, acc) = Continue (f (x, acc))
   in
      fold' (v, start, b, step, fn acc => acc)
   end

(* [fold (v, b, f)] is a left-to-right fold over every element of [v]. *)
fun fold (v, b, f) = foldFrom (v, 0, b, f)

(* [isEmpty v] holds iff [v] has no elements. *)
fun isEmpty v = length v = 0

(* [dropPrefix (v, k)] is [v] without its first [k] elements. *)
fun dropPrefix (v, k) =
   let
      val len = length v - k
   in
      tabulate (len, fn i => sub (v, k + i))
   end

(* [dropSuffix (v, k)] is [v] without its last [k] elements. *)
fun dropSuffix (v, k) =
   let
      val len = length v - k
   in
      tabulate (len, fn i => sub (v, i))
   end

(* [new (n, x)] is a vector holding [n] copies of [x]. *)
fun new (n, x) = tabulate (n, fn _ => x)

(* [mapi (v, f)] builds a vector whose i-th element is [f (i, x_i)]. *)
fun mapi (v, f) =
   let
      fun at i = f (i, unsafeSub (v, i))
   in
      tabulate (length v, at)
   end

(* [map (v, f)] is [mapi] with an index-blind [f]. *)
fun map (v, f) = mapi (v, fn (_, x) => f x)

(* [copy v] rebuilds [v] element by element via [map]. *)
fun copy v = map (v, fn x => x)
77
(* [existsR (v, start, stop, f)] holds iff some element of [v] whose index
 * lies in the half-open range [start, stop) satisfies [f].  Elements are
 * tested left to right and scanning stops at the first match.
 *
 * An empty or inverted range (stop <= start) yields false without calling
 * [f].  Previously the stop test was [i = stop], so with stop < start the
 * scan never hit [stop] and ran on to the end of [v].  A [stop] beyond
 * [length v] simply scans to the end.  [start] is range-checked by fold'
 * (0 <= start <= length v), which fails via Error.bug otherwise. *)
fun existsR (v, start, stop, f) =
   fold' (v, start, (),
          fn (i, x, ()) =>
             if i >= stop
                then Done false
             else if f x
                then Done true
             else Continue (),
          fn _ => false)
86
(* [foldi (v, b, f)] folds [f (i, x, acc)] over [v] left to right. *)
fun foldi (v, b, f) =
   fold' (v, 0, b, fn (i, x, acc) => Continue (f (i, x, acc)), fn acc => acc)

(* [loopi (v, f, g)] calls [f (i, x)] left to right until one call returns
 * [SOME r], and then yields [r].  If every call returns NONE, the result
 * is [g ()]. *)
fun loopi (v, f, g) =
   let
      fun step (i, x, ()) =
         case f (i, x) of
            SOME r => Done r
          | NONE => Continue ()
   in
      fold' (v, 0, (), step, g)
   end

(* [loop (v, f, g)] is [loopi] with an index-blind [f]. *)
fun loop (v, f, g) = loopi (v, fn (_, x) => f x, g)
97
(* [peekMapi (v, f)] finds the first index [i], scanning left to right, at
 * which [f] applied to the element there returns [SOME y], and answers
 * [SOME (i, y)].  It answers NONE when there is no such index.  Note that
 * [f] receives only the element, not the index. *)
fun peekMapi (v, f) =
   let
      val n = length v
      fun search i =
         if i = n
            then NONE
         else (case f (sub (v, i)) of
                  SOME y => SOME (i, y)
                | NONE => search (i + 1))
   in
      search 0
   end

(* [peekMap (v, f)] answers [SOME y] for the first element [x] (left to
 * right) with [f x = SOME y], and NONE otherwise.  The extra SOME wrapping
 * in [try] marks a hit so that [loop] stops there. *)
fun peekMap (v, f) =
   let
      fun try x =
         case f x of
            SOME y => SOME (SOME y)
          | NONE => NONE
   in
      loop (v, try, fn () => NONE)
   end
118
(* [fromListMap (l, f)] builds a vector of [f x] for each [x] of [l], in
 * list order.  The remaining list is threaded through [unfold] as its
 * state.  The empty-list case cannot occur, because exactly
 * [List.length l] steps are taken. *)
fun fromListMap (l, f) =
   let
      fun next [] = Error.bug "Vector.fromListMap"
        | next (x :: rest) = (f x, rest)
      val (vec, _) = unfold (List.length l, l, next)
   in
      vec
   end

(* [fromList xs] is the vector of the elements of [xs], in order. *)
fun fromList xs = fromListMap (xs, fn x => x)
130
(* [foldr2 (a, a', b, f)] folds [f (x, x', acc)] over the paired elements
 * of two equal-length vectors, from the last index down to 0.  Fails via
 * Error.bug if the lengths differ. *)
fun foldr2 (a, a', b, f) =
   let
      val n = length a
      fun go (i, acc) =
         if i >= 0
            then go (i - 1, f (unsafeSub (a, i), unsafeSub (a', i), acc))
         else acc
   in
      if n <> length a'
         then Error.bug "Vector.foldr2"
      else go (n - 1, b)
   end

(* [foldi2From (a, a', start, b, f)] folds [f (i, x, x', acc)] over the
 * paired elements of two equal-length vectors, left to right from index
 * [start].  Fails via Error.bug if the lengths differ or if [start] lies
 * outside [0, length a]. *)
fun foldi2From (a, a', start, b, f) =
   let
      val n = length a
      fun go (i, acc) =
         if i < n
            then go (i + 1, f (i, unsafeSub (a, i), unsafeSub (a', i), acc))
         else acc
      val valid = n = length a' andalso 0 <= start andalso start <= n
   in
      if valid
         then go (start, b)
      else Error.bug "Vector.foldi2From"
   end

(* [foldi2 (a, a', b, f)] is [foldi2From] starting at index 0. *)
fun foldi2 (a, a', b, f) = foldi2From (a, a', 0, b, f)

(* [foreachi2 (v, v', f)] calls [f (i, x, x')] on each aligned pair. *)
fun foreachi2 (v, v', f) =
   let
      fun step (i, x, x', ()) = f (i, x, x')
   in
      foldi2 (v, v', (), step)
   end

(* [fold2 (a, a', b, f)] is [foldi2] with an index-blind [f]. *)
fun fold2 (a, a', b, f) =
   foldi2 (a, a', b, fn (_, x, x', acc) => f (x, x', acc))
166
(* [fold3From (a, a', a'', start, b, f)] folds [f (x, x', x'', acc)] over
 * aligned elements of three equal-length vectors, left to right from index
 * [start].  Fails via Error.bug if the lengths differ or if [start] lies
 * outside [0, length a]. *)
fun fold3From (a, a', a'', start, b, f) =
   let
      val n = length a
      val sameLength = n = length a' andalso n = length a''
      fun go (i, acc) =
         if i < n
            then go (i + 1, f (unsafeSub (a, i),
                               unsafeSub (a', i),
                               unsafeSub (a'', i),
                               acc))
         else acc
   in
      if sameLength andalso 0 <= start andalso start <= n
         then go (start, b)
      else Error.bug "Vector.fold3From"
   end

(* [fold3 (a, a', a'', b, f)] is [fold3From] starting at index 0. *)
fun fold3 (a, a', a'', b, f) = fold3From (a, a', a'', 0, b, f)
186
(* [foreachR (v, start, stop, f)] applies [f] to the elements at indices
 * start, start + 1, ..., stop - 1, left to right.  Fails via Error.bug
 * unless 0 <= start <= stop <= length v. *)
fun foreachR (v, start, stop, f: 'a -> unit) =
   if start < 0 orelse stop < start orelse length v < stop
      then Error.bug "Vector.foreachR"
   else fold' (v, start, (),
               fn (i, x, ()) =>
                  if i < stop
                     then (f x; Continue ())
                  else Done (),
               fn () => ())

(* [foreach2 (a, a', f)] calls [f (x, x')] on each aligned pair. *)
fun foreach2 (a, a', f) =
   let
      fun step (x, x', ()) = f (x, x')
   in
      fold2 (a, a', (), step)
   end

(* [forall2 (v, v', f)] holds iff [f] holds on every aligned pair.  It
 * stops at the first failure and fails via Error.bug if the lengths
 * differ. *)
fun forall2 (v, v', f) =
   let
      val n = length v
      fun allFrom i =
         if i = n
            then true
         else f (sub (v, i), sub (v', i)) andalso allFrom (i + 1)
   in
      if n <> length v'
         then Error.bug "Vector.forall2"
      else allFrom 0
   end

(* [foreach3 (v1, v2, v3, f)] calls [f] on each aligned triple, left to
 * right.  Fails via Error.bug, before any call to [f], if the lengths
 * differ. *)
fun foreach3 (v1, v2, v3, f: 'a * 'b * 'c -> unit) =
   let
      val n = length v1
      fun go i =
         if i < n
            then (f (sub (v1, i), sub (v2, i), sub (v3, i)); go (i + 1))
         else ()
   in
      if n = length v2 andalso n = length v3
         then go 0
      else Error.bug "Vector.foreach3"
   end
231
(* [foreachi (v, f)] calls [f (i, x)] on each element, left to right. *)
fun foreachi (v, f) =
   let
      fun step (i, x, ()) = f (i, x)
   in
      foldi (v, (), step)
   end

(* [foreach (v, f)] is [foreachi] with an index-blind [f]. *)
fun foreach (v, f) = foreachi (v, fn (_, x) => f x)

(* [peeki (v, f)] answers [SOME (i, x)] for the first index/element pair,
 * scanning left to right, that satisfies [f], and NONE otherwise. *)
fun 'a peeki (v, f) =
   let
      val n = length v
      fun scan i =
         if i = n
            then NONE
         else
            let
               val hit = (i, sub (v, i))
            in
               if f hit then SOME hit else scan (i + 1)
            end
   in
      scan 0
   end

(* [peek (v, f)] answers the first element satisfying [f], if any. *)
fun peek (v, f) =
   case peeki (v, fn (_, x) => f x) of
      NONE => NONE
    | SOME (_, x) => SOME x

(* [existsi (v, f)] holds iff some (index, element) pair satisfies [f]. *)
fun existsi (v, f) =
   case peeki (v, f) of
      NONE => false
    | SOME _ => true

(* [exists (v, f)] holds iff some element satisfies [f]. *)
fun exists (v, f) = existsi (v, fn (_, x) => f x)

(* [contains (v, a, f)] holds iff [f (a, x)] for some element [x] of [v]. *)
fun contains (v, a, f) = exists (v, fn x => f (a, x))

(* [foralli (v, f)] holds iff every (index, element) pair satisfies [f]. *)
fun foralli (v, f) = not (existsi (v, fn z => not (f z)))

(* [forall (v, f)] holds iff every element satisfies [f]. *)
fun forall (v, f) = foralli (v, fn (_, x) => f x)
264
(* [equals (v, v', eq)] holds iff the vectors have the same length and
 * [eq] holds on every aligned pair. *)
fun equals (v, v', eq) =
   length v = length v'
   andalso foralli (v, fn (i, x) => eq (x, unsafeSub (v', i)))

(* [foldri (v, b, f)] folds [f (i, x, acc)] over [v] via Int.foldDown over
 * the index range [0, length v).  Presumably it visits indices from high
 * to low (Int.foldDown semantics are not shown here). *)
fun foldri (v, b, f) =
   let
      fun step (i, acc) = f (i, unsafeSub (v, i), acc)
   in
      Int.foldDown (0, length v, b, step)
   end

(* [foldr (v, b, f)] is [foldri] with an index-blind [f]. *)
fun foldr (v, b, f) = foldri (v, b, fn (_, x, acc) => f (x, acc))

(* [foreachri (v, f)] calls [f (i, x)] on each element in [foldri] order. *)
fun foreachri (v, f) = foldri (v, (), fn (i, x, ()) => f (i, x))

(* [foreachr (v, f)] is [foreachri] with an index-blind [f]. *)
fun foreachr (v, f) = foreachri (v, fn (_, x) => f x)

(* [toList v] conses the elements onto [] in [foldr] order. *)
fun toList v = foldr (v, [], fn (x, xs) => x :: xs)

(* [toListMap (v, f)] is like [toList], with [f] applied to each element. *)
fun toListMap (v, f) = foldr (v, [], fn (x, ys) => f x :: ys)

(* [layout layoutElt v] renders [v] as a Layout tuple of its elements. *)
fun layout layoutElt v = Layout.tuple (toListMap (v, layoutElt))

(* [toString xToString v] renders [v] to a string, using [xToString] for
 * each element. *)
fun toString xToString v =
   Layout.toString (layout (fn x => Layout.str (xToString x)) v)
287
(* [new0 ()] is the empty vector; the element function is never called. *)
fun new0 () = tabulate (0, fn _ => Error.bug "Vector.new0")

(* [new1 x] is the one-element vector holding [x]. *)
fun new1 x = new (1, x)

(* [new2 .. new6] build fixed-size vectors from their arguments, in order.
 * The wildcard arms are unreachable for in-range indices. *)
fun new2 (x0, x1) =
   tabulate (2, fn i =>
                case i of
                   0 => x0
                 | 1 => x1
                 | _ => Error.bug "Vector.new2")

fun new3 (x0, x1, x2) =
   tabulate (3, fn i =>
                case i of
                   0 => x0
                 | 1 => x1
                 | 2 => x2
                 | _ => Error.bug "Vector.new3")

fun new4 (x0, x1, x2, x3) =
   tabulate (4, fn i =>
                case i of
                   0 => x0
                 | 1 => x1
                 | 2 => x2
                 | 3 => x3
                 | _ => Error.bug "Vector.new4")

fun new5 (x0, x1, x2, x3, x4) =
   tabulate (5, fn i =>
                case i of
                   0 => x0
                 | 1 => x1
                 | 2 => x2
                 | 3 => x3
                 | 4 => x4
                 | _ => Error.bug "Vector.new5")

fun new6 (x0, x1, x2, x3, x4, x5) =
   tabulate (6, fn i =>
                case i of
                   0 => x0
                 | 1 => x1
                 | 2 => x2
                 | 3 => x3
                 | 4 => x4
                 | 5 => x5
                 | _ => Error.bug "Vector.new6")
327
(* [unzip v] splits a vector of pairs into a pair of vectors. *)
fun unzip (a: ('a * 'b) t) =
   (map (a, fn (x, _) => x), map (a, fn (_, y) => y))

(* [unzip3 v] splits a vector of triples into a triple of vectors. *)
fun unzip3 (a: ('a * 'b * 'c) t) =
   (map (a, fn (x, _, _) => x),
    map (a, fn (_, y, _) => y),
    map (a, fn (_, _, z) => z))

(* [rev v] is [v] with its elements in reverse order. *)
fun rev v =
   let
      val n = length v
      fun mirror i = unsafeSub (v, n - 1 - i)
   in
      tabulate (n, mirror)
   end

(* [fromListRev xs] is the vector of [xs]'s elements, last element first. *)
fun fromListRev xs = rev (fromList xs)
341
(* [mapAndFold (v, b, f)] maps [v] while threading an accumulator.  Each
 * call [f (x, acc)] returns (mapped element, next acc).  The result is the
 * mapped vector paired with the final accumulator.  This relies on [map]
 * calling [f] in index order. *)
fun mapAndFold (v, b, f) =
   let
      val state = ref b
      fun step x =
         let
            val (y, b') = f (x, !state)
         in
            state := b'
            ; y
         end
      val v' = map (v, step)
   in
      (v', !state)
   end

(* [map2i (v, v', f)] maps [f (i, x, x')] over aligned pairs of two
 * equal-length vectors; fails via Error.bug if the lengths differ. *)
fun map2i (v, v', f) =
   let
      val n = length v
      fun at i = f (i, unsafeSub (v, i), unsafeSub (v', i))
   in
      if n <> length v'
         then Error.bug "Vector.map2i"
      else tabulate (n, at)
   end

(* [map2 (v, v', f)] is [map2i] with an index-blind [f]. *)
fun map2 (v, v', f) = map2i (v, v', fn (_, x, x') => f (x, x'))

(* [map2AndFold (v, v', b, f)] is the two-vector analogue of [mapAndFold]:
 * [f (x, x', acc)] returns (mapped element, next acc). *)
fun map2AndFold (v, v', b, f) =
   let
      val state = ref b
      fun step (x, x') =
         let
            val (y, b') = f (x, x', !state)
         in
            state := b'
            ; y
         end
      val v'' = map2 (v, v', step)
   in
      (v'', !state)
   end

(* [map3 (v1, v2, v3, f)] maps [f] over aligned triples of three
 * equal-length vectors; fails via Error.bug if the lengths differ. *)
fun map3 (v1, v2, v3, f) =
   let
      val n = length v1
      fun at i = f (unsafeSub (v1, i), unsafeSub (v2, i), unsafeSub (v3, i))
   in
      if n = length v2 andalso n = length v3
         then tabulate (n, at)
      else Error.bug "Vector.map3"
   end

(* [zip (v, v')] pairs up aligned elements of two equal-length vectors. *)
fun zip (v, v') = map2 (v, v', fn pair => pair)
390
local
   (* Shared engine for the keepAllMap*i family.  [mapper] runs [f] over the
    * input, collecting the option results and counting the SOMEs.  A second
    * pass copies the SOME payloads, in index order, into a vector of exactly
    * that size. *)
   fun doit (f, mapper) =
      let
         val count = ref 0
         val opts =
            mapper (fn x =>
                    let
                       val y = f x
                    in
                       (case y of
                           SOME _ => count := !count + 1
                         | NONE => ())
                       ; y
                    end)
         val cursor = ref 0
         (* Payload of the first SOME at index >= i; records where to resume. *)
         fun nextFrom i =
            case unsafeSub (opts, i) of
               SOME z => (cursor := i + 1; z)
             | NONE => nextFrom (i + 1)
      in
         tabulate (!count, fn _ => nextFrom (!cursor))
      end
in
   fun keepAllMapi (a, f) = doit (f, fn g => mapi (a, g))
   fun keepAllMap2i (a, b, f) = doit (f, fn g => map2i (a, b, g))
end

(* [keepAllMap (v, f)] keeps the payloads of the SOME results of [f]. *)
fun keepAllMap (v, f) = keepAllMapi (v, fn (_, x) => f x)

(* [keepAllMap2 (v, v', f)] is the two-vector analogue of [keepAllMap]. *)
fun keepAllMap2 (v, v', f) =
   keepAllMap2i (v, v', fn (_, x, x') => f (x, x'))

(* [keepAllSome v] keeps the payloads of the SOME elements of [v]. *)
fun keepAllSome v = keepAllMap (v, fn opt => opt)

(* [keepAll (v, f)] keeps the elements satisfying [f], in order. *)
fun keepAll (v, f) =
   keepAllMap (v, fn x => if f x then SOME x else NONE)
420
(* [compare (v, v', comp)] orders vectors first by length and then, for
 * equal lengths, element-wise from index 0 using [comp].  Later elements
 * are consulted only while earlier ones compare EQUAL, via
 * Relation.lexico's thunk. *)
fun compare (v, v', comp) =
   let
      val n = length v
      fun elementsFrom i =
         if i = n
            then EQUAL
         else Relation.lexico (comp (unsafeSub (v, i), unsafeSub (v', i)),
                               fn () => elementsFrom (i + 1))
   in
      Relation.lexico (Int.compare (n, length v'), fn () => elementsFrom 0)
   end

(* [toListRev v] is the list of [v]'s elements, last element first. *)
fun toListRev v = fold (v, [], fn (x, acc) => x :: acc)
442
(* [last v] is the element at index length v - 1; fails via Error.bug when
 * [v] is empty. *)
fun last v =
   let
      val n = length v
   in
      if n = 0
         then Error.bug "Vector.last"
      else unsafeSub (v, n - 1)
   end

(* [tabulator (n, f)] builds an [n]-element vector from a generator.  [f]
 * is handed an emit function and must call it exactly [n] times; the
 * emitted values fill the vector in call order.  It fails via Error.bug if
 * [f] emits more than [n] values, or fewer. *)
fun tabulator (n: int, f: ('a -> unit) -> unit) =
   let
      val slots = Pervasive.Array.array (n, NONE)
      val filled = ref 0
      fun emit x =
         let
            val i = !filled
         in
            if i < n
               then (Pervasive.Array.update (slots, i, SOME x)
                     ; filled := i + 1)
            else Error.bug "Vector.tabulator: too many elements"
         end
      val () = f emit
   in
      if !filled >= n
         then tabulate (n, fn i => valOf (Pervasive.Array.sub (slots, i)))
      else Error.bug "Vector.tabulator: not enough elements"
   end
471
(* [concat vs] joins a list of vectors end to end.  The state threaded
 * through [unfold] is (position in current vector, current vector,
 * remaining vectors); exhausted vectors, including empty ones, are
 * skipped. *)
fun 'a concat (vs: 'a t list): 'a t =
   case vs of
      [] => new0 ()
    | v :: vs' =>
         let
            val total = List.fold (vs, 0, fn (v, s) => s + length v)
            fun next (i, cur, rest) =
               if i < length cur
                  then (sub (cur, i), (i + 1, cur, rest))
               else (case rest of
                        [] => Error.bug "Vector.concat"
                      | cur' :: rest' => next (0, cur', rest'))
            val (result, _) = unfold (total, (0, v, vs'), next)
         in
            result
         end

(* [concatV vs] joins a vector of vectors end to end.  The state is
 * (index of current inner vector, that vector, position within it). *)
fun concatV vs =
   if length vs = 0
      then new0 ()
   else
      let
         val total = fold (vs, 0, fn (v, s) => s + length v)
         fun at i = (i, sub (vs, i), 0)
         fun next (i, cur, j) =
            if j < length cur
               then (sub (cur, j), (i, cur, j + 1))
            else next (at (i + 1))
         val (result, _) = unfold (total, at 0, next)
      in
         result
      end
510
(* [splitLast v] is (all but the last element, the last element); fails
 * via Error.bug when [v] is empty. *)
fun splitLast v =
   let
      val n = length v
   in
      if n > 0
         then (tabulate (n - 1, fn i => unsafeSub (v, i)),
               unsafeSub (v, n - 1))
      else Error.bug "Vector.splitLast"
   end

(* [isSortedRange (v, start, stop, le)] holds iff every adjacent pair in
 * v[start .. stop - 1] satisfies [le].  Empty ranges are sorted.  The
 * range bounds are checked only through Assert.assert. *)
fun isSortedRange (v: 'a t,
                   start: int,
                   stop: int,
                   le : 'a * 'a -> bool): bool =
   let
      val _ =
         Assert.assert
         ("Vector.isSortedRange", fn () =>
          0 <= start andalso start <= stop andalso stop <= length v)
      fun ordered (i, prev) =
         if i >= stop
            then true
         else
            let
               val cur = sub (v, i)
            in
               le (prev, cur) andalso ordered (i + 1, cur)
            end
   in
      start = stop orelse ordered (start + 1, sub (v, start))
   end

(* [isSorted (v, le)] holds iff all of [v] is sorted with respect to [le]. *)
fun isSorted (v, le) = isSortedRange (v, 0, length v, le)
542
(* [indexi (v, f)] answers [SOME i] for the first index, scanning left to
 * right, whose (index, element) pair satisfies [f], and NONE otherwise. *)
fun indexi (v, f) =
   let
      fun step (i, x, ()) =
         if f (i, x) then Done (SOME i) else Continue ()
   in
      fold' (v, 0, (), step, fn () => NONE)
   end

(* [index (v, f)] is [indexi] with an index-blind [f]. *)
fun index (v, f) = indexi (v, fn (_, x) => f x)

(* [indices a] lists, in increasing order, the positions holding true. *)
fun indices (a: bool t): int t =
   keepAllMapi (a, fn (i, true) => SOME i
                    | (_, false) => NONE)

val indices =
   Trace.trace ("Vector.indices", layout Bool.layout, layout Int.layout)
   indices
556
(* [isSubsequence (va, vb, f)] holds iff the elements of [va] can be
 * matched, in order, against elements of [vb] with [f].  The match is
 * greedy: each element of [va] is paired with the earliest remaining
 * element of [vb] that [f] accepts. *)
fun isSubsequence (va, vb, f) =
   let
      val na = length va
      val nb = length vb
      fun matchFrom (ia, ib) =
         if ia >= na
            then true
         else if ib >= nb
            then false
         else if f (sub (va, ia), sub (vb, ib))
            then matchFrom (ia + 1, ib + 1)
         else matchFrom (ia, ib + 1)
   in
      matchFrom (0, 0)
   end
576
(* [removeFirst (v, f)] is [v] without its first element satisfying [f].
 * [f] is still applied to every element.  Fails via Error.bug if no
 * element satisfies [f]. *)
fun removeFirst (v, f) =
   let
      val found = ref false
      fun keep x =
         if f x
            then (if !found
                     then true
                  else (found := true; false))
         else true
      val v' = keepAll (v, keep)
   in
      if !found
         then v'
      else Error.bug "Vector.removeFirst"
   end
589
(* [partitioni (v, f)] splits [v] into the elements whose (index, element)
 * pair satisfies [f] ([yes]) and the rest ([no]), preserving order within
 * each part.  [f] is called once per element, during the tagging pass. *)
fun partitioni (v, f) =
   let
      val yesCount = ref 0
      val tagged =
         mapi (v, fn (i, x) =>
               let
                  val keep = f (i, x)
               in
                  (if keep then yesCount := !yesCount + 1 else ())
                  ; (x, keep)
               end)
      val nYes = !yesCount
      val cursor = ref 0
      (* Next element at or after the cursor whose tag equals [flag]. *)
      fun nextWith flag =
         let
            fun scan (i: int) =
               let
                  val (x, b) = unsafeSub (tagged, i)
               in
                  if b = flag
                     then (cursor := i + 1; x)
                  else scan (i + 1)
               end
         in
            scan (!cursor)
         end
      val yes = tabulate (nYes, fn _ => nextWith true)
      val () = cursor := 0
      val no = tabulate (length v - nYes, fn _ => nextWith false)
   in
      {yes = yes, no = no}
   end

(* [partition (v, f)] is [partitioni] with an index-blind [f]. *)
fun partition (v, f) = partitioni (v, fn (_, x) => f x)
615
(* [prefix (v, k)] is the first [k] elements of [v]. *)
fun prefix (v, k) =
   let
      fun at i = sub (v, i)
   in
      tabulate (k, at)
   end

(* [removeDuplicates (v, equals)] drops every element that [equals] its
 * immediate predecessor in the original [v].  Only adjacent duplicates are
 * collapsed. *)
fun removeDuplicates (v, equals) =
   keepAllMapi (v, fn (i, x) =>
                   if i <= 0 orelse not (equals (x, sub (v, i - 1)))
                      then SOME x
                   else NONE)

(* [randomElement v] is the element at an index drawn from
 * Random.natLessThan (length v). *)
fun randomElement v =
   let
      val i = Random.natLessThan (length v)
   in
      sub (v, i)
   end
625
626 end