Import Debian changes 20180207-1
[hcoop/debian/mlton.git] / mlton / match-compile / match-compile.fun
1 (* Copyright (C) 2015,2017 Matthew Fluet.
2 * Copyright (C) 1999-2007 Henry Cejtin, Matthew Fluet, Suresh
3 * Jagannathan, and Stephen Weeks.
4 * Copyright (C) 1997-2000 NEC Research Institute.
5 *
6 * MLton is released under a BSD-style license.
7 * See the file MLton-LICENSE for details.
8 *)
9
10 functor MatchCompile (S: MATCH_COMPILE_STRUCTS): MATCH_COMPILE =
11 struct
12
13 open S
14
structure Example =
struct
   (* Descriptions of values used in match diagnostics.  Wild stands for any
    * value; ConstRange for an interval of constants (NONE bounds are open);
    * Or for a set of alternatives; Vector with dots = true stands for
    * vectors that are at least as long as the listed elements.
    *)
   datatype t =
      ConApp of {arg: t option, con: Con.t}
    | ConstRange of {lo: Const.t option, hi: Const.t option, isChar: bool, isInt: bool}
    | Exn
    | Or of t vector
    | Record of t SortedRecord.t
    | Vector of t vector * {dots: bool}
    | Wild

   (* layoutAux (ex, isDelimited): render `ex` in SML source syntax.  When
    * `isDelimited` is false, compound forms are parenthesized so the result
    * may appear in an argument position.
    *)
   fun layoutAux (ex, isDelimited) =
      let
         open Layout
         fun delimit t = if isDelimited then t else paren t
         (* Render code point `c` as a character: SML escapes up to 0xFF,
          * then \uXXXX up to 0xFFFF, otherwise \UXXXXXXXX. *)
         fun layoutChar c =
            let
               (* The low `n` hex digits of `v`, most significant first. *)
               fun hexDigits (n: int, v: IntInf.t, acc: char list) =
                  if n = 0
                     then implode acc
                  else
                     let
                        val (q, r) = IntInf.quotRem (v, 0x10)
                     in
                        hexDigits (n - 1, q, Char.fromHexDigit (Int.fromIntInf r) :: acc)
                     end
               fun escape (n, esc) = str (concat ["\\", esc, hexDigits (n, c, [])])
            in
               if c <= 0xFF
                  then str (Char.escapeSML (Char.fromInt (Int.fromIntInf c)))
               else if c <= 0xFFFF
                  then escape (4, "u")
               else escape (8, "U")
            end
         (* Internal error for a constant of an unexpected kind. *)
         fun strange (what, c) =
            Error.bug (concat
                       ["MatchCompile.Example.layout.layoutConst: ",
                        what,
                        Layout.toString (Const.layout c)])
         (* Render a constant as a char, an int, or a word / string. *)
         fun layoutConst (c, isChar, isInt) =
            case (isChar, isInt, c) of
               (true, _, Const.Word w) =>
                  seq [str "#\"", layoutChar (WordX.toIntInf w), str "\""]
             | (true, _, _) => strange ("strange char: ", c)
             | (false, true, Const.IntInf i) => IntInf.layout i
             | (false, true, Const.Word w) => IntInf.layout (WordX.toIntInfX w)
             | (false, true, _) => strange ("strange int: ", c)
             | (false, false, Const.Word w) =>
                  seq [str "0wx", str (IntInf.format (WordX.toIntInf w, StringCvt.HEX))]
             | (false, false, Const.WordVector ws) =>
                  seq [str "\"",
                       seq (WordXVector.toListMap (ws, layoutChar o WordX.toIntInf)),
                       str "\""]
             | (false, false, _) => strange ("strange const: ", c)
      in
         case ex of
            ConApp {arg = NONE, con} => str (Con.originalName con)
          | ConApp {arg = SOME arg, con} =>
               (delimit o seq) [str (Con.originalName con), str " ", layoutF arg]
          | ConstRange {lo = NONE, hi = NONE, ...} => str "..."
          | ConstRange {lo = NONE, hi = SOME hi, isChar, isInt} =>
               delimit (seq [str "... ", layoutConst (hi, isChar, isInt)])
          | ConstRange {lo = SOME lo, hi = NONE, isChar, isInt} =>
               delimit (seq [layoutConst (lo, isChar, isInt), str " ..."])
          | ConstRange {lo = SOME lo, hi = SOME hi, isChar, isInt} =>
               if Const.equals (lo, hi)
                  then layoutConst (lo, isChar, isInt)
               else delimit (seq [layoutConst (lo, isChar, isInt),
                                  str " .. ",
                                  layoutConst (hi, isChar, isInt)])
          | Exn => delimit (str "_ : exn")
          | Or exs =>
               (delimit o mayAlign o separateLeft)
               (Vector.toListMap (exs, layoutT), "| ")
          | Record rexs =>
               SortedRecord.layout
               {extra = "",
                layoutElt = layoutT,
                layoutTuple = fn exs => tuple (Vector.toListMap (exs, layoutT)),
                record = rexs,
                separator = " = "}
          | Vector (exs, {dots}) =>
               let
                  val elts = Vector.map (exs, layoutT)
                  val elts =
                     if dots
                        then Vector.concat [elts, Vector.new1 (str "...")]
                     else elts
               in
                  vector elts
               end
          | Wild => str "_"
      end
   and layoutF ex = layoutAux (ex, false)
   and layoutT ex = layoutAux (ex, true)

   val layout = layoutT

   val isWild: t -> bool =
      fn Wild => true
       | _ => false

   fun constRange {lo, hi, isChar, isInt} =
      ConstRange {hi = hi, isChar = isChar, isInt = isInt, lo = lo}
   (* A single constant: the degenerate range [c, c]. *)
   fun const {const = c, isChar, isInt} =
      constRange {lo = SOME c, hi = SOME c, isChar = isChar, isInt = isInt}

   (* A record whose fields are all Wild is itself just Wild. *)
   fun record rexs =
      if not (SortedRecord.forall (rexs, isWild))
         then Record rexs
      else Wild

   local
      fun mkVector dots exs = Vector (exs, {dots = dots})
   in
      fun vector exs = mkVector false exs
      fun vectorDots exs = mkVector true exs
   end

   (* Ordering used to sort and merge alternatives.  Wild is greatest, then
    * Exn; otherwise both arguments must be of the same kind. *)
   fun compare (ex1, ex2) =
      case (ex1, ex2) of
         (Wild, Wild) => EQUAL
       | (_, Wild) => LESS
       | (Wild, _) => GREATER
       | (Exn, Exn) => EQUAL
       | (_, Exn) => LESS
       | (Exn, _) => GREATER
       | (ConstRange {lo = lo1, hi = hi1, isInt, ...},
          ConstRange {lo = lo2, hi = hi2, ...}) =>
            let
               (* Compare optional bounds.  An absent bound is open: below
                * everything for a lower bound (isLo), above everything for
                * an upper bound.  Ties between present bounds go to kont. *)
               fun cmpBound (b1, b2, isLo, kont) =
                  case (b1, b2) of
                     (NONE, NONE) => kont EQUAL
                   | (NONE, SOME _) => if isLo then LESS else GREATER
                   | (SOME _, NONE) => if isLo then GREATER else LESS
                   | (SOME (Const.Word w1), SOME (Const.Word w2)) =>
                        kont (WordX.compare (w1, w2, {signed = isInt}))
                   | (SOME (Const.IntInf ii1), SOME (Const.IntInf ii2)) =>
                        kont (IntInf.compare (ii1, ii2))
                   | (SOME (Const.WordVector ws1), SOME (Const.WordVector ws2)) =>
                        kont (WordXVector.compare (ws1, ws2))
                   | _ => Error.bug "MatchCompile.Example.compare: ConstRange/ConstRange"
               fun thenHi EQUAL = cmpBound (hi1, hi2, false, fn ord => ord)
                 | thenHi ord = ord
            in
               cmpBound (lo1, lo2, true, thenHi)
            end
       | (ConApp {con = con1, arg = arg1}, ConApp {con = con2, arg = arg2}) =>
            (case String.compare (Con.toString con1, Con.toString con2) of
                EQUAL =>
                   (case (arg1, arg2) of
                       (NONE, NONE) => EQUAL
                     | (SOME a1, SOME a2) => compare' (a1, a2)
                     | _ => Error.bug "MatchCompile.Example.compare: ConApp/ConApp")
              | ord => ord)
       | (Vector (exs1, {dots = dots1}), Vector (exs2, {dots = dots2})) =>
            if dots1 = dots2
               then Vector.compare (exs1, exs2, compare')
            else if dots1
               then GREATER
            else LESS
       | (Record rexs1, Record rexs2) =>
            Vector.compare (SortedRecord.range rexs1, SortedRecord.range rexs2, compare')
       | _ => Error.bug "MatchCompile.Example.compare"
   (* Like compare, but an Or is compared by its list of alternatives. *)
   and compare' (ex1, ex2) =
      let
         fun alts (Or exs) = SOME (Vector.toList exs)
           | alts _ = NONE
      in
         case (alts ex1, alts ex2) of
            (NONE, NONE) => compare (ex1, ex2)
          | (SOME l1, NONE) => compares (l1, [ex2])
          | (NONE, SOME l2) => compares ([ex1], l2)
          | (SOME l1, SOME l2) => compares (l1, l2)
      end
   and compares (exs1, exs2) =
      List.compare (exs1, exs2, compare)

   (* Combine examples into one sorted, duplicate-free alternative set.
    * Nested Or's are flattened, and applications of the same constructor
    * are merged by or-ing their arguments.  NONE for no examples. *)
   fun or exs =
      let
         (* Two examples judged equal: merge constructor arguments,
          * otherwise keep the first. *)
         fun combine (ConApp {con, arg = arg1}, ConApp {arg = arg2, ...}) =
                ConApp {con = con,
                        arg = (case (arg1, arg2) of
                                  (SOME a1, SOME a2) => or [a1, a2]
                                | (NONE, NONE) => NONE
                                | _ => Error.bug "MatchCompile.Example.or.join")}
           | combine (x, _) = x
         (* Merge two lists already sorted by the ordering above. *)
         fun merge (xs, ys) =
            case (xs, ys) of
               ([], _) => ys
             | (_, []) => xs
             | (x :: xs', y :: ys') =>
                  let
                     val ord =
                        case (x, y) of
                           (ConApp {con = c1, ...}, ConApp {con = c2, ...}) =>
                              String.compare (Con.toString c1, Con.toString c2)
                         | _ => compare (x, y)
                  in
                     case ord of
                        LESS => x :: merge (xs', ys)
                      | GREATER => y :: merge (xs, ys')
                      | EQUAL => combine (x, y) :: merge (xs', ys')
                  end
         val flattened =
            List.map (exs, fn Or alts => Vector.toList alts | ex => [ex])
      in
         case List.fold (flattened, [], merge) of
            [] => NONE
          | [ex] => SOME ex
          | merged => SOME (Or (Vector.fromList merged))
      end

end
247
(* Renaming environments from pattern variables to the variables that hold
 * the corresponding matched values (see Facts.bind). *)
structure Env =
   MonoEnv (structure Domain = Var
            structure Range = Var)
250
structure Fact =
struct
   (* The shape learned about a variable once it has been taken apart: the
    * constructor it holds (and the variable naming its argument), or the
    * variables naming the fields of a record or the elements of a vector. *)
   datatype t =
      Con of {arg: Var.t option,
              con: Con.t}
    | Record of Var.t SortedRecord.t
    | Vector of Var.t vector

   fun layout (f: t): Layout.t =
      let
         open Layout
      in
         case f of
            Vector xs => vector (Vector.map (xs, Var.layout))
          | Record r =>
               SortedRecord.layout
               {extra = "",
                layoutElt = Var.layout,
                layoutTuple = fn xs => tuple (Vector.toListMap (xs, Var.layout)),
                record = r,
                separator = " = "}
          | Con {arg = NONE, con} => seq [Con.layout con, empty]
          | Con {arg = SOME x, con} => seq [Con.layout con, seq [str " ", Var.layout x]]
      end
end
279
structure Examples =
struct
   (* Example values accumulated along a compilation path, newest first.
    * isOnlyExns remains true only while every added example was flagged
    * as isOnlyExns. *)
   datatype t = T of {exs: (Var.t * Example.t) list,
                      isOnlyExns: bool}

   fun layout (T {exs, ...}) =
      let
         val layoutPair = Layout.tuple2 (Var.layout, Example.layout)
      in
         List.layout layoutPair exs
      end

   val empty = T {isOnlyExns = true, exs = []}

   fun add (T {exs, isOnlyExns = soFar}, x, ex, {isOnlyExns = onlyExns: bool}) =
      T {exs = (x, ex) :: exs,
         isOnlyExns = soFar andalso onlyExns}
end
294
structure Facts =
struct
   (* Facts gathered on the current path: each entry says that `var` holds
    * a value of shape `fact`.  Newest entries are at the front. *)
   datatype t = T of {fact: Fact.t,
                      var: Var.t} list

   fun layout (T fs) =
      let
         open Layout
         fun layoutOne {fact, var} =
            seq [Var.layout var, str " = ", Fact.layout fact]
      in
         List.layout layoutOne fs
      end

   val empty: t = T []

   fun add (T entries, var, fact) = T ({fact = fact, var = var} :: entries)

   (* bind (facts, x, p): walk pattern `p` against the value in `x`, using
    * the facts to find the variables holding each sub-value, and return the
    * environment mapping every variable bound in `p` to such a variable.
    * Raises a bug if a needed fact is missing or has the wrong shape, or if
    * `p` still contains an or-pattern. *)
   fun bind (T facts, x: Var.t, p: NestedPat.t): Env.t =
      let
         val {destroy, get = factOf: Var.t -> Fact.t, set = setFact, ...} =
            Property.destGetSetOnce
            (Var.plist, Property.initRaise ("fact", Var.layout))
         val () = List.foreach (facts, fn {fact, var} => setFact (var, fact))
         fun walk (p: NestedPat.t, x: Var.t, env: Env.t): Env.t =
            let
               datatype z = datatype NestedPat.node
            in
               case NestedPat.node p of
                  Wild => env
                | Const _ => env
                | Var y => Env.extend (env, y, x)
                | Layered (y, p') => walk (p', x, Env.extend (env, y, x))
                | Or _ => Error.bug "MatchCompile.factbind: or pattern shouldn't be here"
                | Con {arg = NONE, ...} => env
                | Con {arg = SOME p', ...} =>
                     (case factOf x of
                         Fact.Con {arg = SOME x', ...} => walk (p', x', env)
                       | _ => Error.bug "MatchCompile.Facts.bind: Con:wrong fact")
                | Record rp =>
                     (case factOf x of
                         Fact.Record rx =>
                            Vector.fold2 (SortedRecord.range rp, SortedRecord.range rx, env, walk)
                       | _ => Error.bug "MatchCompile.Facts.bind: Record:wrong fact")
                | Vector ps =>
                     (case factOf x of
                         Fact.Vector xs => Vector.fold2 (ps, xs, env, walk)
                       | _ => Error.bug "MatchCompile.Facts.bind: Vector:wrong fact")
            end
         val result = walk (p, x, Env.empty)
         val () = destroy ()
      in
         result
      end

   val bind =
      Trace.trace3 ("MatchCompile.Facts.bind",
                    layout, Var.layout, NestedPat.layout, Env.layout)
      bind

   (* example (facts, examples, x): rebuild an Example for the value in `x`.
    * Variables with a fact are expanded structurally; others use the
    * example recorded for them in `examples`, or Wild if there is none. *)
   fun example (T facts, Examples.T {exs, ...}, x: Var.t): Example.t =
      let
         val {destroy, get = factOf: Var.t -> Fact.t option, set = setFact, ...} =
            Property.destGetSetOnce (Var.plist, Property.initConst NONE)
         val () = List.foreach (facts, fn {fact, var} => setFact (var, SOME fact))
         fun recorded y =
            case List.peek (exs, fn (y', _) => Var.equals (y, y')) of
               NONE => Example.Wild
             | SOME (_, ex) => ex
         fun build (y: Var.t): Example.t =
            case factOf y of
               NONE => recorded y
             | SOME (Fact.Con {arg, con}) =>
                  Example.ConApp {con = con, arg = Option.map (arg, build)}
             | SOME (Fact.Record rys) =>
                  Example.record (SortedRecord.map (rys, build))
             | SOME (Fact.Vector ys) =>
                  Example.vector (Vector.map (ys, build))
         val result = build x
         val () = destroy ()
      in
         result
      end

   val example =
      Trace.trace3
      ("MatchCompile.Facts.example",
       layout, Examples.layout, Var.layout, Example.layout)
      example
end
393
structure Pat =
struct
   (* Simplified patterns used by the compiler: variables and layered
    * patterns collapse to Wild; constructor arguments carry their type. *)
   datatype t =
      Const of {const: Const.t,
                isChar: bool,
                isInt: bool}
    | Con of {arg: (t * Type.t) option,
              con: Con.t,
              targs: Type.t vector}
    | Record of t SortedRecord.t
    | Vector of t vector
    | Wild

   fun layout (p: t): Layout.t =
      let
         open Layout
         fun layoutArg NONE = empty
           | layoutArg (SOME (p', _)) = seq [str " ", layout p']
      in
         case p of
            Wild => str "_"
          | Const {const, ...} => Const.layout const
          | Con {arg, con, ...} => seq [Con.layout con, layoutArg arg]
          | Vector ps => vector (Vector.map (ps, layout))
          | Record rps =>
               SortedRecord.layout
               {extra = "",
                layoutElt = layout,
                layoutTuple = fn ps => tuple (Vector.toListMap (ps, layout)),
                record = rps,
                separator = " = "}
      end

   fun isWild (p: t): bool =
      case p of
         Wild => true
       | _ => false

   (* Convert a NestedPat (which must contain no or-patterns). *)
   fun fromNestedPat (p: NestedPat.t): t =
      case NestedPat.node p of
         NestedPat.Con {arg, con, targs} =>
            Con {arg = Option.map (arg, fn p' => (fromNestedPat p', NestedPat.ty p')),
                 con = con,
                 targs = targs}
       | NestedPat.Const r => Const r
       | NestedPat.Layered (_, p') => fromNestedPat p'
       | NestedPat.Or _ => Error.bug "MatchCompile.fromNestedPat: or pattern shouldn't be here"
       | NestedPat.Record rps => Record (SortedRecord.map (rps, fromNestedPat))
       | NestedPat.Var _ => Wild
       | NestedPat.Vector ps => Vector (Vector.map (ps, fromNestedPat))
       | NestedPat.Wild => Wild
end
455
(* Vector, extended with a helper for deleting one position. *)
structure Vector =
struct
   open Vector

   (* `v` without the element at index `n` (unchanged if no index equals n). *)
   fun dropNth (v: 'a t, n: int): 'a t =
      keepAllMapi (v, fn (j, x) => if j <> n then SOME x else NONE)
end
463
structure Rule =
struct
   (* One clause under compilation: a pattern per remaining column, plus
    * what is needed once it is selected (the example collector, if any;
    * the right-hand-side builder; and the original pattern, for binding). *)
   datatype t =
      T of {pats: Pat.t vector,
            rest: {examples: (Example.t * {isOnlyExns: bool}) list ref option,
                   finish: (Var.t -> Var.t) -> Exp.t,
                   nestedPat: NestedPat.t}}

   fun layout (T {pats, ...}) =
      Layout.tuple (Vector.toListMap (pats, Pat.layout))

   (* True when no remaining column constrains the value. *)
   fun allWild (T {pats, ...}) =
      Vector.forall (pats, Pat.isWild)

   fun dropNth (T {pats, rest}, n) =
      T {rest = rest, pats = Vector.dropNth (pats, n)}
end
482
structure Rules =
struct
   type t = Rule.t vector

   fun layout (rs: t) =
      Layout.align (Vector.toListMap (rs, Rule.layout))

   (* Remove column n from every rule. *)
   fun dropNth (rs: t, n: int): t =
      Vector.map (rs, fn rule => Rule.dropNth (rule, n))
end
492
structure Vars =
struct
   (* The scrutinized variables, with their types, one per column. *)
   type t = (Var.t * Type.t) vector

   val layout =
      let
         val layoutOne = Layout.tuple2 (Var.layout, Type.layout)
      in
         Vector.layout layoutOne
      end
end
499
(* Word types for which constant matches compile to a direct word switch.
 * 64-bit words are excluded; matches on them (and on other constant types)
 * compile to a chain of equality tests instead. *)
val directCases =
   List.keepAllMap
   (WordSize.all, fn size =>
    if not (WordSize.equals (size, WordSize.fromBits (Bits.fromInt 64)))
       then SOME {size = size, ty = Type.word size}
    else NONE)
505
(* unhandledConsts {consts, isChar, isInt}: describe the values of the
 * scrutinee's constant type that do NOT appear in `consts`, as an Example
 * (NONE when every value is handled).
 * `consts` is assumed non-empty (Vector.first / hd below) and homogeneous;
 * its kind (IntInf, Word or WordVector) is taken from the first element.
 * Null and Real constants are not supported.
 *)
fun unhandledConsts {consts = cs: Const.t vector, isChar, isInt}: Example.t option =
   let
      (* Generic gap search over an ordered domain.
       *   <=, equals : the order on the domain
       *   extract    : project a Const.t into the domain
       *   make       : inject a domain value back into a Const.t
       *   min, max   : domain bounds (NONE = unbounded on that side)
       *   next, prev : successor / predecessor in the same order
       * The handled constants are sorted and every gap between them (and
       * before the first / after the last) becomes an example. *)
      fun search {<= : 'a * 'a -> bool,
                  equals: 'a * 'a -> bool,
                  extract: Const.t -> 'a,
                  make: 'a -> Const.t,
                  max: 'a option,
                  min: 'a option,
                  next: 'a -> 'a,
                  prev: 'a -> 'a} =
         let
            fun exampleConstRange (lo, hi) =
               Example.constRange
               {lo = Option.map (lo, make),
                hi = Option.map (hi, make),
                isChar = isChar, isInt = isInt}
            (* Examples for the closed interval [lo, hi]; nothing if it is
             * empty.  Intervals of at most three values are listed value by
             * value, longer ones as a single range. *)
            fun mkExampleConstRange (lo, hi) =
               if lo <= hi
                  then if equals (lo, hi)
                          then [exampleConstRange (SOME lo, SOME hi)]
                       else let
                               val lo' = next lo
                               val hi' = prev hi
                            in
                               if equals (lo', hi)
                                  then [exampleConstRange (SOME lo, SOME lo),
                                        exampleConstRange (SOME hi, SOME hi)]
                               else if equals (lo', hi')
                                  then [exampleConstRange (SOME lo, SOME lo),
                                        exampleConstRange (SOME lo', SOME hi'),
                                        exampleConstRange (SOME hi, SOME hi)]
                               else [exampleConstRange (SOME lo, SOME hi)]
                            end
               else []
            val cs = QuickSort.sortVector (Vector.map (cs, extract), op <=)
            val cs = Vector.toList cs
            (* Gaps between consecutive handled values, then after the last
             * one up to `max` (or unbounded). *)
            fun loop cs =
               case cs of
                  [] => []
                | [cMax] =>
                     (case max of
                         NONE => [exampleConstRange (SOME (next cMax), NONE)]
                       | SOME max' =>
                            if equals (cMax, max')
                               then []
                            else mkExampleConstRange (next cMax, max'))
                | c1::c2::cs =>
                     (mkExampleConstRange (next c1, prev c2)) @ (loop (c2::cs))
            val cMin = hd cs
            (* Gap below the smallest handled value, then the rest. *)
            val examples =
               case min of
                  NONE => [exampleConstRange (NONE, SOME (prev cMin))] @ (loop cs)
                | SOME min' =>
                     if equals (cMin, min')
                        then loop cs
                     else (mkExampleConstRange (min', prev cMin)) @ (loop cs)
         in
            Example.or examples
         end
      datatype z = datatype Const.t
   in
      case Vector.first cs of
         IntInf _ =>
            let
               fun extract c =
                  case c of
                     IntInf i => i
                   | _ => Error.bug "MatchCompile.unhandledConsts: expected IntInf"
            in
               (* Unbounded integers: open-ended ranges on both sides. *)
               search {<= = op <=,
                       equals = op =,
                       extract = extract,
                       make = Const.IntInf,
                       max = NONE,
                       min = NONE,
                       next = fn i => i + 1,
                       prev = fn i => i - 1}
            end
       | Null => Error.bug "MatchCompile.unhandledConsts: Null"
       | Real _ => Error.bug "MatchCompile.unhandledConsts: Real"
       | Word w =>
            let
               val s = WordX.size w
               (* Words of an int type are ordered as signed values. *)
               val signed = {signed = isInt}
               fun extract c =
                  case c of
                     Word w => w
                   | _ => Error.bug "MatchCompile.unhandledConsts: expected Word"
            in
               search {<= = fn (w1, w2) => WordX.le (w1, w2, signed),
                       equals = WordX.equals,
                       extract = extract,
                       make = Const.word,
                       max = SOME (WordX.max (s, signed)),
                       min = SOME (WordX.min (s, signed)),
                       next = fn w => WordX.add (w, WordX.one s),
                       prev = fn w => WordX.sub (w, WordX.one s)}
            end
       | WordVector ws =>
            let
               val s = WordXVector.elementSize ws
               val signed = {signed = false}
               fun extract c =
                  case c of
                     WordVector ws => ws
                   | _ => Error.bug "MatchCompile.unhandledConsts: expected WordVector"
               (* next / prev step through vectors ordered by length first,
                * then element-wise (increment / decrement the last element
                * with carry / borrow).  NOTE(review): this assumes
                * WordXVector.le uses that same length-first order -- confirm. *)
               fun next ws =
                  let
                     val wsOrig = List.rev (WordXVector.toListMap (ws, fn w => w))
                     val wsNext =
                        let
                           fun loop ws =
                              case ws of
                                 [] => [WordX.min (s, signed)]
                               | w::ws =>
                                    if WordX.isMax (w, signed)
                                       then (WordX.min (s, signed))::(loop ws)
                                    else (WordX.add (w, WordX.one s))::ws
                        in
                           loop wsOrig
                        end
                  in
                     WordXVector.fromListRev ({elementSize = s}, wsNext)
                  end
               fun prev ws =
                  let
                     val wsOrig = List.rev (WordXVector.toListMap (ws, fn w => w))
                     val wsPrev =
                        let
                           fun loop ws =
                              case ws of
                                 [] => Error.bug "MatchCompile.unhandledConsts: WordXVector.prev"
                               | [w] =>
                                    if WordX.isMin (w, signed)
                                       then []
                                    else [WordX.sub (w, WordX.one s)]
                               | w::ws =>
                                    if WordX.isMin (w, signed)
                                       then (WordX.max (s, signed))::(loop ws)
                                    else (WordX.sub (w, WordX.one s))::ws
                        in
                           loop wsOrig
                        end
                  in
                     WordXVector.fromListRev ({elementSize = s}, wsPrev)
                  end
            in
               (* The empty vector is the least element; no greatest one. *)
               search {<= = WordXVector.le,
                       equals = WordXVector.equals,
                       extract = extract,
                       make = Const.wordVector,
                       max = NONE,
                       min = SOME (WordXVector.fromVector ({elementSize = s}, Vector.new0 ())),
                       next = next,
                       prev = prev}
            end
   end
663
(* Exp with a fixed placeholder layout, so traces print "<exp>" instead of
 * a full expression. *)
structure Exp =
struct
   open Exp

   val layout: t -> Layout.t = fn _ => Layout.str "<exp>"
end
670
(* Layout of the state shared by every compilation step, followed by any
 * step-specific fields in `extra`. *)
fun layoutStepArgs (vars, rules, facts, es, extra) =
   Layout.record ([("vars", Vars.layout vars),
                   ("rules", Rules.layout rules),
                   ("facts", Facts.layout facts),
                   ("examples", Examples.layout es)] @ extra)

(* Layout for steps that also work on a column index and test expression. *)
fun layoutColumnArgs (vars, rules, facts, es, i: Int.t, test: Exp.t) =
   layoutStepArgs (vars, rules, facts, es,
                   [("index", Int.layout i),
                    ("test", Exp.layout test)])

val traceMatch =
   Trace.trace ("MatchCompile.match",
                fn (vars, rules, facts, es) =>
                layoutStepArgs (vars, rules, facts, es, []),
                Exp.layout)
val traceConst =
   Trace.trace ("MatchCompile.const", layoutColumnArgs, Exp.layout)
val traceSum =
   Trace.trace ("MatchCompile.sum",
                fn (vars, rules, facts, es, i, test, _: Tycon.t) =>
                layoutColumnArgs (vars, rules, facts, es, i, test),
                Exp.layout)
val traceRecord =
   Trace.trace ("MatchCompile.record",
                fn (vars, rules, facts, es, i, test, _: Field.t vector) =>
                layoutColumnArgs (vars, rules, facts, es, i, test),
                Exp.layout)
val traceVector =
   Trace.trace ("MatchCompile.vector", layoutColumnArgs, Exp.layout)
719
720 (*---------------------------------------------------*)
721 (* matchCompile *)
722 (*---------------------------------------------------*)
723
(* matchCompile: compile the rules `cases` of a match on variable `test`
 * (of type `testType`) into an expression of type `caseType`.  Returns the
 * expression and a function that lays out the examples of values reaching
 * the last rule; {dropOnlyExns = true} omits those recorded as
 * isOnlyExns.  No case may contain an or-pattern.
 *)
fun matchCompile {caseType: Type.t,
                  cases: (NestedPat.t * ((Var.t -> Var.t) -> Exp.t)) vector,
                  conTycon: Con.t -> Tycon.t,
                  region: Region.t,
                  test: Var.t,
                  testType: Type.t,
                  tyconCons: Tycon.t -> {con: Con.t,
                                         hasArg: bool} vector} =
   let
      (* Column-selection heuristic: always the leftmost column. *)
      fun chooseColumn _ = 0

      (* `rule` matches unconditionally.  If it collects examples (only the
       * last rule does), push the value that reached it; then build its
       * right-hand side with pattern variables renamed via the facts. *)
      fun fire (Rule.T {rest = {examples, finish, nestedPat}, ...},
                facts: Facts.t,
                es as Examples.T {isOnlyExns, ...}) =
         let
            val env = Facts.bind (facts, test, nestedPat)
            val () =
               case examples of
                  NONE => ()
                | SOME exsRef =>
                     List.push (exsRef,
                                (Facts.example (facts, es, test),
                                 {isOnlyExns = isOnlyExns}))
         in
            finish (fn x => Env.lookup (env, x))
         end

      (* Compile the rule matrix `rules` over the columns `vars`. *)
      fun match arg : Exp.t =
         traceMatch
         (fn (vars: Vars.t, rules: Rules.t, facts: Facts.t, es) =>
          if Vector.isEmpty rules
             then Error.bug "MatchCompile.match: no rules"
          else if Rule.allWild (Vector.first rules)
             then fire (Vector.first rules, facts, es)
          else
             let
                val col = chooseColumn rules
                fun patAt (Rule.T {pats, ...}) = Vector.sub (pats, col)
             in
                case Vector.peek (rules, not o Pat.isWild o patAt) of
                   NONE =>
                      (* No rule inspects this column: discard it. *)
                      match (Vector.dropNth (vars, col),
                             Rules.dropNth (rules, col),
                             facts, es)
                 | SOME rule =>
                      let
                         val scrutinee = Exp.var (Vector.sub (vars, col))
                      in
                         (* Dispatch on the kind of the first real pattern. *)
                         case patAt rule of
                            Pat.Const _ =>
                               const (vars, rules, facts, es, col, scrutinee)
                          | Pat.Con {con, ...} =>
                               sum (vars, rules, facts, es, col, scrutinee, conTycon con)
                          | Pat.Record rps =>
                               record (vars, rules, facts, es, col, scrutinee,
                                       SortedRecord.domain rps)
                          | Pat.Vector _ =>
                               vector (vars, rules, facts, es, col, scrutinee)
                          | Pat.Wild => Error.bug "MatchCompile.match: Wild"
                      end
             end) arg
      (* Column i holds constants: group the rules by constant and test. *)
      and const arg =
         traceConst
         (fn (vars, rules, facts, es, i, test) =>
          let
             val (var, ty) = Vector.sub (vars, i)
             fun constFlags (Rule.T {pats, ...}) =
                case Vector.sub (pats, i) of
                   Pat.Const {isChar, isInt, ...} => SOME {isChar = isChar, isInt = isInt}
                 | _ => NONE
             val {isChar, isInt} =
                case Vector.peekMap (rules, constFlags) of
                   NONE => {isChar = false, isInt = false}
                 | SOME flags => flags
             fun exampleConst c =
                Example.const {const = c, isChar = isChar, isInt = isInt}
             (* groups: one entry per distinct constant, holding the rules
              * that apply when the value equals it; wilds: the rules with a
              * wildcard in column i.  Column i is removed from both. *)
             val (groups, wilds) =
                Vector.foldr
                (rules, ([], []), fn (rule as Rule.T {pats, ...}, (groups, wilds)) =>
                 let
                    val rule' = Rule.dropNth (rule, i)
                 in
                    case Vector.sub (pats, i) of
                       Pat.Const {const = c, ...} =>
                          let
                             (* Add rule' to c's group; a new group is
                              * seeded with the later wildcard rules. *)
                             fun insert ([], seen) =
                                    {const = c, rules = rule' :: wilds} :: seen
                               | insert ((g as {const = c', rules = rs}) :: gs, seen) =
                                    if Const.equals (c, c')
                                       then {const = c, rules = rule' :: rs}
                                            :: List.appendRev (seen, gs)
                                    else insert (gs, g :: seen)
                          in
                             (insert (groups, []), wilds)
                          end
                     | Pat.Wild =>
                          (List.map (groups, fn {const = c', rules = rs} =>
                                     {const = c', rules = rule' :: rs}),
                           rule' :: wilds)
                     | _ => Error.bug "MatchCompile.const: expected Const pat"
                 end)
             val groups =
                Vector.fromListMap (groups, fn {const = c, rules = rs} =>
                                    {const = c, rules = Vector.fromList rs})
             val wilds = Vector.fromList wilds
             val vars' = Vector.dropNth (vars, i)
             fun finish (rs: Rule.t vector, e): Exp.t =
                match (vars', rs, facts,
                       Examples.add (es, var, e, {isOnlyExns = false}))
             (* Fall-through branch, present only when some value is missed. *)
             val default: Exp.t option =
                Option.map
                (unhandledConsts {consts = Vector.map (groups, #const),
                                  isChar = isChar, isInt = isInt},
                 fn e => finish (wilds, e))
          in
             case List.peek (directCases, fn {ty = ty', ...} => Type.equals (ty, ty')) of
                NONE =>
                   (* Chain of equality tests; when every value is covered,
                    * the last group needs no test. *)
                   let
                      val (tested, fallback) =
                         case default of
                            SOME e => (groups, e)
                          | NONE =>
                               (Vector.dropSuffix (groups, 1),
                                let
                                   val {const = c, rules = rs} = Vector.last groups
                                in
                                   finish (rs, exampleConst c)
                                end)
                   in
                      Vector.fold
                      (tested, fallback, fn ({const = c, rules = rs}, elsee) =>
                       Exp.iff {test = Exp.equal (test, Exp.const c),
                                thenn = finish (rs, exampleConst c),
                                elsee = elsee,
                                ty = caseType})
                   end
              | SOME {size, ...} =>
                   (* Direct word switch. *)
                   let
                      val default = Option.map (default, fn e => (e, region))
                      val arms =
                         Vector.map
                         (groups, fn {const = c, rules = rs} =>
                          let
                             val w =
                                case c of
                                   Const.Word w => w
                                 | _ => Error.bug "MatchCompile.const: caseWord type error"
                          in
                             (w, finish (rs, exampleConst c))
                          end)
                   in
                      Exp.casee {cases = Cases.word (size, arms),
                                 default = default,
                                 test = test,
                                 ty = caseType}
                   end
          end) arg
      (* Column i holds constructor patterns of datatype `tycon`. *)
      and sum arg =
         traceSum
         (fn (vars: Vars.t, rules: Rules.t, facts: Facts.t, es, i, test, tycon) =>
          let
             val (var, _) = Vector.sub (vars, i)
             (* groups: one entry per constructor; wilds: wildcard rules. *)
             val (groups, wilds) =
                Vector.foldr
                (rules, ([], []), fn (rule as Rule.T {pats, ...}, (groups, wilds)) =>
                 case Vector.sub (pats, i) of
                    Pat.Con {arg, con, targs} =>
                       let
                          (* New group for `con`: column i becomes the
                           * constructor argument (bound to a fresh variable)
                           * or disappears when there is no argument. *)
                          fun newGroup () =
                             let
                                val (argVar, vars') =
                                   case arg of
                                      NONE =>
                                         (NONE,
                                          Vector.keepAllMapi
                                          (vars, fn (j, x) =>
                                           if i = j then NONE else SOME x))
                                    | SOME (_, ty) =>
                                         let
                                            val a = Var.newNoname ()
                                         in
                                            (SOME (a, ty),
                                             Vector.mapi
                                             (vars, fn (j, x) =>
                                              if i = j then (a, ty) else x))
                                         end
                             in
                                {rest = {arg = argVar,
                                         con = con,
                                         targs = targs,
                                         vars = vars'},
                                 rules = rule :: wilds}
                             end
                          fun insert ([], seen) = newGroup () :: seen
                            | insert ((g as {rest as {con = con', ...}, rules = rs}) :: gs, seen) =
                                 if Con.equals (con, con')
                                    then {rest = rest, rules = rule :: rs}
                                         :: List.appendRev (seen, gs)
                                 else insert (gs, g :: seen)
                       in
                          (insert (groups, []), wilds)
                       end
                  | Pat.Wild =>
                       (List.map (groups, fn {rest, rules = rs} =>
                                  {rest = rest, rules = rule :: rs}),
                        rule :: wilds)
                  | _ => Error.bug "MatchCompile.sum: expected Con pat")
             val groups =
                Vector.fromListMap
                (groups, fn {rest = {arg, con, targs, vars = vars'}, rules = rs} =>
                 let
                    (* Column i: keep the constructor's argument pattern (a
                     * wildcard stays a wildcard), or drop the column. *)
                    fun decon (j, p) =
                       if i <> j
                          then SOME p
                       else
                          case p of
                             Pat.Con {arg = pArg, ...} => Option.map (pArg, #1)
                           | Pat.Wild => Option.map (arg, fn _ => Pat.Wild)
                           | _ => Error.bug "MatchCompile.sum: decon got strange pattern"
                    val rs =
                       Vector.fromListMap
                       (rs, fn Rule.T {pats, rest} =>
                        Rule.T {pats = Vector.keepAllMapi (pats, decon), rest = rest})
                    val facts' =
                       Facts.add (facts, var,
                                  Fact.Con {arg = Option.map (arg, #1), con = con})
                 in
                    {arg = arg,
                     con = con,
                     rhs = match (vars', rs, facts', es),
                     targs = targs}
                 end)
             (* Continue with the wildcard rules, recording example `e`. *)
             fun fallThrough (e, isOnlyExns) =
                SOME (match (Vector.dropNth (vars, i),
                             Rules.dropNth (Vector.fromList wilds, i),
                             facts,
                             Examples.add (es, var, e, {isOnlyExns = isOnlyExns})))
             val default =
                if Vector.isEmpty groups
                   then fallThrough (Example.Wild, true)
                else if Tycon.equals (tycon, Tycon.exn)
                   then fallThrough (Example.Exn, true)
                else
                   let
                      fun missing {con, hasArg} =
                         if Vector.exists (groups, fn {con = con', ...} =>
                                           Con.equals (con, con'))
                            then NONE
                         else SOME (Example.ConApp
                                    {con = con,
                                     arg = if hasArg then SOME Example.Wild else NONE})
                   in
                      case Example.or (List.keepAllMap
                                       (Vector.toList (tyconCons tycon), missing)) of
                         NONE => NONE
                       | SOME e => fallThrough (e, false)
                   end
             fun normal () =
                Exp.casee {cases = Cases.con groups,
                           default = Option.map (default, fn e => (e, region)),
                           test = test,
                           ty = caseType}
          in
             if Vector.length groups <> 1
                then normal ()
             else
                let
                   val {arg, con, rhs, ...} = Vector.first groups
                in
                   (* A lone `ref` pattern becomes a dereference. *)
                   if Con.equals (con, Con.reff)
                      then (case arg of
                               NONE => Error.bug "MatchCompile.sum: ref missing arg"
                             | SOME (x, _) =>
                                  Exp.lett {body = rhs,
                                            exp = Exp.deref test,
                                            var = x})
                   else normal ()
                end
          end) arg
      (* Column i holds record patterns with fields `fs`. *)
      and record arg =
         traceRecord
         (fn (vars: Vars.t, rules: Rules.t, facts: Facts.t, es, i, test, fs) =>
          let
             val (var, varTy) = Vector.sub (vars, i)
             (* Replace column i by one column per field variable. *)
             fun body fieldVars =
                let
                   val n = Vector.length fieldVars
                   val vars' =
                      Vector.concatV
                      (Vector.mapi (vars, fn (j, x) =>
                                    if i = j then fieldVars else Vector.new1 x))
                   fun widen (j, p) =
                      if i <> j
                         then Vector.new1 p
                      else
                         case p of
                            Pat.Record rps => SortedRecord.range rps
                          | Pat.Wild => Vector.tabulate (n, fn _ => Pat.Wild)
                          | _ => Error.bug "MatchCompile.record: derecord"
                   val rules' =
                      Vector.map
                      (rules, fn Rule.T {pats, rest} =>
                       Rule.T {pats = Vector.concatV (Vector.mapi (pats, widen)),
                               rest = rest})
                   val facts' =
                      Facts.add (facts, var,
                                 Fact.Record (SortedRecord.zip (fs, Vector.map (fieldVars, #1))))
                in
                   match (vars', rules', facts', es)
                end
          in
             if Vector.length fs = 1
                then
                   let
                      val var' = Var.newNoname ()
                   in
                      (* Bind a fresh variable even if `test` is a variable:
                       * Facts must hold a single fact per variable. *)
                      Exp.lett {var = var', exp = test,
                                body = body (Vector.new1 (var', varTy))}
                   end
             else Exp.detuple {body = body, tuple = test}
          end) arg
      (* Column i holds vector patterns: switch on the length. *)
      and vector arg =
         traceVector
         (fn (vars: Vars.t, rules: Rules.t, facts: Facts.t, es, i, test) =>
          let
             val (var, _) = Vector.sub (vars, i)
             (* groups: one entry per pattern length; wilds: wildcard rules. *)
             val (groups, wilds) =
                Vector.foldr
                (rules, ([], []), fn (rule as Rule.T {pats, ...}, (groups, wilds)) =>
                 case Vector.sub (pats, i) of
                    Pat.Vector elts =>
                       let
                          val n = Vector.length elts
                          fun insert ([], seen) = {len = n, rules = rule :: wilds} :: seen
                            | insert ((g as {len, rules = rs}) :: gs, seen) =
                                 if n = len
                                    then {len = len, rules = rule :: rs}
                                         :: List.appendRev (seen, gs)
                                 else insert (gs, g :: seen)
                       in
                          (insert (groups, []), wilds)
                       end
                  | Pat.Wild =>
                       (List.map (groups, fn {len, rules = rs} =>
                                  {len = len, rules = rule :: rs}),
                        rule :: wilds)
                  | _ => Error.bug "MatchCompile.vector: expected Vector pat")
             (* Lengths not matched by any group, plus anything longer than
              * the longest one, fall through to the wildcard rules. *)
             val default =
                let
                   val maxLen =
                      List.fold (groups, ~1, fn ({len, ...}, m) => Int.max (m, len))
                   fun covered k = List.exists (groups, fn {len, ...} => k = len)
                   val longer =
                      Example.vectorDots (Vector.new (maxLen + 1, Example.Wild))
                   val unhandled =
                      Int.foldDown
                      (0, maxLen, [longer], fn (k, acc) =>
                       if covered k
                          then acc
                       else Example.vector (Vector.new (k, Example.Wild)) :: acc)
                   val es' =
                      case Example.or unhandled of
                         NONE => es
                       | SOME ex => Examples.add (es, var, ex, {isOnlyExns = false})
                in
                   match (Vector.dropNth (vars, i),
                          Rules.dropNth (Vector.fromList wilds, i),
                          facts, es')
                end
             val arms =
                Vector.fromListMap
                (groups, fn {len, rules = rs} =>
                 let
                    (* Replace column i by one column per element variable. *)
                    fun body elemVars =
                       let
                          val vars' =
                             Vector.concatV
                             (Vector.mapi (vars, fn (j, x) =>
                                           if i = j then elemVars else Vector.new1 x))
                          fun widen (j, p) =
                             if i <> j
                                then Vector.new1 p
                             else
                                case p of
                                   Pat.Vector ps => ps
                                 | Pat.Wild => Vector.new (len, Pat.Wild)
                                 | _ => Error.bug "MatchCompile.vector: devector"
                          val rs' =
                             Vector.fromListMap
                             (rs, fn Rule.T {pats, rest} =>
                              Rule.T {pats = Vector.concatV (Vector.mapi (pats, widen)),
                                      rest = rest})
                       in
                          match (vars', rs',
                                 Facts.add (facts, var,
                                            Fact.Vector (Vector.map (elemVars, #1))),
                                 es)
                       end
                 in
                    (WordX.fromIntInf (IntInf.fromInt len, WordSize.seqIndex ()),
                     Exp.devector {vector = test, length = len, body = body})
                 end)
          in
             Exp.casee
             {cases = Cases.word (WordSize.seqIndex (), arms),
              default = SOME (default, region),
              test = Exp.vectorLength test,
              ty = caseType}
          end) arg

      (* Examples are collected only by the last rule. *)
      val examples = ref []
      val lastIndex = Vector.length cases - 1
      val res =
         match (Vector.new1 (test, testType),
                Vector.mapi
                (cases, fn (i, (p, f)) =>
                 Rule.T {pats = Vector.new1 (Pat.fromNestedPat p),
                         rest = {examples = if i = lastIndex then SOME examples else NONE,
                                 finish = f,
                                 nestedPat = p}}),
                Facts.empty,
                Examples.empty)
      fun layoutExamples {dropOnlyExns} =
         let
            fun keep (ex, {isOnlyExns}) =
               if dropOnlyExns andalso isOnlyExns then NONE else SOME ex
         in
            Option.map (Example.or (List.keepAllMap (!examples, keep)),
                        Example.layout)
         end
   in
      (res, layoutExamples)
   end
1203
(* Front end: NestedPat.flatten splits each case's pattern into
 * alternatives, so no or-pattern reaches the compiler proper.  All
 * alternatives share one right-hand-side builder, which is first told how
 * many alternatives there are. *)
val matchCompile =
   fn {caseType: Type.t,
       cases: (NestedPat.t * (int -> (Var.t -> Var.t) -> Exp.t)) vector,
       conTycon: Con.t -> Tycon.t,
       region: Region.t,
       test: Var.t,
       testType: Type.t,
       tyconCons: Tycon.t -> {con: Con.t,
                              hasArg: bool} vector} =>
   let
      fun expand (pat, mk) =
         let
            val alts = NestedPat.flatten pat
            val rhs = mk (Vector.length alts)
         in
            Vector.map (alts, fn alt => (alt, rhs))
         end
   in
      matchCompile {caseType = caseType,
                    cases = Vector.concatV (Vector.map (cases, expand)),
                    conTycon = conTycon,
                    region = region,
                    test = test,
                    testType = testType,
                    tyconCons = tyconCons}
   end
1233
(* Traced entry point: logs the scrutinee and case patterns, and the
 * resulting expression (printed as a placeholder). *)
val matchCompile =
   Trace.trace
   ("MatchCompile.matchCompile",
    fn {caseType, cases, test, testType, ...} =>
    Layout.record [("caseType", Type.layout caseType),
                   ("cases", Vector.layout (NestedPat.layout o #1) cases),
                   ("test", Var.layout test),
                   ("testType", Type.layout testType)],
    fn (e, _) => Exp.layout e)
   matchCompile
1244
1245 end