Import Upstream version 20180207
[hcoop/debian/mlton.git] / mlton / xml / shrink.fun
1 (* Copyright (C) 2009 Matthew Fluet.
2 * Copyright (C) 1999-2006, 2008 Henry Cejtin, Matthew Fluet, Suresh
3 * Jagannathan, and Stephen Weeks.
4 * Copyright (C) 1997-2000 NEC Research Institute.
5 *
6 * MLton is released under a BSD-style license.
7 * See the file MLton-LICENSE for details.
8 *)
9
(* This simplifier is based on the following article:
 *   Andrew W. Appel and Trevor Jim.
 *   Shrinking Lambda Expressions in Linear Time.
 *   Journal of Functional Programming, 7(5), 1997.
 *)
14
15 functor Shrink (S: SHRINK_STRUCTS): SHRINK =
16 struct
17
18 open S
19 open Dec PrimExp
20
(* Trace hooks.  [tracePrimApplyInfo] is the info record used when tracing
 * calls to Prim.apply; the other two wrap the recursive shrinkers, laying
 * out both argument and result with the same layout function.
 *)
val tracePrimApplyInfo = Trace.info "Xml.Shrink.Prim.apply"

local
   fun traceSame (name, lay) = Trace.trace (name, lay, lay)
in
   val traceShrinkExp = traceSame ("Xml.Shrink.shrinkExp", Exp.layout)
   val traceShrinkLambda = traceSame ("Xml.Shrink.shrinkLambda", Lambda.layout)
end
28
(* Add [n] (which may be negative) to the counter [r].  Callers use it on
 * occurrence counts; the assertion enforces that the counter never becomes
 * negative.
 *)
fun inc (r: int ref, n) =
   let
      val sum = !r + n
      val () = Assert.assert ("Xml.Shrink.inc", fn () => sum >= 0)
   in
      r := sum
   end
34
structure VarInfo =
   struct
      (* What the shrinker knows about a variable occurrence.
       *  - [Mono]: a monomorphic variable, with a mutable occurrence count,
       *    an optional known value and the variable expression it stands for.
       *  - [Poly]: an occurrence of a polymorphic variable; nothing is tracked.
       * A known value refers to its components by their own [t].
       *)
      datatype t =
         Mono of monoVarInfo
       | Poly of VarExp.t
      and value =
         ConApp of {con: Con.t,
                    targs: Type.t vector,
                    arg: t option}
       | Const of Const.t
       | Lambda of {isInlined: bool ref,
                    lam: Lambda.t}
       | Tuple of t vector
      withtype monoVarInfo = {numOccurrences: int ref,
                              value: value option ref,
                              varExp: VarExp.t}

      local
         open Layout
      in
         fun layout (Mono {numOccurrences, value, varExp}) =
                record [("numOccurrences", Int.layout (!numOccurrences)),
                        ("value", Option.layout layoutValue (!value)),
                        ("varExp", VarExp.layout varExp)]
           | layout (Poly x) = seq [str "Poly ", VarExp.layout x]
         and layoutValue (ConApp {con, arg, ...}) =
                seq [Con.layout con,
                     case arg of
                        NONE => empty
                      | SOME i => paren (layout i)]
           | layoutValue (Const c) = Const.layout c
           | layoutValue (Lambda {isInlined, ...}) =
                seq [str "Lambda ", Bool.layout (!isInlined)]
           | layoutValue (Tuple is) = Vector.layout layout is
      end

      (* Adjust the occurrence count of a monomorphic variable by [n], via the
       * functor-level [inc]; polymorphic occurrences carry no count.
       *)
      val inc: t * int -> unit =
         Trace.trace2 ("Xml.Shrink.VarInfo.inc", layout, Int.layout, Unit.layout)
         (fn (Mono {numOccurrences, ...}, n) => inc (numOccurrences, n)
           | (Poly _, _) => ())

      (* One more occurrence. *)
      val inc1 =
         Trace.trace ("Xml.Shrink.VarInfo.inc1", layout, Unit.layout)
         (fn i => inc (i, 1))

      (* One occurrence fewer. *)
      val delete =
         Trace.trace ("Xml.Shrink.VarInfo.delete", layout, Unit.layout)
         (fn i => inc (i, ~1))

      fun deletes is = Vector.foreach (is, delete)

      (* The variable expression an occurrence currently stands for. *)
      fun varExp (Mono {varExp = x, ...}) = x
        | varExp (Poly x) = x

      fun equals (a, b) = VarExp.equals (varExp a, varExp b)
   end
99
structure InternalVarInfo =
   struct
      (* What a variable's property list holds: full [VarInfo] for variables
       * bound without type variables, and [Self] for polymorphic bindings,
       * whose occurrences are represented by the occurrence itself.
       *)
      datatype t =
         VarInfo of VarInfo.t
       | Self

      fun layout (VarInfo i) = VarInfo.layout i
        | layout Self = Layout.str "self"
   end
110
(* Name for the mutable record carried by [VarInfo.Mono]. *)
structure MonoVarInfo = struct type t = VarInfo.monoVarInfo end
115
structure Value =
   struct
      datatype t = datatype VarInfo.value

      (* Rebuild the primitive expression that produces a known value, naming
       * each component by the variable expression it currently stands for.
       *)
      val toPrimExp =
         fn ConApp {con, targs, arg} =>
               PrimExp.ConApp {con = con,
                               targs = targs,
                               arg = Option.map (arg, VarInfo.varExp)}
          | Const c => PrimExp.Const c
          | Lambda {lam, ...} => PrimExp.Lambda lam
          | Tuple vs => PrimExp.Tuple (Vector.map (vs, VarInfo.varExp))
   end
130
131 fun shrinkOnce (Program.T {datatypes, body, overflow}) =
132 let
(* Record, for every constructor, how many constructors its datatype has, so
 * that a redundant default branch can be eliminated.  Constructors that do
 * not come from [datatypes] keep the initial value ~1, which never equals a
 * vector length.
 *)
val {get = conNumCons: Con.t -> int, set = setConNumCons, ...} =
   Property.getSetOnce (Con.plist, Property.initConst ~1)
val _ =
   Vector.foreach
   (datatypes, fn {cons, ...} =>
    Vector.foreach
    (cons,
     let
        val size = Vector.length cons
     in
        fn {con, ...} => setConNumCons (con, size)
     end))
(* A case is exhaustive when it is a non-empty constructor case with one
 * branch per constructor of the scrutinee's datatype.
 *)
fun isExhaustive (cases: exp Cases.t): bool =
   case cases of
      Cases.Con v =>
         not (Vector.isEmpty v)
         andalso conNumCons (Pat.con (#1 (Vector.first v))) = Vector.length v
    | _ => false
(* Per-variable information lives on the variable's property list.  It is
 * written by [countExp] and [replaceInfo]; the property is created with
 * Property.initRaise, so reading a variable that was never set fails.
 *)
local
   val {get = get: Var.t -> InternalVarInfo.t, set = set, ...} =
      Property.getSet (Var.plist,
                       Property.initRaise ("shrink varInfo", Var.layout))
in
   val setVarInfo =
      Trace.trace2 ("Xml.Shrink.setVarInfo",
                    Var.layout, InternalVarInfo.layout, Unit.layout)
      set
   val varInfo =
      Trace.trace ("Xml.Shrink.varInfo", Var.layout, InternalVarInfo.layout)
      get
end
(* The mutable record of a variable bound without type variables; asking for
 * it on any other variable is a compiler bug.
 *)
fun monoVarInfo x =
   case varInfo x of
      InternalVarInfo.VarInfo (VarInfo.Mono r) => r
    | InternalVarInfo.VarInfo (VarInfo.Poly _) => Error.bug "Xml.Shrink.monoVarInfo"
    | InternalVarInfo.Self => Error.bug "Xml.Shrink.monoVarInfo"
(* Resolve an occurrence to its current info.  An occurrence of a polymorphic
 * ([Self]) variable is represented by the occurrence itself.
 *)
val varExpInfo: VarExp.t -> VarInfo.t =
   Trace.trace ("Xml.Shrink.varExpInfo", VarExp.layout, VarInfo.layout)
   (fn x as VarExp.T {var, ...} =>
       case varInfo var of
          InternalVarInfo.Self => VarInfo.Poly x
        | InternalVarInfo.VarInfo i => i)
fun varExpInfos xs = Vector.map (xs, varExpInfo)
(* [replaceInfo (x, xInfo, i)] makes every later lookup of [x] resolve to
 * [i], crediting [i] with the occurrences counted in [xInfo].  [replace]
 * does the same, fetching [x]'s mono record itself.
 *)
val replaceInfo: Var.t * MonoVarInfo.t * VarInfo.t -> unit =
   Trace.trace ("Xml.Shrink.replaceInfo",
                fn (x, _, i) => Layout.tuple [Var.layout x, VarInfo.layout i],
                Unit.layout)
   (fn (x, {numOccurrences, ...}: MonoVarInfo.t, i) =>
       (VarInfo.inc (i, !numOccurrences)
        ; setVarInfo (x, InternalVarInfo.VarInfo i)))
fun replace (x, i) = replaceInfo (x, monoVarInfo x, i)
(* The variable expression an occurrence currently stands for. *)
fun shrinkVarExp x = VarInfo.varExp (varExpInfo x)
(* Initialise the per-variable info for [e]: every bound variable gets a
 * zero-count Mono record if it has no type variables and [Self] otherwise,
 * and every variable occurrence adds one to the count of what it refers to.
 *)
fun countExp (e: Exp.t): unit =
   let
      fun bind (x, tyvars, _) =
         setVarInfo
         (x,
          if not (Vector.isEmpty tyvars)
             then InternalVarInfo.Self
          else InternalVarInfo.VarInfo
               (VarInfo.Mono {numOccurrences = ref 0,
                              value = ref NONE,
                              varExp = VarExp.mono x}))
   in
      Exp.foreach {exp = e,
                   handleBoundVar = bind,
                   handleExp = fn _ => (),
                   handlePrimExp = fn _ => (),
                   handleVarExp = VarInfo.inc1 o varExpInfo}
   end
(* Discarding code: every variable occurrence in it is subtracted from the
 * occurrence count of the variable it refers to.
 *)
val deleteVarExp: VarExp.t -> unit = fn x => VarInfo.delete (varExpInfo x)
val deleteExp: Exp.t -> unit =
   Trace.trace ("Xml.Shrink.deleteExp", Exp.layout, Unit.layout)
   (fn e => Exp.foreachVarExp (e, deleteVarExp))
fun deleteLambda (l: Lambda.t): unit = deleteExp (Lambda.body l)
(* Let Prim.apply try to simplify [prim] applied to [args].  An argument whose
 * value is known to be a constant or an argument-less constructor is shown
 * to Prim.apply as such; every other argument (including all polymorphic
 * occurrences) is passed as an opaque variable.
 *)
fun primApp (prim: Type.t Prim.t, args: VarInfo.t vector)
   : (Type.t, VarInfo.t) Prim.ApplyResult.t =
   let
      fun knownValue (VarInfo.Mono {value, ...}) = !value
        | knownValue (VarInfo.Poly _) = NONE
      fun toApplyArg (vi: VarInfo.t) =
         case knownValue vi of
            SOME (Value.Const c) => Prim.ApplyArg.Const c
          | SOME (Value.ConApp {con, arg = NONE, ...}) =>
               Prim.ApplyArg.Con {con = con, hasArg = false}
          | _ => Prim.ApplyArg.Var vi
      val applyArgs = Vector.toList (Vector.map (args, toApplyArg))
      val layoutVarInfo = VarExp.layout o VarInfo.varExp
      fun layoutCall (p, callArgs, _) =
         let
            open Layout
         in
            seq [Prim.layout p, str " ",
                 List.layout (Prim.ApplyArg.layout layoutVarInfo) callArgs]
         end
   in
      Trace.traceInfo'
      (tracePrimApplyInfo,
       layoutCall,
       Prim.ApplyResult.layout layoutVarInfo)
      Prim.apply
      (prim, applyArgs, VarInfo.equals)
   end
(*---------------------------------------------------*)
(*                     shrinkExp                     *)
(*---------------------------------------------------*)
(* Shrink [let decs in result end].  The declarations are shrunk first and
 * only then is the result looked up, because shrinking the declarations can
 * substitute the result variable.
 *)
fun shrinkExp arg: Exp.t =
   traceShrinkExp
   (fn (e: Exp.t) =>
       let
          val {decs, result} = Exp.dest e
          val decs' = shrinkDecs decs
          val result' = shrinkVarExp result
       in
          Exp.make {decs = decs', result = result'}
       end)
   arg
(* Shrink a declaration list front to back.  For a monomorphic function group
 * or a [val] binding, the rest of the list (its scope) is shrunk in the
 * middle of processing the declaration, because shrinking the scope changes
 * occurrence counts and may inline or substitute the declaration away.
 *)
and shrinkDecs (decs: Dec.t list): Dec.t list =
   case decs of
      [] => []
    | dec :: decs =>
         case dec of
            Exception _ => dec :: shrinkDecs decs
          | PolyVal {var, tyvars, ty, exp} =>
               Dec.PolyVal {var = var, tyvars = tyvars, ty = ty,
                            exp = shrinkExp exp}
               :: shrinkDecs decs
          | Fun {tyvars, decs = decs'} =>
               if Vector.isEmpty tyvars
                  then
                     let
                        (* Monomorphic group, first pass: drop functions with
                         * no occurrences (removing the counts of their
                         * bodies), and record every survivor as a known,
                         * not-yet-inlined lambda so that a call that is its
                         * only occurrence can inline it.
                         *)
                        val decs' =
                           Vector.keepAll
                           (decs', fn {lambda, var, ...} =>
                            let
                               val {numOccurrences, value, ...} =
                                  monoVarInfo var
                            in if 0 = !numOccurrences
                                  then (deleteLambda lambda; false)
                               else (value := (SOME
                                               (Value.Lambda
                                                {isInlined = ref false,
                                                 lam = lambda}))
                                     ; true)
                            end)
                        (* Shrink the scope of the group. *)
                        val decs = shrinkDecs decs
                        (* Second pass: drop functions that were inlined (the
                         * body now lives at the call site) or that are no
                         * longer used, and clear the known value of the
                         * survivors.  Every value must be cleared before any
                         * body is shrunk: the functions are mutually
                         * recursive, and a cleared value prevents a call
                         * inside the group from being inlined.
                         *)
                        val decs' =
                           Vector.keepAll
                           (decs', fn {var, lambda, ...} =>
                            let
                               val {numOccurrences, value, ...} =
                                  monoVarInfo var
                            in
                               case !value of
                                  SOME (Value.Lambda {isInlined, ...}) =>
                                     not (!isInlined)
                                     andalso
                                     if 0 = !numOccurrences
                                        then (deleteLambda lambda
                                              ; false)
                                     else (value := NONE; true)
                                | _ => Error.bug "Xml.Shrink.shrinkDecs: should be a lambda"
                            end)
                     in
                        (* Emit the group only if some function survived. *)
                        if Vector.isEmpty decs'
                           then decs
                        else
                           Dec.Fun {tyvars = tyvars,
                                    decs =
                                    Vector.map
                                    (decs', fn {var, ty, lambda} =>
                                     {var = var,
                                      ty = ty,
                                      lambda = shrinkLambda lambda})}
                           :: decs
                     end
               else
                  (* Polymorphic group: every body is shrunk, but no function
                   * is removed or recorded as a known value.
                   *)
                  Dec.Fun {tyvars = tyvars,
                           decs =
                           Vector.map
                           (decs', fn {var, ty, lambda} =>
                            {var = var,
                             ty = ty,
                             lambda = shrinkLambda lambda})}
                  :: shrinkDecs decs
          | MonoVal b =>
               shrinkMonoVal (b, fn () => shrinkDecs decs)
(* Shrink the monomorphic binding [val var: ty = exp] followed by the
 * declarations produced by [rest].  [rest] is a thunk so that the scope of
 * [var] is shrunk only after the binding has been analysed (and any known
 * value recorded), while the decision to keep the binding can still look at
 * the occurrence count left afterwards.  Returns the declarations for the
 * binding and its scope.
 *
 * Occurrence counts must stay exact: whenever an occurrence disappears from
 * the program it is debited, and whenever [var] is substituted by another
 * variable, [var]'s occurrences are credited to that variable.
 *)
and shrinkMonoVal ({var, ty, exp},
                   rest: unit -> Dec.t list) =
   let
      val info as {numOccurrences, value, ...} = monoVarInfo var
      fun finish (exp, decs) =
         MonoVal {var = var, ty = ty, exp = exp} :: decs
      (* For side-effect-free expressions.  If [var] is unused, drop the
       * binding ([delete] debits the expression's operands).  Otherwise run
       * [set], which may record a known value and returns how to rebuild the
       * expression (NONE if the binding was substituted away), shrink the
       * scope, and check the count again, since the scope may have consumed
       * every occurrence of [var] (e.g. by inlining or selecting from it).
       *)
      fun nonExpansive (delete: unit -> unit,
                        set: unit -> (unit -> PrimExp.t) option) =
         if 0 = !numOccurrences
            then (delete (); rest ())
         else let
                 val s = set ()
                 val decs = rest ()
              in if 0 = !numOccurrences
                    then (delete (); decs)
                 else (case s of
                          NONE => decs
                        | SOME mk => finish (mk (), decs))
              end
      (* Expressions that may have effects are always kept. *)
      fun expansive (e: PrimExp.t) = finish (e, rest ())
      fun nonExpansiveValue (delete, v: Value.t) =
         nonExpansive
         (delete,
          fn () => (value := SOME v
                    ; SOME (fn () => Value.toPrimExp v)))
      (* Replace the binding by the shrunk expression [e]: its declarations
       * are spliced in place and [var] is substituted by its result.  The
       * result occurrence itself is dropped, so it is debited, exactly as in
       * the [Var x] case below; without this the result variable stays
       * over-counted, which blocks dead-code removal and inlining of it.
       *)
      fun expression (e: Exp.t): Dec.t list =
         let
            val {decs = decs', result} = Exp.dest (shrinkExp e)
            val result = varExpInfo result
            val _ = replaceInfo (var, info, result)
            val _ = VarInfo.delete result
            val decs = rest ()
         in decs' @ decs
         end
   in
      case exp of
         App {func, arg} =>
            let
               val arg = varExpInfo arg
               fun normal func =
                  expansive (App {func = func,
                                  arg = VarInfo.varExp arg})
            in case varExpInfo func of
                  VarInfo.Poly x => normal x
                | VarInfo.Mono {numOccurrences, value, varExp, ...} =>
                     case (!numOccurrences, !value) of
                        (1, SOME (Value.Lambda {isInlined, lam = l})) =>
                           (* This call is the lambda's only occurrence:
                            * inline it.  The argument flows into the formal
                            * parameter, the lambda is marked inlined (so its
                            * definition is dropped without debiting its body)
                            * and the call is replaced by the body.
                            *)
                           if not (Lambda.mayInline l)
                              then normal varExp
                           else
                              let
                                 val {arg = form, body, ...} = Lambda.dest l
                              in
                                 VarInfo.delete arg
                                 ; replace (form, arg)
                                 ; isInlined := true
                                 ; numOccurrences := 0
                                 ; expression body
                              end
                      | _ => normal varExp
            end
       | Case {test, cases, default} =>
            let
               (* The scrutinee's value is known: keep only the first arm
                * accepted by [f] (or the default), debiting everything else.
                *)
               fun match (cases, f): Dec.t list =
                  let
                     val _ = deleteVarExp test
                     fun step (i, (c, e), ()) =
                        if f c
                           then
                              (Vector.foreachR (cases, i + 1,
                                                Vector.length cases,
                                                deleteExp o #2)
                               ; Option.app (default, deleteExp o #1)
                               ; Vector.Done (expression e))
                        else (deleteExp e; Vector.Continue ())
                     fun done () =
                        case default of
                           SOME (e, _) => expression e
                         | NONE => Error.bug "Xml.Shrink.shrinkMonoVal: Case, match"
                  in Vector.fold' (cases, 0, (), step, done)
                  end
               (* Unknown scrutinee: keep the case, dropping a default that
                * an exhaustive constructor case makes redundant.
                *)
               fun normal test =
                  let
                     val default =
                        if isExhaustive cases
                           then (Option.app (default, deleteExp o #1)
                                 ; NONE)
                        else Option.map (default, fn (e, r) =>
                                         (shrinkExp e, r))
                  in
                     expansive
                     (Case {test = test,
                            cases = Cases.map (cases, shrinkExp),
                            default = default})
                  end
            in
               case varExpInfo test of
                  VarInfo.Poly test => normal test
                | VarInfo.Mono {value, varExp, ...} =>
                     case (cases, !value) of
                        (Cases.Con cases,
                         SOME (Value.ConApp {con = c, arg, ...})) =>
                           let
                              val match =
                                 fn f =>
                                 match (cases,
                                        fn Pat.T {con = c', arg, ...} =>
                                        Con.equals (c, c')
                                        andalso f arg)
                           in case arg of
                                 NONE => match Option.isNone
                               | SOME v =>
                                    match
                                    (fn SOME (x, _) => (replace (x, v); true)
                                      | _ => false)
                           end
                      | (_, SOME (Value.Const c)) =>
                           (case (cases, c) of
                               (Cases.Word (_, l), Const.Word w) =>
                                  match (l, fn w' => WordX.equals (w, w'))
                             | _ => Error.bug "Xml.Shrink.shrinkMonoVal: Case, strange case")
                      | (_, NONE) => normal varExp
                      | _ => Error.bug "Xml.Shrink.shrinkMonoVal: Case, default"
            end
       | ConApp {con, targs, arg} =>
            if Con.equals (con, Con.overflow)
               then
                  expansive
                  (ConApp
                   {con = con,
                    targs = targs,
                    arg = Option.map (arg, shrinkVarExp)})
            else
               let
                  val arg = Option.map (arg, varExpInfo)
               in nonExpansiveValue
                  (fn () => Option.app (arg, VarInfo.delete),
                   Value.ConApp {con = con, targs = targs, arg = arg})
               end
       | Const c => nonExpansiveValue (fn () => (), Value.Const c)
       | Handle {try, catch, handler} =>
            expansive (Handle {try = shrinkExp try,
                               catch = catch,
                               handler = shrinkExp handler})
       | Lambda l =>
            let val isInlined = ref false
            in nonExpansive
               (* An inlined lambda's body now lives at the call site, so its
                * occurrences must not be debited when the binding is dropped.
                *)
               (fn () => if !isInlined then () else deleteLambda l,
                fn () => (value := SOME (Value.Lambda
                                         {isInlined = isInlined,
                                          lam = l})
                          ; SOME (fn () => Lambda (shrinkLambda l))))
            end
       | PrimApp {prim, args, targs} =>
            let
               val args = varExpInfos args
               fun doit {prim, targs, args} =
                  let
                     fun make () =
                        PrimApp {prim = prim, targs = targs,
                                 args = Vector.map (args, VarInfo.varExp)}
                  in
                     if Prim.maySideEffect prim
                        then expansive (make ())
                     else nonExpansive (fn () => VarInfo.deletes args,
                                        fn () => SOME make)
                  end
               fun default () = doit {prim = prim, targs = targs, args = args}
               datatype z = datatype Prim.ApplyResult.t
            in
               case primApp (prim, args) of
                  Apply (prim, args') =>
                     let
                        val args' = Vector.fromList args'
                        (* The new application mentions [args'] instead of
                         * [args].  Credit the former, then debit the latter,
                         * so the counts stay exact even when an argument is
                         * repeated, and never dip below zero in between.
                         *)
                        val _ = Vector.foreach (args', VarInfo.inc1)
                        val _ = VarInfo.deletes args
                     in
                        doit {prim = prim, targs = targs, args = args'}
                     end
                | Bool b =>
                     let
                        val _ = VarInfo.deletes args
                     in
                        nonExpansiveValue
                        (fn () => (),
                         Value.ConApp {con = Con.fromBool b,
                                       targs = Vector.new0 (),
                                       arg = NONE})
                     end
                | Const c =>
                     let
                        val _ = VarInfo.deletes args
                     in
                        nonExpansiveValue
                        (fn () => (),
                         Value.Const c)
                     end
                | Var x =>
                     (* [var] becomes an alias of [x]: [var]'s occurrences are
                      * credited to [x], and every argument occurrence goes
                      * away with the application, including each occurrence
                      * of [x] itself (there may be several, e.g. andb (x, x)).
                      *)
                     (replaceInfo (var, info, x)
                      ; VarInfo.deletes args
                      ; rest ())
                | _ => default ()
            end
       | Profile _ => expansive exp
       | Raise {exn, extend} =>
            expansive (Raise {exn = shrinkVarExp exn, extend = extend})
       | Select {tuple, offset} =>
            let
               fun normal x = Select {tuple = x, offset = offset}
            in case varExpInfo tuple of
                  VarInfo.Poly x => finish (normal x, rest ())
                | VarInfo.Mono {numOccurrences, value, varExp, ...} =>
                     (* Selecting from a known tuple makes [var] an alias of
                      * the component and drops this occurrence of the tuple.
                      *)
                     nonExpansive
                     (fn () => inc (numOccurrences, ~1),
                      fn () =>
                      case !value of
                         NONE => SOME (fn () => normal varExp)
                       | SOME (Value.Tuple vs) =>
                            (inc (numOccurrences, ~1)
                             ; replaceInfo (var, info, Vector.sub (vs, offset))
                             ; NONE)
                       | _ => Error.bug "Xml.Shrink.shrinkMonoVal: Select")
            end
       | Tuple xs =>
            let val xs = varExpInfos xs
            in nonExpansiveValue (fn () => VarInfo.deletes xs,
                                  Value.Tuple xs)
            end
       | Var x => let val x = varExpInfo x
                  in replaceInfo (var, info, x)
                     ; VarInfo.delete x
                     ; rest ()
                  end
   end
(* Shrink a lambda's body; its argument, argument type and mayInline flag are
 * carried over unchanged.
 *)
and shrinkLambda l: Lambda.t =
   traceShrinkLambda
   (fn lam =>
       let
          val {arg, argType, body, mayInline} = Lambda.dest lam
          val body' = shrinkExp body
       in
          Lambda.make {arg = arg, argType = argType,
                       body = body', mayInline = mayInline}
       end)
   l
(* Driver for one round.  Count occurrences in the whole body, and give the
 * overflow exception variable (if any) one extra occurrence, presumably so
 * that its binding is never eliminated.  Then shrink, and look the overflow
 * variable up again, because shrinking may have substituted it.  Finally
 * call Exp.clear on the new body and Con.clear on every constructor,
 * presumably to drop the properties set by this round.
 *)
val _ = countExp body
fun overflowVarInfo x =
   case varInfo x of
      InternalVarInfo.VarInfo i => i
    | InternalVarInfo.Self => Error.bug "Xml.Shrink.shrinkOnce: strange overflow var"
val _ = Option.app (overflow, VarInfo.inc1 o overflowVarInfo)
val body = shrinkExp body
val overflow =
   Option.map (overflow, VarExp.var o VarInfo.varExp o overflowVarInfo)
val _ = Exp.clear body
val _ =
   Vector.foreach
   (datatypes, fn {cons, ...} =>
    Vector.foreach (cons, fn {con, ...} => Con.clear con))
608 in
609 Program.T {datatypes = datatypes,
610 body = body,
611 overflow = overflow}
612 end
613
val shrinkOnce =
   Trace.trace ("Xml.Shrink.shrinkOnce", Program.layout, Program.layout) shrinkOnce

structure SccFuns = SccFuns (S)

(* The full pass: first [SccFuns.sccFuns] (presumably splitting function
 * groups into strongly connected components), then two rounds of
 * [shrinkOnce], each of which recounts occurrences from scratch.
 *)
val shrink =
   Trace.trace ("Xml.Shrink.shrink", Program.layout, Program.layout)
   (fn p => shrinkOnce (shrinkOnce (SccFuns.sccFuns p)))
625
626 end