Commit | Line | Data |
---|---|---|
7f918cf1 CE |
1 | (* Copyright (C) 2016 Matthew Surawski. |
2 | * | |
3 | * MLton is released under a BSD-style license. | |
4 | * See the file MLton-LICENSE for details. | |
5 | *) | |
6 | ||
7 | (* Reduces or eliminates the iteration count of loops by duplicating | |
8 | * the loop body. | |
9 | *) | |
10 | functor LoopUnroll (S: SSA_TRANSFORM_STRUCTS): SSA_TRANSFORM = | |
11 | struct | |
12 | ||
13 | open S | |
14 | open Exp Transfer Prim | |
15 | ||
16 | structure Graph = DirectedGraph | |
17 | local | |
18 | open Graph | |
19 | in | |
20 | structure Forest = LoopForest | |
21 | end | |
22 | ||
(* Add one, in place, to the integer stored in the given counter cell. *)
fun ++ (v: int ref): unit =
   let
      val current = !v
   in
      v := current + 1
   end
25 | ||
26 | ||
(* A counting set keyed by IntInf values: every distinct key is stored once,
 * paired with a mutable count of how often it has been recorded.
 *)
structure Histogram =
   struct
      type t = (IntInf.t * int ref) HashSet.t

      (* Create an empty histogram; entries are hashed on their key only. *)
      fun new (): t =
         HashSet.new {hash = fn (k, _) => IntInf.hash k}

      (* Record one occurrence of key: a new key starts at count 1,
       * a key already present has its count incremented.
       *)
      fun inc (set: t, key: IntInf.t): unit =
         let
            fun sameKey (k, _) = k = key
            fun freshEntry () = (key, ref 1)
            fun bump (_, count) = ++count
            val _ = HashSet.insertIfNew (set, IntInf.hash key,
                                         sameKey, freshEntry, bump)
         in
            ()
         end

      (* All (key, count) entries, in whatever order the hash set yields. *)
      fun toList (set: t): (IntInf.t * int ref) list =
         HashSet.toList set

      (* One "key: count" line per entry, each terminated by a newline. *)
      fun toString (set: t) : string =
         let
            fun line (k, r) =
               concat [IntInf.toString k, ": ",
                       Int.toString (!r), "\n"]
         in
            concat (List.map (toList set, line))
         end
   end
56 | ||
(* Pass-wide statistics.  Each counter is a mutable cell starting at 0.
 * The first four, `infinite` and `histogram` are not updated in this part
 * of the file; presumably they are maintained by the driver code further
 * down -- confirm there.
 *)
val loopCount = ref 0
val optCount = ref 0
val total = ref 0
val partial = ref 0
(* Bumped by findOpportunity when a loop does not have exactly one header. *)
val multiHeaders = ref 0
(* Bumped by checkArg when a header argument has no constant entry value. *)
val varEntryArg = ref 0
(* Bumped by checkArg when the back edges disagree on variable or offset. *)
val variantTransfer = ref 0
(* Bumped by checkArg when a back edge uses a transfer it cannot analyse. *)
val unsupported = ref 0
(* Bumped by checkArg when varChain cannot compute the per-iteration step. *)
val ccTransfer = ref 0
(* Bumped by checkArg when no constant loop bound can be found. *)
val varBound = ref 0
val infinite = ref 0
(* Bumped by checkArg when the bound test does not dominate every back edge. *)
val boundDom = ref 0
val histogram = ref (Histogram.new ())
70 | ||
(* Information recorded per original block label while a loop is copied:
 * copyLoop stores the fresh label given to the copy together with the
 * block's argument list, and fixLabel reads the label component back.
 *)
type BlockInfo = Label.t * (Var.t * Type.t) vector
72 | ||
73 | structure Loop = | |
74 | struct | |
(* Bound: the constant the loop variable is compared against, tagged with
 * the comparison (equal, less-than, greater-than).
 *)
datatype Bound = Eq of IntInf.t | Lt of IntInf.t | Gt of IntInf.t
(* Start: value of the loop variable on entry to the loop. *)
type Start = IntInf.t
(* Step: amount added to the loop variable per iteration (may be negative). *)
type Step = IntInf.t
(* A loop description.  `invert` is true when the loop keeps running while
 * the comparison in `bound` is false (checkArg sets it to
 * `not contIsTrue`; toString prints it as a leading "!").
 *)
datatype t = T of {start: Start, step: Step, bound: Bound, invert: bool}
79 | ||
(* Render a loop description as " Start: s Step: k Bound: <op> b", where
 * the operator is prefixed with "!" when the comparison is inverted.
 *)
fun toString (T {start, step, bound, invert}): string =
   let
      (* plain and negated operator text, plus the bound constant *)
      val (plain, negated, limit) =
         case bound of
            Eq b => ("= ", "!= ", b)
          | Lt b => ("< ", "!< ", b)
          | Gt b => ("> ", "!> ", b)
      val opText = if invert then negated else plain
      val boundStr = concat [opText, IntInf.toString limit]
   in
      concat [" Start: ", IntInf.toString start,
              " Step: ", IntInf.toString step,
              " Bound: ", boundStr]
   end
100 | ||
(* Decide whether the described loop can never reach its exit condition.
 * The loop keeps running while the comparison holds (or, when `invert`
 * is set, while it does not hold).
 *)
fun isInfiniteLoop (T {start, step, bound, invert}): bool =
   case (bound, invert) of
      (* runs while var <> b: finite only if stepping lands exactly on b *)
      (Eq b, true) =>
         if start = b then
            false
         else if start < b andalso step > 0 then
            ((b - start) mod step) <> 0
         else if start > b andalso step < 0 then
            ((start - b) mod (~step)) <> 0
         else
            true
      (* runs while var = b: only a zero step keeps it there *)
    | (Eq _, false) => step = 0
      (* runs while var >= b *)
    | (Lt b, true) => start >= b andalso step >= 0
      (* runs while var < b *)
    | (Lt b, false) => start < b andalso step <= 0
      (* runs while var <= b *)
    | (Gt b, true) => start <= b andalso step <= 0
      (* runs while var > b *)
    | (Gt b, false) => start > b andalso step >= 0
125 | ||
126 | ||
(* Count the values start, start + step, start + 2*step, ... that are
 * strictly below max, i.e. the ceiling of (max - start) / step.
 * Meaningful for step > 0; a step larger than the whole range gives 1.
 *
 * The count must be a ceiling division.  Adding the remainder to the
 * quotient is only correct while the remainder is 0 or 1: for a range of
 * 11 and a step of 3 the values 0, 3, 6, 9 make 4 iterations, whereas
 * quotient + remainder would give 3 + 2 = 5.
 *)
fun iters (start: IntInf.t, step: IntInf.t, max: IntInf.t): IntInf.t =
   let
      val range = max - start
   in
      if step > range then
         1
      else
         (* ceiling (range / step) for positive step *)
         (range + step - 1) div step
   end
138 | ||
(* Number of iterations the described loop performs.
 * Assumes isInfiniteLoop is false, otherwise the result is undefined.
 *)
fun iterCount (T {start, step, bound, invert}): IntInf.t =
   case bound of
      Eq b =>
         if invert then (b - start) div step else 1
    | Lt b =>
         if start >= b then
            (* comparison is false on entry *)
            (if invert then iters (b - 1, ~step, start) else 0)
         else
            (* comparison is true on entry *)
            (if invert then 0 else iters (start, step, b))
    | Gt b =>
         if start <= b then
            (* comparison is false on entry *)
            (if invert then iters (start, step, b + 1) else 0)
         else
            (* comparison is true on entry *)
            (if invert then 0 else iters (b, ~step, start))
159 | ||
(* Build a statement binding a fresh variable to the word constant v of
 * size wsize.  Returns the fresh variable and the statement.
 *)
fun makeConstStmt (v: IntInf.t, wsize: WordSize.t): Var.t * Statement.t =
   let
      val literal = Exp.Const (Const.word (WordX.fromIntInf (v, wsize)))
      val wordTy = Type.word wsize
      val target = Var.newNoname ()
   in
      (target,
       Statement.T {exp = literal, ty = wordTy, var = SOME target})
   end
173 | ||
(* Build the statements computing `var + v` at word size wsize: first a
 * constant for v, then a word addition.  Returns the fresh variable that
 * holds the sum and the two statements in execution order.
 *)
fun makeVarStmt (v: IntInf.t, wsize: WordSize.t, var: Var.t)
   : Var.t * Statement.t list =
   let
      val (offsetVar, offsetStmt) = makeConstStmt (v, wsize)
      val sum = Exp.PrimApp {args = Vector.new2 (var, offsetVar),
                             prim = Prim.wordAdd wsize,
                             targs = Vector.new0 ()}
      val wordTy = Type.word wsize
      val sumVar = Var.newNoname ()
      val sumStmt = Statement.T {exp = sum, ty = wordTy, var = SOME sumVar}
   in
      (sumVar, [offsetStmt, sumStmt])
   end
189 | ||
(* Given:
    - a loop
    - a word size
   Returns:
    A singleton variable list and singleton statement list for the constant
    start + step * iterCount, i.e. the value of the loop variable after the
    loop's final iteration.  This value will make the loop exit.
    Assumes isInfiniteLoop is false, otherwise the result is undefined.
*)
fun makeLastConstant (loop as T {start, step, ...},
                      wsize: WordSize.t)
   : Var.t list * Statement.t list =
   let
      val finalValue = start + (step * iterCount loop)
      val (finalVar, finalStmt) = makeConstStmt (finalValue, wsize)
   in
      ([finalVar], [finalStmt])
   end
211 | ||
(* Given:
    - a loop
    - a word size
    - an iteration limit
   Returns:
    A list of fresh variables and the constant statements defining them,
    one per value the loop variable takes, in iteration order.
    The list goes 1 step beyond the end of the loop: the first value for
    which the loop would exit is still emitted.  At most `limit` constants
    are produced.
    Assumes isInfiniteLoop is false, otherwise only `limit` stops it. *)
fun makeConstants (T {start, step, bound, invert},
                   wsize: WordSize.t,
                   limit: IntInf.t)
   : Var.t list * Statement.t list =
   if limit > 0 then
      let
         (* does the raw comparison hold for the current value? *)
         val holds =
            case bound of
               Eq b => start = b
             | Lt b => start < b
             | Gt b => start > b
         val (newVar, newStatement) = makeConstStmt (start, wsize)
      in
         if holds <> invert then
            (* the loop continues: emit this value and recurse on the next *)
            let
               val nextIter = T {start = start + step,
                                 step = step,
                                 bound = bound,
                                 invert = invert}
               val (rVars, rStmts) =
                  makeConstants (nextIter, wsize, limit - 1)
            in
               (newVar :: rVars, newStatement :: rStmts)
            end
         else
            (* the loop exits on this value: it is the last one emitted *)
            ([newVar], [newStatement])
      end
   else
      ([], [])
286 | ||
(* Worker for makeSteps.  Counting `times` down to 0, prepends to the
 * accumulators a variable holding var + step * (times - 1) and the
 * statements computing it, so the finished lists are in ascending order.
 *)
fun makeStepsRec (step: Step, wsize: WordSize.t,
                  var: Var.t, times: IntInf.t,
                  vList: Var.t list, sList: Statement.t list)
   : Var.t list * Statement.t list =
   if times <= 0 then
      (vList, sList)
   else
      let
         val remaining = times - 1
         val (newVar, newStatements) =
            makeVarStmt (step * remaining, wsize, var)
      in
         makeStepsRec (step, wsize, var, remaining,
                       newVar :: vList, newStatements @ sList)
      end
300 | ||
(* Produce `times` variables holding var, var + step, var + 2*step, ...
 * together with the statements that compute them.
 *)
fun makeSteps (T {step, ...},
               wsize: WordSize.t,
               var: Var.t,
               times: IntInf.t)
   : Var.t list * Statement.t list =
   let
      val noVars: Var.t list = []
      val noStmts: Statement.t list = []
   in
      makeStepsRec (step, wsize, var, times, noVars, noStmts)
   end
307 | ||
308 | end | |
309 | ||
(* Emit a layout through Control.diagnostics, indented by 2 * i columns.
 * The indentation is only built inside the diagnostics callback.
 *)
fun logli (l: Layout.t, i: int): unit =
   let
      fun emit display = display (Layout.indent (l, i * 2))
   in
      Control.diagnostics emit
   end
314 | ||
(* Log a string at indentation level i. *)
fun logsi (s: string, i: int): unit =
   let
      val text = Layout.str s
   in
      logli (text, i)
   end
317 | ||
(* Log a string with no indentation. *)
fun logs (s: string): unit =
   let
      val topLevel = 0
   in
      logsi (s, topLevel)
   end
320 | ||
(* Log the current value of a counter followed by a description. *)
fun logstat (x: int ref, s: string): unit =
   let
      val count = Int.toString (!x)
   in
      logs (concat [count, " ", s])
   end
323 | ||
(* Drop the first element of a list; the empty list is returned unchanged. *)
fun listPop [] = []
  | listPop (_ :: rest) = rest
328 | ||
(* Map a label to the label of its copy.
   Labels listed in origLabels were renamed, so their new name is looked up
   through getBlockInfo; any other label is returned unchanged. *)
fun fixLabel (getBlockInfo: Label.t -> BlockInfo,
              label: Label.t,
              origLabels: Label.t vector): Label.t =
   if not (Vector.contains (origLabels, label, Label.equals)) then
      label
   else
      let
         val (renamed, _) = getBlockInfo label
      in
         renamed
      end
342 | ||
(* True exactly when v2 is present and is the same variable as v1. *)
fun varOptEquals (v1: Var.t, v2: Var.t option): bool =
   isSome v2 andalso Var.equals (v1, valOf v2)
347 | ||
(* For a binary operation where exactly one argument is a constant,
   load that constant.
   Returns the non-constant variable, the constant, and true if the
   variable was the first argument.  Returns NONE when both or neither
   of the arguments load as constants. *)
fun varConst (args, loadVar, signed): (Var.t * IntInf.int * bool) option =
   let
      val first = Vector.sub (args, 0)
      val second = Vector.sub (args, 1)
   in
      case (loadVar (first, signed), loadVar (second, signed)) of
         (SOME c, NONE) => SOME (second, c, false)
       | (NONE, SOME c) => SOME (first, c, true)
       | _ => NONE
   end
363 | ||
(* Given:
    - an argument vector with two arguments
    - a primitive operation that is an addition or subtraction of a const value
    - a function from variables to their constant values
   Returns:
    - The non-const variable and the constant value in terms of addition,
      or NONE if the operation is not such an addition or subtraction.

   Addition is commutative, so either operand may be the constant.
   Subtraction is only an additive step when the variable is the first
   operand: `v - c` is `v + (~c)`, but `c - v` is not `v` plus a constant,
   so that form is rejected. *)
fun checkPrim (args, prim, loadVar) =
   let
      (* v + c or c + v *)
      fun added signed =
         case varConst (args, loadVar, signed) of
            SOME (nextVar, x, _) => SOME (nextVar, x)
          | NONE => NONE
      (* v - c only; the flag from varConst says the variable came first *)
      fun subtracted signed =
         case varConst (args, loadVar, signed) of
            SOME (nextVar, x, true) => SOME (nextVar, ~x)
          | _ => NONE
   in
      case Prim.name prim of
         Name.Word_add _ => added false
       | Name.Word_addCheck (_, {signed}) => added signed
       | Name.Word_sub _ => subtracted false
       | Name.Word_subCheck (_, {signed}) => subtracted signed
       | _ => NONE
   end
389 | ||
(* Given:
    - origVar: a variable in the loop (the header argument being analysed)
    - endVar: another variable in the loop, traced backwards towards origVar
    - blocks: the loop body
    - loadVar: a function from variables to their constant values
    - total: the offset accumulated so far (callers start with the offset
      contributed by the transfer to the header, 0 for a plain goto)
   Returns:
    - Some x such that the value of origVar in loop iteration i+1 is equal to
      (the value of origVar in iteration i) + x,
      or None if the step couldn't be computed

   Each step finds the single definition of endVar, either a statement
   `endVar = v +/- const` or endVar being the only argument of a block whose
   incoming transfers all supply the same variable and offset, then
   recurses on that variable with the offset added to total. *)
fun varChain (origVar, endVar, blocks, loadVar, total) =
   case Var.equals (origVar, endVar) of
      true => SOME (total)
    | false =>
      let
         (* First block that yields a usable definition of endVar. *)
         val endVarAssign = Vector.peekMap (blocks, fn b =>
            let
               val stmts = Block.statements b
               (* Statements of b that define endVar as var +/- constant. *)
               val assignments = Vector.keepAllMap (stmts, fn s =>
                  case varOptEquals (endVar, Statement.var s) of
                     false => NONE
                   | true =>
                      (case Statement.exp s of
                         Exp.PrimApp {args, prim, ...} =>
                           checkPrim (args, prim, loadVar)
                       | _ => NONE))
               val label = Block.label b
               val blockArgs = Block.args b
               (* If we found the assignment or the block isn't unary,
                  skip this step *)
               val arithTransfers =
                  if ((Vector.length assignments) > 0) orelse
                     ((Vector.length blockArgs) <> 1)
                  then
                     Vector.new0 ()
                  else
                     let
                        val (blockArg, _) = Vector.sub (blockArgs, 0)
                        (* One entry per loop block that transfers to b:
                           SOME (var, offset) when the value passed is
                           understood, NONE when it is not. *)
                        val blockEntrys = Vector.keepAllMap (blocks, fn b' =>
                           case Block.transfer b' of
                              Transfer.Arith {args, prim, success, ...} =>
                                 if Label.equals (label, success) then
                                    SOME(checkPrim(args, prim, loadVar))
                                 else NONE
                            | Transfer.Call {return, ...} =>
                                 (case return of
                                    Return.NonTail {cont, ...} =>
                                       if Label.equals (label, cont) then
                                          SOME(NONE)
                                       else NONE
                                  | _ => NONE)
                            | Transfer.Case {cases, ...} =>
                                 (case cases of
                                    Cases.Con v =>
                                       if Vector.exists (v, fn (_, lbl) =>
                                                         Label.equals (label, lbl)) then
                                          SOME(NONE)
                                       else
                                          NONE
                                  | Cases.Word (_, v) =>
                                       if Vector.exists (v, fn (_, lbl) =>
                                                         Label.equals (label, lbl)) then
                                          SOME(NONE)
                                       else NONE)
                            | Transfer.Goto {args, dst} =>
                                 (* b is unary here, so the value passed is
                                    argument 0, with no offset *)
                                 if Label.equals (label, dst) then
                                    SOME(SOME(Vector.sub (args, 0), 0))
                                 else NONE
                            | _ => NONE)
                     in
                        (* Only relevant when endVar is b's argument. *)
                        if Var.equals (endVar, blockArg) then
                           blockEntrys
                        else
                           Vector.new0 ()
                     end
               (* Collapse the incoming transfers: the fold keeps a
                  (var, offset) pair only while every entry agrees with the
                  accumulator; any NONE or mismatch falls back to the
                  statement assignments. *)
               val assignments' =
                  if Vector.length (arithTransfers) > 0 then
                     case (Vector.fold (arithTransfers,
                                        Vector.sub (arithTransfers, 0),
                                        fn (trans, trans') =>
                                           case (trans, trans') of
                                              (SOME(a1, v1), SOME(a2, v2)) =>
                                                 if Var.equals (a1, a2) andalso
                                                    v1 = v2 then
                                                    trans
                                                 else
                                                    NONE
                                            | _ => NONE)) of
                        SOME(a, v) => Vector.new1 (a, v)
                      | NONE => assignments
                  else
                     assignments
            in
               case Vector.length assignments' of
                  0 => NONE
                | 1 => SOME (Vector.sub (assignments', 0))
                | _ => raise Fail "Multiple assignments in SSA form!"
            end)
      in
         case endVarAssign of
            NONE => NONE
          | SOME (nextVar, x) =>
               varChain(origVar, nextVar, blocks, loadVar, x + total)
      end
494 | ||
(* Given:
    - a list of loop body labels
    - a transfer on a boolean value where one branch exits the loop
      and the other continues
   Returns:
    - the label that exits the loop
    - the label that continues the loop
    - true if the continue branch is the true branch
   Raises Fail if the transfer is not a case over constructors.
*)
fun loopExit (loopLabels: Label.t vector, transfer: Transfer.t)
   : (Label.t * Label.t * bool) =
   let
      fun inLoop l = Vector.contains (loopLabels, l, Label.equals)
      fun isBool (truth, con) = Con.equals (Con.fromBool truth, con)
   in
      case transfer of
         (* One explicit case plus a default: the default stands for the
            other boolean constructor. *)
         Transfer.Case {cases, default = SOME defaultLabel, ...} =>
            let
               val (caseCon, caseLabel) =
                  case cases of
                     Cases.Con v => Vector.sub (v, 0)
                   | _ => raise Fail "This should be a con"
            in
               if inLoop defaultLabel then
                  (* continuing on the default means the explicit case,
                     which exits, must be `false` for continue = true *)
                  (caseLabel, defaultLabel, isBool (false, caseCon))
               else
                  (defaultLabel, caseLabel, isBool (true, caseCon))
            end
         (* Both constructors listed explicitly. *)
       | Transfer.Case {cases = Cases.Con v, default = NONE, ...} =>
            let
               val (c1, d1) = Vector.sub (v, 0)
               val (c2, d2) = Vector.sub (v, 1)
            in
               if inLoop d1 then
                  (d2, d1, isBool (true, c1))
               else
                  (d1, d2, isBool (true, c2))
            end
       | Transfer.Case _ => raise Fail "This should be a con"
       | _ => raise Fail "This should be a case statement"
   end
543 | ||
(* True when a case transfer is a two-way constructor branch with exactly
 * one target inside the loop and one outside: either one listed case plus
 * a default, or two listed cases and no default.
 *)
fun isLoopBranch (loopLabels, cases, default) =
   let
      fun inLoop l = Vector.contains (loopLabels, l, Label.equals)
      fun exactlyOneInLoop (l1, l2) = inLoop l1 <> inLoop l2
   in
      case (cases, default) of
         (Cases.Con v, SOME defaultLabel) =>
            if Vector.length v <> 1 then
               false
            else
               let
                  val (_, caseLabel) = Vector.sub (v, 0)
               in
                  exactlyOneInLoop (defaultLabel, caseLabel)
               end
       | (Cases.Con v, NONE) =>
            if Vector.length v <> 2 then
               false
            else
               let
                  val (_, target1) = Vector.sub (v, 0)
                  val (_, target2) = Vector.sub (v, 1)
               in
                  exactlyOneInLoop (target1, target2)
               end
       | _ => false
   end
577 | ||
(* True when the block's transfer can pass control to the block labelled
 * headerLabel.
 *
 * For a non-tail call both the normal continuation and the handler are
 * checked; the continuation is the edge every other scan in this file
 * (entry-argument and back-edge detection) treats as the way a call
 * reaches the header.
 *)
fun transfersToHeader (headerLabel, block) =
   let
      fun isHeader l = Label.equals (headerLabel, l)
   in
      case Block.transfer block of
         Transfer.Arith {success, ...} =>
            isHeader success
       | Transfer.Call {return, ...} =>
            (case return of
               Return.NonTail {cont, handler} =>
                  isHeader cont orelse
                  (case handler of
                     Handler.Handle l => isHeader l
                   | _ => false)
             | _ => false)
       | Transfer.Case {cases, ...} =>
            (* We don't have to check default because we know the header
               isn't nullary *)
            (case cases of
               Cases.Con v =>
                  Vector.exists (v, (fn (_, lbl) => isHeader lbl))
             | Cases.Word (_, v) =>
                  Vector.exists (v, (fn (_, lbl) => isHeader lbl)))
       | Transfer.Goto {dst, ...} =>
            isHeader dst
       | Transfer.Runtime {return, ...} =>
            isHeader return
       | _ => false
   end
601 | ||
(* Given:
    - a loop phi variable (with its type, which is ignored)
    - that variable's index in the loop header's arguments
    - that variable's constant entry value (if it has one), as a pair of
      the value loaded unsigned and the value loaded signed
    - the loop header block
    - the loop body blocks
    - a function from variables to their constant values
    - domInfo, which returns a list of labels for a label; the bound
      block must be in that list for every block that transfers to the
      header (presumably the dominators of that label; confirm at the caller)
    - the logging indentation depth
   Returns:
    - SOME (argIndex, block containing the bound test, Loop structure) for
      unrolling that phi var, if one exists; otherwise NONE after logging
      the reason and bumping the matching statistics counter *)
fun checkArg ((argVar, _), argIndex, entryArg, header, loopBody,
              loadVar: Var.t * bool -> IntInf.t option, domInfo, depth) =
   case entryArg of
      NONE => (logsi ("Can't unroll: entry arg not constant", depth) ;
               ++varEntryArg ;
               NONE)
    | SOME (entryX, entryXSigned) =>
      let
         val headerLabel = Block.label header
         (* Set when a back edge uses a transfer whose argument cannot be
            analysed. *)
         val unsupportedTransfer = ref false

         (* For every transfer to the start of the loop
            get the variable at argIndex *)
         val loopVars = Vector.keepAllMap (loopBody, fn block =>
            case Block.transfer block of
               Transfer.Arith {args, prim, success, ...} =>
                  if Label.equals (headerLabel, success) then
                     case checkPrim (args, prim, loadVar) of
                        NONE => (unsupportedTransfer := true ; NONE)
                      | SOME (arg, x) => SOME (arg, x)
                  else NONE
             | Transfer.Call {return, ...} =>
                  (case return of
                     Return.NonTail {cont, ...} =>
                        if Label.equals (headerLabel, cont) then
                           (unsupportedTransfer := true ; NONE)
                        else NONE
                   | _ => NONE)
             | Transfer.Case {cases, ...} =>
                  (case cases of
                     Cases.Con v =>
                        if Vector.exists(v, fn (_, lbl) =>
                                         Label.equals (headerLabel, lbl)) then
                           (unsupportedTransfer := true ; NONE)
                        else NONE
                   | Cases.Word (_, v) =>
                        if Vector.exists(v, fn (_, lbl) =>
                                         Label.equals (headerLabel, lbl)) then
                           (unsupportedTransfer := true ; NONE)
                        else NONE)
             | Transfer.Goto {args, dst} =>
                  if Label.equals (headerLabel, dst) then
                     SOME (Vector.sub (args, argIndex), 0)
                  else NONE
             | _ => NONE)
      in
         (* All back edges must pass the same variable with the same offset. *)
         if (Vector.length loopVars) > 1
            andalso not (Vector.forall
                         (loopVars, fn (arg, x) =>
                          let
                             val (arg0, x0) = Vector.sub (loopVars, 0)
                          in
                             Var.equals (arg0, arg) andalso (x0 = x)
                          end))
         then
            (logsi ("Can't unroll: variant transfer to head of loop", depth) ;
             ++variantTransfer ;
             NONE)
         else if (!unsupportedTransfer) then
            (logsi ("Can't unroll: unsupported transfer to head of loop",
                    depth) ;
             ++unsupported ;
             NONE)
         else
            let
               val (loopVar, x) = Vector.sub (loopVars, 0)
            in
               (* Trace loopVar back to argVar to obtain the constant step. *)
               case varChain (argVar, loopVar, loopBody, loadVar, x) of
                  NONE => (logsi ("Can't unroll: can't compute transfer",
                                  depth) ;
                           ++ccTransfer ;
                           NONE)
                | SOME (step) =>
                  let
                     (* `var < c` gives Lt c; `c < var` gives Gt c. *)
                     fun ltOrGt (vc, signed) =
                        case vc of
                           NONE => NONE
                         | SOME (_, c, b) =>
                              if b then
                                 SOME(Loop.Lt (c), signed)
                              else
                                 SOME(Loop.Gt (c), signed)

                     fun eq (vc, signed) =
                        case vc of
                           NONE => NONE
                         | SOME (_, c, _) => SOME(Loop.Eq (c), signed)
                     val loopLabels = Vector.map (loopBody, Block.label)
                     (* First block that ends in a loop-exit branch whose
                        test variable is defined, in that same block, by a
                        comparison of argVar against a constant. *)
                     val transferVarBlock = Vector.peekMap (loopBody, (fn b =>
                        let
                           val transferVar =
                              case Block.transfer b of
                                 Transfer.Case {cases, default, test} =>
                                    if isLoopBranch (loopLabels, cases, default) then
                                       SOME(test)
                                    else NONE
                               | _ => NONE
                           val loopBound =
                              case (transferVar) of
                                 NONE => NONE
                               | SOME (tVar) =>
                                    Vector.peekMap (Block.statements b,
                                       (fn s => case Statement.var s of
                                          NONE => NONE
                                        | SOME (sVar) =>
                                          if Var.equals (tVar, sVar) then
                                             case Statement.exp s of
                                                PrimApp {args, prim, ...} =>
                                                   if not (Vector.contains
                                                           (args, argVar, Var.equals))
                                                   then
                                                      NONE
                                                   else
                                                      (case Prim.name prim of
                                                         Name.Word_lt (_, {signed}) =>
                                                            ltOrGt
                                                            (varConst (args,
                                                                       loadVar,
                                                                       signed),
                                                             signed)
                                                       | Name.Word_equal _ =>
                                                            eq
                                                            (varConst (args,
                                                                       loadVar,
                                                                       false),
                                                             false)
                                                       | _ => NONE)
                                              | _ => NONE
                                          else NONE))
                        in
                           case loopBound of
                              NONE => NONE
                            | SOME (bound, signed) =>
                                 SOME(bound, b, signed)
                        end))
                  in
                     case transferVarBlock of
                        NONE =>
                           (logsi ("Can't unroll: can't determine bound", depth) ;
                            ++varBound ;
                            NONE)
                      | SOME(bound, block, signed) =>
                           let
                              val headerTransferBlocks =
                                 Vector.keepAll(loopBody, (fn b =>
                                    transfersToHeader (headerLabel, b)))
                              (* The bound block must appear in domInfo of
                                 every block that jumps back to the header. *)
                              val boundDominates = Vector.forall (headerTransferBlocks,
                                 (fn b => List.exists ((domInfo (Block.label b)),
                                    (fn l => Label.equals
                                             ((Block.label block), l)))))
                              val loopLabels = Vector.map (loopBody, Block.label)
                              val (_, _, contIsTrue) =
                                 loopExit (loopLabels, Block.transfer block)
                              (* Use the entry constant in the signedness of
                                 the comparison. *)
                              val entryVal = if signed then entryXSigned
                                             else entryX
                           in
                              if boundDominates then
                                 SOME (argIndex,
                                       block,
                                       Loop.T {start = entryVal,
                                               step = step,
                                               bound = bound,
                                               invert = not contIsTrue})
                              else
                                 (logsi ("Can't unroll: bound doesn't dominate", depth) ;
                                  ++boundDom ;
                                  NONE)
                           end
                  end
            end
      end
(* Check all of a loop's entry point arguments to see if each is a constant
   value.  Given one argument vector per entry point, returns a vector with
   one slot per header argument: SOME (x, x') when every entry point passes
   a constant whose first component equals x at that position, NONE
   otherwise.  Returns an empty vector when there are no entry points.

   The comparison is done position by position with the result kept in
   argument order.  Building each round's result by consing onto a list
   would reverse it, so that from the second entry point on, argument i
   would be compared against argument (n - 1 - i) of the earlier entries. *)
fun findConstantStart (entryArgs:
                       (((IntInf.t * IntInf.t) option) vector) vector)
   : ((IntInf.t * IntInf.t) option) vector =
   if (Vector.length entryArgs) > 0 then
      Vector.fold (entryArgs, Vector.sub (entryArgs, 0),
         fn (entry, agreed) =>
            Vector.map2 (entry, agreed, fn (a1, a2) =>
               case (a1, a2) of
                  (SOME (x1, x1'), SOME (x2, _)) =>
                     if x1 = x2 then SOME (x1, x1')
                     else NONE
                | _ => NONE))
   else Vector.new0 ()
797 | ||
(* Look for any optimization opportunities in the loop.
   Given the blocks of the enclosing function, the loop body, the loop
   headers, a constant lookup, domInfo (passed through to checkArg) and a
   logging depth, returns the first header argument that checkArg accepts as
   SOME (argument index, bound block, loop description), or NONE.
   Only loops with exactly one header are considered. *)
fun findOpportunity(functionBody: Block.t vector,
                    loopBody: Block.t vector,
                    loopHeaders: Block.t vector,
                    loadGlobal: Var.t * bool -> IntInf.t option,
                    domInfo: Label.t -> Label.t list,
                    depth: int):
                    (int * Block.t * Loop.t) option =
   if (Vector.length loopHeaders) = 1 then
      let
         val header = Vector.sub (loopHeaders, 0)
         val headerArgs = Block.args header
         val headerLabel = Block.label header
         val () = logsi (concat["Evaluating loop with header: ",
                                Label.toString headerLabel], depth - 1)
         fun blockEquals (b1, b2) =
            Label.equals (Block.label b1, Block.label b2)
         (* Entry through a transfer whose arguments cannot be read:
            every header argument is treated as non-constant. *)
         val emptyArgs = SOME(Vector.new (Vector.length headerArgs, NONE))
         (* One argument vector per block outside the loop that enters it. *)
         val entryArgs = Vector.keepAllMap(functionBody, fn block =>
            if Vector.contains (loopBody, block, blockEquals) then
               NONE
            else case Block.transfer block of
               Transfer.Arith {success, ...} =>
                  if Label.equals (headerLabel, success) then
                     emptyArgs
                  else NONE
             | Transfer.Call {return, ...} =>
                  (case return of
                     Return.NonTail {cont, ...} =>
                        if Label.equals (headerLabel, cont) then
                           emptyArgs
                        else NONE
                   | _ => NONE)
             | Transfer.Case {cases, ...} =>
                  (case cases of
                     Cases.Con v =>
                        if Vector.exists (v, fn (_, lbl) =>
                                          Label.equals (headerLabel, lbl)) then
                           emptyArgs
                        else
                           NONE
                   | Cases.Word (_, v) =>
                        if Vector.exists (v, fn (_, lbl) =>
                                          Label.equals (headerLabel, lbl)) then
                           emptyArgs
                        else NONE)
             | Transfer.Goto {args, dst} =>
                  if Label.equals (dst, headerLabel) then
                     (* Load each argument both unsigned and signed. *)
                     SOME(Vector.map (args, fn a =>
                                      case (loadGlobal(a, false),
                                            loadGlobal(a, true))
                                      of
                                         (NONE, NONE) => NONE
                                       | (SOME v1, SOME v2) => SOME (v1, v2)
                                       | _ => raise Fail "Impossible"))
                  else NONE
             | _ => NONE)
         val () = logsi (concat["Loop has ",
                                Int.toString (Vector.length entryArgs),
                                " entry points"], depth - 1)
         val constantArgs = findConstantStart entryArgs
         (* findConstantStart returns an empty vector when no entry edge
            was recognised; treat every argument as non-constant then
            instead of indexing out of bounds. *)
         fun entryValue i =
            if i < Vector.length constantArgs then
               Vector.sub (constantArgs, i)
            else
               NONE
         val unrollableArgs =
            Vector.keepAllMapi
            (headerArgs, fn (i, arg) => (
               logsi (concat["Checking arg: ", Var.toString (#1 arg)], depth) ;
               checkArg (arg, i, entryValue i,
                         header, loopBody, loadGlobal, domInfo, depth + 1)))
      in
         if (Vector.length unrollableArgs) > 0 then
            SOME(Vector.sub (unrollableArgs, 0))
         else NONE
      end
   else
      (logsi ("Can't optimize: loop has more than 1 header", depth) ;
       ++multiHeaders ;
       NONE)
874 | ||
(* Build a replacement for a loop header: a block with the old header's
 * label and arguments, the given statements as its body, and a goto to
 * newEntry that forwards the header's own arguments.  The given variable
 * list is returned alongside the block unchanged.
 *)
fun makeHeader(oldHeader, (newVars, newStmts), newEntry) =
   let
      val formals = Block.args oldHeader
      val forwarded = Vector.map (formals, fn (v, _) => v)
      val jump = Transfer.Goto {args = forwarded, dst = newEntry}
      val replacement =
         Block.T {args = formals,
                  label = Block.label oldHeader,
                  statements = Vector.fromList newStmts,
                  transfer = jump}
   in
      (replacement, newVars)
   end
887 | ||
(* Copy an entire loop under fresh labels.
   - Every block gets a new label; setBlockInfo records, for each original
     label, the pair (new label, original args).
   - Labels mentioned by the copied transfers are passed through fixLabel.
   - A Goto destination or Arith success label equal to headerLabel is
     replaced by nextLabel, so that one copy falls through into the next.
   - In the copied header, argument number argi is bound to argVar by a
     statement placed in front of the header's own statements.
   - When rewriteTransfer is true, the transfer of tBlock's copy is replaced
     by an argument-less Goto to the second component of loopExit's result.
   The original note asks that the header be the first element of the input;
   the copies are returned in the reverse of the input order. *)
fun copyLoop(blocks: Block.t vector,
             nextLabel: Label.t,
             headerLabel: Label.t,
             tBlock: Block.t,
             argi: int,
             argVar: Var.t,
             rewriteTransfer: bool,
             blockInfo: Label.t -> BlockInfo,
             setBlockInfo: Label.t * BlockInfo -> unit): Block.t vector =
   let
      val origLabels = Vector.map (blocks, Block.label)

      (* Pass 1: mint a fresh label per block and record old -> (new, args). *)
      val relabeled =
         Vector.map
         (blocks, fn blk =>
          let
             val fresh = Label.newNoname ()
             val () = setBlockInfo (Block.label blk, (fresh, Block.args blk))
          in
             Block.T {args = Block.args blk,
                      label = fresh,
                      statements = Block.statements blk,
                      transfer = Block.transfer blk}
          end)

      (* Translate a label through fixLabel. *)
      fun remap (l: Label.t): Label.t =
         fixLabel (blockInfo, l, origLabels)

      (* Like remap, except that an edge to the original header is sent to
         nextLabel. *)
      fun remapOrNext (l: Label.t): Label.t =
         if Label.equals (l, headerLabel)
            then nextLabel
         else remap l

      (* Rewrite the labels mentioned by an ordinary transfer. *)
      fun redirect (transfer: Transfer.t): Transfer.t =
         case transfer of
            Transfer.Arith {args, overflow, prim, success, ty} =>
               Transfer.Arith {args = args,
                               overflow = remap overflow,
                               prim = prim,
                               success = remapOrNext success,
                               ty = ty}
          | Transfer.Call {args, func, return} =>
               let
                  val return' =
                     case return of
                        Return.NonTail {cont, handler} =>
                           let
                              val handler' =
                                 case handler of
                                    Handler.Handle l => Handler.Handle (remap l)
                                  | _ => handler
                           in
                              Return.NonTail {cont = remap cont,
                                              handler = handler'}
                           end
                      | _ => return
               in
                  Transfer.Call {args = args, func = func, return = return'}
               end
          | Transfer.Case {cases, default, test} =>
               let
                  val cases' = Cases.map (cases, remap)
                  val default' =
                     case default of
                        SOME l => SOME (remap l)
                      | NONE => default
               in
                  Transfer.Case {cases = cases',
                                 default = default',
                                 test = test}
               end
          | Transfer.Goto {args, dst} =>
               Transfer.Goto {args = args, dst = remapOrNext dst}
          | Transfer.Runtime {args, prim, return} =>
               Transfer.Runtime {args = args,
                                 prim = prim,
                                 return = remap return}
          | _ => transfer

      (* Pass 2: patch the header's statements and every block's transfer. *)
      val rewired =
         Vector.map
         (relabeled, fn Block.T {args, label, statements, transfer} =>
          let
             val isHeader = Label.equals (label, remap headerLabel)
             val statements' =
                if isHeader
                   then
                      let
                         (* Bind the unrolled argument to this copy's value. *)
                         val (var, ty) = Vector.sub (args, argi)
                         val bind = Statement.T {exp = Exp.Var argVar,
                                                 ty = ty,
                                                 var = SOME var}
                      in
                         Vector.concat [Vector.new1 bind, statements]
                      end
                else statements
             val transfer' =
                if rewriteTransfer
                   andalso Label.equals (label, remap (Block.label tBlock))
                   then
                      let
                         val (_, contLabel, _) = loopExit (origLabels, transfer)
                      in
                         Transfer.Goto {args = Vector.new0 (),
                                        dst = remap contLabel}
                      end
                else redirect transfer
          in
             Block.T {args = args,
                      label = label,
                      statements = statements',
                      transfer = transfer'}
          end)
   in
      Vector.rev rewired
   end
1011 | ||
(* Unroll a loop: produce one copy of loopBlocks per element of argLabels,
   followed by the exit block.  Each copy is built by copyLoop with the first
   block of everything that follows it as its "next" label, so the copies are
   generated from the last element of argLabels back to the first.
   The header should ALWAYS be the first element in the returned list. *)
fun unrollLoop (oldHeader, tBlock, argi, loopBlocks, argLabels,
                exit, rewriteTransfer, blockInfo, setBlockInfo) =
   case argLabels of
      [] => [exit]
    | argVar :: rest =>
         let
            (* Build everything after this copy first: its head is our target. *)
            val following =
               unrollLoop (oldHeader, tBlock, argi, loopBlocks, rest,
                           exit, rewriteTransfer, blockInfo, setBlockInfo)
            val thisCopy =
               copyLoop (loopBlocks,
                         Block.label (List.first following),
                         Block.label oldHeader,
                         tBlock, argi, argVar, rewriteTransfer,
                         blockInfo, setBlockInfo)
         in
            (* Prepend this copy's blocks, keeping their order. *)
            Vector.foldr (thisCopy, following, op ::)
         end
1034 | ||
(* Given:
    - an iteration count
    - a loop body
    - a logging depth
   Returns (b, x, y, z) such that:
    if b is true
     - unroll the loop completely
     - x, y, and z carry no meaning (they are returned as 0).
    if b is false
     - x is the number of times to expand the loop body
     - y is the number of iterations to run the expanded body (must never be 0)
     - z is the number of times to peel the loop body
   The decision compares the body size (Block.sizeV) against
   Control.loopUnrollLimit.
*)
fun shouldOptimize (iterCount, loopBlocks, depth) =
   let
      val bodySize =
         IntInf.fromInt
         (Block.sizeV (loopBlocks, {sizeExp = Exp.size,
                                    sizeTransfer = Transfer.size}))
      val limit = IntInf.fromInt (!Control.loopUnrollLimit)
      val () = logsi ("iterations * loop size < unroll factor = can total unroll",
                      depth)
      val fitsEntirely = (iterCount * bodySize) < limit
      val () = logsi (concat [IntInf.toString iterCount, " * ",
                              IntInf.toString bodySize, " < ",
                              IntInf.toString limit, " = ",
                              Bool.toString fitsEntirely], depth)
      val unrollEverything: bool * IntInf.t * IntInf.t * IntInf.t =
         (true, 0, 0, 0)
   in
      if iterCount = 1 orelse fitsEntirely
         then
            (* Loop runs once or it's small enough to unroll *)
            unrollEverything
      else if bodySize >= limit
         then
            (* Body alone is over the limit: do not expand, peel 1 iteration *)
            (false, 1, iterCount - 1, 1)
      else
         let
            (* How many body copies fit under the limit, how many rounds of the
               expanded body that gives, and what is left over. *)
            val copies = limit div bodySize
            val rounds = iterCount div copies
            val remainder = iterCount - (rounds * copies)
         in
            if rounds < 3
               then
                  (* rounds - 1 < 2: the remaining loop would run 1 or 0
                     times, so just unroll the whole thing *)
                  unrollEverything
            else if remainder = 0
               then
                  (* Nothing left over to peel naturally: take one round out
                     of the loop and peel it instead *)
                  (false, copies, rounds - 1, copies)
            else
               (* Otherwise peel the leftovers in front of the loop *)
               (false, copies, rounds, remainder)
         end
   end
1085 | ||
(* Build the expanded (partially unrolled) loop.

   oldHeader, loopBlocks, loop, tBlock, argi are the loop being unrolled and
   the position of its induction argument; argSize is the word size of that
   argument and oldArg its variable.  exBody is the number of body copies per
   trip and iterBody the number of trips.  exitLabel is where control goes
   once the trips are done.  blockInfo/setBlockInfo are handed to unrollLoop.

   Returns the list of new blocks.  The first one is the new loop header,
   whose arguments are those of the block built by makeHeader plus one extra
   trip-counter argument of type word argSize. *)
fun expandLoop (oldHeader, loopBlocks, loop, tBlock, argi, argSize, oldArg,
                exBody, iterBody, exitLabel, blockInfo, setBlockInfo) =
   let
      (* Make a new loop header with an additional arg (the trip counter) *)
      val newLoopEntry = Label.newNoname ()
      val (newLoopHeader, loopArgLabels) =
         makeHeader (oldHeader,
                     Loop.makeSteps (loop, argSize, oldArg, exBody),
                     newLoopEntry)
      val iterVar = Var.newNoname ()
      val newLoopHeaderArgs' =
         Vector.concat [Block.args newLoopHeader,
                        Vector.new1 (iterVar, Type.word argSize)]
      val newLoopHeader' =
         Block.T {args = newLoopHeaderArgs',
                  label = Label.newNoname (),
                  statements = Block.statements newLoopHeader,
                  transfer = Block.transfer newLoopHeader}

      (* Make a new goto to the top of the loop increasing the iter by 1 *)
      val loopHeaderGoto =
         let
            val (newVar, newVarStmts) = Loop.makeVarStmt (1, argSize, iterVar)
            val nonIterArgs = Vector.map (Block.args oldHeader, fn (a, _) => a)
            val newArgs = Vector.concat [nonIterArgs, Vector.new1 newVar]
            val newTransfer = Transfer.Goto {args = newArgs,
                                             dst = Block.label newLoopHeader'}
         in
            Block.T {args = Vector.new0 (),
                     label = Label.newNoname (),
                     statements = Vector.fromList newVarStmts,
                     transfer = newTransfer}
         end

      (* Block reached after the last body copy of a trip: go round again
         while iterVar < iterBody - 1, otherwise leave through exitLabel. *)
      val newLoopExit =
         let
            val (newLimitVar, newLimitStmt) =
               Loop.makeConstStmt (iterBody - 1, argSize)
            val (newComp, newCompVar) =
               let
                  val newVar = Var.newNoname ()
                  val newTy = Type.datatypee Tycon.bool
                  (* The trip counter only ever counts upward (optimizeLoop
                     seeds it with 0 and loopHeaderGoto adds 1), so counter
                     and limit are both non-negative.  Compare them unsigned:
                     a signed comparison would read a limit of 2^(bits-1) or
                     more as negative and leave the loop after a single trip.
                     NOTE(review): assumes Loop.makeConstStmt wraps the
                     constant to argSize bits -- confirm. *)
                  val newExp =
                     PrimApp {args = Vector.new2 (iterVar, newLimitVar),
                              prim = Prim.wordLt (argSize, {signed = false}),
                              targs = Vector.new0 ()}
               in
                  (Statement.T {exp = newExp,
                                ty = newTy,
                                var = SOME newVar},
                   newVar)
               end
            val exitStatements = Vector.new2 (newLimitStmt, newComp)
            val exitCases =
               Cases.Con (Vector.new1 (Con.fromBool true,
                                       Block.label loopHeaderGoto))
            val exitTransfer = Transfer.Case {cases = exitCases,
                                              default = SOME exitLabel,
                                              test = newCompVar}
         in
            Block.T {args = Block.args oldHeader,
                     label = Label.newNoname (),
                     statements = exitStatements,
                     transfer = exitTransfer}
         end

      (* Expand the loop exBody times. Rewrite the bound's transfer,
         because we know it will always be true and it won't be eliminated
         by shrink. *)
      val newLoopBlocks = unrollLoop (oldHeader, tBlock, argi,
                                      loopBlocks, loopArgLabels, newLoopExit,
                                      true, blockInfo, setBlockInfo)

      (* The first copied block must answer to the label the new header
         jumps to, so re-create it under newLoopEntry. *)
      val firstLoopBlock = List.first newLoopBlocks
      val newLoopHead = Block.T {args = Block.args firstLoopBlock,
                                 label = newLoopEntry,
                                 statements = Block.statements firstLoopBlock,
                                 transfer = Block.transfer firstLoopBlock}
   in
      newLoopHeader' :: loopHeaderGoto :: newLoopHead :: listPop newLoopBlocks
   end
1173 | ||
(* Attempt to optimize a single loop.  Returns a pair: the list of blocks to
   add to the program and the list of labels of blocks to remove from it.
   Both lists are empty when findOpportunity finds nothing or when the loop
   is reported infinite by Loop.isInfiniteLoop.  Otherwise shouldOptimize
   chooses between complete unrolling and partial unrolling (peeled
   iterations, then an expanded loop, then a final exit iteration), and the
   labels of all of the loop's blocks are returned for removal. *)
fun optimizeLoop(allBlocks, headerNodes, loopNodes,
                 nodeBlock, loadGlobal, domInfo, depth) =
   let
      val () = ++loopCount
      val headers = Vector.map (headerNodes, nodeBlock)
      val loopBlocks = Vector.map (loopNodes, nodeBlock)
      val loopBlockNames = Vector.map (loopBlocks, Block.label)
      val optOpt = findOpportunity (allBlocks, loopBlocks, headers,
                                    loadGlobal, domInfo, depth + 1)
      val {get = blockInfo: Label.t -> BlockInfo,
           set = setBlockInfo: Label.t * BlockInfo -> unit, destroy} =
         Property.destGetSet (Label.plist,
                              Property.initRaise ("blockInfo", Label.layout))

      (* Re-create the first block of a list under the label `entry` and put
         it in front of the remainder of the list (as given by listPop). *)
      fun relabelHead (blocks: Block.t list, entry: Label.t): Block.t list =
         let
            val head = List.first blocks
            val head' = Block.T {args = Block.args head,
                                 label = entry,
                                 statements = Block.statements head,
                                 transfer = Block.transfer head}
         in
            head' :: listPop blocks
         end

      (* The actual unrolling, once an opportunity has been found. *)
      fun unroll (argi, tBlock, loop) =
         let
            val () = ++optCount
            val oldHeader = Vector.sub (headers, 0)
            val oldArgs = Block.args oldHeader
            val (oldArg, oldType) = Vector.sub (oldArgs, argi)
            val () = logsi (concat ["Can unroll loop on ",
                                    Var.toString oldArg], depth)
            val () = logsi (concat ["Index: ", Int.toString argi,
                                    Loop.toString loop], depth)
            val iterCount = Loop.iterCount loop
            val () = logsi (concat ["Loop will run ",
                                    IntInf.toString iterCount,
                                    " times"], depth)
            val () = logsi (concat ["Transfer block is ",
                                    Label.toString (Block.label tBlock)],
                            depth)
            val () = Histogram.inc (!histogram, iterCount)
            val (totalUnroll, exBody, iterBody, peel) =
               shouldOptimize (iterCount, loopBlocks, depth + 1)
            val argSize =
               case Type.dest oldType of
                  Type.Word wsize => wsize
                | _ => raise Fail "Argument is not of type word"

            (* One body copy per constant, ending in `exit`; the bound's
               transfer is left as it is. *)
            fun copies (consts, exit) =
               unrollLoop (oldHeader, tBlock, argi, loopBlocks, consts, exit,
                           false, blockInfo, setBlockInfo)

            (* Terminal block placed after the last copy of a chain. *)
            fun bugBlock () =
               Block.T {args = oldArgs,
                        label = Label.newNoname (),
                        statements = Vector.new0 (),
                        transfer = Transfer.Bug}

            fun complete () =
               let
                  val () = ++total
                  val () = logsi ("Completely unrolling loop", depth)
                  val newEntry = Label.newNoname ()
                  val (newHeader, argLabels) =
                     makeHeader (oldHeader,
                                 Loop.makeConstants (loop, argSize,
                                                     iterCount + 1),
                                 newEntry)
                  (* For each induction variable value, copy the loop's body *)
                  val body = copies (argLabels, bugBlock ())
                  (* Fix the first entry's label *)
                  val added = newHeader :: relabelHead (body, newEntry)
                  val () = destroy ()
               in
                  (added, Vector.toList loopBlockNames)
               end

            fun partly () =
               let
                  val () = ++partial
                  val () = logsi ("Partially unrolling loop", depth)
                  val () = logsi (concat ["Body expansion: ",
                                          IntInf.toString exBody,
                                          " Body iterations: ",
                                          IntInf.toString iterBody,
                                          " Peel iterations: ",
                                          IntInf.toString peel],
                                  depth)
                  val oldArgVars = Vector.map (oldArgs, fn (a, _) => a)

                  (* Step 1: produce an exit loop iteration. *)
                  val exitEntry = Label.newNoname ()
                  val (exitHeader, exitConsts) =
                     makeHeader (oldHeader,
                                 Loop.makeLastConstant (loop, argSize),
                                 exitEntry)
                  val exitHeader' =
                     Block.T {args = Block.args exitHeader,
                              label = Label.newNoname (),
                              statements = Block.statements exitHeader,
                              transfer = Block.transfer exitHeader}
                  val exitBody = copies (exitConsts, bugBlock ())
                  (* Argument-less block the expanded loop leaves through; it
                     forwards the old header's argument variables. *)
                  val exitGotoLabel = Label.newNoname ()
                  val exitGoto =
                     Block.T {args = Vector.new0 (),
                              label = exitGotoLabel,
                              statements = Vector.new0 (),
                              transfer =
                              Transfer.Goto {args = oldArgVars,
                                             dst = Block.label exitHeader'}}
                  val exitBlocks' =
                     exitGoto :: exitHeader' :: relabelHead (exitBody, exitEntry)

                  (* Step 2: expand the loop *)
                  val exLoopBlocks =
                     expandLoop (oldHeader, loopBlocks, loop,
                                 tBlock, argi, argSize,
                                 oldArg, exBody, iterBody,
                                 exitGotoLabel,
                                 blockInfo, setBlockInfo)
                  (* Entry to the expanded loop: passes a zero constant as the
                     extra argument of the expanded loop's first block. *)
                  val exLoopEntry =
                     let
                        val (zeroVar, zeroStmt) =
                           Loop.makeConstStmt (0, argSize)
                        val enter =
                           Transfer.Goto
                           {args = Vector.concat [oldArgVars,
                                                  Vector.new1 zeroVar],
                            dst = Block.label (List.first exLoopBlocks)}
                     in
                        Block.T {args = oldArgs,
                                 label = Label.newNoname (),
                                 statements = Vector.new1 zeroStmt,
                                 transfer = enter}
                     end

                  (* Step 3: the replacement loop entry with the peeled
                     iterations, which continue into the expanded loop. *)
                  val newEntry = Label.newNoname ()
                  val (newHeader, argLabels) =
                     makeHeader (oldHeader,
                                 Loop.makeConstants (loop, argSize, peel),
                                 newEntry)
                  val peeled = copies (argLabels, exLoopEntry)
                  val peeled' = newHeader :: relabelHead (peeled, newEntry)
                  val () = destroy ()
               in
                  (peeled' @ exLoopBlocks @ exitBlocks',
                   Vector.toList loopBlockNames)
               end
         in
            if totalUnroll then complete () else partly ()
         end
   in
      case optOpt of
         NONE => ([], [])
       | SOME (argi, tBlock, loop) =>
            if Loop.isInfiniteLoop loop
               then
                  (logsi ("Can't unroll: infinite loop", depth)
                   ; ++infinite
                   ; logsi (concat ["Index: ", Int.toString argi,
                                    Loop.toString loop],
                            depth)
                   ; ([], []))
            else unroll (argi, tBlock, loop)
   end
1352 | ||
(* Traverse sub-forests until the innermost loop is found.  A sub-forest
   without nested loops is handed to optimizeLoop together with the headers of
   the loop that encloses it; otherwise each nested loop is traversed in turn
   and the (blocks to add, labels to remove) results are concatenated. *)
fun traverseSubForest ({loops, notInLoop},
                       allBlocks,
                       enclosingHeaders,
                       labelNode, nodeBlock, loadGlobal, domInfo) =
   case Vector.length loops of
      0 =>
         optimizeLoop (allBlocks, enclosingHeaders, notInLoop,
                       nodeBlock, loadGlobal, domInfo, 1)
    | _ =>
         Vector.fold
         (loops, ([], []), fn (inner, (added, removed)) =>
          let
             val (moreAdded, moreRemoved) =
                traverseLoop (inner, allBlocks, labelNode, nodeBlock,
                              loadGlobal, domInfo)
          in
             (added @ moreAdded, removed @ moreRemoved)
          end)

(* Traverse loops in the loop forest: descend into the loop's child forest,
   carrying the loop's headers along. *)
and traverseLoop ({headers, child},
                  allBlocks,
                  labelNode, nodeBlock, loadGlobal, domInfo) =
   let
      val subForest = Forest.dest child
   in
      traverseSubForest (subForest, allBlocks, headers,
                         labelNode, nodeBlock, loadGlobal, domInfo)
   end
1376 | ||
(* Traverse the top-level loop forest.  Collects, over all top-level loops,
   the blocks to add and the labels to remove, then returns the blocks of
   allBlocks whose label is not to be removed (in their original order)
   followed by the added blocks. *)
fun traverseForest ({loops, notInLoop = _}, allBlocks, labelNode, nodeBlock, loadGlobal, domInfo) =
   let
      (* Gather the blocks to add/remove *)
      fun gather (loop, (added, removed)) =
         let
            val (moreAdded, moreRemoved) =
               traverseLoop (loop, allBlocks, labelNode, nodeBlock,
                             loadGlobal, domInfo)
         in
            (added @ moreAdded, removed @ moreRemoved)
         end
      val (added, removed) = Vector.fold (loops, ([], []), gather)

      fun isRemoved (b: Block.t): bool =
         List.exists (removed, fn l => Label.equals (l, Block.label b))
   in
      (* Walk right-to-left so survivors keep their order in front of `added`. *)
      Vector.foldr (allBlocks, added, fn (b, rest) =>
                    if isRemoved b then rest else b :: rest)
   end
1395 | ||
(* Record, for the label of every block in the given tree, the list of labels
   on the path from that block's parent up to the root (nearest first; the
   root gets the empty list).  Returns the lookup function together with the
   function that destroys the underlying property. *)
fun setDoms tree =
   let
      val {get = domInfo: Label.t -> Label.t list,
           set = setDomInfo: Label.t * Label.t list -> unit, destroy} =
         Property.destGetSet (Label.plist,
                              Property.initRaise ("domInfo", Label.layout))
      fun walk (Tree.T (block, children), ancestors) =
         let
            val here = Block.label block
            val () = setDomInfo (here, ancestors)
         in
            Vector.foreach (children, fn child =>
                            walk (child, here :: ancestors))
         end
      val () = walk (tree, [])
   in
      (domInfo, destroy)
   end
1412 | ||
(* Performs the optimization on the body of a single function: builds the
   control-flow graph, its loop forest (loopForestSteensgaard) and the
   dominator information, lets traverseForest compute the new block list, and
   rebuilds the function with those blocks; every other component of the
   function is carried over unchanged. *)
fun optimizeFunction loadGlobal function =
   let
      val {graph, labelNode, nodeBlock} = Function.controlFlow function
      val {args, blocks, mayInline, name, raises, returns, start} =
         Function.dest function
      val size =
         Function.size (function, {sizeExp = Exp.size,
                                   sizeTransfer = Transfer.size})
      val () = logs (concat ["Optimizing function: ", Func.toString name,
                             " of size ", Int.toString size])
      val loopForest =
         Graph.loopForestSteensgaard (graph, {root = labelNode start})
      (* Dominator lists live in a property; release it once we are done. *)
      val (domInfo, releaseDoms) = setDoms (Function.dominatorTree function)
      val rewritten =
         traverseForest (Forest.dest loopForest, blocks,
                         labelNode, nodeBlock, loadGlobal, domInfo)
      val () = releaseDoms ()
   in
      Function.new {args = args,
                    blocks = Vector.fromList rewritten,
                    mayInline = mayInline,
                    name = name,
                    raises = raises,
                    returns = returns,
                    start = start}
   end
1438 | ||
(* Entry point.  Resets the statistics, runs optimizeFunction over every
   function, then applies restoreFunction and shrinkFunction to the results,
   logs the statistics and returns the program with the new functions. *)
fun transform (Program.T {datatypes, globals, functions, main}) =
   let
      (* Look up `var` among the globals.  Gives SOME n when the global is
         bound to a word constant, converted with WordX.toIntInfX if `signed`
         and WordX.toIntInf otherwise; NONE in every other case. *)
      fun loadGlobal (var: Var.t, signed: bool): IntInf.t option =
         let
            fun binds stmt =
               case Statement.var stmt of
                  SOME v => Var.equals (var, v)
                | NONE => false
         in
            case Vector.peek (globals, binds) of
               SOME stmt =>
                  (case Statement.exp stmt of
                      Exp.Const (Const.Word w) =>
                         SOME (if signed
                                  then WordX.toIntInfX w
                               else WordX.toIntInf w)
                    | _ => NONE)
             | NONE => NONE
         end

      (* Start every statistic from scratch. *)
      val () =
         List.foreach
         ([loopCount, total, partial, optCount, multiHeaders, varEntryArg,
           variantTransfer, unsupported, ccTransfer, varBound, infinite,
           boundDom],
          fn counter => counter := 0)
      val () = histogram := Histogram.new ()

      val () = logs (concat ["Unrolling loops. Unrolling factor = ",
                             Int.toString (!Control.loopUnrollLimit)])
      val optimizedFunctions =
         List.map (functions, optimizeFunction loadGlobal)
      val restore = restoreFunction {globals = globals}
      val () = logs "Performing SSA restore"
      val cleanedFunctions = List.map (optimizedFunctions, restore)
      val shrink = shrinkFunction {globals = globals}
      val () = logs "Performing shrink"
      val shrunkFunctions = List.map (cleanedFunctions, shrink)

      (* Report the statistics, in this fixed order. *)
      val () =
         List.foreach
         ([(loopCount, "total innermost loops"),
           (optCount, "loops optimized"),
           (total, "loops completely unrolled"),
           (partial, "loops partially unrolled"),
           (multiHeaders, "loops had multiple headers"),
           (varEntryArg, "variable entry values"),
           (variantTransfer, "loops had variant transfers to the header"),
           (unsupported, "loops had unsupported transfers to the header"),
           (ccTransfer, "loops had non-computable steps"),
           (varBound, "loops had variable bounds"),
           (infinite, "infinite loops"),
           (boundDom, "loops had non-dominating bounds")],
          fn (counter, what) => logstat (counter, what))
      val () = logs ("Iterations: Occurences")
      val () = logs (Histogram.toString (!histogram))
      val () = logs "Done."
   in
      Program.T {datatypes = datatypes,
                 globals = globals,
                 functions = shrunkFunctions,
                 main = main}
   end
1518 | ||
1519 | end |