Import Debian changes 20180207-1
[hcoop/debian/mlton.git] / mlton / ssa / flatten.fun
CommitLineData
7f918cf1
CE
1(* Copyright (C) 1999-2006 Henry Cejtin, Matthew Fluet, Suresh
2 * Jagannathan, and Stephen Weeks.
3 * Copyright (C) 1997-2000 NEC Research Institute.
4 *
5 * MLton is released under a BSD-style license.
6 * See the file MLton-LICENSE for details.
7 *)
8
(*
 * Flatten arguments to jumps, constructors, and functions, as well as the
 * values returned and raised by functions.
 * If a tuple is explicitly available at all uses of a jump (resp. function)
 * then
 * - The formals and call sites are changed so that the components of the
 *   tuple are passed.
 * - The tuple is reconstructed at the beginning of the body of the jump
 *   (resp. function).
 *
 * Similarly, if a tuple is explicitly available at all uses of a constructor,
 * - The constructor argument type is changed to flatten the tuple type.
 * - The tuple is passed flat at each ConApp.
 * - At each Case target that still expects the tuple, the tuple is
 *   reconstructed in an intermediate block before jumping to the target.
 *)
22
functor Flatten (S: SSA_TRANSFORM_STRUCTS): SSA_TRANSFORM =
struct

open S
open Exp Transfer

(* Representation chosen for one argument, return, or raise position.
 * It is a two-point lattice: bottom ("flatten") means the value may be
 * passed as its tuple components; top ("don't flatten") means it must be
 * passed as a single value.
 *)
structure Rep =
   struct
      structure L = TwoPointLattice (val bottom = "flatten"
                                     val top = "don't flatten")

      open L

      val isFlat = not o isTop

      (* Only tuple types can be flattened; any other type starts at top. *)
      fun fromType t =
         case Type.deTupleOpt t of
            NONE => let val r = new () in makeTop r; r end
          | SOME _ => new ()

      fun fromTypes (ts: Type.t vector): t vector =
         Vector.map (ts, fromType)

      (* Force a position to be passed as a single (tuple) value. *)
      val tuplize: t -> unit = makeTop

      (* coerce (r, r') requires r <= r': if r is top, r' is top too. *)
      val coerce = op <=

      fun coerces (rs, rs') = Vector.foreach2 (rs, rs', coerce)

      (* unify (r, r') requires r and r' to be equal. *)
      val unify = op ==

      fun unifys (rs, rs') = Vector.foreach2 (rs, rs', unify)
   end

(* The pass runs in two phases.
 * Phase 1 attaches a Rep to every constructor argument, function
 * argument/return/raise position, and block formal, and propagates the
 * constraints imposed by every use.
 * Phase 2 rewrites the datatypes, globals, and functions so that every
 * position that is still flat is passed as its tuple components.
 *)
fun transform (Program.T {datatypes, globals, functions, main}) =
   let
      (* ---- Per-entity information ---- *)

      val {get = conInfo: Con.t -> {argsTypes: Type.t vector,
                                    args: Rep.t vector},
           set = setConInfo, ...} =
         Property.getSetOnce
         (Con.plist, Property.initRaise ("Flatten.conInfo", Con.layout))
      val conArgs = #args o conInfo
      val {get = funcInfo: Func.t -> {args: Rep.t vector,
                                      returns: Rep.t vector option,
                                      raises: Rep.t vector option},
           set = setFuncInfo, ...} =
         Property.getSetOnce
         (Func.plist, Property.initRaise ("Flatten.funcInfo", Func.layout))
      val funcArgs = #args o funcInfo
      val {get = labelInfo: Label.t -> {args: Rep.t vector},
           set = setLabelInfo, ...} =
         Property.getSetOnce
         (Label.plist, Property.initRaise ("Flatten.labelInfo", Label.layout))
      val labelArgs = #args o labelInfo
      (* Variables not bound by a Tuple, a Var alias, or a formal get a
       * fresh top rep by default: their components are not explicitly
       * available.  [tuple] holds the component variables once known.
       *)
      val {get = varInfo: Var.t -> {rep: Rep.t,
                                    tuple: Var.t vector option ref},
           set = setVarInfo, ...} =
         Property.getSetOnce
         (Var.plist, Property.initFun
          (fn _ => {rep = let val r = Rep.new ()
                          in Rep.tuplize r; r
                          end,
                    tuple = ref NONE}))
      (* Record a formal (x, ty); its rep is derived from its type. *)
      val fromFormal = fn (x, ty) =>
         let
            val r = Rep.fromType ty
         in
            setVarInfo (x, {rep = r, tuple = ref NONE})
            ; r
         end
      val fromFormals = fn xtys => Vector.map (xtys, fromFormal)
      val varRep = #rep o varInfo
      val varTuple = #tuple o varInfo
      fun coerce (x: Var.t, r: Rep.t) = Rep.coerce (varRep x, r)
      fun coerces (xs: Var.t vector, rs: Rep.t vector) =
         Vector.foreach2 (xs, rs, coerce)

      (* ---- Phase 1: analysis ---- *)

      (* Seed constructor and function reps from their declared types. *)
      val _ =
         Vector.foreach
         (datatypes, fn Datatype.T {cons, ...} =>
          Vector.foreach
          (cons, fn {con, args} =>
           setConInfo (con, {argsTypes = args,
                             args = Vector.map (args, Rep.fromType)})))
      val _ =
         List.foreach
         (functions, fn f =>
          let
             val {args, name, raises, returns, ...} = Function.dest f
          in
             setFuncInfo (name, {args = fromFormals args,
                                 returns = Option.map (returns, Rep.fromTypes),
                                 raises = Option.map (raises, Rep.fromTypes)})
          end)

      (* Constraints from one statement:
       * - a bound Tuple makes its variable's components explicitly
       *   available, so the variable starts out flat;
       * - a ConApp coerces its actuals' reps into the constructor's reps;
       * - a bound Var shares the info of the variable it copies.
       * Statements with no binder (var = NONE) record nothing for Tuple
       * and Var.
       *)
      fun analyzeStatement (Statement.T {exp, var, ...}) =
         case exp of
            Tuple xs =>
               Option.app
               (var, fn var =>
                setVarInfo (var, {rep = Rep.new (),
                                  tuple = ref (SOME xs)}))
          | ConApp {con, args} => coerces (args, conArgs con)
          | Var x => Option.app (var, fn var => setVarInfo (var, varInfo x))
          | _ => ()
      val _ = Vector.foreach (globals, analyzeStatement)
      (* NOTE(review): blocks are analyzed in vector order, not dominator
       * order, and setVarInfo comes from Property.getSetOnce; this
       * presumably relies on bindings being recorded before any use
       * queries varInfo -- confirm.
       *)
      val _ =
         List.foreach
         (functions, fn f =>
          let
             val {blocks, name, ...} = Function.dest f
             val {raises, returns, ...} = funcInfo name
             (* Constraints imposed by one transfer of function [name]. *)
             fun analyzeTransfer transfer =
                case transfer of
                   Return xs =>
                      (case returns of
                          NONE => Error.bug "Flatten.flatten: return mismatch"
                        | SOME rs => coerces (xs, rs))
                 | Raise xs =>
                      (case raises of
                          NONE => Error.bug "Flatten.flatten: raise mismatch"
                        | SOME rs => coerces (xs, rs))
                 | Call {func, args, return} =>
                      let
                         val {args = calleeArgs,
                              returns = calleeReturns,
                              raises = calleeRaises} = funcInfo func
                         val _ = coerces (args, calleeArgs)
                         (* When the callee returns or raises straight to our
                          * caller, its positions must match ours exactly.
                          *)
                         fun unifyReturns () =
                            case (calleeReturns, returns) of
                               (SOME rs, SOME rs') => Rep.unifys (rs, rs')
                             | _ => ()
                         fun unifyRaises () =
                            case (calleeRaises, raises) of
                               (SOME rs, SOME rs') => Rep.unifys (rs, rs')
                             | _ => ()
                      in
                         case return of
                            Return.Dead => ()
                          | Return.NonTail {cont, handler} =>
                               (Option.app
                                (calleeReturns, fn rs =>
                                 Rep.unifys (rs, labelArgs cont))
                                ; (case handler of
                                      Handler.Caller => unifyRaises ()
                                    | Handler.Dead => ()
                                    | Handler.Handle h =>
                                         Option.app
                                         (calleeRaises, fn rs =>
                                          Rep.unifys (rs, labelArgs h))))
                          | Return.Tail => (unifyReturns (); unifyRaises ())
                      end
                 | Goto {dst, args} => coerces (args, labelArgs dst)
                 | Case {cases = Cases.Con cases, ...} =>
                      Vector.foreach
                      (cases, fn (con, label) =>
                       Rep.coerces (conArgs con, labelArgs label))
                 | _ => ()
          in
             (* All labels must have info before any transfer is analyzed. *)
             Vector.foreach
             (blocks, fn Block.T {label, args, statements, ...} =>
              (setLabelInfo (label, {args = fromFormals args})
               ; Vector.foreach (statements, analyzeStatement)))
             ; Vector.foreach
               (blocks, fn Block.T {transfer, ...} => analyzeTransfer transfer)
          end)
      val _ =
         Control.diagnostics
         (fn display =>
          List.foreach
          (functions, fn f =>
           let
              val name = Function.name f
              val {args, raises, returns} = funcInfo name
              open Layout
           in
              display
              (seq [Func.layout name,
                    str " ",
                    record
                    [("args", Vector.layout Rep.layout args),
                     ("returns", Option.layout (Vector.layout Rep.layout) returns),
                     ("raises", Option.layout (Vector.layout Rep.layout) raises)]])
           end))

      (* ---- Phase 2: rewriting ---- *)

      (* Replace each flat position of ts by the tuple's component types.
       * NOTE: assuming Vector.fold/fold2/fold3 traverse left to right, the
       * cons-then-Vector.fromList pattern used here, in flattens, doitArgs,
       * and doitCaseCon yields vectors in reverse order.  This is harmless
       * because every producer and consumer (formals, actuals, constructor
       * arguments, return/raise types) is reversed the same way.
       *)
      fun flattenTypes (ts: Type.t vector, rs: Rep.t vector): Type.t vector =
         Vector.fromList
         (Vector.fold2 (ts, rs, [], fn (t, r, acc) =>
                        if Rep.isFlat r
                           then Vector.fold (Type.deTuple t, acc, op ::)
                        else t :: acc))
      val datatypes =
         Vector.map
         (datatypes, fn Datatype.T {tycon, cons} =>
          Datatype.T {tycon = tycon,
                      cons = Vector.map
                             (cons, fn {con, args} =>
                              {con = con,
                               args = flattenTypes (args, conArgs con)})})
      (* Replace each actual in a flat position by its tuple components.
       * It is a bug for a flat actual to have no known components.
       *)
      fun flattens (xs: Var.t vector, rs: Rep.t vector): Var.t vector =
         Vector.fromList
         (Vector.fold2
          (xs, rs, [], fn (x, r, acc) =>
           if Rep.isFlat r
              then (case !(varTuple x) of
                       SOME ys => Vector.fold (ys, acc, op ::)
                     | NONE =>
                          Error.bug
                          (concat
                           ["Flatten.flattens: tuple unavailable: ",
                            Var.toString x, " ",
                            Layout.toString (Vector.layout Var.layout xs)]))
           else x :: acc))
      fun doitStatement (stmt as Statement.T {var, ty, exp}) =
         case exp of
            ConApp {con, args} =>
               Statement.T {var = var,
                            ty = ty,
                            exp = ConApp {con = con,
                                          args = flattens (args, conArgs con)}}
          | _ => stmt
      val globals = Vector.map (globals, doitStatement)
      fun doitFunction f =
         let
            val {args, mayInline, name, raises, returns, start, ...} =
               Function.dest f
            val {args = argsReps, returns = returnsReps, raises = raisesReps} =
               funcInfo name

            (* Blocks of the rewritten function, most recently created first. *)
            val newBlocks = ref []

            (* Flatten a formal list.  Each flat formal x is replaced by fresh
             * formals for its components; a statement x = Tuple components
             * is returned to rebuild it, and varTuple x records the
             * components so that later uses of x can be passed flat.
             *)
            fun doitArgs (args, reps) =
               let
                  val (args, stmts) =
                     Vector.fold2
                     (args, reps, ([], []), fn ((x, ty), r, (args, stmts)) =>
                      if Rep.isFlat r
                         then
                            let
                               val tys = Type.deTuple ty
                               val xs = Vector.map (tys, fn _ => Var.newNoname ())
                               val _ = varTuple x := SOME xs
                               val args =
                                  Vector.fold2
                                  (xs, tys, args, fn (x, ty, args) =>
                                   (x, ty) :: args)
                            in
                               (args,
                                Statement.T {var = SOME x,
                                             ty = ty,
                                             exp = Tuple xs}
                                :: stmts)
                            end
                      else ((x, ty) :: args, stmts))
               in
                  (Vector.fromList args, Vector.fromList stmts)
               end

            (* Rewrite a constructor Case.  When a target label's formal layout
             * differs from the constructor's argument layout, the branch goes
             * through a fresh adapter block.  The adapter takes the
             * constructor's layout and jumps to the label with the label's
             * layout.
             *)
            fun doitCaseCon {test, cases, default} =
               let
                  (* Fresh variables, one per component of tuple type ty. *)
                  fun freshComponents ty =
                     Vector.map (Type.deTuple ty, fn ty => (Var.newNoname (), ty))
                  (* Extend the adapter's (stmts, formals, actuals) for one
                   * argument of type ty, where r is the constructor's rep and
                   * r' the target label's rep.
                   *)
                  fun adaptArg (r, r', ty, (stmts, formals, actuals)) =
                     if Rep.isFlat r
                        then
                           (* The con passes the components. *)
                           let
                              val xts = freshComponents ty
                              val xs = Vector.map (xts, #1)
                              val formals = Vector.fold (xts, formals, op ::)
                           in
                              if Rep.isFlat r'
                                 then (stmts, formals,
                                       Vector.fold (xs, actuals, op ::))
                              else
                                 let
                                    val x = Var.newNoname ()
                                 in
                                    (Statement.T {var = SOME x,
                                                  ty = ty,
                                                  exp = Tuple xs}
                                     :: stmts,
                                     formals,
                                     x :: actuals)
                                 end
                           end
                     else
                        (* The con passes the tuple.  Phase 1 coerces
                         * constructor reps into target reps, so r' should be
                         * top here as well; the flat-target case is kept for
                         * completeness.
                         *)
                        let
                           val tuple = Var.newNoname ()
                           val formals = (tuple, ty) :: formals
                        in
                           if Rep.isFlat r'
                              then
                                 let
                                    val xts = freshComponents ty
                                    val xs = Vector.map (xts, #1)
                                    val stmts =
                                       Vector.foldi
                                       (xts, stmts, fn (i, (x, ty), stmts) =>
                                        Statement.T
                                        {var = SOME x,
                                         ty = ty,
                                         exp = Select {tuple = tuple,
                                                       offset = i}}
                                        :: stmts)
                                 in
                                    (stmts, formals,
                                     Vector.fold (xs, actuals, op ::))
                                 end
                           else (stmts, formals, tuple :: actuals)
                        end
                  fun adaptCase (c, l) =
                     let
                        val {args, argsTypes} = conInfo c
                        val actualReps = labelArgs l
                     in
                        if Vector.forall2
                           (args, actualReps, fn (r, r') =>
                            Rep.isFlat r = Rep.isFlat r')
                           then (c, l)
                        else
                           let
                              val l' = Label.newNoname ()
                              (* Formals match the con's type; actuals match
                               * the type of l.
                               *)
                              val (stmts, formals, actuals) =
                                 Vector.fold3
                                 (args, actualReps, argsTypes, ([], [], []),
                                  adaptArg)
                              val _ =
                                 List.push
                                 (newBlocks,
                                  Block.T
                                  {label = l',
                                   args = Vector.fromList formals,
                                   statements = Vector.fromList stmts,
                                   transfer = Goto {dst = l,
                                                    args = Vector.fromList actuals}})
                           in
                              (c, l')
                           end
                     end
               in
                  Case {test = test,
                        cases = Cases.Con (Vector.map (cases, adaptCase)),
                        default = default}
               end

            fun doitTransfer transfer =
               case transfer of
                  Call {func, args, return} =>
                     Call {func = func,
                           args = flattens (args, funcArgs func),
                           return = return}
                | Case {test, cases = Cases.Con cases, default} =>
                     doitCaseCon {test = test, cases = cases, default = default}
                | Goto {dst, args} =>
                     Goto {dst = dst, args = flattens (args, labelArgs dst)}
                | Raise xs => Raise (flattens (xs, valOf raisesReps))
                | Return xs => Return (flattens (xs, valOf returnsReps))
                | _ => transfer

            fun doitBlock (Block.T {label, args, statements, transfer}) =
               let
                  (* doitArgs must run first: it sets varTuple for this
                   * block's flat formals, which flattens reads below.
                   *)
                  val (args, rebuild) = doitArgs (args, labelArgs label)
                  val statements = Vector.map (statements, doitStatement)
                  val statements = Vector.concat [rebuild, statements]
                  val transfer = doitTransfer transfer
               in
                  Block.T {label = label,
                           args = args,
                           statements = statements,
                           transfer = transfer}
               end

            (* A new entry block rebuilds the function's flat formals and
             * jumps to the original start label.
             *)
            val (args, entryStmts) = doitArgs (args, argsReps)
            val start' = Label.newNoname ()
            val _ =
               List.push
               (newBlocks,
                Block.T {label = start',
                         args = Vector.new0 (),
                         statements = entryStmts,
                         transfer = Goto {dst = start, args = Vector.new0 ()}})
            (* Function.dfs is used so that (presumably) each block's formals
             * are processed, setting varTuple, before blocks that use them;
             * flattens fails otherwise.
             *)
            val _ =
               Function.dfs
               (f, fn b =>
                (List.push (newBlocks, doitBlock b)
                 ; fn () => ()))
            val blocks = Vector.fromList (!newBlocks)
            val returns =
               Option.map (returns, fn ts => flattenTypes (ts, valOf returnsReps))
            val raises =
               Option.map (raises, fn ts => flattenTypes (ts, valOf raisesReps))
         in
            Function.new {args = args,
                          blocks = blocks,
                          mayInline = mayInline,
                          name = name,
                          raises = raises,
                          returns = returns,
                          start = start'}
         end

      val shrink = shrinkFunction {globals = globals}
      val functions = List.revMap (functions, shrink o doitFunction)
      val program =
         Program.T {datatypes = datatypes,
                    globals = globals,
                    functions = functions,
                    main = main}
      val _ = Program.clearTop program
   in
      program
   end

end