1 (* Copyright (C) 2016 Matthew Surawski.
3 * MLton is released under a BSD-style license.
4 * See the file MLton-LICENSE for details.
7 (* Reduces or eliminates the iteration count of loops by duplicating
10 functor LoopUnroll (S: SSA_TRANSFORM_STRUCTS): SSA_TRANSFORM =
14 open Exp Transfer Prim
16 structure Graph = DirectedGraph
20 structure Forest = LoopForest
23 fun ++ (v: int ref): unit =
29 type t = (IntInf.t * int ref) HashSet.t
31 fun inc (set: t, key: IntInf.t): unit =
33 val _ = HashSet.insertIfNew (set, IntInf.hash key,
34 (fn (k, _) => k = key),
35 (fn () => (key, ref 1)),
42 HashSet.new {hash = fn (k, _) => IntInf.hash k}
44 fun toList (set: t): (IntInf.t * int ref) list =
47 fun toString (set: t) : string =
51 List.fold (eles, "", fn ((k, r), s) => concat[s,
52 IntInf.toString k, ": ",
53 Int.toString (!r), "\n"])
(* Mutable integer counters, each starting at 0. Within this excerpt only
   multiHeaders is visibly updated (incremented right after the message about
   a loop having more than 1 header is logged); the others presumably tally
   the remaining reasons a loop is rejected -- TODO confirm against their
   update sites. *)
61 val multiHeaders = ref 0
62 val varEntryArg = ref 0
63 val variantTransfer = ref 0
64 val unsupported = ref 0
65 val ccTransfer = ref 0
(* Shared histogram held in a ref; optimizeLoop records a loop's computed
   iteration count in it via Histogram.inc. *)
69 val histogram = ref (Histogram.new ())
(* Per-label record: a label together with a vector of (variable, type)
   arguments. copyLoop stores (new label, original args) under the old label,
   and fixLabel destructures the record to obtain the name. *)
71 type BlockInfo = Label.t * (Var.t * Type.t) vector
(* A constant bound tagged with the kind of comparison it takes part in:
   equality (Eq), less-than (Lt) or greater-than (Gt). *)
75 datatype Bound = Eq of IntInf.t | Lt of IntInf.t | Gt of IntInf.t
(* Description of a loop: its start value, step, bound, and whether the
   comparison is inverted. toString renders an inverted bound with a leading
   "!", and checkArg builds the record with invert = not contIsTrue. *)
78 datatype t = T of {start: Start, step: Step, bound: Bound, invert: bool}
80 fun toString (T {start, step, bound, invert}): string =
82 val boundStr = case bound of
83 Eq b => if invert then
84 concat ["!= ", IntInf.toString b]
86 concat ["= ", IntInf.toString b]
87 | Lt b => if invert then
88 concat ["!< ", IntInf.toString b]
90 concat ["< ", IntInf.toString b]
91 | Gt b => if invert then
92 concat ["!> ", IntInf.toString b]
94 concat ["> ", IntInf.toString b]
96 concat[" Start: ", IntInf.toString start,
97 " Step: ", IntInf.toString step,
101 fun isInfiniteLoop (T {start, step, bound, invert}): bool =
107 else if start < b andalso step > 0 then
108 not (((b - start) mod step) = 0)
109 else if start > b andalso step < 0 then
110 not (((start - b) mod (~step)) = 0)
117 start >= b andalso step >= 0
119 start < b andalso step <= 0
122 start <= b andalso step <= 0
124 start > b andalso step >= 0
127 fun iters (start: IntInf.t, step: IntInf.t, max: IntInf.t): IntInf.t =
129 val range = max - start
130 val iters = range div step
131 val adds = range mod step
139 (* Assumes isInfiniteLoop is false, otherwise the result is undefined. *)
140 fun iterCount (T {start, step, bound, invert}): IntInf.t =
148 (case (start >= b, invert) of
150 | (true, true) => iters (b - 1, ~step, start)
152 | (false, false) => iters (start, step, b))
154 (case (start <= b, invert) of
156 | (true, true) => iters (start, step, b + 1)
158 | (false, false) => iters (b, ~step, start))
160 fun makeConstStmt (v: IntInf.t, wsize: WordSize.t): Var.t * Statement.t =
162 val newWord = WordX.fromIntInf (v, wsize)
163 val newConst = Const.word newWord
164 val newExp = Exp.Const (newConst)
165 val newType = Type.word wsize
166 val newVar = Var.newNoname()
167 val newStatement = Statement.T {exp = newExp,
171 (newVar, newStatement)
174 fun makeVarStmt (v: IntInf.t, wsize: WordSize.t, var: Var.t)
175 : Var.t * Statement.t list =
177 val (cVar, cStmt) = makeConstStmt (v, wsize)
178 val newExp = Exp.PrimApp {args = Vector.new2 (var, cVar),
179 prim = Prim.wordAdd wsize,
180 targs = Vector.new0 ()}
181 val newType = Type.word wsize
182 val newVar = Var.newNoname()
183 val newStatement = Statement.T {exp = newExp,
187 (newVar, [cStmt, newStatement])
194 A variable and statement for the constant value after the loop's final
195 iteration. This value will make the loop exit.
196 Assumes isInfiniteLoop is false, otherwise the result is undefined.
198 fun makeLastConstant (T {start, step, bound, invert},
200 : Var.t list * Statement.t list =
202 val ic = iterCount (T {start = start,
206 val last = start + (step * ic)
207 val (newVar, newStatement) = makeConstStmt(last, wsize)
209 ([newVar], [newStatement])
217 A pair of variables and statements for those variables
218 for each iteration of the loop.
219 This should go 1 step beyond the end of the loop.
220 Assumes isInfiniteLoop is false, otherwise this will run forever. *)
221 fun makeConstants (T {start, step, bound, invert},
224 : Var.t list * Statement.t list =
227 if (start = b) <> invert andalso limit > 0 then
229 val (newVar, newStatement) = makeConstStmt(start, wsize)
230 val nextIter = T {start = start + step,
234 val (rVars, rStmts) = makeConstants (nextIter, wsize, limit - 1)
236 (newVar::rVars, newStatement::rStmts)
238 else if limit > 0 then
240 val (newVar, newStatement) = makeConstStmt(start, wsize)
242 ([newVar], [newStatement])
247 if (start < b) <> invert andalso limit > 0 then
249 val (newVar, newStatement) = makeConstStmt(start, wsize)
250 val nextIter = T {start = start + step,
254 val (rVars, rStmts) = makeConstants (nextIter, wsize, limit - 1)
256 (newVar::rVars, newStatement::rStmts)
258 else if limit > 0 then
260 val (newVar, newStatement) = makeConstStmt(start, wsize)
262 ([newVar], [newStatement])
267 if (start > b) <> invert andalso limit > 0 then
269 val (newVar, newStatement) = makeConstStmt(start, wsize)
270 val nextIter = T {start = start + step,
274 val (rVars, rStmts) = makeConstants (nextIter, wsize, limit - 1)
276 (newVar::rVars, newStatement::rStmts)
278 else if limit > 0 then
280 val (newVar, newStatement) = makeConstStmt(start, wsize)
282 ([newVar], [newStatement])
287 fun makeStepsRec (step: Step, wsize: WordSize.t,
288 var: Var.t, times: IntInf.t,
289 vList: Var.t list, sList: Statement.t list)
290 : Var.t list * Statement.t list =
293 val stepAdd = step * (times - 1)
294 val (newVar, newStatements) = makeVarStmt(stepAdd, wsize, var)
296 makeStepsRec (step, wsize, var, times - 1,
297 newVar::vList, newStatements @ sList)
301 fun makeSteps (T {step, ...},
305 : Var.t list * Statement.t list =
306 makeStepsRec (step, wsize, var, times, [], [])
310 fun logli (l: Layout.t, i: int): unit =
313 display(Layout.indent(l, i * 2)))
(* Log the string s at indent level i: wrap it with Layout.str and hand the
   resulting layout, together with i, to logli. *)
315 fun logsi (s: string, i: int): unit =
316 logli((Layout.str s), i)
318 fun logs (s: string): unit =
(* Log a statistic line through logs: the current value of the counter x,
   a single space, then the description s. *)
321 fun logstat (x: int ref, s: string): unit =
322 logs (concat[Int.toString(!x), " ", s])
329 (* If a block was renamed, return the new name.
330 Otherwise return the old name. *)
331 fun fixLabel (getBlockInfo: Label.t -> BlockInfo,
333 origLabels: Label.t vector): Label.t =
334 if Vector.contains(origLabels, label, Label.equals) then
336 val (name, _) = getBlockInfo(label)
343 fun varOptEquals (v1: Var.t, v2: Var.t option): bool =
346 | SOME (v2') => Var.equals (v1, v2')
348 (* For a binary operation where one argument is a constant,
350 Returns the variable, the constant, and true if the var was the first arg *)
351 fun varConst (args, loadVar, signed): (Var.t * IntInf.int * bool) option =
353 val a1 = Vector.sub (args, 0)
354 val a2 = Vector.sub (args, 1)
355 val a1v = loadVar(a1, signed)
356 val a2v = loadVar(a2, signed)
359 (SOME x, NONE) => SOME (a2, x, false)
360 | (NONE, SOME x) => SOME (a1, x, true)
365 - an argument vector with two arguments
366 - a primitive operation that is an addition or subtraction of a const value
367 - a function from variables to their constant values
369 - The non-const variable and the constant value in terms of addition *)
370 fun checkPrim (args, prim, loadVar) =
371 case Prim.name prim of
373 (case varConst(args, loadVar, false) of
374 SOME(nextVar, x, _) => SOME (nextVar, x)
376 | Name.Word_addCheck (_, {signed}) =>
377 (case varConst(args, loadVar, signed) of
378 SOME(nextVar, x, _) => SOME(nextVar, x)
381 (case varConst(args, loadVar, false) of
382 SOME(nextVar, x, _) => SOME (nextVar, ~x)
384 | Name.Word_subCheck (_, {signed}) =>
385 (case varConst(args, loadVar, signed) of
386 SOME(nextVar, x, _) => SOME (nextVar, ~x)
391 - a variable in the loop
392 - another variable in the loop
394 - a function from variables to their constant values
395 - a starting value, if the transfer to the header is an arith transfer
397 - Some x such that the value of origVar in loop iteration i+1 is equal to
398 (the value of origVar in iteration i) + x,
399 or None if the step couldn't be computed *)
400 fun varChain (origVar, endVar, blocks, loadVar, total) =
401 case Var.equals (origVar, endVar) of
405 val endVarAssign = Vector.peekMap (blocks, fn b =>
407 val stmts = Block.statements b
408 val assignments = Vector.keepAllMap (stmts, fn s =>
409 case varOptEquals (endVar, Statement.var s) of
412 (case Statement.exp s of
413 Exp.PrimApp {args, prim, ...} =>
414 checkPrim (args, prim, loadVar)
416 val label = Block.label b
417 val blockArgs = Block.args b
418 (* If we found the assignment or the block isn't unary,
421 if ((Vector.length assignments) > 0) orelse
422 ((Vector.length blockArgs) <> 1)
427 val (blockArg, _) = Vector.sub (blockArgs, 0)
428 val blockEntrys = Vector.keepAllMap (blocks, fn b' =>
429 case Block.transfer b' of
430 Transfer.Arith {args, prim, success, ...} =>
431 if Label.equals (label, success) then
432 SOME(checkPrim(args, prim, loadVar))
434 | Transfer.Call {return, ...} =>
436 Return.NonTail {cont, ...} =>
437 if Label.equals (label, cont) then
441 | Transfer.Case {cases, ...} =>
444 if Vector.exists (v, fn (_, lbl) =>
445 Label.equals (label, lbl)) then
449 | Cases.Word (_, v) =>
450 if Vector.exists (v, fn (_, lbl) =>
451 Label.equals (label, lbl)) then
454 | Transfer.Goto {args, dst} =>
455 if Label.equals (label, dst) then
456 SOME(SOME(Vector.sub (args, 0), 0))
460 if Var.equals (endVar, blockArg) then
466 if Vector.length (arithTransfers) > 0 then
467 case (Vector.fold (arithTransfers,
468 Vector.sub (arithTransfers, 0),
469 fn (trans, trans') =>
470 case (trans, trans') of
471 (SOME(a1, v1), SOME(a2, v2)) =>
472 if Var.equals (a1, a2) andalso
478 SOME(a, v) => Vector.new1 (a, v)
479 | NONE => assignments
483 case Vector.length assignments' of
485 | 1 => SOME (Vector.sub (assignments', 0))
486 | _ => raise Fail "Multiple assignments in SSA form!"
491 | SOME (nextVar, x) =>
492 varChain(origVar, nextVar, blocks, loadVar, x + total)
496 - a list of loop body labels
497 - a transfer on a boolean value where one branch exits the loop
498 and the other continues
500 - the label that exits the loop
501 - the label that continues the loop
502 - true if the continue branch is the true branch
504 fun loopExit (loopLabels: Label.t vector, transfer: Transfer.t)
505 : (Label.t * Label.t * bool) =
507 (* This should be a case statement on a boolean,
508 so all dsts should be unary.
509 One should transfer outside the loop, the other inside. *)
510 Transfer.Case {cases, default, ...} =>
512 SOME(defaultLabel) =>
514 val (caseCon, caseLabel) =
516 Cases.Con v => Vector.sub (v, 0)
517 | _ => raise Fail "This should be a con"
519 if Vector.contains (loopLabels, defaultLabel, Label.equals) then
522 Con.equals (Con.fromBool false, caseCon))
526 Con.equals (Con.fromBool true, caseCon))
532 val (c1, d1) = Vector.sub (v, 0)
533 val (c2, d2) = Vector.sub (v, 1)
535 if Vector.contains (loopLabels, d1, Label.equals) then
536 (d2, d1, Con.equals (Con.fromBool true, c1))
538 (d1, d2, Con.equals (Con.fromBool true, c2))
540 | _ => raise Fail "This should be a con"))
542 | _ => raise Fail "This should be a case statement"
544 fun isLoopBranch (loopLabels, cases, default) =
546 SOME (defaultLabel) =>
549 if (Vector.length v) = 1 then
551 val (_, caseLabel) = Vector.sub (v, 0)
553 Vector.contains (loopLabels, defaultLabel, Label.equals)
555 Vector.contains (loopLabels, caseLabel, Label.equals)
557 defaultInLoop <> caseInLoop
565 if (Vector.length v) = 2 then
567 val (_, c1) = Vector.sub (v, 0)
568 val (_, c2) = Vector.sub (v, 1)
569 val c1il = Vector.contains (loopLabels, c1, Label.equals)
570 val c2il = Vector.contains (loopLabels, c2, Label.equals)
578 fun transfersToHeader (headerLabel, block) =
579 case Block.transfer block of
580 Transfer.Arith {success, ...} =>
581 Label.equals (headerLabel, success)
582 | Transfer.Call {return, ...} =>
584 Return.NonTail {handler, ...} =>
586 Handler.Handle l => Label.equals (headerLabel, l)
589 | Transfer.Case {cases, ...} =>
590 (* We don't have to check default because we know the header isn't nullary *)
593 Vector.exists (v, (fn (_, lbl) => Label.equals (headerLabel, lbl)))
594 | Cases.Word (_, v) =>
595 Vector.exists (v, (fn (_, lbl) => Label.equals (headerLabel, lbl))))
596 | Transfer.Goto {dst, ...} =>
597 Label.equals (headerLabel, dst)
598 | Transfer.Runtime {return, ...} =>
599 Label.equals(headerLabel, return)
603 - a loop phi variable
604 - that variable's index in the loop header's arguments
605 - that variable's constant entry value (if it has one)
606 - the loop header block
607 - the loop body block
608 - a function from variables to their constant values
610 - a Loop structure for unrolling that phi var, if one exists *)
611 fun checkArg ((argVar, _), argIndex, entryArg, header, loopBody,
612 loadVar: Var.t * bool -> IntInf.t option, domInfo, depth) =
614 NONE => (logsi ("Can't unroll: entry arg not constant", depth) ;
617 | SOME (entryX, entryXSigned) =>
619 val headerLabel = Block.label header
620 val unsupportedTransfer = ref false
622 (* For every transfer to the start of the loop
623 get the variable at argIndex *)
624 val loopVars = Vector.keepAllMap (loopBody, fn block =>
625 case Block.transfer block of
626 Transfer.Arith {args, prim, success, ...} =>
627 if Label.equals (headerLabel, success) then
628 case checkPrim (args, prim, loadVar) of
629 NONE => (unsupportedTransfer := true ; NONE)
630 | SOME (arg, x) => SOME (arg, x)
632 | Transfer.Call {return, ...} =>
634 Return.NonTail {cont, ...} =>
635 if Label.equals (headerLabel, cont) then
636 (unsupportedTransfer := true ; NONE)
639 | Transfer.Case {cases, ...} =>
642 if Vector.exists(v, fn (_, lbl) =>
643 Label.equals (headerLabel, lbl)) then
644 (unsupportedTransfer := true ; NONE)
646 | Cases.Word (_, v) =>
647 if Vector.exists(v, fn (_, lbl) =>
648 Label.equals (headerLabel, lbl)) then
649 (unsupportedTransfer := true ; NONE)
651 | Transfer.Goto {args, dst} =>
652 if Label.equals (headerLabel, dst) then
653 SOME (Vector.sub (args, argIndex), 0)
657 if (Vector.length loopVars) > 1
658 andalso not (Vector.forall
659 (loopVars, fn (arg, x) =>
661 val (arg0, x0) = Vector.sub (loopVars, 0)
663 Var.equals (arg0, arg) andalso (x0 = x)
666 (logsi ("Can't unroll: variant transfer to head of loop", depth) ;
669 else if (!unsupportedTransfer) then
670 (logsi ("Can't unroll: unsupported transfer to head of loop",
676 val (loopVar, x) = Vector.sub (loopVars, 0)
678 case varChain (argVar, loopVar, loopBody, loadVar, x) of
679 NONE => (logsi ("Can't unroll: can't compute transfer",
685 fun ltOrGt (vc, signed) =
690 SOME(Loop.Lt (c), signed)
692 SOME(Loop.Gt (c), signed)
694 fun eq (vc, signed) =
697 | SOME (_, c, _) => SOME(Loop.Eq (c), signed)
698 val loopLabels = Vector.map (loopBody, Block.label)
699 val transferVarBlock = Vector.peekMap (loopBody, (fn b =>
702 case Block.transfer b of
703 Transfer.Case {cases, default, test} =>
704 if isLoopBranch (loopLabels, cases, default) then
709 case (transferVar) of
712 Vector.peekMap (Block.statements b,
713 (fn s => case Statement.var s of
716 if Var.equals (tVar, sVar) then
717 case Statement.exp s of
718 PrimApp {args, prim, ...} =>
719 if not (Vector.contains
720 (args, argVar, Var.equals))
724 (case Prim.name prim of
725 Name.Word_lt (_, {signed}) =>
731 | Name.Word_equal _ =>
743 | SOME (bound, signed) =>
744 SOME(bound, b, signed)
747 case transferVarBlock of
749 (logsi ("Can't unroll: can't determine bound", depth) ;
752 | SOME(bound, block, signed) =>
754 val headerTransferBlocks =
755 Vector.keepAll(loopBody, (fn b =>
756 transfersToHeader (headerLabel, b)))
757 val boundDominates = Vector.forall (headerTransferBlocks,
758 (fn b => List.exists ((domInfo (Block.label b)),
759 (fn l => Label.equals
760 ((Block.label block), l)))))
761 val loopLabels = Vector.map (loopBody, Block.label)
762 val (_, _, contIsTrue) =
763 loopExit (loopLabels, Block.transfer block)
764 val entryVal = if signed then entryXSigned
767 if boundDominates then
770 Loop.T {start = entryVal,
773 invert = not contIsTrue})
775 (logsi ("Can't unroll: bound doesn't dominate", depth) ;
782 (* Check all of a loop's entry point arguments to see if each has a constant value.
783 Returns a vector of options where SOME(x) means the argument is always x at every entry. *)
784 fun findConstantStart (entryArgs:
785 (((IntInf.t * IntInf.t) option) vector) vector)
786 : ((IntInf.t * IntInf.t) option) vector =
787 if (Vector.length entryArgs) > 0 then
788 Vector.rev (Vector.fold (entryArgs, Vector.sub (entryArgs, 0),
789 fn (v1, v2) => Vector.fromList (
790 Vector.fold2 (v1, v2, [], fn (a1, a2, lst) =>
792 (SOME(x1, x1'), SOME(x2, _)) =>
793 if x1 = x2 then SOME(x1, x1')::lst
798 (* Look for any optimization opportunities in the loop. *)
799 fun findOpportunity(functionBody: Block.t vector,
800 loopBody: Block.t vector,
801 loopHeaders: Block.t vector,
802 loadGlobal: Var.t * bool -> IntInf.t option,
803 domInfo: Label.t -> Label.t list,
805 (int * Block.t * Loop.t) option =
806 if (Vector.length loopHeaders) = 1 then
808 val header = Vector.sub (loopHeaders, 0)
809 val headerArgs = Block.args header
810 val headerLabel = Block.label header
811 val () = logsi (concat["Evaluating loop with header: ",
812 Label.toString headerLabel], depth - 1)
813 fun blockEquals (b1, b2) =
814 Label.equals (Block.label b1, Block.label b2)
815 val emptyArgs = SOME(Vector.new (Vector.length headerArgs, NONE))
816 val entryArgs = Vector.keepAllMap(functionBody, fn block =>
817 if Vector.contains (loopBody, block, blockEquals) then
819 else case Block.transfer block of
820 Transfer.Arith {success, ...} =>
821 if Label.equals (headerLabel, success) then
824 | Transfer.Call {return, ...} =>
826 Return.NonTail {cont, ...} =>
827 if Label.equals (headerLabel, cont) then
831 | Transfer.Case {cases, ...} =>
834 if Vector.exists (v, fn (_, lbl) =>
835 Label.equals (headerLabel, lbl)) then
839 | Cases.Word (_, v) =>
840 if Vector.exists (v, fn (_, lbl) =>
841 Label.equals (headerLabel, lbl)) then
844 | Transfer.Goto {args, dst} =>
845 if Label.equals (dst, headerLabel) then
846 SOME(Vector.map (args, fn a =>
847 case (loadGlobal(a, false),
851 | (SOME v1, SOME v2) => SOME (v1, v2)
852 | _ => raise Fail "Impossible"))
855 val () = logsi (concat["Loop has ",
856 Int.toString (Vector.length entryArgs),
857 " entry points"], depth - 1)
858 val constantArgs = findConstantStart entryArgs
861 (headerArgs, fn (i, arg) => (
862 logsi (concat["Checking arg: ", Var.toString (#1 arg)], depth) ;
863 checkArg (arg, i, Vector.sub (constantArgs, i),
864 header, loopBody, loadGlobal, domInfo, depth + 1)))
866 if (Vector.length unrollableArgs) > 0 then
867 SOME(Vector.sub (unrollableArgs, 0))
871 (logsi ("Can't optimize: loop has more than 1 header", depth) ;
872 multiHeaders := (!multiHeaders) + 1 ;
875 fun makeHeader(oldHeader, (newVars, newStmts), newEntry) =
877 val oldArgs = Block.args oldHeader
878 val newArgs = Vector.map (oldArgs, fn (arg, _) => arg)
879 val newTransfer = Transfer.Goto {args = newArgs, dst = newEntry}
881 (Block.T {args = oldArgs,
882 label = Block.label oldHeader,
883 statements = Vector.fromList newStmts,
884 transfer = newTransfer},
888 (* Copy an entire loop. In the header, rewrite the transfer to take the loop branch.
889 In the transfers to the top of the loop, rewrite the transfer to goto next.
890 Ensure that the header is the first element in the list.
891 Replace all instances of argi with argVar *)
892 fun copyLoop(blocks: Block.t vector,
894 headerLabel: Label.t,
898 rewriteTransfer: bool,
899 blockInfo: Label.t -> BlockInfo,
900 setBlockInfo: Label.t * BlockInfo -> unit): Block.t vector =
902 val labels = Vector.map (blocks, Block.label)
903 (* Assign a new label for each block *)
904 val newBlocks = Vector.map (blocks, fn b =>
906 val oldName = Block.label b
907 val oldArgs = Block.args b
908 val newName = Label.newNoname()
909 val () = setBlockInfo(oldName, (newName, oldArgs))
911 Block.T {args = Block.args b,
913 statements = Block.statements b,
914 transfer = Block.transfer b}
916 (* Rewrite the transfers of each block *)
917 val fixedBlocks = Vector.map
918 (newBlocks, fn Block.T {args, label, statements, transfer} =>
920 val f = fn l => fixLabel(blockInfo, l, labels)
921 val isHeader = Label.equals (label, f(headerLabel))
922 val (newArgs, unrolledArg) =
924 (args, SOME(Vector.sub (args, argi)))
932 val assignExp = Exp.Var (argVar)
933 val assign = Statement.T {exp = assignExp,
936 val assignV = Vector.new1(assign)
938 Vector.concat [assignV, statements]
943 if rewriteTransfer andalso
944 Label.equals (label, f(Block.label tBlock))
947 val (_, contLabel, _) = loopExit(labels, transfer)
949 Transfer.Goto {args = Vector.new0 (), dst = f(contLabel)}
953 Transfer.Arith {args, overflow, prim, success, ty} =>
954 if Label.equals (success, headerLabel) then
955 Transfer.Arith {args = args,
956 overflow = f(overflow),
961 Transfer.Arith {args = args,
962 overflow = f(overflow),
964 success = f(success),
966 | Transfer.Call {args, func, return} =>
970 Return.NonTail {cont, handler} =>
974 Handler.Handle l => Handler.Handle(f(l))
977 Return.NonTail {cont = f(cont), handler = newHandler}
981 Transfer.Call {args = args, func = func, return = newReturn}
983 | Transfer.Case {cases, default, test} =>
985 val newCases = Cases.map(cases, f)
986 val newDefault = case default of
988 | SOME(l) => SOME(f(l))
990 Transfer.Case {cases = newCases,
991 default = newDefault,
994 | Transfer.Goto {args, dst} =>
995 if Label.equals (dst, headerLabel) then
996 Transfer.Goto {args = args, dst = nextLabel}
998 Transfer.Goto {args = args, dst = f(dst)}
999 | Transfer.Runtime {args, prim, return} =>
1000 Transfer.Runtime {args = args, prim = prim, return = f(return)}
1003 Block.T {args = newArgs,
1005 statements = newStmts,
1006 transfer = newTransfer}
1009 Vector.rev fixedBlocks
1012 (* Unroll a loop. The header should ALWAYS be the first element
1013 in the returned list. *)
1014 fun unrollLoop (oldHeader, tBlock, argi, loopBlocks, argLabels,
1015 exit, rewriteTransfer, blockInfo, setBlockInfo) =
1017 val oldHeaderLabel = Block.label oldHeader
1023 val res = unrollLoop (oldHeader, tBlock, argi,
1024 loopBlocks, tl, exit, rewriteTransfer,
1025 blockInfo, setBlockInfo)
1026 val nextBlockLabel = Block.label (List.first res)
1027 val newLoop = copyLoop(loopBlocks, nextBlockLabel, oldHeaderLabel,
1028 tBlock, argi, hd, rewriteTransfer,
1029 blockInfo, setBlockInfo)
1031 (Vector.toList newLoop) @ res
1038 Returns (b, x, y, z) such that:
1040 - unroll the loop completely
1041 - x, y, and z are undefined.
1043 - x is the number of times to expand the loop body
1044 - y is the number of iterations to run the expanded body (must never be 0)
1045 - z is the number of times to peel the loop body
1047 fun shouldOptimize (iterCount, loopBlocks, depth) =
1049 val loopSize' = Block.sizeV (loopBlocks, {sizeExp = Exp.size, sizeTransfer = Transfer.size})
1050 val loopSize = IntInf.fromInt loopSize'
1051 val unrollLimit = IntInf.fromInt (!Control.loopUnrollLimit)
1052 val () = logsi ("iterations * loop size < unroll factor = can total unroll",
1054 val canTotalUnroll = (iterCount * loopSize) < unrollLimit
1055 val () = logsi (concat[IntInf.toString iterCount, " * ",
1056 IntInf.toString loopSize, " < ",
1057 IntInf.toString unrollLimit, " = ",
1058 Bool.toString canTotalUnroll], depth)
1060 if (iterCount = 1) orelse canTotalUnroll then
1061 (* Loop runs once or it's small enough to unroll *)
1063 else if loopSize >= unrollLimit then
1064 (* Loop is too big to unroll at all, peel off 1 iteration *)
1065 (false, 1, iterCount - 1, 1)
1068 val exBodySize = unrollLimit div loopSize
1069 val exIters = iterCount div exBodySize
1070 val leftovers = iterCount - (exIters * exBodySize)
1072 if (exIters - 1) < 2 then
1073 (* If the unpeeled loop would run 1 or 0 times, just unroll the
1077 if leftovers = 0 then
1078 (* If we don't get any unpeelings naturally, force one *)
1079 (false, exBodySize, exIters - 1, exBodySize)
1081 (* Otherwise stick them on the front of the loop *)
1082 (false, exBodySize, exIters, leftovers)
1086 fun expandLoop (oldHeader, loopBlocks, loop, tBlock, argi, argSize, oldArg,
1087 exBody, iterBody, exitLabel, blockInfo, setBlockInfo) =
1089 (* Make a new loop header with an additional arg *)
1090 val newLoopEntry = Label.newNoname()
1091 val (newLoopHeader, loopArgLabels) =
1092 makeHeader (oldHeader,
1093 Loop.makeSteps (loop, argSize, oldArg, exBody),
1095 val iterVar = Var.newNoname ()
1096 val newLoopHeaderArgs' = Vector.concat
1097 [Block.args newLoopHeader,
1098 Vector.new1 (iterVar, Type.word argSize)]
1099 val newLoopHeader' =
1100 Block.T {args = newLoopHeaderArgs',
1101 label = Label.newNoname (),
1102 statements = Block.statements newLoopHeader,
1103 transfer = Block.transfer newLoopHeader}
1105 (* Make a new goto to the top of the loop increasing the iter by 1 *)
1106 val loopHeaderGoto =
1108 val (newVar, newVarStmts) = Loop.makeVarStmt (1, argSize, iterVar)
1109 val nonIterArgs = Vector.map (Block.args oldHeader, fn (a, _) => a)
1110 val newArgs = Vector.concat [nonIterArgs, Vector.new1 (newVar)]
1111 val newTransfer = Transfer.Goto {args = newArgs,
1112 dst = Block.label newLoopHeader'}
1114 Block.T {args = Vector.new0 (),
1115 label = Label.newNoname (),
1116 statements = Vector.fromList newVarStmts,
1117 transfer = newTransfer}
1122 val (newLimitVar, newLimitStmt) =
1123 Loop.makeConstStmt (iterBody - 1, argSize)
1124 val (newComp, newCompVar) =
1126 val newVar = Var.newNoname ()
1127 val newTy = Type.datatypee Tycon.bool
1129 PrimApp {args = Vector.new2 (iterVar, newLimitVar),
1130 prim = Prim.wordLt (argSize, {signed = true}),
1131 targs = Vector.new0 ()}
1133 (Statement.T {exp = newExp,
1135 var = SOME(newVar)},
1138 val exitStatements = Vector.new2(newLimitStmt, newComp)
1139 val exitCases = Cases.Con (
1140 Vector.new1 (Con.fromBool true,
1141 Block.label loopHeaderGoto))
1142 val exitTransfer = Transfer.Case {cases = exitCases,
1143 default = SOME(exitLabel),
1146 Block.T {args = Block.args oldHeader,
1147 label = Label.newNoname (),
1148 statements = exitStatements,
1149 transfer = exitTransfer}
1152 (* Expand the loop exBody times. Rewrite the bound's transfer,
1153 because we know it will always be true and it won't be eliminated
1155 val newLoopBlocks = unrollLoop (oldHeader, tBlock, argi,
1156 loopBlocks, loopArgLabels, newLoopExit,
1157 true, blockInfo, setBlockInfo)
1158 val firstLoopBlock = List.first newLoopBlocks
1159 val loopArgs' = Block.args firstLoopBlock
1160 val loopStatements' = Block.statements firstLoopBlock
1161 val loopTransfer' = Block.transfer firstLoopBlock
1162 val newLoopHead = Block.T {args = loopArgs',
1163 label = newLoopEntry,
1164 statements = loopStatements',
1165 transfer = loopTransfer'}
1166 val newLoopBlocks' = newLoopHeader'::
1169 (listPop newLoopBlocks)))
1174 (* Attempt to optimize a single loop. Returns a list of blocks to add to the
1175 program and a list of blocks to remove from the program. *)
1176 fun optimizeLoop(allBlocks, headerNodes, loopNodes,
1177 nodeBlock, loadGlobal, domInfo, depth) =
1179 val () = ++loopCount
1180 val headers = Vector.map (headerNodes, nodeBlock)
1181 val loopBlocks = Vector.map (loopNodes, nodeBlock)
1182 val loopBlockNames = Vector.map (loopBlocks, Block.label)
1183 val optOpt = findOpportunity (allBlocks, loopBlocks, headers,
1184 loadGlobal, domInfo, depth + 1)
1185 val {get = blockInfo: Label.t -> BlockInfo,
1186 set = setBlockInfo: Label.t * BlockInfo -> unit, destroy} =
1187 Property.destGetSet(Label.plist,
1188 Property.initRaise("blockInfo", Label.layout))
1192 | SOME (argi, tBlock, loop) =>
1193 if Loop.isInfiniteLoop loop then
1194 (logsi ("Can't unroll: infinite loop", depth) ;
1196 logsi (concat["Index: ", Int.toString argi, Loop.toString loop],
1202 val oldHeader = Vector.sub (headers, 0)
1203 val oldArgs = Block.args oldHeader
1204 val (oldArg, oldType) = Vector.sub (oldArgs, argi)
1205 val () = logsi (concat["Can unroll loop on ",
1206 Var.toString oldArg], depth)
1207 val () = logsi (concat["Index: ", Int.toString argi,
1208 Loop.toString loop], depth)
1209 val iterCount = Loop.iterCount loop
1210 val () = logsi (concat["Loop will run ",
1211 IntInf.toString iterCount,
1213 val () = logsi (concat["Transfer block is ",
1214 Label.toString (Block.label tBlock)],
1216 val () = Histogram.inc ((!histogram), iterCount)
1217 val (totalUnroll, exBody, iterBody, peel) =
1218 shouldOptimize (iterCount, loopBlocks, depth + 1)
1219 val argSize = case Type.dest oldType of
1220 Type.Word wsize => wsize
1221 | _ => raise Fail "Argument is not of type word"
1226 val () = logsi ("Completely unrolling loop", depth)
1227 val newEntry = Label.newNoname()
1228 val (newHeader, argLabels) =
1229 makeHeader (oldHeader,
1230 Loop.makeConstants (loop, argSize, iterCount+1),
1232 val exitBlock = Block.T {args = oldArgs,
1233 label = Label.newNoname (),
1234 statements = Vector.new0 (),
1235 transfer = Transfer.Bug}
1236 (* For each induction variable value, copy the loop's body *)
1237 val newBlocks = unrollLoop (oldHeader, tBlock, argi,
1238 loopBlocks, argLabels, exitBlock,
1239 false, blockInfo, setBlockInfo)
1240 (* Fix the first entry's label *)
1241 val firstBlock = List.first newBlocks
1242 val args' = Block.args firstBlock
1243 val statements' = Block.statements firstBlock
1244 val transfer' = Block.transfer firstBlock
1245 val newHead = Block.T {args = args',
1247 statements = statements',
1248 transfer = transfer'}
1249 val newBlocks' = newHeader::(newHead::(listPop newBlocks))
1252 (newBlocks', (Vector.toList loopBlockNames))
1257 val () = logsi ("Partially unrolling loop", depth)
1258 val () = logsi (concat["Body expansion: ",
1259 IntInf.toString exBody,
1260 " Body iterations: ",
1261 IntInf.toString iterBody,
1262 " Peel iterations: ",
1263 IntInf.toString peel],
1265 val oldArgLabels = Vector.map (oldArgs, fn (a, _) => a)
1266 (* Produce an exit loop iteration. *)
1267 val exitEntry = Label.newNoname()
1268 val (exitHeader, exitConsts) =
1269 makeHeader (oldHeader,
1270 Loop.makeLastConstant (loop, argSize),
1273 Block.T {args = Block.args exitHeader,
1274 label = Label.newNoname (),
1275 statements = Block.statements exitHeader,
1276 transfer = Block.transfer exitHeader}
1277 val exitBlock = Block.T {args = oldArgs,
1278 label = Label.newNoname (),
1279 statements = Vector.new0 (),
1280 transfer = Transfer.Bug}
1281 val exitBlocks = unrollLoop (oldHeader, tBlock, argi,
1282 loopBlocks, exitConsts, exitBlock,
1283 false, blockInfo, setBlockInfo)
1284 val exitFirstBlock = List.first exitBlocks
1285 val exitArgs = Block.args exitFirstBlock
1286 val exitStatements = Block.statements exitFirstBlock
1287 val exitTransfer = Block.transfer exitFirstBlock
1288 val exitHead = Block.T {args = exitArgs,
1290 statements = exitStatements,
1291 transfer = exitTransfer}
1292 val exitGotoLabel = Label.newNoname()
1293 val exitGoto = Block.T {args = Vector.new0 (),
1294 label = exitGotoLabel,
1295 statements = Vector.new0 (),
1298 {args = oldArgLabels,
1299 dst = Block.label exitHeader'}}
1300 val exitBlocks' = exitGoto::
1302 (exitHead::(listPop exitBlocks))
1304 (* Expand the loop *)
1305 val exLoopBlocks = expandLoop (oldHeader, loopBlocks, loop,
1306 tBlock, argi, argSize,
1307 oldArg, exBody, iterBody,
1309 blockInfo, setBlockInfo)
1310 (* Make an entry to the expanded loop *)
1313 val (zeroVar, zeroStmt) = Loop.makeConstStmt(0, argSize)
1314 val exLoopHeader = Block.label (List.first exLoopBlocks)
1316 Vector.concat [oldArgLabels, Vector.new1(zeroVar)]
1317 val newTransfer = Transfer.Goto {args = transferArgs,
1320 Block.T {args = oldArgs,
1321 label = Label.newNoname(),
1322 statements = Vector.new1 zeroStmt,
1323 transfer = newTransfer}
1325 (* Make a replacement loop entry *)
1326 val newEntry = Label.newNoname()
1327 val (newHeader, argLabels) =
1328 makeHeader (oldHeader,
1329 Loop.makeConstants (loop, argSize, peel),
1331 (* For each induction variable value, copy the loop's body *)
1332 val newBlocks = unrollLoop (oldHeader, tBlock, argi,
1333 loopBlocks, argLabels, exLoopEntry,
1334 false, blockInfo, setBlockInfo)
1335 (* Fix the first entry's label *)
1336 val firstBlock = List.first newBlocks
1337 val args' = Block.args firstBlock
1338 val statements' = Block.statements firstBlock
1339 val transfer' = Block.transfer firstBlock
1340 val newHead = Block.T {args = args',
1342 statements = statements',
1343 transfer = transfer'}
1344 val newBlocks' = newHeader::(newHead::(listPop newBlocks))
1347 (newBlocks' @ exLoopBlocks @ exitBlocks',
1348 (Vector.toList loopBlockNames))
1353 (* Traverse sub-forests until the innermost loop is found. *)
(* Takes a destructured loop forest {loops, notInLoop} together with the
 * traversal context and yields a pair of lists, which the fold below names
 * (new, remove).
 * NOTE(review): the embedded numbering jumps from 1354 to 1357, so part of
 * the parameter list is not visible here; presumably it binds allBlocks and
 * enclosingHeaders, both of which the body uses -- confirm. *)
1354 fun traverseSubForest ({loops, notInLoop},
1357 labelNode, nodeBlock, loadGlobal, domInfo) =
(* No nested loops in this forest: hand it to optimizeLoop. *)
1358 if (Vector.length loops) = 0 then
(* NOTE(review): the meaning of the trailing literal 1 is not visible in
 * this section -- confirm against optimizeLoop's parameter list. *)
1359 optimizeLoop(allBlocks, enclosingHeaders, notInLoop,
1360 nodeBlock, loadGlobal, domInfo, 1)
(* Otherwise visit every nested loop and append each one's results onto the
 * accumulated (new, remove) pair. *)
1362 Vector.fold(loops, ([], []), fn (loop, (new, remove)) =>
1364 val (nBlocks, rBlocks) =
1365 traverseLoop(loop, allBlocks, labelNode, nodeBlock, loadGlobal, domInfo)
(* Results are concatenated after what has been gathered so far. *)
1367 ((new @ nBlocks), (remove @ rBlocks))
1370 (* Traverse loops in the loop forest. *)
(* Mutually recursive with traverseSubForest (joined by `and`). Destructures
 * one loop into its headers and its child forest, then continues the
 * traversal on the child forest.
 * NOTE(review): the embedded numbering jumps from 1371 to 1373, so one line
 * of the parameter list is not visible here; presumably it binds allBlocks,
 * which the body uses -- confirm. *)
1371 and traverseLoop ({headers, child},
1373 labelNode, nodeBlock, loadGlobal, domInfo) =
(* This loop's headers are passed as the second argument of the recursive
 * call; presumably they become enclosingHeaders there -- confirm. *)
1374 traverseSubForest ((Forest.dest child), allBlocks,
1375 headers, labelNode, nodeBlock, loadGlobal, domInfo)
1377 (* Traverse the top-level loop forest. *)
(* Visits every top-level loop, collects the blocks the traversal produced
 * and the labels of the blocks it wants removed, and returns a list made of
 * the surviving blocks of allBlocks followed by the new blocks.
 * The top-level notInLoop field is ignored. *)
1378 fun traverseForest ({loops, notInLoop = _}, allBlocks, labelNode, nodeBlock, loadGlobal, domInfo) =
1380 (* Gather the blocks to add/remove *)
1381 val (newBlocks, blocksToRemove) =
1382 Vector.fold(loops, ([], []), fn (loop, (new, remove)) =>
1384 val (nBlocks, rBlocks) =
1385 traverseLoop(loop, allBlocks, labelNode, nodeBlock, loadGlobal, domInfo)
(* Append this loop's results after what has been gathered so far. *)
1387 ((new @ nBlocks), (remove @ rBlocks))
(* A block is kept only when its label is not listed in blocksToRemove;
 * labels are compared with Label.equals. *)
1389 val keep: Block.t -> bool =
1390 (fn b => not (List.contains(blocksToRemove, (Block.label b), Label.equals)))
1391 val reducedBlocks = Vector.keepAll(allBlocks, keep)
(* Result: kept original blocks first, then the newly created ones. *)
1393 (Vector.toList reducedBlocks) @ newBlocks
(* NOTE(review): the header of this definition is not visible in this
 * listing (the embedded numbering jumps from 1393 to 1398). Presumably this
 * is the body of setDoms, which optimizeFunction calls on a dominator tree
 * and which yields (domInfo, destroy) -- confirm. *)
(* Label property mapping a block's label to a list of labels; it is
 * initialised with Property.initRaise, so presumably reading a label that
 * was never set raises -- confirm. *)
1398 val {get = domInfo: Label.t -> Label.t list,
1399 set = setDomInfo: Label.t * Label.t list -> unit, destroy} =
1400 Property.destGetSet(Label.plist,
1401 Property.initRaise("domInfo", Label.layout))
(* Walks the tree top-down. doms holds the labels of the blocks on the path
 * from the root to the current node, nearest ancestor first. *)
1402 fun loop (tree, doms) =
1404 Tree.T (block, children) =>
(* Record the ancestors of this block, then recurse into each child with
 * this block's label pushed onto the front of the list. *)
1405 (setDomInfo (Block.label block, doms) ;
1406 Vector.foreach (children, fn tree => loop(tree,
1407 (Block.label block)::doms)))
(* The root has no ancestors. *)
1408 val () = loop (tree, [])
1413 (* Performs the optimization on the body of a single function. *)
(* Curried: takes loadGlobal first, then the function to process. Builds the
 * function's control-flow graph, loop forest and dominator information,
 * runs traverseForest over the loop forest, and rebuilds the function from
 * the resulting block list.
 * NOTE(review): the embedded numbering has gaps (e.g. 1428-1429 and
 * everything after 1432), so the remaining fields of the Function.new record
 * and any use of `destroy` are not visible in this listing. *)
1414 fun optimizeFunction loadGlobal function =
1416 val {graph, labelNode, nodeBlock} = Function.controlFlow function
1417 val {args, blocks, mayInline, name, raises, returns, start} =
1418 Function.dest function
(* The size is only used in the log message below (in the visible lines). *)
1419 val fsize = Function.size (function, {sizeExp = Exp.size, sizeTransfer = Transfer.size})
1420 val () = logs (concat["Optimizing function: ", Func.toString name,
1421 " of size ", Int.toString fsize])
(* The loop forest is rooted at the graph node of the function's start label. *)
1422 val root = labelNode start
1423 val forest = Graph.loopForestSteensgaard(graph, {root = root})
(* Dominator information is derived from the function's dominator tree. *)
1424 val dtree = Function.dominatorTree function
1425 val (domInfo, destroy) = setDoms dtree
(* traverseForest returns the function's new block list. *)
1426 val newBlocks = traverseForest((Forest.dest forest),
1427 blocks, labelNode, nodeBlock, loadGlobal, domInfo)
(* Rebuild the function with the new blocks; the visible fields are carried
 * over unchanged from the original function. *)
1430 Function.new {args = args,
1431 blocks = Vector.fromList(newBlocks),
1432 mayInline = mayInline,
1440 fun transform (Program.T {datatypes, globals, functions, main}) =
1442 fun loadGlobal (var: Var.t, signed: bool): IntInf.t option =
1444 fun matchGlobal v g =
1445 case Statement.var g of
1447 | SOME (v') => Var.equals (v, v')
1449 case Vector.peek (globals, matchGlobal var) of
1452 (case Statement.exp stmt of
1457 SOME(WordX.toIntInfX w)
1459 SOME(WordX.toIntInf w)
1463 val () = loopCount := 0
1465 val () = partial := 0
1466 val () = optCount := 0
1467 val () = multiHeaders := 0
1468 val () = varEntryArg := 0
1469 val () = variantTransfer := 0
1470 val () = unsupported := 0
1471 val () = ccTransfer := 0
1472 val () = varBound := 0
1473 val () = infinite := 0
1474 val () = boundDom := 0
1475 val () = histogram := Histogram.new ()
1476 val () = logs (concat["Unrolling loops. Unrolling factor = ",
1477 Int.toString (!Control.loopUnrollLimit)])
1478 val optimizedFunctions = List.map (functions, optimizeFunction loadGlobal)
1479 val restore = restoreFunction {globals = globals}
1480 val () = logs "Performing SSA restore"
1481 val cleanedFunctions = List.map (optimizedFunctions, restore)
1482 val shrink = shrinkFunction {globals = globals}
1483 val () = logs "Performing shrink"
1484 val shrunkFunctions = List.map (cleanedFunctions, shrink)
1485 val () = logstat (loopCount,
1486 "total innermost loops")
1487 val () = logstat (optCount,
1489 val () = logstat (total,
1490 "loops completely unrolled")
1491 val () = logstat (partial,
1492 "loops partially unrolled")
1493 val () = logstat (multiHeaders,
1494 "loops had multiple headers")
1495 val () = logstat (varEntryArg,
1496 "variable entry values")
1497 val () = logstat (variantTransfer,
1498 "loops had variant transfers to the header")
1499 val () = logstat (unsupported,
1500 "loops had unsupported transfers to the header")
1501 val () = logstat (ccTransfer,
1502 "loops had non-computable steps")
1503 val () = logstat (varBound,
1504 "loops had variable bounds")
1505 val () = logstat (infinite,
1507 val () = logstat (boundDom,
1508 "loops had non-dominating bounds")
1509 val () = logs ("Iterations: Occurences")
1510 val () = logs (Histogram.toString (!histogram))
1511 val () = logs "Done."
1513 Program.T {datatypes = datatypes,
1515 functions = shrunkFunctions,