Import Upstream version 20180207
[hcoop/debian/mlton.git] / mlton / match-compile / nested-pat.fun
1 (* Copyright (C) 2015,2017 Matthew Fluet.
2 * Copyright (C) 1999-2007 Henry Cejtin, Matthew Fluet, Suresh
3 * Jagannathan, and Stephen Weeks.
4 * Copyright (C) 1997-2000 NEC Research Institute.
5 *
6 * MLton is released under a BSD-style license.
7 * See the file MLton-LICENSE for details.
8 *)
9
10 functor NestedPat (S: NESTED_PAT_STRUCTS): NESTED_PAT =
11 struct
12
13 open S
14
(* A nested pattern: a pattern node together with the type recorded for it. *)
datatype t = T of {pat: node, ty: Type.t}
and node =
   (* Constructor [con] at type arguments [targs], with an optional argument
    * pattern. *)
   Con of {arg: t option, con: Con.t, targs: Type.t vector}
   (* A constant, carrying two boolean flags. *)
 | Const of {const: Const.t, isChar: bool, isInt: bool}
   (* A variable bound alongside a sub-pattern (laid out as "x as p"). *)
 | Layered of Var.t * t
   (* A vector of alternative patterns. *)
 | Or of t vector
   (* A record (or tuple) of field patterns. *)
 | Record of t SortedRecord.t
   (* A variable pattern. *)
 | Var of Var.t
   (* A vector pattern. *)
 | Vector of t vector
   (* The wildcard pattern "_". *)
 | Wild

(* node p is the pattern node stored in p. *)
fun node (T {pat, ...}) = pat

(* ty p is the type stored in p. *)
fun ty (T {ty, ...}) = ty
36
(* tuple ps builds the tuple pattern whose components are ps; its type is the
 * tuple type over the components' types. *)
fun tuple ps =
   let
      val fields = SortedRecord.tuple ps
      val componentTys = Vector.map (ps, ty)
   in
      T {pat = Record fields, ty = Type.tuple componentTys}
   end
40
(* layout (p, isDelimited) pretty-prints the pattern p.
 *
 * When isDelimited is false, constructor applications and layered patterns
 * are wrapped in parentheses; when it is true they are emitted bare.
 * Or-patterns are always parenthesized.
 *
 * layoutF and layoutT are the two specializations (isDelimited = false and
 * true respectively) used for the recursive calls.
 *)
fun layout (p, isDelimited) =
   let
      open Layout
      (* Identity when the context already delimits the pattern, otherwise
       * parenthesize. *)
      val delimit = if isDelimited then (fn l => l) else paren
   in
      case node p of
         Con {arg, con, targs} =>
            let
               (* The constructor argument is laid out undelimited, so a
                * compound argument gets its own parentheses. *)
               val argL = Option.map (arg, layoutF)
               val conL = Con.layout con
               val targsL = Vector.map (targs, Type.layout)
            in
               delimit (Pretty.conApp {arg = argL, con = conL, targs = targsL})
            end
       | Const {const, ...} => Const.layout const
       | Layered (x, inner) =>
            delimit (seq [Var.layout x, str " as ", layoutT inner])
       | Or alts =>
            let
               val altLs = Vector.toListMap (alts, layoutT)
            in
               paren (mayAlign (separateLeft (altLs, "| ")))
            end
       | Record rps =>
            SortedRecord.layout
            {extra = "",
             layoutElt = layoutT,
             layoutTuple = fn elts => tuple (Vector.toListMap (elts, layoutT)),
             record = rps,
             separator = " = "}
       | Var x => Var.layout x
       | Vector elts => vector (Vector.map (elts, layoutT))
       | Wild => str "_"
   end
and layoutF p = layout (p, false)
and layoutT p = layout (p, true)

(* From here on, layout is the one-argument, delimited-context printer. *)
val layout = layoutT
69
(* make (pat, ty) wraps a pattern node and its type into a pattern. *)
val make = fn (pat, ty) => T {pat = pat, ty = ty}
72
(* flatten p expands the Or nodes of p: an Or node contributes the
 * concatenation of its flattened alternatives, and every other compound node
 * is rebuilt once for each combination of its flattened children.  Each
 * rebuilt node keeps the type of the node it replaces.
 *
 * flattens (ps, build) flattens every pattern in ps, forms all combinations
 * that pick one flattened result per position (in order), and applies build
 * to each combination.
 *)
fun flatten p =
   let
      val patTy = ty p
      (* Re-wrap a node at the type of the pattern being flattened. *)
      fun rebuild n = make (n, patTy)
   in
      case node p of
         Con {arg = NONE, ...} => Vector.new1 p
       | Con {arg = SOME inner, con, targs} =>
            Vector.map
            (flatten inner, fn inner' =>
             rebuild (Con {arg = SOME inner', con = con, targs = targs}))
       | Const _ => Vector.new1 p
       | Layered (x, inner) =>
            Vector.map (flatten inner, fn inner' => rebuild (Layered (x, inner')))
       | Or alts => Vector.concatV (Vector.map (alts, flatten))
       | Record rps =>
            let
               val (labels, subpats) = SortedRecord.unzip rps
            in
               flattens
               (subpats, fn qs =>
                rebuild (Record (SortedRecord.zip (labels, qs))))
            end
       | Var _ => Vector.new1 p
       | Vector elts => flattens (elts, fn qs => rebuild (Vector qs))
       | Wild => Vector.new1 p
   end
and flattens (ps, build) =
   let
      (* Prepend each alternative for one position onto every tail built for
       * the positions to its right. *)
      fun extend (alts, tails) =
         List.concat
         (Vector.toListMap
          (alts, fn alt => List.map (tails, fn tail => alt :: tail)))
      val choices = Vector.foldr (Vector.map (ps, flatten), [[]], extend)
   in
      Vector.fromListMap (choices, fn choice => build (Vector.fromList choice))
   end

(* Wrap flatten with tracing of its argument and result. *)
val flatten =
   Trace.trace
   ("NestedPat.flatten", layout, Vector.layout layout)
   flatten
113
(* isRefutable p: variables and wildcards are not refutable; constructor,
 * constant and vector patterns are; a layered pattern defers to its
 * sub-pattern; an Or or Record pattern is refutable when at least one of its
 * components is. *)
fun isRefutable p =
   case node p of
      Var _ => false
    | Wild => false
    | Layered (_, inner) => isRefutable inner
    | Or alts => Vector.exists (alts, isRefutable)
    | Record fields => SortedRecord.exists (fields, isRefutable)
    | Con _ => true
    | Const _ => true
    | Vector _ => true
124
(* isVarOrWild p is true exactly when the top node of p is a variable or a
 * wildcard. *)
fun isVarOrWild (T {pat, ...}) =
   case pat of
      Wild => true
    | Var _ => true
    | _ => false
130
(* removeOthersReplace (p, {new, old}) rewrites the variables bound in p:
 * occurrences of old become new, and every other variable binding is dropped.
 * A dropped Var becomes Wild; a dropped Layered collapses to the node of its
 * (already rewritten) sub-pattern.  The type stored at each node is kept.
 *)
fun removeOthersReplace (p, {new, old}) =
   let
      fun isOld x = Var.equals (x, old)
      fun loop (T {pat, ty}) = T {pat = rewrite pat, ty = ty}
      and rewrite pat =
         case pat of
            Con {arg, con, targs} =>
               Con {arg = Option.map (arg, loop), con = con, targs = targs}
          | Const _ => pat
          | Layered (x, inner) =>
               let
                  (* Rewrite the sub-pattern first, whether or not the
                   * binding itself survives. *)
                  val inner' = loop inner
               in
                  if isOld x then Layered (new, inner') else node inner'
               end
          | Or alts => Or (Vector.map (alts, loop))
          | Record rps => Record (SortedRecord.map (rps, loop))
          | Var x => if isOld x then Var new else Wild
          | Vector elts => Vector (Vector.map (elts, loop))
          | Wild => Wild
   in
      loop p
   end

(* Wrap removeOthersReplace with tracing of the input and output patterns. *)
val removeOthersReplace =
   Trace.trace
   ("NestedPat.removeOthersReplace", fn (p, _) => layout p, layout)
   removeOthersReplace
168
(* removeVars p calls removeOthersReplace with one variable, created once when
 * this binding is evaluated, used as both new and old.
 * NOTE(review): presumably Var.newNoname yields a variable that occurs in no
 * pattern, so every binding in p is dropped -- confirm against Var. *)
val removeVars: t -> t =
   let
      val unused = Var.newNoname ()
      val swap = {new = unused, old = unused}
   in
      fn p => removeOthersReplace (p, swap)
   end
175
(* replaceTypes (p, f) applies f to the type stored at every node of p and to
 * the type arguments of every Con node, leaving the pattern structure
 * unchanged.  At each node the sub-patterns (and targs) are rewritten before
 * f is applied to the node's own type.
 *)
fun replaceTypes (p: t, f: Type.t -> Type.t): t =
   let
      fun loop (T {pat, ty}) =
         let
            val pat' = rewrite pat
         in
            T {pat = pat', ty = f ty}
         end
      and rewrite pat =
         case pat of
            Con {arg, con, targs} =>
               Con {arg = Option.map (arg, loop),
                    con = con,
                    targs = Vector.map (targs, f)}
          | Layered (x, inner) => Layered (x, loop inner)
          | Or alts => Or (Vector.map (alts, loop))
          | Record rps => Record (SortedRecord.map (rps, loop))
          | Vector elts => Vector (Vector.map (elts, loop))
            (* Nodes without sub-patterns or type arguments are reused. *)
          | Const _ => pat
          | Var _ => pat
          | Wild => pat
   in
      loop p
   end
199
(* varsAndTypes p lists each variable bound by p together with a type: a Var
 * is paired with the type of its own node, a Layered variable with the type
 * of its sub-pattern.  For an Or pattern only the first alternative is
 * traversed.
 *)
fun varsAndTypes (p: t): (Var.t * Type.t) list =
   let
      fun collect (pat: t, acc: (Var.t * Type.t) list) =
         case node pat of
            Var x => (x, ty pat) :: acc
          | Layered (x, inner) => collect (inner, (x, ty inner) :: acc)
          | Con {arg = SOME inner, ...} => collect (inner, acc)
          | Con {arg = NONE, ...} => acc
          | Or alts => collect (Vector.first alts, acc)
          | Record fields => SortedRecord.fold (fields, acc, collect)
          | Vector elts => Vector.fold (elts, acc, collect)
          | Const _ => acc
          | Wild => acc
   in
      collect (p, [])
   end
216
217 end