Commit | Line | Data |
---|---|---|
7f918cf1 CE |
1 | (* Copyright (C) 2016 Matthew Surawski. |
2 | * | |
3 | * MLton is released under a BSD-style license. | |
4 | * See the file MLton-LICENSE for details. | |
5 | *) | |
6 | ||
(* Loop unswitching: moves a loop-invariant conditional (a Case transfer whose
 * test variable is not bound inside the loop) outside the loop by duplicating
 * the loop's body under each branch of the conditional.
 *
 * Each copy keeps its own (now redundant) Case transfer and re-defines the
 * same variables as the original loop; SSA form is re-established at the end
 * of the pass by restoreFunction.
 *)
functor LoopUnswitch (S: SSA_TRANSFORM_STRUCTS): SSA_TRANSFORM =
struct

open S
open Exp Transfer

structure Graph = DirectedGraph
local
   open Graph
in
   structure Node = Node
   structure Forest = LoopForest
end

(* Increment an integer counter in place. *)
fun ++ (v: int ref): unit =
   v := (!v) + 1

(* Pass statistics: reset by transform and reported through logstat. *)
val optCount = ref 0      (* loops that were unswitched *)
val tooBig = ref 0        (* loops rejected by shouldOptimize *)
val notInvariant = ref 0  (* Case blocks whose test is bound inside the loop *)
val multiHeaders = ref 0  (* loops rejected for not having exactly one header *)

(* The label of a block's copy, paired with the original block's arguments. *)
type BlockInfo = Label.t * (Var.t * Type.t) vector

(* Diagnostic logging helpers.  All output goes through Control.diagnostics;
 * the int argument is the indentation passed to Layout.indent. *)
fun logli (l: Layout.t, i: int): unit =
   Control.diagnostics
   (fn display =>
    display (Layout.indent (l, i)))

fun logsi (s: string, i: int): unit =
   logli (Layout.str s, i)

fun logs (s: string): unit =
   logsi (s, 0)

fun logstat (x: int ref, s: string): unit =
   logs (concat [Int.toString (!x), " ", s])

(* If label is one of the blocks being copied (it is in origLabels), return the
 * label of its copy as recorded by getBlockInfo.  Otherwise (for example, a
 * jump that leaves the loop), return label unchanged. *)
fun fixLabel (getBlockInfo: Label.t -> BlockInfo,
              label: Label.t,
              origLabels: Label.t vector): Label.t =
   if Vector.contains (origLabels, label, Label.equals) then
      let
         val (name, _) = getBlockInfo label
      in
         name
      end
   else
      label

(* Copy an entire loop.  Every block gets a fresh label, recorded with
 * setBlockInfo under the original label.  Every transfer is then rewritten so
 * that jumps to blocks inside the loop target the copies, while jumps out of
 * the loop are left alone.  Statements are copied verbatim, so the copy
 * re-defines the original variables.
 *
 * All labels are renamed in a first pass before any transfer is rewritten,
 * because fixLabel must be able to look up the copy of every loop block. *)
fun copyLoop (blocks: Block.t vector,
              blockInfo: Label.t -> BlockInfo,
              setBlockInfo: Label.t * BlockInfo -> unit): Block.t vector =
   let
      val labels = Vector.map (blocks, Block.label)
      (* Assign a new label for each block *)
      val newBlocks =
         Vector.map
         (blocks, fn b =>
          let
             val oldName = Block.label b
             val oldArgs = Block.args b
             val newName = Label.newNoname ()
             val () = setBlockInfo (oldName, (newName, oldArgs))
          in
             Block.T {args = oldArgs,
                      label = newName,
                      statements = Block.statements b,
                      transfer = Block.transfer b}
          end)
      (* Redirect in-loop jumps to the copies *)
      fun fix l = fixLabel (blockInfo, l, labels)
   in
      Vector.map
      (newBlocks, fn Block.T {args, label, statements, transfer} =>
       Block.T {args = args,
                label = label,
                statements = statements,
                transfer = Transfer.replaceLabel (transfer, fix)})
   end

(* All variables bound in a block: its arguments plus every variable defined
 * by one of its statements.  Only membership is used below, so order does
 * not matter. *)
fun blockVars (block: Block.t): Var.t list =
   let
      val args = Vector.fold (Block.args block, [], fn ((var, _), lst) => var :: lst)
      val stmts = Vector.fold (Block.statements block, [], fn (stmt, lst) =>
                               case Statement.var stmt of
                                  NONE => lst
                                | SOME v => v :: lst)
   in
      args @ stmts
   end

(* Determine if the block can be unswitched.  Its transfer must be a Case
 * whose test variable is not in loopVars (the variables bound in the loop).
 * Returns SOME (cases, test, default) for such a block and NONE otherwise. *)
fun detectCases (block: Block.t, loopVars: Var.t list, depth: int) =
   case Block.transfer block of
      Case {cases, default, test} =>
         let
            val blockName = Block.label block
            val () = logsi (concat ["Evaluating ", Label.toString(blockName)], depth)
            val () = logli (Transfer.layout (Block.transfer block), depth)
            val testIsInvariant = not (List.contains (loopVars, test, Var.equals))
         in
            if testIsInvariant then
               (logsi ("Can optimize!", depth) ; SOME (cases, test, default))
            else
               (logsi ("Can't optimize: condition not invariant", depth) ;
                ++notInvariant ;
                NONE)
         end
    | _ => NONE

(* Look for an optimization opportunity in the loop.  Only loops with exactly
 * one header are considered.  The header count is checked first, so rejected
 * loops neither scan their blocks nor count toward notInvariant.  Returns the
 * first unswitchable Case found, together with the loop header. *)
fun findOpportunity (loopBody: Block.t vector,
                     loopHeaders: Block.t vector,
                     depth: int)
                    : ((Cases.t * Var.t * Label.t option) * Block.t) option =
   if Vector.length loopHeaders <> 1 then
      (logsi ("Can't optimize: loop has more than 1 header", depth) ;
       ++multiHeaders ;
       NONE)
   else
      let
         val vars = Vector.fold (loopBody, [], fn (b, lst) => (blockVars b) @ lst)
         val canOptimize = Vector.keepAllMap (loopBody,
                                              fn b => detectCases (b, vars, depth + 1))
      in
         if Vector.length canOptimize = 0 then
            NONE
         else
            SOME (Vector.sub (canOptimize, 0), Vector.sub (loopHeaders, 0))
      end

(* Copy the loop for one branch of the hoisted Case and build that branch's
 * entry block.  The entry block takes the same argument types as the original
 * branch target (branchLabel), because the hoisted Case passes it the same
 * values.  It then jumps to the copied header, passing the original header's
 * argument variables.
 *
 * The entry block's own parameters are fresh, unused variables.  Reusing the
 * target's variable names would rebind the header arguments whenever the
 * target is the loop header itself, feeding the wrong values into the copy.
 *
 * Returns the copied blocks plus the entry block, and the entry label. *)
fun makeBranch (loopBody: Block.t vector,
                loopHeader: Block.t,
                branchLabel: Label.t,
                blockInfo: Label.t -> BlockInfo,
                setBlockInfo: Label.t * BlockInfo -> unit,
                labelNode: Label.t -> unit Node.t,
                nodeBlock: unit Node.t -> Block.t)
               : Block.t vector * Label.t =
   let
      (* Copy the loop body *)
      val loopBodyLabels = Vector.map (loopBody, Block.label)
      val newLoop = copyLoop (loopBody, blockInfo, setBlockInfo)
      (* Set up a goto for the loop *)
      val (newLoopHeaderLabel, _) = blockInfo (Block.label loopHeader)
      val newLoopArgs = Vector.map (Block.args loopHeader, fn (v, _) => v)
      val newLoopEntryTransfer = Transfer.Goto {args = newLoopArgs,
                                                dst = newLoopHeaderLabel}
      val newLoopEntryLabel = Label.newNoname ()
      val targetArgs =
         if Vector.contains (loopBodyLabels, branchLabel, Label.equals) then
            let
               val (_, args) = blockInfo branchLabel
            in
               args
            end
         else
            Block.args (nodeBlock (labelNode branchLabel))
      val newLoopEntryArgs =
         Vector.map (targetArgs, fn (_, ty) => (Var.newNoname (), ty))
      val newLoopEntry = Block.T {args = newLoopEntryArgs,
                                  label = newLoopEntryLabel,
                                  statements = Vector.new0 (),
                                  transfer = newLoopEntryTransfer}
   in
      (* Return the new loop plus its entry block, and the entry label *)
      (Vector.concat [newLoop, Vector.new1 newLoopEntry], newLoopEntryLabel)
   end

(* Decide whether unswitching is worth the code growth.  One loop copy is made
 * per branch (each case plus the default, if any), so the product
 * branchCount * loopSize must stay below Control.loopUnswitchLimit.  The
 * arithmetic is done in IntInf so the product cannot overflow. *)
fun shouldOptimize (cases, default, loopBlocks, depth) =
   let
      val loopSize' = Block.sizeV (loopBlocks, {sizeExp = Exp.size, sizeTransfer = Transfer.size})
      val loopSize = IntInf.fromInt (loopSize')
      val branchCount =
         IntInf.fromInt (
            (case cases of
                Cases.Con v => Vector.length v
              | Cases.Word (_, v) => Vector.length v)
            +
            (case default of
                NONE => 0
              | SOME _ => 1))
      val unswitchLimit = IntInf.fromInt (!Control.loopUnswitchLimit)
      val shouldUnswitch = (branchCount * loopSize) < unswitchLimit
      val () = logsi ("branches * loop size < unswitch factor = can unswitch",
                      depth)
      val () = logsi (concat[IntInf.toString branchCount, " * ",
                             IntInf.toString loopSize, " < ",
                             IntInf.toString unswitchLimit, " = ",
                             Bool.toString shouldUnswitch], depth)
   in
      shouldUnswitch
   end

(* Attempt to optimize a single innermost loop.  Returns a list of blocks to
 * add to the program and a list of block labels to remove from it.  When the
 * loop is unswitched, every original loop block is removed.  The header's
 * label is reused by a new entry block whose Case on the invariant test
 * dispatches to one full copy of the loop per branch. *)
fun optimizeLoop (headerNodes, loopNodes, labelNode, nodeBlock, depth):
                 Block.t list * Label.t list =
   let
      val () = logsi ("At innermost loop", depth)
      val headers = Vector.map (headerNodes, nodeBlock)
      val blocks = Vector.map (loopNodes, nodeBlock)
      val blockNames = Vector.map (blocks, Block.label)
   in
      case findOpportunity (blocks, headers, depth) of
         NONE => ([], [])
       | SOME ((cases, check, default), header) =>
            if shouldOptimize (cases, default, blocks, depth + 1) then
               let
                  val () = ++optCount
                  (* Maps each original loop label to its copy in the branch
                   * currently being built.  Each call to mkBranch overwrites
                   * the entries.  The property is destroyed before returning. *)
                  val {get = blockInfo: Label.t -> BlockInfo,
                       set = setBlockInfo: Label.t * BlockInfo -> unit,
                       destroy} =
                     Property.destGetSet (Label.plist,
                                          Property.initRaise ("blockInfo", Label.layout))
                  val mkBranch = fn lbl => makeBranch (blocks, header, lbl, blockInfo,
                                                       setBlockInfo, labelNode, nodeBlock)
                  (* Copy the loop once per case.  Each case key (constructor
                   * or word) is paired with the entry label of its copy. *)
                  fun copyCases v =
                     Vector.unzip
                     (Vector.map
                      (v, fn (key, lbl) =>
                       let
                          val (newLoop, newLoopEntryLabel) = mkBranch lbl
                       in
                          (newLoop, (key, newLoopEntryLabel))
                       end))
                  (* Copy the loop body for the default case if necessary *)
                  val (newDefaultLoop, newDefault) =
                     case default of
                        NONE => ([], NONE)
                      | SOME defaultLabel =>
                           let
                              val (newLoop, newLoopEntryLabel) = mkBranch defaultLabel
                           in
                              (Vector.toList newLoop, SOME newLoopEntryLabel)
                           end
                  (* Copy the loop body for each case (except default) *)
                  val (newLoops, newCases) =
                     case cases of
                        Cases.Con v =>
                           let
                              val (loops, cs) = copyCases v
                           in
                              (loops, Cases.Con cs)
                           end
                      | Cases.Word (size, v) =>
                           let
                              val (loops, cs) = copyCases v
                           in
                              (loops, Cases.Word (size, cs))
                           end
                  (* Produce a single list of new blocks *)
                  val loopBlocks = Vector.fold (newLoops, newDefaultLoop, fn (loop, acc) =>
                                                acc @ (Vector.toList loop))
                  (* Produce a new entry block with the same label as the old loop header *)
                  val newTransfer = Transfer.Case {cases = newCases,
                                                   default = newDefault,
                                                   test = check}
                  val newEntry = Block.T {args = Block.args header,
                                          label = Block.label header,
                                          statements = Vector.new0 (),
                                          transfer = newTransfer}
                  val () = destroy ()
               in
                  (newEntry :: loopBlocks, Vector.toList blockNames)
               end
            else
               (logsi ("Can't unswitch: too big", depth) ;
                ++tooBig ;
                ([], []))
   end

(* Traverse sub-forests until the innermost loop is found.  Only loops with no
 * nested loops are handed to optimizeLoop; for those, notInLoop holds the
 * loop's nodes and enclosingHeaders its headers.  Results from sibling loops
 * are concatenated in order. *)
fun traverseSubForest ({loops, notInLoop},
                       enclosingHeaders,
                       labelNode, nodeBlock, depth): Block.t list * Label.t list =
   if Vector.length loops = 0 then
      optimizeLoop (enclosingHeaders, notInLoop, labelNode, nodeBlock, depth)
   else
      Vector.fold (loops, ([], []), fn (loop, (new, remove)) =>
                   let
                      val (nBlocks, rBlocks) = traverseLoop (loop, labelNode, nodeBlock, depth + 1)
                   in
                      (new @ nBlocks, remove @ rBlocks)
                   end)

(* Traverse loops in the loop forest. *)
and traverseLoop ({headers, child},
                  labelNode, nodeBlock, depth): Block.t list * Label.t list =
   traverseSubForest (Forest.dest child, headers, labelNode, nodeBlock, depth + 1)

(* Traverse the top-level loop forest and unswitch what can be unswitched.
 * Returns the function's new block list: the original blocks, minus those of
 * every rewritten loop, followed by the replacement blocks. *)
fun traverseForest ({loops, notInLoop = _}, allBlocks, labelNode, nodeBlock): Block.t list =
   let
      (* Gather the blocks to add/remove *)
      val (newBlocks, blocksToRemove) =
         Vector.fold (loops, ([], []), fn (loop, (new, remove)) =>
                      let
                         val (nBlocks, rBlocks) = traverseLoop (loop, labelNode, nodeBlock, 1)
                      in
                         (new @ nBlocks, remove @ rBlocks)
                      end)
      (* Mark the labels to drop with a property, so filtering costs one lookup
       * per block instead of a scan of blocksToRemove per block. *)
      val {get = isRemoved: Label.t -> bool,
           set = setRemoved: Label.t * bool -> unit,
           destroy} =
         Property.destGetSet (Label.plist, Property.initConst false)
      val () = List.foreach (blocksToRemove, fn l => setRemoved (l, true))
      val reducedBlocks =
         Vector.keepAll (allBlocks, fn b => not (isRemoved (Block.label b)))
      val () = destroy ()
   in
      Vector.toList reducedBlocks @ newBlocks
   end

(* Performs the optimization on the body of a single function.  The loop forest
 * is computed from the function's original control-flow graph, rooted at its
 * start label. *)
fun optimizeFunction (function: Function.t): Function.t =
   let
      val {graph, labelNode, nodeBlock} = Function.controlFlow function
      val {args, blocks, mayInline, name, raises, returns, start} =
         Function.dest function
      val fsize = Function.size (function, {sizeExp = Exp.size, sizeTransfer = Transfer.size})
      val () = logs (concat["Optimizing function: ", Func.toString name,
                            " of size ", Int.toString fsize])
      val root = labelNode start
      val forest = Graph.loopForestSteensgaard (graph, {root = root})
      val newBlocks = traverseForest (Forest.dest forest, blocks, labelNode, nodeBlock)
   in
      Function.new {args = args,
                    blocks = Vector.fromList newBlocks,
                    mayInline = mayInline,
                    name = name,
                    raises = raises,
                    returns = returns,
                    start = start}
   end

(* Entry point.  Resets the statistics, unswitches loops in every function, and
 * then runs restoreFunction over each result.  Copied loops re-define
 * variables, so this restore step is required. *)
fun transform (Program.T {datatypes, globals, functions, main}) =
   let
      val () = optCount := 0
      val () = tooBig := 0
      val () = notInvariant := 0
      val () = multiHeaders := 0
      val () = logs "Unswitching loops"
      val optimizedFunctions = List.map (functions, optimizeFunction)
      val restore = restoreFunction {globals = globals}
      val () = logs "Performing SSA restore"
      val cleanedFunctions = List.map (optimizedFunctions, restore)
      val () = logstat (optCount,
                        "loops optimized")
      val () = logstat (tooBig,
                        "loops too big to unswitch")
      val () = logstat (notInvariant,
                        "loops had variant conditions")
      val () = logstat (multiHeaders,
                        "loops had multiple headers")
      val () = logs "Done."
   in
      Program.T {datatypes = datatypes,
                 globals = globals,
                 functions = cleanedFunctions,
                 main = main}
   end

end