Import Debian changes 20180207-1
[hcoop/debian/mlton.git] / mlton / elaborate / scope.fun
CommitLineData
7f918cf1
CE
(* Copyright (C) 2017 Matthew Fluet.
 * Copyright (C) 1999-2005 Henry Cejtin, Matthew Fluet, Suresh
 *    Jagannathan, and Stephen Weeks.
 * Copyright (C) 1997-2000 NEC Research Institute.
 *
 * MLton is released under a BSD-style license.
 * See the file MLton-LICENSE for details.
 *)

(* Scope: resolves which declaration each type variable is scoped at.
 *
 * The exported function `scope` runs two passes of the generic traversal
 * `processDec` over a declaration:
 *   1. bottom-up: each fun, val, and overload declaration binds its explicit
 *      type variables together with the free type variables of its sub-tree;
 *      an error is reported when a nested declaration explicitly binds a
 *      type variable that an enclosing declaration also binds;
 *   2. top-down: each fun, val, and overload declaration drops any type
 *      variable that an enclosing declaration already binds.
 *)
functor Scope (S: SCOPE_STRUCTS): SCOPE =
struct

open S
open Ast

structure Tyvars = UnorderedSet (Tyvar)
structure Tyvars =
   struct
      open Tyvars
      (* The set containing every element of the vector v. *)
      val fromVector = fn v =>
         Vector.fold (v, empty, fn (x, s) => add (s, x))
   end

(* processDec (d, hooks) is a generic traversal of the declaration d.
 *
 * A value of type 'down is passed down the tree, starting from initDown.
 * A value of type 'up is computed for every sub-tree and merged with
 * combineUp; initUp is the value for sub-trees that mention no type.
 * Each type-variable occurrence inside a type is mapped to an 'up by tyvar.
 *
 * At each binding site, the bind hook returns the 'down to use for the
 * sub-tree and a continuation, finish, that is applied to the sub-tree's 'up.
 * For fun, val, and overload declarations, finish also returns the
 * type-variable vector that is stored back into the rebuilt declaration.
 *
 * Types and patterns are only inspected.  Declarations and expressions are
 * rebuilt, keeping their original regions.
 * Returns the rebuilt declaration together with its 'up.
 *)
fun ('down, 'up)
   processDec (d: Dec.t,
               {(* bindType is used at datatype and type declarations. *)
                bindType: ('down * Tyvar.t vector
                           -> 'down * ('up -> 'up)),
                (* bindFunVal is used at fun, overload, and val declarations. *)
                bindFunVal: ('down * Tyvar.t vector * Region.t
                             -> ('down * ('up -> Tyvar.t vector * 'up))),
                combineUp: 'up * 'up -> 'up,
                initDown: 'down,
                initUp: 'up,
                tyvar: Tyvar.t * 'down -> 'up
                }): Dec.t * 'up =
   let
      (* The combined 'up of every element of xs, starting from initUp. *)
      fun visits (xs: 'a vector, visitX: 'a -> 'up): 'up =
         Vector.fold (xs, initUp, fn (x, u) => combineUp (u, visitX x))
      (* Like visits, but each element is also rebuilt; returns the rebuilt
       * vector together with the combined 'up.
       *)
      fun loops (xs: 'a vector, loopX: 'a -> 'a * 'up): 'a vector * 'up =
         Vector.mapAndFold (xs, initUp, fn (x, u) =>
                            let
                               val (x, u') = loopX x
                            in
                               (x, combineUp (u, u'))
                            end)
      (* Inspect a type under d; every type variable goes through tyvar. *)
      fun visitTy (t: Type.t, d: 'down): 'up =
         let
            datatype z = datatype Type.node
            fun visit (t: Type.t): 'up =
               case Type.node t of
                  Con (_, ts) => visits (ts, visit)
                | Paren t => visit t
                | Record r =>
                     Record.fold
                     (r, initUp, fn ((_, t), u) =>
                      combineUp (u, visit t))
                | Var a => tyvar (a, d)
         in
            visit t
         end
      fun visitTyOpt (to: Type.t option, d: 'down): 'up =
         case to of
            NONE => initUp
          | SOME t => visitTy (t, d)
      (* Each type binding's right-hand side is visited under the 'down that
       * bindType returns for that binding's parameters, and the result is
       * passed through that binding's finish.
       *)
      fun visitTypBind (tb: TypBind.t, d: 'down): 'up =
         let
            val TypBind.T tbs = TypBind.node tb
            val u =
               visits
               (tbs, fn {def, tyvars, ...} =>
                let
                   val (d, finish) = bindType (d, tyvars)
                in
                   finish (visitTy (def, d))
                end)
         in
            u
         end
      (* The constructor argument types of each datatype are visited under
       * that datatype's bindType.  The withtype bindings are visited under
       * the incoming d, each with its own bindType through visitTypBind.
       *)
      fun visitDatBind (db: DatBind.t, d: 'down): 'up =
         let
            val DatBind.T {datatypes, withtypes} = DatBind.node db
            val u =
               visits
               (datatypes, fn {cons, tyvars, ...} =>
                let
                   val (d, finish) = bindType (d, tyvars)
                in
                   finish (visits (cons, fn (_, arg) =>
                                   visitTyOpt (arg, d)))
                end)
            val u' = visitTypBind (withtypes, d)
         in
            combineUp (u, u')
         end
      (* Inspect a pattern under d; only the type constraints it contains
       * contribute anything other than initUp.
       *)
      fun visitPat (p: Pat.t, d: 'down): 'up =
         let
            datatype z = datatype Pat.node
            fun visit (p: Pat.t): 'up =
               (case Pat.node p of
                   App (_, p) => visit p
                 | Const _ => initUp
                 | Constraint (p, t) =>
                      combineUp (visit p, visitTy (t, d))
                 | FlatApp ps => visits (ps, visit)
                 | Layered {constraint, pat, ...} =>
                      combineUp (visitTyOpt (constraint, d), visit pat)
                 | List ps => visits (ps, visit)
                 | Or ps => visits (ps, visit)
                 | Paren p => visit p
                 | Record {items, ...} =>
                      Vector.fold
                      (items, initUp, fn ((_, _, i), u) =>
                       let
                          datatype z = datatype Pat.Item.t
                          val u' =
                             case i of
                                Field p => visit p
                              | Vid (_, to, po) =>
                                   let
                                      val u = visitTyOpt (to, d)
                                      val u' = visitOpt po
                                   in
                                      combineUp (u, u')
                                   end
                       in
                          combineUp (u, u')
                       end)
                 | Tuple ps => visits (ps, visit)
                 | Var _ => initUp
                 | Vector ps => visits (ps, visit)
                 | Wild => initUp)
            and visitOpt opt =
               (case opt of
                   NONE => initUp
                 | SOME p => visit p)
         in
            visit p
         end
      (* Every primitive form listed here carries a type; only that type
       * is visited.
       *)
      fun visitPrimKind (kind: PrimKind.t, d: 'down): 'up =
         let
            datatype z = datatype PrimKind.t
         in
            case kind of
               Address {ty, ...} =>
                  visitTy (ty, d)
             | BuildConst {ty, ...} =>
                  visitTy (ty, d)
             | CommandLineConst {ty, ...} =>
                  visitTy (ty, d)
             | Const {ty, ...} =>
                  visitTy (ty, d)
             | Export {ty, ...} =>
                  visitTy (ty, d)
             | IImport {ty, ...} =>
                  visitTy (ty, d)
             | Import {ty, ...} =>
                  visitTy (ty, d)
             | ISymbol {ty} =>
                  visitTy (ty, d)
             | Prim {ty, ...} =>
                  visitTy (ty, d)
             | Symbol {ty, ...} =>
                  visitTy (ty, d)
         end
      (* Rebuild the declaration d, passing `down` to its sub-terms. *)
      fun loopDec (d: Dec.t, down: 'down): Dec.t * 'up =
         let
            (* Rebuilt nodes reuse the region of d. *)
            fun doit n = Dec.makeRegion (n, Dec.region d)
            fun do1 ((a, u), f) = (doit (f a), u)
            fun do2 ((a1, u1), (a2, u2), f) =
               (doit (f (a1, a2)), combineUp (u1, u2))
            fun doVec (ds: Dec.t vector, f: Dec.t vector -> Dec.node)
               : Dec.t * 'up =
               let
                  val (ds, u) = loops (ds, fn d => loopDec (d, down))
               in
                  (doit (f ds), u)
               end
            (* For declarations that contain nothing to visit. *)
            fun empty () = (d, initUp)
            datatype z = datatype Dec.node
         in
            case Dec.node d of
               Abstype {body, datBind} =>
                  let
                     val (body, u) = loopDec (body, down)
                     val u' = visitDatBind (datBind, down)
                  in
                     (doit (Abstype {body = body, datBind = datBind}),
                      combineUp (u, u'))
                  end
             | Datatype rhs =>
                  (* Only inspected; the original d is returned. *)
                  let
                     datatype z = datatype DatatypeRhs.node
                     val u =
                        case DatatypeRhs.node rhs of
                           DatBind db => visitDatBind (db, down)
                         | Repl _ => initUp
                  in
                     (d, u)
                  end
             | DoDec e =>
                  do1 (loopExp (e, down), DoDec)
             | Exception ebs =>
                  (* Only the optional argument type of a generative
                   * exception binding is visited; d is returned unchanged.
                   *)
                  let
                     val u =
                        visits (ebs, fn (_, rhs) =>
                                let
                                   datatype z = datatype EbRhs.node
                                   val u =
                                      case EbRhs.node rhs of
                                         Def _ => initUp
                                       | Gen to =>
                                            let
                                               val u = visitTyOpt (to, down)
                                            in
                                               u
                                            end
                                in
                                   u
                                end)
                  in
                     (d, u)
                  end
             | Fix _ => empty ()
             | Fun {tyvars, fbs} =>
                  (* `down` is rebound to the value from bindFunVal, which is
                   * then used for every clause.  The tyvars stored in the
                   * rebuilt Fun are those returned by finish.
                   *)
                  let
                     val (down, finish) = bindFunVal (down, tyvars, Dec.region d)
                     val (fbs, u) =
                        loops (fbs, fn clauses =>
                               let
                                  val (clauses, u) =
                                     loops
                                     (clauses, fn {body, pats, resultType} =>
                                      let
                                         val (body, u) = loopExp (body, down)
                                         val u' =
                                            visits (pats, fn p =>
                                                    visitPat (p, down))
                                         val u'' =
                                            visitTyOpt (resultType, down)
                                      in
                                         ({body = body,
                                           pats = pats,
                                           resultType = resultType},
                                          combineUp (u, combineUp (u', u'')))
                                      end)
                               in
                                  (clauses, u)
                               end)
                     val (tyvars, u) = finish u
                  in
                     (doit (Fun {tyvars = tyvars, fbs = fbs}), u)
                  end
             | Local (d, d') =>
                  do2 (loopDec (d, down), loopDec (d', down), Local)
             | Open _ => empty ()
             | Overload (i, x, tyvars, ty, ys) =>
                  (* Like Fun: the overload's type is visited under the
                   * 'down from bindFunVal, and finish supplies the tyvars.
                   *)
                  let
                     val (down, finish) = bindFunVal (down, tyvars, Dec.region d)
                     val up = visitTy (ty, down)
                     val (tyvars, up) = finish up
                  in
                     (doit (Overload (i, x, tyvars, ty, ys)), up)
                  end
             | SeqDec ds => doVec (ds, SeqDec)
             | Type tb =>
                  let
                     val u = visitTypBind (tb, down)
                  in
                     (d, u)
                  end
             | Val {rvbs, tyvars, vbs} =>
                  (* The recursive and non-recursive bindings are all visited
                   * under the 'down from bindFunVal.  finish receives their
                   * combined 'up and supplies the tyvars stored in the
                   * rebuilt Val.
                   *)
                  let
                     val (down, finish) = bindFunVal (down, tyvars, Dec.region d)
                     val (rvbs, u) =
                        loops (rvbs, fn {match, pat} =>
                               let
                                  val (match, u) = loopMatch (match, down)
                                  val u' = visitPat (pat, down)
                               in
                                  ({match = match,
                                    pat = pat},
                                   combineUp (u, u'))
                               end)
                     val (vbs, u') =
                        loops (vbs, fn {exp, pat} =>
                               let
                                  val (exp, u) = loopExp (exp, down)
                                  val u' = visitPat (pat, down)
                               in
                                  ({exp = exp,
                                    pat = pat},
                                   combineUp (u, u'))
                               end)
                     val (tyvars, u) = finish (combineUp (u, u'))
                  in
                     (doit (Val {rvbs = rvbs,
                                 tyvars = tyvars,
                                 vbs = vbs}),
                      u)
                  end
         end
      (* Rebuild the expression e under d.  Declarations nested in Let are
       * handled by loopDec, so nested fun/val declarations reach bindFunVal.
       *)
      and loopExp (e: Exp.t, d: 'down): Exp.t * 'up =
         let
            val loopMatch = fn m => loopMatch (m, d)
            fun loop (e: Exp.t): Exp.t * 'up =
               let
                  fun empty () = (e, initUp)
                  (* Rebuilt nodes reuse the region of the expression being
                   * rebuilt; region is computed before any case binding
                   * shadows e.
                   *)
                  val region = Exp.region e
                  fun doit n = Exp.makeRegion (n, region)
                  datatype z = datatype Exp.node
                  fun do1 ((a, u), f) = (doit (f a), u)
                  fun do2 ((a1, u1), (a2, u2), f) =
                     (doit (f (a1, a2)), combineUp (u1, u2))
                  fun do3 ((a1, u1), (a2, u2), (a3, u3), f) =
                     (doit (f (a1, a2, a3)), combineUp (u1, combineUp (u2, u3)))
                  fun doVec (es: Exp.t vector, f: Exp.t vector -> Exp.node)
                     : Exp.t * 'up =
                     let
                        val (es, u) = loops (es, loop)
                     in
                        (doit (f es), u)
                     end
               in
                  case Exp.node e of
                     Andalso (e1, e2) => do2 (loop e1, loop e2, Andalso)
                   | App (e1, e2) => do2 (loop e1, loop e2, App)
                   | Case (e, m) => do2 (loop e, loopMatch m, Case)
                   | Const _ => empty ()
                   | Constraint (e, t) =>
                        let
                           val (e, u) = loop e
                           val u' = visitTy (t, d)
                        in
                           (doit (Constraint (e, t)),
                            combineUp (u, u'))
                        end
                   | FlatApp es => doVec (es, FlatApp)
                   | Fn m => do1 (loopMatch m, Fn)
                   | Handle (e, m) => do2 (loop e, loopMatch m, Handle)
                   | If (e1, e2, e3) => do3 (loop e1, loop e2, loop e3, If)
                   | Let (dec, e) => do2 (loopDec (dec, d), loop e, Let)
                   | List ts => doVec (ts, List)
                   | Orelse (e1, e2) => do2 (loop e1, loop e2, Orelse)
                   | Paren e => do1 (loop e, Paren)
                   | Prim kind => (e, visitPrimKind (kind, d))
                   | Raise exn => do1 (loop exn, Raise)
                   | Record r =>
                        let
                           val (r, u) =
                              Record.change
                              (r, fn res =>
                               loops (res, fn (r, e) =>
                                      let val (e', u) = loop e
                                      in ((r, e'), u)
                                      end))
                        in
                           (doit (Record r), u)
                        end
                   | Selector _ => empty ()
                   | Seq es => doVec (es, Seq)
                   | Var _ => empty ()
                   | Vector vs => doVec (vs, Vector)
                   | While {expr, test} =>
                        do2 (loop expr, loop test, fn (expr, test) =>
                             While {expr = expr, test = test})
               end
         in
            loop e
         end
      (* Rebuild each rule of the match m under d: the rule's pattern is
       * inspected and its body rebuilt.  The match keeps its region.
       *)
      and loopMatch (m, d) =
         let
            val (Match.T rules, region) = Match.dest m
            val (rules, u) =
               loops (rules, fn (p, e) =>
                      let
                         val u = visitPat (p, d)
                         val (e, u') = loopExp (e, d)
                      in
                         ((p, e), combineUp (u, u'))
                      end)
         in
            (Match.makeRegion (Match.T rules, region),
             u)
         end
   in
      loopDec (d, initDown)
   end

(* Resolve type-variable scoping in dec with two passes of processDec and
 * return the rebuilt declaration.
 *)
fun scope (dec: Dec.t): Dec.t =
   let
      (* Pass 1, bottom-up.  'down is unit, and 'up is a record of
       *   free:       type variables occurring in the sub-tree that were not
       *               discharged by a fun/val/overload declaration or by a
       *               type/datatype parameter list inside it;
       *   mayNotBind: type variables explicitly bound by nested
       *               fun/val/overload declarations, which an enclosing
       *               fun/val/overload declaration may not also bind.
       *)

      (* A fun/val/overload declaration binds its explicit tyvars together
       * with every free tyvar of its sub-tree.  Each element of mayNotBind
       * that this declaration binds is reported as an error and dropped.
       * Upward, nothing remains free, and this declaration's explicit tyvars
       * are prepended to the remaining mayNotBind.
       *)
      fun bindFunVal ((), tyvars, regionDec) =
         let
            fun finish {free, mayNotBind} =
               let
                  val bound = Tyvars.+ (free, Tyvars.fromVector tyvars)
                  val mayNotBind =
                     List.keepAll
                     (mayNotBind, fn a =>
                      not (Tyvars.contains (bound, a))
                      orelse
                      let
                         open Layout
                         val _ =
                            Control.error
                            (Tyvar.region a,
                             seq [str "type variable scoped at an outer declaration: ",
                                  Tyvar.layout a],
                             seq [str "scoped at: ", Region.layout regionDec])
                      in
                         false
                      end)
                  val bound = Vector.fromList (Tyvars.toList bound)
               in
                  (bound,
                   {free = Tyvars.empty,
                    mayNotBind = List.append (Vector.toList tyvars, mayNotBind)})
               end
         in
            ((), finish)
         end
      (* A type or datatype binding discharges its parameters from free.
       * mayNotBind is reset to [].  The sub-tree of such a binding is only
       * visited as types, whose 'up always has mayNotBind = [].
       *)
      fun bindType ((), tyvars) =
         let
            fun finish {free, mayNotBind = _} =
               {free = Tyvars.- (free, Tyvars.fromVector tyvars),
                mayNotBind = []}
         in
            ((), finish)
         end
      (* A type-variable occurrence is free. *)
      fun tyvar (a, ()) =
         {free = Tyvars.singleton a,
          mayNotBind = []}
      fun combineUp ({free = f, mayNotBind = m}, {free = f', mayNotBind = m'}) =
         {free = Tyvars.+ (f, f'),
          mayNotBind = List.append (m, m')}
      (* dec is rebound to the result, so pass 2 sees the tyvars computed
       * here.
       *)
      val (dec, _) =
         processDec (dec, {bindFunVal = bindFunVal,
                           bindType = bindType,
                           combineUp = combineUp,
                           initDown = (),
                           initUp = {free = Tyvars.empty, mayNotBind = []},
                           tyvar = tyvar})

      (* Walk down and bind a tyvar as soon as you see it, removing
       * all lower binding occurrences of the tyvar.
       *)
      (* Pass 2, top-down.  'down is the set of tyvars bound by enclosing
       * declarations, and 'up is unit.  A fun/val/overload declaration
       * keeps only the tyvars not already in that set and adds them to it.
       *)
      fun bindFunVal (bound, tyvars: Tyvar.t vector, _) =
         let
            val tyvars =
               Vector.keepAll
               (tyvars, fn a =>
                not (Tyvars.contains (bound, a)))
            val bound =
               Tyvars.+ (bound, Tyvars.fromVector tyvars)
         in
            (bound, fn () => (tyvars, ()))
         end
      (* Type and datatype parameters are added to the bound set for their
       * sub-tree.
       *)
      fun bindType (bound, tyvars) =
         let
            val bound = Tyvars.+ (bound, Tyvars.fromVector tyvars)
         in
            (bound, fn () => ())
         end
      fun tyvar (_, _) = ()
      val (dec, ()) =
         processDec (dec, {bindFunVal = bindFunVal,
                           bindType = bindType,
                           combineUp = fn ((), ()) => (),
                           initDown = Tyvars.empty,
                           initUp = (),
                           tyvar = tyvar})
   in
      dec
   end

(* scope, wrapped with Trace.trace under the name Scope.scope; the input and
 * output declarations are laid out with Dec.layout.
 *)
val scope = Trace.trace ("Scope.scope", Dec.layout, Dec.layout) scope

end