Import Debian changes 20180207-1
[hcoop/debian/mlton.git] / mlton / ssa / common-subexp.fun
1 (* Copyright (C) 2009,2011,2017 Matthew Fluet.
2 * Copyright (C) 1999-2006 Henry Cejtin, Matthew Fluet, Suresh
3 * Jagannathan, and Stephen Weeks.
4 * Copyright (C) 1997-2000 NEC Research Institute.
5 *
6 * MLton is released under a BSD-style license.
7 * See the file MLton-LICENSE for details.
8 *)
9
(* Common-subexpression elimination for SSA programs.
 *
 * Each function is walked along its dominator tree.  Statement
 * expressions are canonicalized (variables mapped to their recorded
 * replacements, the first two arguments of commutative primitives put in
 * a fixed order) and looked up in a hash set of the expressions that are
 * in scope.  When an equal expression is already in scope, the statement
 * is dropped and its variable is recorded as a replacement for the
 * in-scope one.  Primitive applications only take part in this when
 * Prim.isFunctional holds; other primitive applications are kept.  The
 * pass also tracks array/vector length variables and removes redundant
 * Arith transfers.  Every rewritten function is then run through
 * shrinkFunction.
 *)
10 functor CommonSubexp (S: SSA_TRANSFORM_STRUCTS): SSA_TRANSFORM =
11 struct
12
13 open S
14
15 open Exp Transfer
16
(* Rewrites every function of the program.  Globals, datatypes and main
 * are passed through unchanged; Program.clearTop is called on the
 * resulting program before it is returned.
 *)
17 fun transform (Program.T {globals, datatypes, functions, main}) =
18 let
19 (* Keep track of control-flow specific cse's,
20 * arguments, and in-degree of blocks.
21 *)
(* add: (var, exp) pairs to insert into the table when the block is
 *   entered; queued by the transfer of the block's predecessor, and only
 *   when the block's in-degree is 1.
 * args: the block's formal arguments.
 * inDeg: number of transfer edges that target the block.
 *)
22 val {get = labelInfo: Label.t -> {add: (Var.t * Exp.t) list ref,
23 args: (Var.t * Type.t) vector,
24 inDeg: int ref},
25 set = setLabelInfo, ...} =
26 Property.getSetOnce (Label.plist,
27 Property.initRaise ("info", Label.layout))
28 (* Keep track of a total ordering on variables. *)
(* Indices come from a counter in the order this pass visits variable
 * bindings; canon uses them to order commutative arguments.  Reading the
 * index of a variable that was never indexed raises (initRaise).
 *)
29 val {get = varIndex : Var.t -> int, set = setVarIndex, ...} =
30 Property.getSetOnce (Var.plist,
31 Property.initRaise ("varIndex", Var.layout))
(* Shadows setVarIndex: gives a variable the next counter value. *)
32 val setVarIndex =
33 let
34 val c = Counter.new 0
35 in
36 fn x => setVarIndex (x, Counter.next c)
37 end
38 (* Keep track of variables used as overflow variables. *)
(* These are fresh variables created in the Arith case below.  One is
 * queued for an overflow handler with in-degree 1.  If the same
 * arithmetic expression is later found bound to such a variable, the
 * operation is already known to overflow.
 *)
39 val {get = overflowVar: Var.t -> bool, set = setOverflowVar, ...} =
40 Property.getSetOnce (Var.plist, Property.initConst false)
41 (* Keep track of the replacements of variables. *)
42 val {get = replace: Var.t -> Var.t option, set = setReplace, ...} =
43 Property.getSetOnce (Var.plist, Property.initConst NONE)
44 (* Keep track of the variable that holds the length of arrays (and
45 * vectors and strings).
46 *)
47 val {get = getLength: Var.t -> Var.t option, set = setLength, ...} =
48 Property.getSetOnce (Var.plist, Property.initConst NONE)
(* Returns the recorded replacement of x, or x itself.  Only one lookup
 * is made; a replacement's own replacement is not followed.
 *)
49 fun canonVar x =
50 case replace x of
51 NONE => x
52 | SOME y => y
53 fun canonVars xs = Vector.map (xs, canonVar)
54 (* Canonicalize an Exp.
55 * Replace vars with their replacements.
56 * Put commutative arguments in canonical order.
57 *)
58 fun canon (e: Exp.t): Exp.t =
59 case e of
60 ConApp {con, args} =>
61 ConApp {con = con, args = canonVars args}
62 | Const _ => e
63 | PrimApp {prim, targs, args} =>
64 let
65 fun doit args =
66 PrimApp {prim = prim,
67 targs = targs,
68 args = args}
69 val args = canonVars args
70 fun arg i = Vector.sub (args, i)
(* The first two arguments, higher varIndex first. *)
71 fun canon2 () =
72 let
73 val a0 = arg 0
74 val a1 = arg 1
75 in
76 if varIndex a0 >= varIndex a1
77 then (a0, a1)
78 else (a1, a0)
79 end
80 datatype z = datatype Prim.Name.t
81 in
82 if Prim.isCommutative prim
83 then doit (Vector.new2 (canon2 ()))
84 else
(* These IntInf primitives get the same ordering of their first two
 * arguments, and their third argument stays in place.  Presumably
 * Prim.isCommutative is false for them, because the branch above
 * would drop the third argument; TODO confirm.
 *)
85 if (case Prim.name prim of
86 IntInf_add => true
87 | IntInf_andb => true
88 | IntInf_gcd => true
89 | IntInf_mul => true
90 | IntInf_orb => true
91 | IntInf_xorb => true
92 | _ => false)
93 then
94 let
95 val (a0, a1) = canon2 ()
96 in doit (Vector.new3 (a0, a1, arg 2))
97 end
98 else doit args
99 end
100 | Select {tuple, offset} => Select {tuple = canonVar tuple,
101 offset = offset}
102 | Tuple xs => Tuple (canonVars xs)
103 | Var x => Var (canonVar x)
104 | _ => e
105
106 (* Keep a hash table of canonicalized Exps that are in scope. *)
107 val table: {hash: word, exp: Exp.t, var: Var.t} HashSet.t =
108 HashSet.new {hash = #hash}
(* Returns the table entry whose exp equals exp.  If there is none, it
 * first inserts {exp, hash, var}.  Callers detect an insertion by
 * checking whether the returned var is their own var.
 *)
109 fun lookup (var, exp, hash) =
110 HashSet.lookupOrInsert
111 (table, hash,
112 fn {exp = exp', ...} => Exp.equals (exp, exp'),
113 fn () => {exp = exp,
114 hash = hash,
115 var = var})
116
117 (* All of the globals are in scope, and never go out of scope. *)
118 (* The hash-cons'ing of globals in ConstantPropagation ensures
119 * that each global is unique.
120 *)
121 val _ =
122 Vector.foreach
123 (globals, fn Statement.T {var, exp, ...} =>
124 let
125 val var = valOf var
126 val () = setVarIndex var
127 val exp = canon exp
128 val _ = lookup (var, exp, Exp.hash exp)
129 in
130 ()
131 end)
132
(* Transforms the blocks of one function, given its dominator tree.
 * loop visits a block before its dominator-tree children.  The table
 * entries a block inserts are removed after its children are done.  So
 * while a block's children are processed, the table holds the globals
 * plus the entries inserted by that block and its dominator-tree
 * ancestors.  Blocks are accumulated with List.push, so the result
 * lists them in reverse visiting order.
 *)
133 fun doitTree tree =
134 let
135 val blocks = ref []
136 fun loop (Tree.T (Block.T {args, label,
137 statements, transfer},
138 children)): unit =
139 let
140 fun diag s =
141 Control.diagnostics
142 (fn display =>
143 let open Layout
144 in
145 display (seq [Label.layout label, str ": ", str s])
146 end)
147 val _ = diag "started"
(* Entries this block inserted; removed once its children are done. *)
148 val remove = ref []
149 val {add, ...} = labelInfo label
150 val _ = Control.diagnostics
151 (fn display =>
152 let open Layout
153 in
154 display (seq [str "add: ",
155 List.layout (fn (var,exp) =>
156 seq [Var.layout var,
157 str ": ",
158 Exp.layout exp]) (!add)])
159 end)
(* Make the expressions queued by the predecessor's transfer available.
 * Only entries this block actually inserted are scheduled for removal.
 *)
160 val _ = List.foreach
161 (!add, fn (var, exp) =>
162 let
163 val hash = Exp.hash exp
164 val elem as {var = var', ...} = lookup (var, exp, hash)
165 val _ = if Var.equals(var, var')
166 then List.push (remove, elem)
167 else ()
168 in
169 ()
170 end)
171 val _ = diag "added"
172
(* The block's formal arguments are bound here, so index them. *)
173 val _ =
174 Vector.foreach
175 (args, fn (var, _) => setVarIndex var)
(* Canonicalize each statement.  Statements without a variable are always
 * kept.  When doit finds an equal expression already in scope, the
 * statement is dropped (keepAllMap receives NONE) and its variable is
 * replaced by the in-scope one.  When doit inserts the expression, the
 * new entry is scheduled for removal.
 *)
176 val statements =
177 Vector.keepAllMap
178 (statements,
179 fn Statement.T {var, ty, exp} =>
180 let
181 val exp = canon exp
182 fun keep () = SOME (Statement.T {var = var,
183 ty = ty,
184 exp = exp})
185 in
186 case var of
187 NONE => keep ()
188 | SOME var =>
189 let
190 val _ = setVarIndex var
191 fun replace var' =
192 (setReplace (var, SOME var'); NONE)
193 fun doit () =
194 let
195 val hash = Exp.hash exp
196 val elem as {var = var', ...} =
197 lookup (var, exp, hash)
198 in
199 if Var.equals(var, var')
200 then (List.push (remove, elem)
201 ; keep ())
202 else replace var'
203 end
204 in
205 case exp of
206 PrimApp ({args, prim, ...}) =>
207 let
208 fun arg () = Vector.first args
209 fun knownLength var' =
210 let
211 val _ = setLength (var, SOME var')
212 in
213 keep ()
214 end
215 fun conv () =
216 case getLength (arg ()) of
217 NONE => keep ()
218 | SOME var' => knownLength var'
219 fun length () =
220 case getLength (arg ()) of
221 NONE => doit ()
222 | SOME var' => replace var'
223 datatype z = datatype Prim.Name.t
224 in
(* Array_alloc records its first argument as the length of the result.
 * Array_length and Vector_length are replaced by the known length
 * variable when there is one; otherwise they go through doit like any
 * other candidate.  Array_toArray and Array_toVector pass a known length
 * on to their result and are always kept.  Other primitives go through
 * doit only when Prim.isFunctional holds.
 *)
225 case Prim.name prim of
226 Array_alloc _ => knownLength (arg ())
227 | Array_length => length ()
228 | Array_toArray => conv ()
229 | Array_toVector => conv ()
230 | Vector_length => length ()
231 | _ => if Prim.isFunctional prim
232 then doit ()
233 else keep ()
234 end
235 | _ => doit ()
236 end
237 end)
238 val _ = diag "statements"
(* Canonicalize the transfer's variables.  Then look for redundant Arith
 * transfers, and pass availability facts to successors whose in-degree
 * is 1.
 *)
239 val transfer = Transfer.replaceVar (transfer, canonVar)
240 val transfer =
241 case transfer of
242 Arith {prim, args, overflow, success, ...} =>
243 let
244 val {args = succArgs,
245 inDeg = succInDeg,
246 add = succAdd, ...} =
247 labelInfo success
248 val {inDeg = overInDeg,
249 add = overAdd, ...} =
250 labelInfo overflow
251 val exp = canon (PrimApp {prim = prim,
252 targs = Vector.new0 (),
253 args = args})
254 val hash = Exp.hash exp
255 in
(* peek, not lookup: the Arith expression is not inserted here.
 * Suppose an equal expression is in scope.  If its var is an overflow
 * variable, the transfer becomes a Goto to the overflow handler.
 * Otherwise it becomes a Goto to the success block that passes the
 * in-scope var; when the success block's in-degree is 1, its argument
 * is also replaced by that var.
 * Suppose nothing equal is in scope.  Then the transfer is kept.  A
 * success block with in-degree 1 gets (its first argument, exp) queued
 * in its add list.  An overflow handler with in-degree 1 gets
 * (a fresh overflow variable, exp) queued.
 *)
256 case HashSet.peek
257 (table, hash,
258 fn {exp = exp', ...} => Exp.equals (exp, exp')) of
259 SOME {var, ...} =>
260 if overflowVar var
261 then Goto {dst = overflow,
262 args = Vector.new0 ()}
263 else (if !succInDeg = 1
264 then let
265 val (var', _) =
266 Vector.first succArgs
267 in
268 setReplace (var', SOME var)
269 end
270 else ()
271 ; Goto {dst = success,
272 args = Vector.new1 var})
273 | NONE => (if !succInDeg = 1
274 then let
275 val (var, _) =
276 Vector.first succArgs
277 in
278 List.push
279 (succAdd, (var, exp))
280 end
281 else () ;
282 if !overInDeg = 1
283 then let
284 val var = Var.newNoname ()
285 val _ = setOverflowVar (var, true)
286 in
287 List.push
288 (overAdd, (var, exp))
289 end
290 else () ;
291 transfer)
292 end
(* When the destination's in-degree is 1, its formals are replaced by the
 * actuals passed here.
 *)
293 | Goto {dst, args} =>
294 let
295 val {args = args', inDeg, ...} = labelInfo dst
296 in
297 if !inDeg = 1
298 then (Vector.foreach2
299 (args, args', fn (var, (var', _)) =>
300 setReplace (var', SOME var))
301 ; transfer)
302 else transfer
303 end
304 | _ => transfer
305 val _ = diag "transfer"
306 val block = Block.T {args = args,
307 label = label,
308 statements = statements,
309 transfer = transfer}
310 val _ = List.push (blocks, block)
311 val _ = Vector.foreach (children, loop)
312 val _ = diag "children"
313 val _ = Control.diagnostics
314 (fn display =>
315 let open Layout
316 in
317 display (seq [str "remove: ",
318 List.layout (fn {var,exp,...} =>
319 seq [Var.layout var,
320 str ": ",
321 Exp.layout exp]) (!remove)])
322 end)
(* The dominator-tree children have been processed, so drop this
 * block's entries from the table.
 *)
323 val _ = List.foreach
324 (!remove, fn {var, hash, ...} =>
325 HashSet.remove
326 (table, hash, fn {var = var', ...} =>
327 Var.equals (var, var')))
328 val _ = diag "removed"
329 in
330 ()
331 end
332 val _ =
333 Control.diagnostics
334 (fn display =>
335 let open Layout
336 in
337 display (seq [str "starting loop"])
338 end)
339 val _ = loop tree
340 val _ =
341 Control.diagnostics
342 (fn display =>
343 let open Layout
344 in
345 display (seq [str "finished loop"])
346 end)
347 in
348 Vector.fromList (!blocks)
349 end
(* shrinkFunction applied to this program's globals; it is used on each
 * rewritten function.
 *)
350 val shrink = shrinkFunction {globals = globals}
(* For each function: index its arguments and initialize the label info
 * of every block.  Then count in-degrees over every transfer, rewrite the
 * blocks along the dominator tree, and shrink the result.
 *)
351 val functions =
352 List.revMap
353 (functions, fn f =>
354 let
355 val {args, blocks, mayInline, name, raises, returns, start} =
356 Function.dest f
357 val _ =
358 Vector.foreach
359 (args, fn (var, _) => setVarIndex var)
360 val _ =
361 Vector.foreach
362 (blocks, fn Block.T {label, args, ...} =>
363 (setLabelInfo (label, {add = ref [],
364 args = args,
365 inDeg = ref 0})))
366 val _ =
367 Vector.foreach
368 (blocks, fn Block.T {transfer, ...} =>
369 Transfer.foreachLabel (transfer, fn label' =>
370 Int.inc (#inDeg (labelInfo label'))))
371 val blocks = doitTree (Function.dominatorTree f)
372 in
373 shrink (Function.new {args = args,
374 blocks = blocks,
375 mayInline = mayInline,
376 name = name,
377 raises = raises,
378 returns = returns,
379 start = start})
380 end)
381 val program =
382 Program.T {datatypes = datatypes,
383 globals = globals,
384 functions = functions,
385 main = main}
386 val _ = Program.clearTop program
387 in
388 program
389 end
390
391 end