Import Debian changes 20180207-1
[hcoop/debian/mlton.git] / mlton / ssa / loop-unswitch.fun
CommitLineData
7f918cf1
CE
1(* Copyright (C) 2016 Matthew Surawski.
2 *
3 * MLton is released under a BSD-style license.
4 * See the file MLton-LICENSE for details.
5 *)
6
(* Moves a conditional statement outside a loop by duplicating the loop's body
 * under each branch of the conditional.
 *)
functor LoopUnswitch (S: SSA_TRANSFORM_STRUCTS): SSA_TRANSFORM =
struct

open S
open Exp Transfer

structure Graph = DirectedGraph
local
   open Graph
in
   structure Node = Node
   structure Forest = LoopForest
end

(* Increment an integer counter in place. *)
fun ++ (v: int ref): unit =
   v := (!v) + 1

(* Diagnostic counters.  They are reset at the start of [transform] and
 * reported (via [logstat]) at its end. *)
val optCount = ref 0     (* loops that were unswitched *)
val tooBig = ref 0       (* loops rejected by [shouldOptimize] *)
val notInvariant = ref 0 (* Case blocks, in single-header loops, whose test
                          * variable is defined inside the loop *)
val multiHeaders = ref 0 (* loops rejected for not having exactly one header *)

(* Per-label information recorded while copying a loop: the label given to
 * the copy of the block, and the argument vector of the original block. *)
type BlockInfo = Label.t * (Var.t * Type.t) vector

(* Emit layout [l] as a diagnostic, indented by [i]. *)
fun logli (l: Layout.t, i: int): unit =
   Control.diagnostics
   (fn display =>
    display (Layout.indent (l, i)))

(* Emit string [s] as a diagnostic, indented by [i]. *)
fun logsi (s: string, i: int): unit =
   logli (Layout.str s, i)

(* Emit string [s] as an unindented diagnostic. *)
fun logs (s: string): unit =
   logsi (s, 0)

(* Emit "<count> <description>" for a diagnostic counter. *)
fun logstat (x: int ref, s: string): unit =
   logs (concat [Int.toString (!x), " ", s])

(* If [label] is one of the loop labels being copied ([origLabels]), return
 * the label of its copy as recorded by [getBlockInfo].  Otherwise the label
 * points outside the loop and is returned unchanged. *)
fun fixLabel (getBlockInfo: Label.t -> BlockInfo,
              label: Label.t,
              origLabels: Label.t vector): Label.t =
   if Vector.contains (origLabels, label, Label.equals)
      then
         let
            val (name, _) = getBlockInfo label
         in
            name
         end
   else label

(* Copy an entire loop.  Every block gets a fresh label, recorded through
 * [setBlockInfo] (overwriting any earlier copy's entry).  Transfers are then
 * rewritten so that jumps between loop blocks target the copies, while jumps
 * leaving the loop are left alone.  Arguments and statements are reused
 * verbatim, so the copy re-binds the same variables as the original loop.
 * NOTE(review): this breaks SSA form; it appears to rely on the SSA restore
 * performed in [transform]. *)
fun copyLoop (blocks: Block.t vector,
              blockInfo: Label.t -> BlockInfo,
              setBlockInfo: Label.t * BlockInfo -> unit): Block.t vector =
   let
      val labels = Vector.map (blocks, Block.label)
      (* First pass: assign a new label to each block. *)
      val newBlocks =
         Vector.map
         (blocks, fn b =>
          let
             val oldName = Block.label b
             val oldArgs = Block.args b
             val newName = Label.newNoname ()
             val () = setBlockInfo (oldName, (newName, oldArgs))
          in
             Block.T {args = Block.args b,
                      label = newName,
                      statements = Block.statements b,
                      transfer = Block.transfer b}
          end)
      (* Second pass: every new label is now known, so retarget transfers. *)
      fun retarget l = fixLabel (blockInfo, l, labels)
   in
      Vector.map
      (newBlocks, fn Block.T {args, label, statements, transfer} =>
       Block.T {args = args,
                label = label,
                statements = statements,
                transfer = Transfer.replaceLabel (transfer, retarget)})
   end

(* All variables bound in [block]: its arguments followed by the variables
 * bound by its statements. *)
fun blockVars (block: Block.t): Var.t list =
   let
      val args = Vector.fold (Block.args block, [],
                              fn ((var, _), lst) => var :: lst)
      val stmts = Vector.fold (Block.statements block, [],
                               fn (stmt, lst) =>
                               case Statement.var stmt of
                                  NONE => lst
                                | SOME v => v :: lst)
   in
      args @ stmts
   end

(* If [block] ends in a Case whose test variable is not bound anywhere in
 * the loop ([loopVars]), return the Case's (cases, test, default).
 * Otherwise return NONE.  A rejected Case bumps [notInvariant]. *)
fun detectCases (block: Block.t, loopVars: Var.t list, depth: int) =
   case Block.transfer block of
      Case {cases, default, test} =>
         let
            val blockName = Block.label block
            val () = logsi (concat ["Evaluating ", Label.toString blockName], depth)
            val () = logli (Transfer.layout (Block.transfer block), depth)
            val testIsInvariant = not (List.contains (loopVars, test, Var.equals))
         in
            if testIsInvariant
               then (logsi ("Can optimize!", depth) ; SOME (cases, test, default))
            else (logsi ("Can't optimize: condition not invariant", depth)
                  ; ++notInvariant
                  ; NONE)
         end
    | _ => NONE

(* Look for an unswitching opportunity in a loop.  Returns the first
 * loop-invariant Case (in [loopBody] order) together with the loop's single
 * header.  Loops without exactly one header are rejected before any block is
 * inspected, so they neither log per-block diagnostics nor count towards
 * [notInvariant]. *)
fun findOpportunity (loopBody: Block.t vector,
                     loopHeaders: Block.t vector,
                     depth: int)
   : ((Cases.t * Var.t * Label.t option) * Block.t) option =
   if Vector.length loopHeaders <> 1
      then (logsi ("Can't optimize: loop has more than 1 header", depth)
            ; ++multiHeaders
            ; NONE)
   else
      let
         val vars = Vector.fold (loopBody, [], fn (b, lst) => blockVars b @ lst)
         val canOptimize =
            Vector.keepAllMap (loopBody,
                               fn b => detectCases (b, vars, depth + 1))
      in
         if Vector.length canOptimize = 0
            then NONE
         else SOME (Vector.sub (canOptimize, 0), Vector.sub (loopHeaders, 0))
      end

(* Copy the loop for one Case arm whose original target is [branchLabel].
 * Returns the copied blocks plus a fresh entry block, and the entry block's
 * label.  The entry block takes the same arguments as the arm's original
 * target (the copied block's original args if the target is in the loop,
 * otherwise the outside block's args).  It jumps to the copied header,
 * passing the header's own argument variables. *)
fun makeBranch (loopBody: Block.t vector,
                loopHeader: Block.t,
                branchLabel: Label.t,
                blockInfo: Label.t -> BlockInfo,
                setBlockInfo: Label.t * BlockInfo -> unit,
                labelNode: Label.t -> unit Node.t,
                nodeBlock: unit Node.t -> Block.t)
   : Block.t vector * Label.t =
   let
      val loopBodyLabels = Vector.map (loopBody, Block.label)
      val newLoop = copyLoop (loopBody, blockInfo, setBlockInfo)
      (* Jump from the entry block into the copied header. *)
      val (newLoopHeaderLabel, _) = blockInfo (Block.label loopHeader)
      val newLoopArgs = Vector.map (Block.args loopHeader, fn (v, _) => v)
      val newLoopEntryTransfer = Transfer.Goto {args = newLoopArgs,
                                                dst = newLoopHeaderLabel}
      val newLoopEntryLabel = Label.newNoname ()
      val newLoopEntryArgs =
         if Vector.contains (loopBodyLabels, branchLabel, Label.equals)
            then
               let
                  val (_, args) = blockInfo branchLabel
               in
                  args
               end
         else Block.args (nodeBlock (labelNode branchLabel))
      val newLoopEntry = Block.T {args = newLoopEntryArgs,
                                  label = newLoopEntryLabel,
                                  statements = Vector.new0 (),
                                  transfer = newLoopEntryTransfer}
   in
      (Vector.concat [newLoop, Vector.new1 newLoopEntry], newLoopEntryLabel)
   end

(* Decide whether unswitching is worth it.  The test is
 * (number of arms incl. default) * (loop size) < !Control.loopUnswitchLimit,
 * computed in IntInf so the product cannot overflow. *)
fun shouldOptimize (cases, default, loopBlocks, depth) =
   let
      val loopSize' = Block.sizeV (loopBlocks, {sizeExp = Exp.size, sizeTransfer = Transfer.size})
      val loopSize = IntInf.fromInt loopSize'
      val branchCount =
         IntInf.fromInt
         ((case cases of
              Cases.Con v => Vector.length v
            | Cases.Word (_, v) => Vector.length v)
          +
          (case default of
              NONE => 0
            | SOME _ => 1))
      val unswitchLimit = IntInf.fromInt (!Control.loopUnswitchLimit)
      val shouldUnswitch = (branchCount * loopSize) < unswitchLimit
      val () = logsi ("branches * loop size < unswitch factor = can unswitch",
                      depth)
      val () = logsi (concat [IntInf.toString branchCount, " * ",
                              IntInf.toString loopSize, " < ",
                              IntInf.toString unswitchLimit, " = ",
                              Bool.toString shouldUnswitch], depth)
   in
      shouldUnswitch
   end

(* Build the replacement blocks for an unswitched loop.  The loop is copied
 * once for the default arm (if any), then once for each explicit arm, in
 * that order.  A new entry block reuses the original header's label and
 * arguments and dispatches on [check] to the matching copy.  The label
 * property used during copying is created and destroyed here, so it only
 * exists while a loop is actually being unswitched. *)
fun unswitchLoop (blocks: Block.t vector,
                  header: Block.t,
                  cases: Cases.t,
                  check: Var.t,
                  default: Label.t option,
                  labelNode,
                  nodeBlock): Block.t list =
   let
      (* Maps each original loop label to its most recent copy. *)
      val {get = blockInfo: Label.t -> BlockInfo,
           set = setBlockInfo: Label.t * BlockInfo -> unit, destroy} =
         Property.destGetSet (Label.plist,
                              Property.initRaise ("blockInfo", Label.layout))
      fun mkBranch lbl =
         makeBranch (blocks, header, lbl, blockInfo, setBlockInfo,
                     labelNode, nodeBlock)
      (* Copy the loop for the default arm first, if there is one. *)
      val (newDefaultLoop, newDefault) =
         case default of
            NONE => ([], NONE)
          | SOME defaultLabel =>
               let
                  val (newLoop, newLoopEntryLabel) = mkBranch defaultLabel
               in
                  (Vector.toList newLoop, SOME newLoopEntryLabel)
               end
      (* Copy the loop once per explicit arm, retargeting each arm at its
       * copy's entry block.  This is polymorphic in the arm key, so it
       * serves both constructor and word cases. *)
      fun copyArms arms =
         Vector.unzip
         (Vector.map (arms, fn (key, lbl) =>
                      let
                         val (newLoop, newLoopEntryLabel) = mkBranch lbl
                      in
                         (newLoop, (key, newLoopEntryLabel))
                      end))
      val (newLoops, newCases) =
         case cases of
            Cases.Con arms =>
               let
                  val (loops, arms') = copyArms arms
               in
                  (loops, Cases.Con arms')
               end
          | Cases.Word (size, arms) =>
               let
                  val (loops, arms') = copyArms arms
               in
                  (loops, Cases.Word (size, arms'))
               end
      (* One flat list: the default copy first, then the arm copies in order. *)
      val loopBlocks =
         Vector.fold (newLoops, newDefaultLoop,
                      fn (loop, acc) => acc @ Vector.toList loop)
      (* The new entry keeps the old header's label, so outside jumps to the
       * loop now reach the dispatch instead. *)
      val newEntry =
         Block.T {args = Block.args header,
                  label = Block.label header,
                  statements = Vector.new0 (),
                  transfer = Transfer.Case {cases = newCases,
                                            default = newDefault,
                                            test = check}}
      val () = destroy ()
   in
      newEntry :: loopBlocks
   end

(* Attempt to optimize a single innermost loop.  Returns the blocks to add to
 * the function and the labels of the blocks to remove from it.  Both are
 * empty when the loop is left alone. *)
fun optimizeLoop (headerNodes, loopNodes, labelNode, nodeBlock, depth):
   Block.t list * Label.t list =
   let
      val () = logsi ("At innermost loop", depth)
      val headers = Vector.map (headerNodes, nodeBlock)
      val blocks = Vector.map (loopNodes, nodeBlock)
      val blockNames = Vector.map (blocks, Block.label)
   in
      case findOpportunity (blocks, headers, depth) of
         NONE => ([], [])
       | SOME ((cases, check, default), header) =>
            if shouldOptimize (cases, default, blocks, depth + 1)
               then (++optCount
                     ; (unswitchLoop (blocks, header, cases, check, default,
                                      labelNode, nodeBlock),
                        Vector.toList blockNames))
            else (logsi ("Can't unswitch: too big", depth)
                  ; ++tooBig
                  ; ([], []))
   end

(* Descend through a loop's sub-forest.  Only innermost loops (those with no
 * sub-loops) are optimized.  Results from sibling loops are concatenated. *)
fun traverseSubForest ({loops, notInLoop},
                       enclosingHeaders,
                       labelNode, nodeBlock, depth): Block.t list * Label.t list =
   if Vector.length loops = 0
      then optimizeLoop (enclosingHeaders, notInLoop, labelNode, nodeBlock, depth)
   else
      Vector.fold (loops, ([], []), fn (loop, (new, remove)) =>
                   let
                      val (nBlocks, rBlocks) =
                         traverseLoop (loop, labelNode, nodeBlock, depth + 1)
                   in
                      (new @ nBlocks, remove @ rBlocks)
                   end)

(* Traverse one loop of the loop forest. *)
and traverseLoop ({headers, child},
                  labelNode, nodeBlock, depth): Block.t list * Label.t list =
   traverseSubForest (Forest.dest child, headers, labelNode, nodeBlock, depth + 1)

(* Traverse the top-level loop forest.  The function's new block list is the
 * original blocks minus every block of an unswitched loop, followed by all
 * newly generated blocks. *)
fun traverseForest ({loops, notInLoop = _}, allBlocks, labelNode, nodeBlock): Block.t list =
   let
      val (newBlocks, blocksToRemove) =
         Vector.fold (loops, ([], []), fn (loop, (new, remove)) =>
                      let
                         val (nBlocks, rBlocks) =
                            traverseLoop (loop, labelNode, nodeBlock, 1)
                      in
                         (new @ nBlocks, remove @ rBlocks)
                      end)
      fun keep (b: Block.t): bool =
         not (List.contains (blocksToRemove, Block.label b, Label.equals))
      val reducedBlocks = Vector.keepAll (allBlocks, keep)
   in
      Vector.toList reducedBlocks @ newBlocks
   end

(* Performs the optimization on the body of a single function. *)
fun optimizeFunction (function: Function.t): Function.t =
   let
      val {graph, labelNode, nodeBlock} = Function.controlFlow function
      val {args, blocks, mayInline, name, raises, returns, start} =
         Function.dest function
      val fsize = Function.size (function, {sizeExp = Exp.size, sizeTransfer = Transfer.size})
      val () = logs (concat ["Optimizing function: ", Func.toString name,
                             " of size ", Int.toString fsize])
      val root = labelNode start
      val forest = Graph.loopForestSteensgaard (graph, {root = root})
      val newBlocks = traverseForest (Forest.dest forest, blocks, labelNode, nodeBlock)
   in
      Function.new {args = args,
                    blocks = Vector.fromList newBlocks,
                    mayInline = mayInline,
                    name = name,
                    raises = raises,
                    returns = returns,
                    start = start}
   end

(* Entry point.  Unswitches loops in every function, then runs
 * [restoreFunction] over the results and logs the diagnostic counters. *)
fun transform (Program.T {datatypes, globals, functions, main}) =
   let
      val () = optCount := 0
      val () = tooBig := 0
      val () = notInvariant := 0
      val () = multiHeaders := 0
      val () = logs "Unswitching loops"
      val optimizedFunctions = List.map (functions, optimizeFunction)
      val restore = restoreFunction {globals = globals}
      val () = logs "Performing SSA restore"
      val cleanedFunctions = List.map (optimizedFunctions, restore)
      val () = logstat (optCount,
                        "loops optimized")
      val () = logstat (tooBig,
                        "loops too big to unswitch")
      val () = logstat (notInvariant,
                        "loops had variant conditions")
      val () = logstat (multiHeaders,
                        "loops had multiple headers")
      val () = logs "Done."
   in
      Program.T {datatypes = datatypes,
                 globals = globals,
                 functions = cleanedFunctions,
                 main = main}
   end

end