Import Debian changes 20180207-1
[hcoop/debian/mlton.git] / mlton / ssa / deep-flatten.fun
1 (* Copyright (C) 2009,2017 Matthew Fluet.
2 * Copyright (C) 2004-2008 Henry Cejtin, Matthew Fluet, Suresh
3 * Jagannathan, and Stephen Weeks.
4 *
5 * MLton is released under a BSD-style license.
6 * See the file MLton-LICENSE for details.
7 *)
8
9 functor DeepFlatten (S: SSA2_TRANSFORM_STRUCTS): SSA2_TRANSFORM =
10 struct
11
12 open S
13
14 datatype z = datatype Exp.t
15 datatype z = datatype Statement.t
16 datatype z = datatype Transfer.t
17
(* Instantiate the Tree functor with Prod as its sequence structure. *)
18 structure Tree = Tree (structure Seq = Prod)
19
structure TypeTree =
   struct
      datatype t = datatype Tree.t

      (* Annotation stored at every tree node.  A NotFlat node records a type
       * and, optionally, a variable; a Flat node records nothing.
       *)
      datatype info =
         Flat
       | NotFlat of {ty: Type.t, var: Var.t option}

      type t = info Tree.t

      local
         (* Layout of a single node annotation. *)
         fun layoutInfo (i: info): Layout.t =
            case i of
               Flat => Layout.str "Flat"
             | NotFlat {ty, var} =>
                  Layout.seq
                  [Layout.str "NotFlat ",
                   Layout.record [("ty", Type.layout ty),
                                  ("var", Option.layout Var.layout var)]]
      in
         (* Layout of a whole tree, node by node. *)
         fun layout (tree: t): Layout.t = Tree.layout (tree, layoutInfo)
      end

      (* A tree counts as flat exactly when its root annotation is Flat. *)
      fun isFlat (T (Flat, _): t): bool = true
        | isFlat (T (NotFlat _, _)) = false
   end
50
structure VarTree =
   struct
      open TypeTree

      (* Record x as the variable of a NotFlat root.  A Flat root is returned
       * unchanged.
       *)
      fun labelRoot (tree as T (info, kids): t, x: Var.t): t =
         case info of
            NotFlat {ty, ...} => T (NotFlat {ty = ty, var = SOME x}, kids)
          | Flat => tree

      (* A type tree can be used directly as a variable tree. *)
      fun fromTypeTree (tree: TypeTree.t): t = tree

      (* Fold f over the variables of the outermost NotFlat nodes, descending
       * only through Flat nodes.  It is a bug for such a node to have no
       * variable.
       *)
      fun foldRoots (tree: t, init: 'a, f: Var.t * 'a -> 'a): 'a =
         let
            fun visit (T (info, kids), acc: 'a): 'a =
               case info of
                  Flat => Prod.fold (kids, acc, visit)
                | NotFlat {var = SOME x, ...} => f (x, acc)
                | NotFlat {var = NONE, ...} =>
                     Error.bug "DeepFlatten.VarTree.foldRoots"
         in
            visit (tree, init)
         end

      (* Apply f to every root variable, for effect only. *)
      fun foreachRoot (tree, f) = foldRoots (tree, (), fn (x, ()) => f x)

      (* Prepend the root variables of tree, in fold order, to tail. *)
      fun rootsOnto (tree: t, tail: Var.t list): Var.t list =
         let
            (* Consing during the fold yields the roots in reverse order. *)
            val reversedRoots = foldRoots (tree, [], op ::)
         in
            List.appendRev (reversedRoots, tail)
         end

      (* Forget every variable in the tree, keeping shape and types. *)
      fun dropVars (T (info, kids): t): t =
         let
            val stripped =
               case info of
                  NotFlat {ty, ...} => NotFlat {ty = ty, var = NONE}
                | Flat => Flat
         in
            T (stripped, Prod.map (kids, dropVars))
         end

      (* Give every outermost NotFlat node that lacks a variable a fresh one,
       * bound by a Select from base.  Slots are numbered consecutively from
       * offset, one per outermost NotFlat node.  Returns the completed tree
       * and the new Bind statements (most recently created first).
       *)
      fun fillInRoots (t: t, {base: Var.t Base.t, offset: int})
         : t * Statement.t list =
         let
            fun walk (node as T (info, kids), slot, binds) =
               case info of
                  Flat =>
                     let
                        val (kids', (slot', binds')) =
                           Vector.mapAndFold
                           (Prod.dest kids, (slot, binds),
                            fn ({elt, isMutable}, (slot, binds)) =>
                            let
                               val (elt', slot', binds') =
                                  walk (elt, slot, binds)
                            in
                               ({elt = elt', isMutable = isMutable},
                                (slot', binds'))
                            end)
                     in
                        (T (Flat, Prod.make kids'), slot', binds')
                     end
                | NotFlat {var = SOME _, ...} => (node, slot + 1, binds)
                | NotFlat {ty, var = NONE} =>
                     let
                        val fresh = Var.newNoname ()
                        val bind =
                           Bind {exp = Select {base = base, offset = slot},
                                 ty = ty,
                                 var = SOME fresh}
                     in
                        (T (NotFlat {ty = ty, var = SOME fresh}, kids),
                         slot + 1,
                         bind :: binds)
                     end
            val (filled, _, binds) = walk (t, offset, [])
         in
            (filled, binds)
         end

      val fillInRoots =
         Trace.trace2
         ("DeepFlatten.VarTree.fillInRoots",
          layout,
          fn {base, offset} =>
          Layout.record [("base", Base.layout (base, Var.layout)),
                         ("offset", Int.layout offset)],
          Layout.tuple2 (layout, List.layout Statement.layout))
         fillInRoots
   end
147
(* Reshape the variable tree `from` so that it is flat wherever the type tree
 * `to` is flat.  `base` and `offset` say where a NotFlat node without a
 * variable can be selected from.  Returns the offset following this node, the
 * reshaped tree, and the Select bindings that were needed.
 *)
fun flatten {base: Var.t Base.t option,
             from: VarTree.t,
             offset: int,
             to: TypeTree.t}: {offset: int} * VarTree.t * Statement.t list =
   let
      val Tree.T (fromInfo, fromKids) = from
   in
      case fromInfo of
         VarTree.Flat =>
            if not (TypeTree.isFlat to)
               then Error.bug "DeepFlatten.flatten: cannot flatten from Flat to NotFlat"
            else flattensAt {base = base,
                             froms = fromKids,
                             offset = offset,
                             tos = Tree.children to}
       | VarTree.NotFlat {ty, var} =>
            let
               (* The variable naming this node, plus the Select that binds it
                * when it had no variable yet.
                *)
               val (root, selects) =
                  case (var, base) of
                     (SOME x, _) => (x, [])
                   | (NONE, NONE) =>
                        Error.bug "DeepFlatten.flatten: flatten missing base"
                   | (NONE, SOME b) =>
                        let
                           val fresh = Var.newNoname ()
                        in
                           (fresh,
                            [Bind {exp = Select {base = b, offset = offset},
                                   ty = ty,
                                   var = SOME fresh}])
                        end
               val next = {offset = 1 + offset}
            in
               if TypeTree.isFlat to
                  then
                     let
                        (* The children are selected out of root itself, so
                         * their offsets restart at 0.
                         *)
                        val (_, tree, more) =
                           flattensAt {base = SOME (Base.Object root),
                                       froms = fromKids,
                                       offset = 0,
                                       tos = Tree.children to}
                     in
                        (next, tree, selects @ more)
                     end
               else (next,
                     Tree.T (VarTree.NotFlat {ty = ty, var = SOME root},
                             fromKids),
                     selects)
            end
   end
(* Flatten a product of children pairwise, threading the offset through.  The
 * result is always a Flat node.  Mutable source fields are a bug here.
 *)
and flattensAt {base: Var.t Base.t option,
                froms: VarTree.t Prod.t,
                offset: int,
                tos: TypeTree.t Prod.t} =
   let
      val (kids, (finalOffset, stmts)) =
         Vector.map2AndFold
         (Prod.dest froms, Prod.dest tos, ({offset = offset}, []),
          fn ({elt = src, isMutable}, {elt = dst, ...}, ({offset}, acc)) =>
          let
             val () =
                if isMutable
                   then Error.bug "DeepFlatten.flattensAt: mutable"
                else ()
             val ({offset = offset'}, kid, new) =
                flatten {base = base,
                         from = src,
                         offset = offset,
                         to = dst}
          in
             ({elt = kid, isMutable = false},
              ({offset = offset'}, new @ acc))
          end)
   in
      (finalOffset, Tree.T (VarTree.Flat, Prod.make kids), stmts)
   end
227
(* Coerce the variable tree `from` to the shape of the type tree `to`,
 * starting with no base and at offset 0.  Returns the coerced tree and the
 * statements that must be emitted for it.
 *)
fun coerceTree {from: VarTree.t, to: TypeTree.t}: VarTree.t * Statement.t list =
   let
      val (_, tree, stmts) =
         flatten {base = NONE, from = from, offset = 0, to = to}
   in
      (tree, stmts)
   end

(* Traced version; shadows the plain definition above. *)
val coerceTree =
   Trace.trace
   ("DeepFlatten.coerceTree",
    fn {from, to} =>
    Layout.record [("from", VarTree.layout from),
                   ("to", TypeTree.layout to)],
    fn (vt, ss) =>
    Layout.tuple [VarTree.layout vt,
                  List.layout Statement.layout ss])
   coerceTree
252
structure Flat =
   struct
      (* Two-point flag: may this object still be flattened, or not? *)
      datatype t = Flat | NotFlat

      fun toString (f: t): string =
         case f of
            Flat => "Flat"
          | NotFlat => "NotFlat"

      fun layout (f: t): Layout.t = Layout.str (toString f)
   end
263
264 datatype z = datatype Flat.t
265
structure Value =
   struct
      (* Abstract value used by the analysis.  An Object carries its fields,
       * the values that have been coerced into it while it was still Flat,
       * its flatness flag, and memo cells for the final results.
       *)
      datatype t =
         Ground of Type.t
       | Object of object Equatable.t
       | Weak of {arg: t}
      withtype object = {args: t Prod.t,
                         coercedFrom: t AppendList.t ref,
                         con: ObjectCon.t,
                         finalOffsets: int vector option ref,
                         finalTree: TypeTree.t option ref,
                         finalType: Type.t option ref,
                         finalTypes: Type.t Prod.t option ref,
                         flat: Flat.t ref}

      fun layout (v: t): Layout.t =
         case v of
            Ground ty => Type.layout ty
          | Object e =>
               Equatable.layout
               (e, fn {args, con, flat, ...} =>
                Layout.seq
                [Layout.str "Object ",
                 Layout.record [("args", Prod.layout (args, layout)),
                                ("con", ObjectCon.layout con),
                                ("flat", Flat.layout (! flat))]])
          | Weak {arg, ...} => Layout.seq [Layout.str "Weak ", layout arg]

      val ground = Ground

      val traceCoerce =
         Trace.trace
         ("DeepFlatten.Value.coerce",
          fn {from, to} =>
          Layout.record [("from", layout from),
                         ("to", layout to)],
          Unit.layout)

      val traceUnify =
         Trace.trace2 ("DeepFlatten.Value.unify", layout, layout, Unit.layout)

      (* Make two values identical.  Objects are equated; if exactly one side
       * was already NotFlat, the merged object is forced NotFlat afterwards.
       *)
      fun unify (arg: t * t): unit =
         traceUnify
         (fn (v, v') =>
          case (v, v') of
             (Ground _, Ground _) => ()
           | (Object e, Object e') =>
                let
                   val needDontFlatten = ref false
                   val () =
                      Equatable.equate
                      (e, e',
                       fn (obj as {args = a, coercedFrom = c, flat = f, ...},
                           obj' as {args = a', coercedFrom = c', flat = f', ...}) =>
                       (* Fields are unified before the flags are read. *)
                       (unifyProd (a, a')
                        ; case (!f, !f') of
                             (Flat, Flat) =>
                                (c := AppendList.append (!c', !c); obj)
                           | (Flat, NotFlat) =>
                                (needDontFlatten := true; obj)
                           | (NotFlat, Flat) =>
                                (needDontFlatten := true; obj')
                           | (NotFlat, NotFlat) => obj))
                in
                   if !needDontFlatten
                      then dontFlatten v
                   else ()
                end
           | (Weak {arg = a, ...}, Weak {arg = a', ...}) => unify (a, a')
           | _ => Error.bug "DeepFlatten.unify: strange") arg
      (* Unify two products field by field. *)
      and unifyProd (p, p') =
         Vector.foreach2
         (Prod.dest p, Prod.dest p',
          fn ({elt = x, ...}, {elt = x', ...}) => unify (x, x'))
      (* Mark an object NotFlat.  Everything that had been coerced into it is
       * then unified with it.  Non-objects and already-NotFlat objects are
       * left alone.
       *)
      and dontFlatten (v: t): unit =
         case v of
            Object e =>
               let
                  val {coercedFrom, flat, ...} = Equatable.value e
               in
                  case !flat of
                     NotFlat => ()
                   | Flat =>
                        let
                           val () = flat := NotFlat
                           val pending = !coercedFrom
                           val () = coercedFrom := AppendList.empty
                        in
                           AppendList.foreach
                           (pending, fn v' => unify (v, v'))
                        end
               end
          | _ => ()

      (* Record that `from` flows into `to`.  While the target object is Flat
       * (and the source has no mutable field and is not a vector) the source
       * is only remembered and the fields are coerced; otherwise the two are
       * unified.
       *)
      fun coerce (arg as {from, to}) =
         traceCoerce
         (fn _ =>
          case (from, to) of
             (Ground _, Ground _) => ()
           | (Object e, Object e') =>
                if Equatable.equals (e, e')
                   then ()
                else
                   Equatable.whenComputed
                   (e', fn {args = a', coercedFrom = c', flat = f', ...} =>
                    let
                       val {args = a, con, ...} = Equatable.value e
                       val mustUnify =
                          Prod.someIsMutable a
                          orelse ObjectCon.isVector con
                          orelse (case !f' of
                                     Flat => false
                                   | NotFlat => true)
                    in
                       if mustUnify
                          then unify (from, to)
                       else (AppendList.push (c', from)
                             ; coerceProd {from = a, to = a'})
                    end)
           | (Weak _, Weak _) => unify (from, to)
           | _ => Error.bug "DeepFlatten.coerce: strange") arg
      (* Coerce two products field by field. *)
      and coerceProd {from = p, to = p'} =
         Vector.foreach2
         (Prod.dest p, Prod.dest p',
          fn ({elt = x, ...}, {elt = x', ...}) =>
          coerce {from = x, to = x'})

      fun mayFlatten {args, con}: bool =
         (* Don't flatten constructors, since they are part of a sum type.
          * Don't flatten unit.
          * Don't flatten vectors (of course their components can be
          * flattened).
          * Don't flatten objects with mutable fields, since sharing must be
          * preserved.
          *)
         case con of
            ObjectCon.Tuple =>
               not (Prod.isEmpty args) andalso Prod.allAreImmutable args
          | ObjectCon.Con _ => false
          | ObjectCon.Vector => false

      (* Build the record behind a fresh Object value. *)
      fun objectFields {args, con} =
         let
            (* Don't flatten object components that are immutable fields.  Those
             * have already had a chance to be flattened by other passes.
             *)
            fun pinImmutableFields () =
               Vector.foreach
               (Prod.dest args, fn {elt, isMutable} =>
                if isMutable
                   then ()
                else dontFlatten elt)
            val () =
               case con of
                  ObjectCon.Con _ => pinImmutableFields ()
                | ObjectCon.Tuple => pinImmutableFields ()
                | ObjectCon.Vector => ()
            val initialFlat =
               if mayFlatten {args = args, con = con}
                  then Flat.Flat
               else Flat.NotFlat
         in
            {args = args,
             coercedFrom = ref AppendList.empty,
             con = con,
             finalOffsets = ref NONE,
             finalTree = ref NONE,
             finalType = ref NONE,
             finalTypes = ref NONE,
             flat = ref initialFlat}
         end

      (* An object whose fields are computed on demand. *)
      fun object f =
         Object (Equatable.delay (fn () => objectFields (f ())))

      fun tuple (vs: t Prod.t): t =
         Object (Equatable.new (objectFields {args = vs,
                                              con = ObjectCon.Tuple}))

      val tuple =
         Trace.trace
         ("DeepFlatten.Value.tuple",
          fn p => Prod.layout (p, layout),
          layout)
         tuple

      fun weak (arg: t) = Weak {arg = arg}

      (* The object record behind a value, if it is an Object. *)
      fun deObject (v: t): object option =
         case v of
            Object e => SOME (Equatable.value e)
          | _ => NONE

      val traceFinalType =
         Trace.trace ("DeepFlatten.Value.finalType", layout, Type.layout)
      val traceFinalTypes =
         Trace.trace
         ("DeepFlatten.Value.finalTypes",
          layout,
          fn p => Prod.layout (p, Type.layout))

      (* Type tree of a value after flattening; memoized per object. *)
      fun finalTree (v: t): TypeTree.t =
         let
            fun notFlat (): TypeTree.info =
               TypeTree.NotFlat {ty = finalType v, var = NONE}
         in
            case deObject v of
               NONE => Tree.T (notFlat (), Prod.empty ())
             | SOME {args, finalTree = memo, flat, ...} =>
                  Ref.memoize
                  (memo, fn () =>
                   let
                      val rootInfo =
                         case !flat of
                            Flat => TypeTree.Flat
                          | NotFlat => notFlat ()
                   in
                      Tree.T (rootInfo, Prod.map (args, finalTree))
                   end)
         end
      (* The single type of a value after flattening.  For an object this is
       * element 0 of finalTypes, memoized.
       *)
      and finalType arg: Type.t =
         traceFinalType
         (fn v =>
          case v of
             Ground ty => ty
           | Object e =>
                let
                   val {finalType = memo, ...} = Equatable.value e
                in
                   Ref.memoize (memo, fn () => Prod.elt (finalTypes v, 0))
                end
           | Weak {arg, ...} => Type.weak (finalType arg)) arg
      (* The types a value occupies in its container: the flattened field
       * types for a Flat object, otherwise one immutable slot.
       *)
      and finalTypes arg: Type.t Prod.t =
         traceFinalTypes
         (fn v =>
          case deObject v of
             NONE =>
                Prod.make (Vector.new1 {elt = finalType v,
                                        isMutable = false})
           | SOME {args, con, finalTypes = memo, flat, ...} =>
                Ref.memoize
                (memo, fn () =>
                 let
                    val fieldTypes = prodFinalTypes args
                 in
                    case !flat of
                       Flat => fieldTypes
                     | NotFlat =>
                          Prod.make
                          (Vector.new1
                           {elt = Type.object {args = fieldTypes, con = con},
                            isMutable = false})
                 end)) arg
      (* Concatenate the final types of all fields.  A resulting slot is
       * mutable if either the field or the inner slot is.
       *)
      and prodFinalTypes (p: t Prod.t): Type.t Prod.t =
         let
            fun addField ({elt = field, isMutable = outer}, acc) =
               Vector.foldr
               (Prod.dest (finalTypes field), acc,
                fn ({elt = ty, isMutable = inner}, acc) =>
                {elt = ty, isMutable = outer orelse inner} :: acc)
         in
            Prod.make (Vector.fromList (Vector.foldr (Prod.dest p, [], addField)))
         end
   end
533
structure Object =
   struct
      type t = Value.object

      (* The value of the field at the given (pre-flattening) offset. *)
      fun select ({args, ...}: t, offset): Value.t =
         Prod.elt (args, offset)

      (* For each original field, the offset at which it starts once every
       * field has been replaced by its final types; memoized.
       *)
      fun finalOffsets ({args, finalOffsets = memo, ...}: t): int vector =
         Ref.memoize
         (memo, fn () =>
          let
             val (_, reversed) =
                Prod.fold
                (args, (0, []), fn (field, (next, starts)) =>
                 (next + Prod.length (Value.finalTypes field),
                  next :: starts))
          in
             Vector.fromListRev reversed
          end)

      (* Final offset of one original field. *)
      fun finalOffset (object, offset) =
         Vector.sub (finalOffsets object, offset)
   end
553
(* Entry point of the pass.  First runs the analysis to compute an abstract
 * Value for every variable, then rewrites datatypes, globals and functions
 * using the final (flattened) types, clears the program and shrinks it.
 *)
fun transform2 (program as Program.T {datatypes, functions, globals, main}) =
   let
      (* One shared value per constructor, created lazily. *)
      val {get = conValue: Con.t -> Value.t option ref, ...} =
         Property.get (Con.plist, Property.initFun (fn _ => ref NONE))
      val conValue =
         Trace.trace
         ("DeepFlatten.conValue",
          Con.layout,
          Ref.layout (Option.layout Value.layout))
         conValue
      (* Either an already built value, or a recipe that builds one each time
       * it is forced.
       *)
      datatype 'a make =
         Const of 'a
       | Make of unit -> 'a
      val traceMakeTypeValue =
         Trace.trace
         ("DeepFlatten.makeTypeValue",
          fn (ty, _) => Type.layout ty,
          Layout.ignore)
      fun makeValue m =
         case m of
            Make build => build ()
          | Const v => v
      fun needToMakeProd p =
         Vector.exists
         (Prod.dest p, fn {elt, ...} =>
          case elt of
             Make _ => true
           | Const _ => false)
      fun makeProd p = Prod.map (p, makeValue)
      val {get = makeTypeValue: Type.t -> Value.t make, ...} =
         Property.get
         (Type.plist,
          Property.initRec
          (traceMakeTypeValue
           (fn (t, makeTypeValue) =>
            let
               fun const () = Const (Value.ground t)
               datatype z = datatype Type.dest
            in
               case Type.dest t of
                  Object {args, con} =>
                     let
                        val fields = Prod.map (args, makeTypeValue)
                        (* An object value is needed when a field needs one
                         * or when the object itself might be flattened.
                         *)
                        fun fromFields () =
                           if needToMakeProd fields
                              orelse Value.mayFlatten {args = fields, con = con}
                              then
                                 Make
                                 (fn () =>
                                  Value.object
                                  (fn () => {args = makeProd fields,
                                             con = con}))
                           else const ()
                        datatype z = datatype ObjectCon.t
                     in
                        case con of
                           Con c =>
                              Const
                              (Ref.memoize
                               (conValue c, fn () =>
                                makeValue (fromFields ())))
                         | Tuple => fromFields ()
                         | Vector => fromFields ()
                     end
                | Weak inner =>
                     (case makeTypeValue inner of
                         Const _ => const ()
                       | Make build => Make (fn () => Value.weak (build ())))
                | _ => const ()
            end)))
      fun typeValue (t: Type.t): Value.t = makeValue (makeTypeValue t)
      val typeValue =
         Trace.trace ("DeepFlatten.typeValue", Type.layout, Value.layout)
         typeValue
      val (coerce, coerceProd) = (Value.coerce, Value.coerceProd)
      fun inject {sum, variant = _} = typeValue (Type.datatypee sum)
      (* Value of an object construction.  Tuples get a fresh value when the
       * type requires one; constructor applications coerce their arguments
       * into the constructor's shared value.
       *)
      fun object {args, con, resultType} =
         case (con, makeTypeValue resultType) of
            (NONE, Const v) => v
          | (NONE, Make _) => Value.tuple args
          | (SOME _, Const v) =>
               let
                  val () =
                     case Value.deObject v of
                        NONE => ()
                      | SOME {args = args', ...} =>
                           coerceProd {from = args, to = args'}
               in
                  v
               end
          | (SOME _, Make _) =>
               Error.bug "DeepFlatten.object: strange con value"
      val object =
         Trace.trace
         ("DeepFlatten.object",
          fn {args, con, ...} =>
          Layout.record [("args", Prod.layout (args, Value.layout)),
                         ("con", Option.layout Con.layout con)],
          Value.layout)
         object
      (* The value stored inside a weak pointer value. *)
      val deWeak : Value.t -> Value.t =
         fn v =>
         case v of
            Value.Ground t =>
               typeValue (case Type.dest t of
                             Type.Weak t => t
                           | _ => Error.bug "DeepFlatten.primApp: deWeak")
          | Value.Weak {arg, ...} => arg
          | _ => Error.bug "DeepFlatten.primApp: Value.deWeak"
      fun primApp {args, prim, resultVar = _, resultType} =
         let
            fun weak v =
               case makeTypeValue resultType of
                  Const known => known
                | Make _ => Value.weak v
            fun arg i = Vector.sub (args, i)
            fun result () = typeValue resultType
            datatype z = datatype Prim.Name.t
            fun dontFlatten () =
               (Vector.foreach (args, Value.dontFlatten)
                ; result ())
            fun equal () =
               (Value.unify (arg 0, arg 1)
                ; Value.dontFlatten (arg 0)
                ; result ())
            (* Shared by the array conversions: the fields of the argument
             * and of the result are unified pairwise.
             *)
            fun unifyFieldsWithResult (bugMsg: string) =
               let
                  val res = result ()
                  val () =
                     case (Value.deObject (arg 0), Value.deObject res) of
                        (NONE, NONE) => ()
                      | (SOME {args = a, ...}, SOME {args = a', ...}) =>
                           Vector.foreach2
                           (Prod.dest a, Prod.dest a',
                            fn ({elt = v, ...}, {elt = v', ...}) =>
                            Value.unify (v, v'))
                      | _ => Error.bug bugMsg
               in
                  res
               end
         in
            case Prim.name prim of
               Array_toArray =>
                  unifyFieldsWithResult "DeepFlatten.primApp: Array_toArray"
             | Array_toVector =>
                  unifyFieldsWithResult "DeepFlatten.primApp: Array_toVector"
             | FFI _ =>
                  (* Some imports, like Real64.modf, take ref cells that can not
                   * be flattened.
                   *)
                  dontFlatten ()
             | MLton_eq => equal ()
             | MLton_equal => equal ()
             | MLton_size => dontFlatten ()
             | MLton_share => dontFlatten ()
             | Weak_get => deWeak (arg 0)
             | Weak_new =>
                  let
                     val a = arg 0
                  in
                     (Value.dontFlatten a; weak a)
                  end
             | _ => result ()
         end
      fun base b =
         case b of
            Base.Object obj => obj
          | Base.VectorSub {vector, ...} => vector
      fun select {base, offset} =
         case base of
            Value.Ground t =>
               (case Type.dest t of
                   Type.Object {args, ...} =>
                      typeValue (Prod.elt (args, offset))
                 | _ => Error.bug "DeepFlatten.select: Ground")
          | Value.Object e => Object.select (Equatable.value e, offset)
          | _ => Error.bug "DeepFlatten.select:"
      (* A store is a flow of the stored value into the selected field. *)
      fun update {base, offset, value} =
         coerce {from = value,
                 to = select {base = base, offset = offset}}
      fun const c = typeValue (Type.ofConst c)
      val {func, value = varValue, ...} =
         analyze {base = base,
                  coerce = coerce,
                  const = const,
                  filter = fn _ => (),
                  filterWord = fn _ => (),
                  fromType = typeValue,
                  inject = inject,
                  layout = Value.layout,
                  object = object,
                  primApp = primApp,
                  program = program,
                  select = fn {base, offset, ...} => select {base = base,
                                                             offset = offset},
                  update = update,
                  useFromTypeOnBinds = false}
      (* Don't flatten outermost part of formal parameters. *)
      fun dontFlattenFormals (xts: (Var.t * Type.t) vector): unit =
         Vector.foreach (xts, fn (x, _) => Value.dontFlatten (varValue x))
      val () =
         List.foreach
         (functions, fn f =>
          let
             val {args, blocks, ...} = Function.dest f
          in
             dontFlattenFormals args
             ; Vector.foreach (blocks, fn Block.T {args, ...} =>
                               dontFlattenFormals args)
          end)
      val () =
         Control.diagnostics
         (fn display =>
          let
             val () =
                Vector.foreach
                (datatypes, fn Datatype.T {cons, ...} =>
                 Vector.foreach
                 (cons, fn {con, ...} =>
                  display (Option.layout Value.layout (! (conValue con)))))
          in
             Program.foreachVar
             (program, fn (x, _) =>
              display
              (Layout.seq [Var.layout x,
                           Layout.str " ",
                           Value.layout (varValue x)]))
          end)
      (* Transform the program. *)
      val datatypes =
         Vector.map
         (datatypes, fn Datatype.T {cons, tycon} =>
          let
             (* A constructor that has a value takes its arguments from the
              * value's final object type.
              *)
             fun finalArgs {con, args} =
                case ! (conValue con) of
                   NONE => {args = args, con = con}
                 | SOME v =>
                      (case Type.dest (Value.finalType v) of
                          Type.Object {args = args', ...} =>
                             {args = args', con = con}
                        | _ =>
                             Error.bug "DeepFlatten.datatypes: strange con")
          in
             Datatype.T {cons = Vector.map (cons, finalArgs),
                         tycon = tycon}
          end)
      val valueType = Value.finalType
      fun valuesTypes vs = Vector.map (vs, Value.finalType)
      (* Every variable of the input program gets a VarTree, set exactly
       * once, describing where its components live in the output.
       *)
      val {get = varTree: Var.t -> VarTree.t, set = setVarTree, ...} =
         Property.getSetOnce (Var.plist,
                              Property.initRaise ("tree", Var.layout))
      val setVarTree =
         Trace.trace2
         ("DeepFlatten.setVarTree",
          Var.layout, VarTree.layout, Unit.layout)
         setVarTree
      (* Tree for a variable that keeps its own name in the output. *)
      fun simpleVarTree (x: Var.t): unit =
         let
            val shape = VarTree.fromTypeTree (Value.finalTree (varValue x))
         in
            setVarTree (x, VarTree.labelRoot (shape, x))
         end
      fun transformFormals xts =
         Vector.map
         (xts, fn (x, _) =>
          (simpleVarTree x
           ; (x, Value.finalType (varValue x))))
      (* The output variable standing for x.  It is a bug if x was flattened
       * away or has no variable.
       *)
      fun replaceVar (x: Var.t): Var.t =
         let
            fun bug () =
               Error.bug (concat ["DeepFlatten.replaceVar ", Var.toString x])
         in
            case varTree x of
               Tree.T (VarTree.NotFlat {var = SOME y, ...}, _) => y
             | _ => bug ()
         end
      fun transformBind {exp, ty, var}: Statement.t list =
         let
            (* Keep the binding, with renamed uses and the final type. *)
            fun simple () =
               let
                  val () = Option.app (var, simpleVarTree)
                  val exp' = Exp.replaceVar (exp, replaceVar)
                  val ty' =
                     case var of
                        NONE => ty
                      | SOME x => valueType (varValue x)
               in
                  [Bind {exp = exp', ty = ty', var = var}]
               end
            fun none () = []
         in
            case exp of
               Exp.Const _ => simple ()
             | Inject _ => simple ()
             | Object {args, con} =>
                  (case var of
                      NONE => none ()
                    | SOME dst =>
                         let
                            val v = varValue dst
                         in
                            case Value.deObject v of
                               NONE => simple ()
                             | SOME {args = expects, flat, ...} =>
                                  let
                                     (* Each argument coerced to the shape
                                      * its field expects, paired with the
                                      * statements that coercion needs.
                                      *)
                                     val coerced =
                                        Vector.map2
                                        (args, Prod.dest expects,
                                         fn (arg, {elt, isMutable}) =>
                                         let
                                            val (vt, ss) =
                                               coerceTree
                                               {from = varTree arg,
                                                to = Value.finalTree elt}
                                         in
                                            ({elt = vt,
                                              isMutable = isMutable},
                                             ss)
                                         end)
                                     val vts = Vector.map (coerced, #1)
                                     fun set info =
                                        setVarTree
                                        (dst, Tree.T (info, Prod.make vts))
                                  in
                                     case !flat of
                                        Flat => (set VarTree.Flat; none ())
                                      | NotFlat =>
                                           let
                                              val ty = Value.finalType v
                                              val () =
                                                 set (VarTree.NotFlat
                                                      {ty = ty,
                                                       var = SOME dst})
                                              val roots =
                                                 Vector.fromList
                                                 (Vector.foldr
                                                  (vts, [],
                                                   fn ({elt = vt, ...}, ac) =>
                                                   VarTree.rootsOnto (vt, ac)))
                                              val obj =
                                                 Bind
                                                 {exp = Object {args = roots,
                                                                con = con},
                                                  ty = ty,
                                                  var = SOME dst}
                                           in
                                              Vector.foldr
                                              (coerced, [obj],
                                               fn ((_, ss), ac) => ss @ ac)
                                           end
                                  end
                         end)
             | PrimApp _ => simple ()
             | Select {base, offset} =>
                  (case var of
                      NONE => none ()
                    | SOME dst =>
                         let
                            val baseVar = Base.object base
                         in
                            case Value.deObject (varValue baseVar) of
                               NONE => simple ()
                             | SOME obj =>
                                  let
                                     val Tree.T (info, children) =
                                        varTree baseVar
                                     val {elt = child, isMutable} =
                                        Prod.sub (children, offset)
                                     val (child', ss) =
                                        case info of
                                           VarTree.Flat => (child, [])
                                         | VarTree.NotFlat _ =>
                                              let
                                                 (* Don't simplify a select out
                                                  * of a mutable field.
                                                  * Something may have mutated
                                                  * it.
                                                  *)
                                                 val start =
                                                    if isMutable
                                                       then VarTree.dropVars child
                                                    else child
                                              in
                                                 VarTree.fillInRoots
                                                 (start,
                                                  {base = Base.map (base, replaceVar),
                                                   offset = (Object.finalOffset
                                                             (obj, offset))})
                                              end
                                     val () = setVarTree (dst, child')
                                  in
                                     ss
                                  end
                         end)
             | Var x =>
                  (Option.app (var, fn y => setVarTree (y, varTree x))
                   ; none ())
         end
      fun transformStatement (s: Statement.t): Statement.t list =
         let
            fun simple () = [Statement.replaceUses (s, replaceVar)]
         in
            case s of
               Bind b => transformBind b
             | Profile _ => simple ()
             | Update {base, offset, value} =>
                  let
                     val baseVar =
                        case base of
                           Base.Object x => x
                         | Base.VectorSub {vector = x, ...} => x
                  in
                     case Value.deObject (varValue baseVar) of
                        NONE => simple ()
                      | SOME obj =>
                           let
                              val child =
                                 Value.finalTree (Object.select (obj, offset))
                              val firstSlot = Object.finalOffset (obj, offset)
                              val base' = Base.map (base, replaceVar)
                           in
                              if TypeTree.isFlat child
                                 then
                                    let
                                       (* A flattened field is written with
                                        * one Update per root of the coerced
                                        * value, at consecutive offsets.
                                        *)
                                       val (vt, prelude) =
                                          coerceTree {from = varTree value,
                                                      to = child}
                                       val (_, updates) =
                                          VarTree.foldRoots
                                          (vt, (firstSlot, []),
                                           fn (root, (slot, acc)) =>
                                           (1 + slot,
                                            Update {base = base',
                                                    offset = slot,
                                                    value = root} :: acc))
                                    in
                                       prelude @ updates
                                    end
                              else [Update {base = base',
                                            offset = firstSlot,
                                            value = replaceVar value}]
                           end
                  end
         end
      val transformStatement =
         Trace.trace
         ("DeepFlatten.transformStatement",
          Statement.layout,
          List.layout Statement.layout)
         transformStatement
      fun transformStatements ss =
         Vector.concatV
         (Vector.map (ss, fn s => Vector.fromList (transformStatement s)))
      fun transformTransfer t = Transfer.replaceVar (t, replaceVar)
      val transformTransfer =
         Trace.trace
         ("DeepFlatten.transformTransfer",
          Transfer.layout, Transfer.layout)
         transformTransfer
      fun transformBlock (Block.T {args, label, statements, transfer}) =
         Block.T {args = transformFormals args,
                  label = label,
                  statements = transformStatements statements,
                  transfer = transformTransfer transfer}
      fun transformFunction (f: Function.t): Function.t =
         let
            val {args, mayInline, name, start, ...} = Function.dest f
            val {raises, returns, ...} = func name
            val args = transformFormals args
            val raises = Option.map (raises, valuesTypes)
            val returns = Option.map (returns, valuesTypes)
            (* Blocks are transformed in the order Function.dfs visits them
             * and collected by pushing onto a list.
             *)
            val transformed = ref []
            val () =
               Function.dfs
               (f, fn b =>
                (List.push (transformed, transformBlock b)
                 ; fn () => ()))
         in
            Function.new {args = args,
                          blocks = Vector.fromList (!transformed),
                          mayInline = mayInline,
                          name = name,
                          raises = raises,
                          returns = returns,
                          start = start}
         end
      (* Globals are transformed before the functions. *)
      val globals = transformStatements globals
      val functions = List.revMap (functions, transformFunction)
      val program =
         Program.T {datatypes = datatypes,
                    functions = functions,
                    globals = globals,
                    main = main}
      val () = Program.clear program
   in
      shrink program
   end
1085
1086 end