Import Upstream version 20180207
[hcoop/debian/mlton.git] / mlton / ssa / loop-invariant.fun
CommitLineData
7f918cf1
CE
(* Copyright (C) 1999-2005, 2008 Henry Cejtin, Matthew Fluet, Suresh
 *    Jagannathan, and Stephen Weeks.
 * Copyright (C) 1997-2000 NEC Research Institute.
 *
 * MLton is released under a BSD-style license.
 * See the file MLton-LICENSE for details.
 *)

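(*
 * Remove loop-invariant arguments from local loops.  An argument of a block
 * is loop invariant when every back-edge Goto into that block passes the
 * block's own formal back unchanged in that position.
 *
 *   fun loop (x, y) = ... loop (x, z) ...
 *
 * becomes
 *
 *   fun loop (x, y') =
 *      let fun loop' (y) = ... loop' (z) ...
 *      in loop' (y')
 *      end
 *)
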
functor LoopInvariant (S: SSA_TRANSFORM_STRUCTS): SSA_TRANSFORM =
struct

open S
open Exp Transfer

(* Loop-invariant argument removal.
 *
 * Each function is traversed depth-first twice.
 *
 * Pass 1 finds back edges, i.e. Gotos whose target is still on the DFS
 * stack.  For each such target it records that the target calls itself, and
 * which formals are handed back unchanged on every such edge.
 *
 * Pass 2 splits every self-calling target with at least one invariant formal
 * into two blocks:
 *   - an outer block under the original label, which keeps all formals;
 *   - a fresh inner block taking only the varying formals.
 * Back-edge Gotos are redirected to the inner block with the invariant
 * actuals dropped.
 *
 * The rebuilt function is then run through the shrinker.
 *)
fun transform (Program.T {globals, datatypes, functions, main}) =
   let
      val shrink = shrinkFunction {globals = globals}

      (* Keep the elements of xs whose position is NOT flagged invariant in
       * flags.  This trims the actuals of a redirected back-edge Goto. *)
      fun dropInvariant (xs: 'a vector, flags: ('b * bool ref) vector)
         : 'a vector =
         Vector.keepAllMap2
         (xs, flags, fn (x, (_, isInvariant)) =>
          if !isInvariant then NONE else SOME x)

      fun simplifyFunction f =
         let
            val {args, blocks, mayInline, name, raises, returns, start} =
               Function.dest f

            (* Per-label state:
             *   callsSelf - some back-edge Goto targets this label
             *   visited   - set on entry to the label in pass 1, cleared by
             *               the thunk returned to Function.dfs
             *   invariant - each formal paired with its invariant flag,
             *               initially true and cleared in pass 1
             *   newLabel  - the inner loop label, set on entry in pass 2
             *               and reset to NONE by the returned thunk *)
            val {get = labelInfo: Label.t -> {callsSelf: bool ref,
                                              visited: bool ref,
                                              invariant: (Var.t * bool ref) vector,
                                              newLabel: Label.t option ref},
                 set = setLabelInfo, ...} =
               Property.getSetOnce
               (Label.plist,
                Property.initRaise ("LoopInvariant.labelInfo", Label.layout))

            val _ =
               Vector.foreach
               (blocks, fn Block.T {label, args = formals, ...} =>
                setLabelInfo
                (label,
                 {callsSelf = ref false,
                  visited = ref false,
                  invariant = Vector.map (formals, fn (v, _) => (v, ref true)),
                  newLabel = ref NONE}))

            (* Pass 1: detect back edges and clear the invariant flags of
             * formals that are not passed back unchanged.  Assumes
             * Function.dfs calls the returned thunk once the block's
             * descendants are done, so visited marks the DFS stack. *)
            fun markBackEdges (Block.T {label, transfer, ...}): unit -> unit =
               let
                  val {visited = onStack, ...} = labelInfo label
                  val _ = onStack := true
                  val _ =
                     case transfer of
                        Goto {dst, args = actuals} =>
                           let
                              val {callsSelf, visited = dstOnStack,
                                   invariant = dstFormals, ...} = labelInfo dst
                           in
                              if not (!dstOnStack)
                                 then ()
                              else
                                 (callsSelf := true
                                  ; Vector.foreach2
                                    (actuals, dstFormals,
                                     fn (actual, (formal, stillInvariant)) =>
                                     if !stillInvariant
                                        andalso not (Var.equals (actual, formal))
                                        then stillInvariant := false
                                     else ()))
                           end
                      | _ => ()
               in
                  fn () => onStack := false
               end
            val _ = Function.dfs (f, markBackEdges)

            (* Rewritten blocks, accumulated in reverse emission order. *)
            val output = ref []

            (* Pass 2: split loop headers that have invariant formals, and
             * redirect back-edge Gotos into the new inner block. *)
            fun splitLoops (Block.T {label, args = formals, statements, transfer})
               : unit -> unit =
               let
                  val {callsSelf, invariant, newLabel, ...} = labelInfo label
                  (* Done before rewriting the transfer below, so that a Goto
                   * from this block back to itself is redirected as well. *)
                  val _ =
                     if !callsSelf andalso Vector.exists (invariant, ! o #2)
                        then newLabel := SOME (Label.new label)
                     else ()
                  val transfer =
                     case transfer of
                        Goto {dst, args = actuals} =>
                           let
                              val {invariant = dstFormals,
                                   newLabel = dstLoop, ...} = labelInfo dst
                           in
                              case !dstLoop of
                                 NONE => transfer
                               | SOME dstInner =>
                                    Goto {dst = dstInner,
                                          args = dropInvariant (actuals, dstFormals)}
                           end
                      | _ => transfer
                  val (formals, statements, transfer) =
                     case !newLabel of
                        NONE => (formals, statements, transfer)
                      | SOME loopLabel =>
                           let
                              val _ =
                                 Control.diagnostic
                                 (fn () =>
                                  let open Layout
                                  in seq [Label.layout label,
                                          str " -> ",
                                          Label.layout loopLabel]
                                  end)
                              (* Invariant formals stay on the outer block.
                               * Each varying formal gets a fresh outer name,
                               * which is forwarded to the inner block's
                               * original formal. *)
                              val (outerFormals, innerFormals, innerActuals) =
                                 Vector.foldr2
                                 (formals, invariant, ([], [], []),
                                  fn ((v, ty), (_, isInvariant),
                                      (outerAcc, innerAcc, actualAcc)) =>
                                  if !isInvariant
                                     then ((v, ty) :: outerAcc, innerAcc, actualAcc)
                                  else
                                     let
                                        val fresh = Var.new v
                                     in
                                        ((fresh, ty) :: outerAcc,
                                         (v, ty) :: innerAcc,
                                         fresh :: actualAcc)
                                     end)
                           in
                              List.push
                              (output,
                               Block.T {label = loopLabel,
                                        args = Vector.fromList innerFormals,
                                        statements = statements,
                                        transfer = transfer})
                              ; (Vector.fromList outerFormals,
                                 Vector.new0 (),
                                 Goto {dst = loopLabel,
                                       args = Vector.fromList innerActuals})
                           end
                  val _ =
                     List.push
                     (output,
                      Block.T {label = label,
                               args = formals,
                               statements = statements,
                               transfer = transfer})
               in
                  fn () => newLabel := NONE
               end
            val _ = Function.dfs (f, splitLoops)
         in
            shrink (Function.new {args = args,
                                  blocks = Vector.fromList (!output),
                                  mayInline = mayInline,
                                  name = name,
                                  raises = raises,
                                  returns = returns,
                                  start = start})
         end

      val program =
         Program.T {datatypes = datatypes,
                    globals = globals,
                    functions = List.revMap (functions, simplifyFunction),
                    main = main}
      val _ = Program.clearTop program
   in
      program
   end

end