Import Upstream version 20180207
[hcoop/debian/mlton.git] / mlton / ssa / loop-unroll.fun
CommitLineData
7f918cf1
CE
1(* Copyright (C) 2016 Matthew Surawski.
2 *
3 * MLton is released under a BSD-style license.
4 * See the file MLton-LICENSE for details.
5 *)
6
7(* Reduces or eliminates the iteration count of loops by duplicating
8 * the loop body.
9 *)
10functor LoopUnroll (S: SSA_TRANSFORM_STRUCTS): SSA_TRANSFORM =
11struct
12
13open S
14open Exp Transfer Prim
15
(* Local aliases: Graph names DirectedGraph, and Forest names the LoopForest
   structure visible after opening Graph (presumably DirectedGraph.LoopForest
   -- confirm). *)
16structure Graph = DirectedGraph
17local
18 open Graph
19in
20 structure Forest = LoopForest
21end
22
(* Increment an integer counter in place. *)
fun ++ (counter: int ref): unit =
   counter := 1 + !counter
25
26
structure Histogram =
   struct
      (* A multiset of IntInf keys: each entry pairs a key with its count. *)
      type t = (IntInf.t * int ref) HashSet.t

      (* Record one more occurrence of key: a new key starts at 1, an
         existing key has its count bumped. *)
      fun inc (set: t, key: IntInf.t): unit =
         ignore (HashSet.insertIfNew
                    (set, IntInf.hash key,
                     fn (k, _) => k = key,
                     fn () => (key, ref 1),
                     fn (_, count) => ++ count))

      (* An empty histogram, hashed on the key. *)
      fun new (): t =
         HashSet.new {hash = fn (k, _) => IntInf.hash k}

      fun toList (set: t): (IntInf.t * int ref) list =
         HashSet.toList set

      (* One "key: count" line per entry, in the order of toList. *)
      fun toString (set: t): string =
         List.fold (toList set, "", fn ((k, count), acc) =>
                    acc ^ IntInf.toString k ^ ": "
                    ^ Int.toString (!count) ^ "\n")
   end
56
(* Diagnostic counters.  Those updated in the code shown here record why a
   loop was rejected:
     varEntryArg     - an argument's entry value is not constant
     variantTransfer - back edges to the header pass different values
     unsupported     - an unsupported transfer targets the header
     ccTransfer      - the per-iteration step could not be computed
     varBound        - no usable loop bound was found
     boundDom        - the bound test does not dominate every back edge
     multiHeaders    - the loop does not have exactly one header
   loopCount, optCount, total, partial, infinite and histogram are not
   updated in this part of the file; presumably they are maintained later. *)
val loopCount = ref 0
and optCount = ref 0
and total = ref 0
and partial = ref 0
and multiHeaders = ref 0
and varEntryArg = ref 0
and variantTransfer = ref 0
and unsupported = ref 0
and ccTransfer = ref 0
and varBound = ref 0
and infinite = ref 0
and boundDom = ref 0
and histogram = ref (Histogram.new ())

(* Per-block record kept while copying a loop: the label given to the
   block's copy, paired with the block's arguments. *)
type BlockInfo = Label.t * (Var.t * Type.t) vector
72
structure Loop =
   struct
      (* Constant bound of a loop's exit comparison. *)
      datatype Bound = Eq of IntInf.t | Lt of IntInf.t | Gt of IntInf.t
      type Start = IntInf.t
      type Step = IntInf.t
      (* An induction variable that starts at start, advances by step each
         iteration and is compared against bound.  The loop keeps running
         while the comparison holds or, when invert is true, while it fails
         (makeConstants continues while "(start OP b) <> invert"). *)
      datatype t = T of {start: Start, step: Step, bound: Bound, invert: bool}

      (* Diagnostic rendering, e.g. " Start: 0 Step: 1 Bound: < 10"; an
         inverted bound is prefixed with "!" (e.g. "!= 10"). *)
      fun toString (T {start, step, bound, invert}): string =
         let
            val (opStr, b) =
               case bound of
                  Eq b => ("= ", b)
                | Lt b => ("< ", b)
                | Gt b => ("> ", b)
            val boundStr =
               concat [if invert then "!" else "", opStr, IntInf.toString b]
         in
            concat [" Start: ", IntInf.toString start,
                    " Step: ", IntInf.toString step,
                    " Bound: ", boundStr]
         end

      (* True when the described loop would never exit.  All arithmetic is
         on unbounded IntInf values, so machine-word wrap-around is not
         taken into account. *)
      fun isInfiniteLoop (T {start, step, bound, invert}): bool =
         case bound of
            Eq b =>
               if invert then
                  (if start = b then
                      false
                   else if start < b andalso step > 0 then
                      not (((b - start) mod step) = 0)
                   else if start > b andalso step < 0 then
                      not (((start - b) mod (~step)) = 0)
                   else
                      true)
               else
                  step = 0
          | Lt b =>
               if invert then
                  start >= b andalso step >= 0
               else
                  start < b andalso step <= 0
          | Gt b =>
               if invert then
                  start <= b andalso step <= 0
               else
                  start > b andalso step >= 0

      (* Number of iterations of a loop whose induction variable starts at
         start, grows by step (> 0) and continues while it is < max, i.e.
         ceil ((max - start) / step).  Example: start 0, step 4, max 10
         runs for 0, 4 and 8, so 3 iterations.  iterCount only calls this
         with max > start. *)
      fun iters (start: IntInf.t, step: IntInf.t, max: IntInf.t): IntInf.t =
         let
            val range = max - start
         in
            if step > range then
               1
            else
               (* Round up: a partial final stride is one more iteration,
                  regardless of how large the remainder is. *)
               range div step + (if range mod step = 0 then 0 else 1)
         end

      (* Number of iterations the loop performs.  Lt/Gt bounds are mapped
         onto iters with a positive stride.
         Assumes isInfiniteLoop is false, otherwise the result is undefined. *)
      fun iterCount (T {start, step, bound, invert}): IntInf.t =
         case bound of
            Eq b =>
               if invert then
                  (b - start) div step
               else
                  1
          | Lt b =>
               (case (start >= b, invert) of
                   (true, false) => 0
                 | (true, true) => iters (b - 1, ~step, start)
                 | (false, true) => 0
                 | (false, false) => iters (start, step, b))
          | Gt b =>
               (case (start <= b, invert) of
                   (true, false) => 0
                 | (true, true) => iters (start, step, b + 1)
                 | (false, true) => 0
                 | (false, false) => iters (b, ~step, start))

      (* A fresh variable and the statement binding it to the word constant
         v of size wsize. *)
      fun makeConstStmt (v: IntInf.t, wsize: WordSize.t): Var.t * Statement.t =
         let
            val newWord = WordX.fromIntInf (v, wsize)
            val newConst = Const.word newWord
            val newExp = Exp.Const (newConst)
            val newType = Type.word wsize
            val newVar = Var.newNoname ()
            val newStatement = Statement.T {exp = newExp,
                                            ty = newType,
                                            var = SOME (newVar)}
         in
            (newVar, newStatement)
         end

      (* A fresh variable bound to var + v (word addition), together with the
         two statements that define the constant and the sum. *)
      fun makeVarStmt (v: IntInf.t, wsize: WordSize.t, var: Var.t)
                      : Var.t * Statement.t list =
         let
            val (cVar, cStmt) = makeConstStmt (v, wsize)
            val newExp = Exp.PrimApp {args = Vector.new2 (var, cVar),
                                      prim = Prim.wordAdd wsize,
                                      targs = Vector.new0 ()}
            val newType = Type.word wsize
            val newVar = Var.newNoname ()
            val newStatement = Statement.T {exp = newExp,
                                            ty = newType,
                                            var = SOME (newVar)}
         in
            (newVar, [cStmt, newStatement])
         end

      (* Given a loop and a word size, returns a variable and the statement
         binding it to start + step * iterCount: the induction variable's
         value after the loop's final iteration, which makes the loop exit.
         Assumes isInfiniteLoop is false, otherwise the result is undefined. *)
      fun makeLastConstant (loop as T {start, step, ...}, wsize: WordSize.t)
                           : Var.t list * Statement.t list =
         let
            val last = start + step * iterCount loop
            val (newVar, newStatement) = makeConstStmt (last, wsize)
         in
            ([newVar], [newStatement])
         end

      (* Given a loop, a word size and an iteration limit, returns one
         variable (and its defining statement) per value the induction
         variable takes: start, start + step, ... while the loop continues,
         plus the first value that exits it.  At most limit values are
         produced.
         Assumes isInfiniteLoop is false, otherwise this will run forever. *)
      fun makeConstants (T {start, step, bound, invert},
                         wsize: WordSize.t,
                         limit: IntInf.t)
                        : Var.t list * Statement.t list =
         if limit <= 0 then
            ([], [])
         else
            let
               val (newVar, newStatement) = makeConstStmt (start, wsize)
               val continues =
                  case bound of
                     Eq b => (start = b) <> invert
                   | Lt b => (start < b) <> invert
                   | Gt b => (start > b) <> invert
            in
               if continues then
                  let
                     val nextIter = T {start = start + step,
                                       step = step,
                                       bound = bound,
                                       invert = invert}
                     val (rVars, rStmts) =
                        makeConstants (nextIter, wsize, limit - 1)
                  in
                     (newVar :: rVars, newStatement :: rStmts)
                  end
               else
                  ([newVar], [newStatement])
            end

      (* Accumulator for makeSteps: prepends var + step * (times - 1), then
         recurses with times - 1. *)
      fun makeStepsRec (step: Step, wsize: WordSize.t,
                        var: Var.t, times: IntInf.t,
                        vList: Var.t list, sList: Statement.t list)
                       : Var.t list * Statement.t list =
         if times > 0 then
            let
               val stepAdd = step * (times - 1)
               val (newVar, newStatements) = makeVarStmt (stepAdd, wsize, var)
            in
               makeStepsRec (step, wsize, var, times - 1,
                             newVar :: vList, newStatements @ sList)
            end
         else (vList, sList)

      (* Variables var + 0 * step, var + 1 * step, ..., var + (times-1) * step,
         in that order, with the statements that define them. *)
      fun makeSteps (T {step, ...},
                     wsize: WordSize.t,
                     var: Var.t,
                     times: IntInf.t)
                    : Var.t list * Statement.t list =
         makeStepsRec (step, wsize, var, times, [], [])
   end
309
(* Emit layout l as a diagnostic, indented two columns per level i.  The
   indentation is built inside the callback, so only when diagnostics are
   actually displayed. *)
fun logli (l: Layout.t, i: int): unit =
   Control.diagnostics (fn display => display (Layout.indent (l, 2 * i)))
314
(* Emit string s as a diagnostic at indentation level i. *)
fun logsi (s: string, i: int): unit =
   logli (Layout.str s, i)
317
(* Emit string s as an unindented diagnostic. *)
fun logs (s: string): unit =
   logsi (s, 0)
320
(* Emit "<counter value> <description>" as an unindented diagnostic. *)
fun logstat (x: int ref, s: string): unit =
   logs (Int.toString (!x) ^ " " ^ s)
323
(* Drop the head of a list; the empty list is returned unchanged rather than
   raising. *)
fun listPop [] = []
  | listPop (_ :: rest) = rest
328
(* Redirect a label to its copy.  Labels listed in origLabels (the blocks
   being copied) are mapped through getBlockInfo to the copy's label; any
   other label is returned unchanged. *)
fun fixLabel (getBlockInfo: Label.t -> BlockInfo,
              label: Label.t,
              origLabels: Label.t vector): Label.t =
   if not (Vector.contains (origLabels, label, Label.equals)) then
      label
   else
      let
         val (copyLabel, _) = getBlockInfo label
      in
         copyLabel
      end
342
(* True exactly when v2 is SOME v and v is the same variable as v1. *)
fun varOptEquals (v1: Var.t, v2: Var.t option): bool =
   case v2 of
      SOME other => Var.equals (v1, other)
    | NONE => false
347
(* For a binary operation where exactly one argument is a constant (as
   reported by loadVar with the given signedness), return
   SOME (the non-constant argument, the constant, true if the non-constant
   argument is the first one).  Returns NONE when neither or both arguments
   are constants. *)
fun varConst (args, loadVar, signed): (Var.t * IntInf.int * bool) option =
   let
      val lhs = Vector.sub (args, 0)
      val rhs = Vector.sub (args, 1)
      val lhsConst = loadVar (lhs, signed)
      val rhsConst = loadVar (rhs, signed)
   in
      case (lhsConst, rhsConst) of
         (NONE, SOME c) => SOME (lhs, c, true)
       | (SOME c, NONE) => SOME (rhs, c, false)
       | _ => NONE
   end
363
(* Given:
   - an argument vector with two arguments
   - a primitive operation that adds a constant to, or subtracts a constant
     from, a variable
   - a function from variables to their constant values
   Returns:
   - SOME (v, c) such that the operation's result equals v + c, where v is
     the non-constant argument; NONE otherwise.
   Subtraction is only accepted in the form v - c.  c - v is not "v plus a
   constant", so reporting it as v + (~c) would give a wrong step. *)
fun checkPrim (args, prim, loadVar) =
   let
      (* v + c or c + v: argument order does not matter. *)
      fun plus signed =
         case varConst (args, loadVar, signed) of
            SOME (v, c, _) => SOME (v, c)
          | NONE => NONE
      (* v - c only: varConst's flag is true when v is the first argument. *)
      fun minus signed =
         case varConst (args, loadVar, signed) of
            SOME (v, c, true) => SOME (v, ~c)
          | _ => NONE
   in
      case Prim.name prim of
         Name.Word_add _ => plus false
       | Name.Word_addCheck (_, {signed}) => plus signed
       | Name.Word_sub _ => minus false
       | Name.Word_subCheck (_, {signed}) => minus signed
       | _ => NONE
   end
389
390(* Given:
391 - a variable in the loop
392 - another variable in the loop
393 - the loop body
394 - a function from variables to their constant values
395 - a starting value, if the transfer to the header is an arith transfer
396 Returns:
397 - Some x such that the value of origVar in loop iteration i+1 is equal to
398 (the value of origVar in iteration i) + x,
399 or None if the step couldn't be computed *)
(* How it works: starting from endVar (the value passed back to the loop
   header), walk definitions backwards until origVar is reached, adding up
   constants in [total].  Each step searches [blocks] for either
   - a statement defining endVar as an add/sub of a constant (checkPrim), or
   - a block whose single argument is endVar and whose incoming transfers
     from [blocks] all agree on the same (source variable, constant):
     an Arith contributes checkPrim's result, a Goto contributes its first
     argument with offset 0, and Call/Case entries make that step fail.
   NOTE(review): there is no visited set; a cycle of definitions that never
   reaches origVar would presumably recurse forever -- confirm that the
   loop shape accepted by the callers rules this out. *)
400fun varChain (origVar, endVar, blocks, loadVar, total) =
401 case Var.equals (origVar, endVar) of
402 true => SOME (total)
403 | false =>
404 let
405 val endVarAssign = Vector.peekMap (blocks, fn b =>
406 let
(* Statements in b that define endVar as variable +/- constant. *)
407 val stmts = Block.statements b
408 val assignments = Vector.keepAllMap (stmts, fn s =>
409 case varOptEquals (endVar, Statement.var s) of
410 false => NONE
411 | true =>
412 (case Statement.exp s of
413 Exp.PrimApp {args, prim, ...} =>
414 checkPrim (args, prim, loadVar)
415 | _ => NONE))
416 val label = Block.label b
417 val blockArgs = Block.args b
418 (* If we found the assignment or the block isn't unary,
419 skip this step *)
(* Otherwise collect, for every transfer into b, SOME (var, constant) when
   the value passed is known to be var + constant, or SOME NONE when it is
   not (Call continuations and Case targets). *)
420 val arithTransfers =
421 if ((Vector.length assignments) > 0) orelse
422 ((Vector.length blockArgs) <> 1)
423 then
424 Vector.new0 ()
425 else
426 let
427 val (blockArg, _) = Vector.sub (blockArgs, 0)
428 val blockEntrys = Vector.keepAllMap (blocks, fn b' =>
429 case Block.transfer b' of
430 Transfer.Arith {args, prim, success, ...} =>
431 if Label.equals (label, success) then
432 SOME(checkPrim(args, prim, loadVar))
433 else NONE
434 | Transfer.Call {return, ...} =>
435 (case return of
436 Return.NonTail {cont, ...} =>
437 if Label.equals (label, cont) then
438 SOME(NONE)
439 else NONE
440 | _ => NONE)
441 | Transfer.Case {cases, ...} =>
442 (case cases of
443 Cases.Con v =>
444 if Vector.exists (v, fn (_, lbl) =>
445 Label.equals (label, lbl)) then
446 SOME(NONE)
447 else
448 NONE
449 | Cases.Word (_, v) =>
450 if Vector.exists (v, fn (_, lbl) =>
451 Label.equals (label, lbl)) then
452 SOME(NONE)
453 else NONE)
454 | Transfer.Goto {args, dst} =>
455 if Label.equals (label, dst) then
456 SOME(SOME(Vector.sub (args, 0), 0))
457 else NONE
458 | _ => NONE)
459 in
(* The entries only matter when endVar is this block's argument. *)
460 if Var.equals (endVar, blockArg) then
461 blockEntrys
462 else
463 Vector.new0 ()
464 end
(* If all entries agree on the same (var, constant), use it as the
   definition of endVar; any disagreement or unknown entry gives NONE and
   falls back to [assignments] (empty in that case). *)
465 val assignments' =
466 if Vector.length (arithTransfers) > 0 then
467 case (Vector.fold (arithTransfers,
468 Vector.sub (arithTransfers, 0),
469 fn (trans, trans') =>
470 case (trans, trans') of
471 (SOME(a1, v1), SOME(a2, v2)) =>
472 if Var.equals (a1, a2) andalso
473 v1 = v2 then
474 trans
475 else
476 NONE
477 | _ => NONE)) of
478 SOME(a, v) => Vector.new1 (a, v)
479 | NONE => assignments
480 else
481 assignments
482 in
(* In SSA a variable has a single definition, so more than one match is
   treated as an internal error. *)
483 case Vector.length assignments' of
484 0 => NONE
485 | 1 => SOME (Vector.sub (assignments', 0))
486 | _ => raise Fail "Multiple assignments in SSA form!"
487 end)
488 in
(* Continue from the variable endVar was derived from, adding its offset. *)
489 case endVarAssign of
490 NONE => NONE
491 | SOME (nextVar, x) =>
492 varChain(origVar, nextVar, blocks, loadVar, x + total)
493 end
494
(* Given:
   - the labels of the loop body
   - a transfer that is a case on a boolean value, where one branch exits
     the loop and the other continues it
   Returns (exit, continue, contIsTrue):
   - the label that exits the loop
   - the label that continues the loop
   - true if the continue branch is the true branch
   Raises Fail if the transfer is not a Case over constructors. *)
fun loopExit (loopLabels: Label.t vector, transfer: Transfer.t)
             : (Label.t * Label.t * bool) =
   let
      fun inLoop l = Vector.contains (loopLabels, l, Label.equals)
      fun isBool (b, c) = Con.equals (Con.fromBool b, c)
      fun conArms cases =
         case cases of
            Cases.Con arms => arms
          | _ => raise Fail "This should be a con"
   in
      case transfer of
         Transfer.Case {cases, default = SOME dflt, ...} =>
            (* One explicit arm plus a default. *)
            let
               val (armCon, armLabel) = Vector.sub (conArms cases, 0)
            in
               if inLoop dflt then
                  (armLabel, dflt, isBool (false, armCon))
               else
                  (dflt, armLabel, isBool (true, armCon))
            end
       | Transfer.Case {cases, default = NONE, ...} =>
            (* Two explicit arms; all targets are unary. *)
            let
               val arms = conArms cases
               val (c1, d1) = Vector.sub (arms, 0)
               val (c2, d2) = Vector.sub (arms, 1)
            in
               if inLoop d1 then
                  (d2, d1, isBool (true, c1))
               else
                  (d1, d2, isBool (true, c2))
            end
       | _ => raise Fail "This should be a case statement"
   end
543
(* True when a Case with these cases/default is a two-way branch over
   constructors with exactly one target inside the loop (loopLabels).  The
   two ways are either one explicit arm plus a default, or exactly two
   explicit arms. *)
fun isLoopBranch (loopLabels, cases, default) =
   let
      fun inLoop l = Vector.contains (loopLabels, l, Label.equals)
      fun armLabel (arms, i) =
         let
            val (_, l) = Vector.sub (arms, i)
         in
            l
         end
   in
      case (cases, default) of
         (Cases.Con arms, SOME dflt) =>
            Vector.length arms = 1
            andalso inLoop dflt <> inLoop (armLabel (arms, 0))
       | (Cases.Con arms, NONE) =>
            Vector.length arms = 2
            andalso inLoop (armLabel (arms, 0)) <> inLoop (armLabel (arms, 1))
       | _ => false
   end
577
(* True when block's transfer can jump to headerLabel: an Arith success, a
   non-tail Call's continuation or handler, a Case arm, a Goto destination
   or a Runtime return.
   The Call continuation is checked as well as the handler, matching the
   other scans of header targets in this file; before, a Call returning to
   the header was not counted. *)
fun transfersToHeader (headerLabel, block) =
   let
      fun isHeader l = Label.equals (headerLabel, l)
   in
      case Block.transfer block of
         Transfer.Arith {success, ...} => isHeader success
       | Transfer.Call {return = Return.NonTail {cont, handler}, ...} =>
            isHeader cont
            orelse (case handler of
                       Handler.Handle l => isHeader l
                     | _ => false)
       | Transfer.Case {cases, ...} =>
            (* We don't have to check default because we know the header
               isn't nullary *)
            (case cases of
                Cases.Con v => Vector.exists (v, fn (_, lbl) => isHeader lbl)
              | Cases.Word (_, v) => Vector.exists (v, fn (_, lbl) => isHeader lbl))
       | Transfer.Goto {dst, ...} => isHeader dst
       | Transfer.Runtime {return, ...} => isHeader return
       | _ => false
   end
601
602(* Given:
603 - a loop phi variable
604 - that variables index in the loop header's arguments
605 - that variables constant entry value (if it has one)
606 - the loop header block
607 - the loop body block
608 - a function from variables to their constant values
609 Returns:
610 - a Loop structure for unrolling that phi var, if one exists *)
(* On success returns SOME (argIndex, the block holding the bound test,
   Loop.T {start, step, bound, invert}).  Every rejection is logged at
   [depth] and counted in one of varEntryArg, variantTransfer, unsupported,
   ccTransfer, varBound or boundDom before NONE is returned. *)
611fun checkArg ((argVar, _), argIndex, entryArg, header, loopBody,
612 loadVar: Var.t * bool -> IntInf.t option, domInfo, depth) =
613 case entryArg of
614 NONE => (logsi ("Can't unroll: entry arg not constant", depth) ;
615 ++varEntryArg ;
616 NONE)
617 | SOME (entryX, entryXSigned) =>
618 let
619 val headerLabel = Block.label header
620 val unsupportedTransfer = ref false
621
622 (* For every transfer to the start of the loop
623 get the variable at argIndex *)
(* Each entry is (variable passed back, constant added to it): a Goto
   passes args[argIndex] with offset 0, an Arith contributes checkPrim's
   result.  An Arith without a constant step, a Call continuation or a Case
   arm reaching the header sets unsupportedTransfer instead. *)
624 val loopVars = Vector.keepAllMap (loopBody, fn block =>
625 case Block.transfer block of
626 Transfer.Arith {args, prim, success, ...} =>
627 if Label.equals (headerLabel, success) then
628 case checkPrim (args, prim, loadVar) of
629 NONE => (unsupportedTransfer := true ; NONE)
630 | SOME (arg, x) => SOME (arg, x)
631 else NONE
632 | Transfer.Call {return, ...} =>
633 (case return of
634 Return.NonTail {cont, ...} =>
635 if Label.equals (headerLabel, cont) then
636 (unsupportedTransfer := true ; NONE)
637 else NONE
638 | _ => NONE)
639 | Transfer.Case {cases, ...} =>
640 (case cases of
641 Cases.Con v =>
642 if Vector.exists(v, fn (_, lbl) =>
643 Label.equals (headerLabel, lbl)) then
644 (unsupportedTransfer := true ; NONE)
645 else NONE
646 | Cases.Word (_, v) =>
647 if Vector.exists(v, fn (_, lbl) =>
648 Label.equals (headerLabel, lbl)) then
649 (unsupportedTransfer := true ; NONE)
650 else NONE)
651 | Transfer.Goto {args, dst} =>
652 if Label.equals (headerLabel, dst) then
653 SOME (Vector.sub (args, argIndex), 0)
654 else NONE
655 | _ => NONE)
656 in
(* All back edges must pass the same (variable, offset).
   NOTE(review): if loopVars is empty and no unsupported transfer was seen,
   the Vector.sub in the final branch raises Subscript -- presumably every
   loop reaching this point has a recognised back edge; confirm. *)
657 if (Vector.length loopVars) > 1
658 andalso not (Vector.forall
659 (loopVars, fn (arg, x) =>
660 let
661 val (arg0, x0) = Vector.sub (loopVars, 0)
662 in
663 Var.equals (arg0, arg) andalso (x0 = x)
664 end))
665 then
666 (logsi ("Can't unroll: variant transfer to head of loop", depth) ;
667 ++variantTransfer ;
668 NONE)
669 else if (!unsupportedTransfer) then
670 (logsi ("Can't unroll: unsupported transfer to head of loop",
671 depth) ;
672 ++unsupported ;
673 NONE)
674 else
675 let
676 val (loopVar, x) = Vector.sub (loopVars, 0)
677 in
(* step: the constant added to argVar on each trip around the loop. *)
678 case varChain (argVar, loopVar, loopBody, loadVar, x) of
679 NONE => (logsi ("Can't unroll: can't compute transfer",
680 depth) ;
681 ++ccTransfer ;
682 NONE)
683 | SOME (step) =>
684 let
(* varConst's flag is true when argVar is the first operand:
   argVar < c gives Lt c, c < argVar gives Gt c. *)
685 fun ltOrGt (vc, signed) =
686 case vc of
687 NONE => NONE
688 | SOME (_, c, b) =>
689 if b then
690 SOME(Loop.Lt (c), signed)
691 else
692 SOME(Loop.Gt (c), signed)
693
694 fun eq (vc, signed) =
695 case vc of
696 NONE => NONE
697 | SOME (_, c, _) => SOME(Loop.Eq (c), signed)
698 val loopLabels = Vector.map (loopBody, Block.label)
(* The first loop block whose Case branches to exactly one target inside
   the loop, where the tested variable is defined in that same block as
   Word_lt or Word_equal between argVar and a constant. *)
699 val transferVarBlock = Vector.peekMap (loopBody, (fn b =>
700 let
701 val transferVar =
702 case Block.transfer b of
703 Transfer.Case {cases, default, test} =>
704 if isLoopBranch (loopLabels, cases, default) then
705 SOME(test)
706 else NONE
707 | _ => NONE
708 val loopBound =
709 case (transferVar) of
710 NONE => NONE
711 | SOME (tVar) =>
712 Vector.peekMap (Block.statements b,
713 (fn s => case Statement.var s of
714 NONE => NONE
715 | SOME (sVar) =>
716 if Var.equals (tVar, sVar) then
717 case Statement.exp s of
718 PrimApp {args, prim, ...} =>
719 if not (Vector.contains
720 (args, argVar, Var.equals))
721 then
722 NONE
723 else
724 (case Prim.name prim of
725 Name.Word_lt (_, {signed}) =>
726 ltOrGt
727 (varConst (args,
728 loadVar,
729 signed),
730 signed)
731 | Name.Word_equal _ =>
732 eq
733 (varConst (args,
734 loadVar,
735 false),
736 false)
737 | _ => NONE)
738 | _ => NONE
739 else NONE))
740 in
741 case loopBound of
742 NONE => NONE
743 | SOME (bound, signed) =>
744 SOME(bound, b, signed)
745 end))
746 in
747 case transferVarBlock of
748 NONE =>
749 (logsi ("Can't unroll: can't determine bound", depth) ;
750 ++varBound ;
751 NONE)
752 | SOME(bound, block, signed) =>
753 let
(* The bound block must be in domInfo (presumably the dominator list) of
   every loop block that transfers to the header. *)
754 val headerTransferBlocks =
755 Vector.keepAll(loopBody, (fn b =>
756 transfersToHeader (headerLabel, b)))
757 val boundDominates = Vector.forall (headerTransferBlocks,
758 (fn b => List.exists ((domInfo (Block.label b)),
759 (fn l => Label.equals
760 ((Block.label block), l)))))
761 val loopLabels = Vector.map (loopBody, Block.label)
762 val (_, _, contIsTrue) =
763 loopExit (loopLabels, Block.transfer block)
(* Use the signed entry value exactly when the comparison is signed. *)
764 val entryVal = if signed then entryXSigned
765 else entryX
766 in
767 if boundDominates then
768 SOME (argIndex,
769 block,
770 Loop.T {start = entryVal,
771 step = step,
772 bound = bound,
773 invert = not contIsTrue})
774 else
775 (logsi ("Can't unroll: bound doesn't dominate", depth) ;
776 ++boundDom ;
777 NONE)
778 end
779 end
780 end
781 end
(* Check all of a loop's entry points to see which header arguments receive
   the same constant from every entry.
   entryArgs has one vector per entry point, giving each header argument's
   constant value at that entry as an (unsigned, signed) pair, or NONE.
   The result has one element per header argument: SOME (x, x') when every
   entry passes the constant whose unsigned value is x, NONE otherwise.
   An empty entryArgs gives an empty result.
   Entries are combined position by position.  Building each step with
   cons + fromList would reverse the accumulator and misalign argument
   positions once there are two or more entries. *)
fun findConstantStart (entryArgs:
                       (((IntInf.t * IntInf.t) option) vector) vector)
                      : ((IntInf.t * IntInf.t) option) vector =
   let
      fun agree (a1, a2) =
         case (a1, a2) of
            (SOME (x1, x1'), SOME (x2, _)) =>
               if x1 = x2 then SOME (x1, x1') else NONE
          | _ => NONE
      fun combine (entry, acc) =
         Vector.tabulate (Vector.length entry, fn i =>
                          agree (Vector.sub (entry, i), Vector.sub (acc, i)))
   in
      if Vector.length entryArgs > 0 then
         Vector.fold (entryArgs, Vector.sub (entryArgs, 0), combine)
      else
         Vector.new0 ()
   end
797
(* Look for an optimization opportunity in a loop.
   Only loops with exactly one header are considered.  The entry value of
   each header argument is computed from the transfers into the header from
   blocks outside the loop.  checkArg is then applied to the header
   arguments in order, and the first unrollable one is returned as
   (argument index, bound-test block, loop description).  Returns NONE
   otherwise. *)
fun findOpportunity (functionBody: Block.t vector,
                     loopBody: Block.t vector,
                     loopHeaders: Block.t vector,
                     loadGlobal: Var.t * bool -> IntInf.t option,
                     domInfo: Label.t -> Label.t list,
                     depth: int):
                    (int * Block.t * Loop.t) option =
   if (Vector.length loopHeaders) = 1 then
      let
         val header = Vector.sub (loopHeaders, 0)
         val headerArgs = Block.args header
         val headerLabel = Block.label header
         val () = logsi (concat ["Evaluating loop with header: ",
                                 Label.toString headerLabel], depth - 1)
         fun blockEquals (b1, b2) =
            Label.equals (Block.label b1, Block.label b2)
         (* An entry whose argument values are all unknown. *)
         val emptyArgs = SOME (Vector.new (Vector.length headerArgs, NONE))
         (* One vector per transfer into the header from outside the loop,
            giving each argument's constant (unsigned, signed) value or
            NONE.  Only Goto entries can carry constants. *)
         val entryArgs = Vector.keepAllMap (functionBody, fn block =>
            if Vector.contains (loopBody, block, blockEquals) then
               NONE
            else case Block.transfer block of
               Transfer.Arith {success, ...} =>
                  if Label.equals (headerLabel, success) then
                     emptyArgs
                  else NONE
             | Transfer.Call {return, ...} =>
                  (case return of
                      Return.NonTail {cont, ...} =>
                         if Label.equals (headerLabel, cont) then
                            emptyArgs
                         else NONE
                    | _ => NONE)
             | Transfer.Case {cases, ...} =>
                  (case cases of
                      Cases.Con v =>
                         if Vector.exists (v, fn (_, lbl) =>
                                           Label.equals (headerLabel, lbl)) then
                            emptyArgs
                         else
                            NONE
                    | Cases.Word (_, v) =>
                         if Vector.exists (v, fn (_, lbl) =>
                                           Label.equals (headerLabel, lbl)) then
                            emptyArgs
                         else NONE)
             | Transfer.Goto {args, dst} =>
                  if Label.equals (dst, headerLabel) then
                     SOME (Vector.map (args, fn a =>
                              case (loadGlobal (a, false),
                                    loadGlobal (a, true)) of
                                 (NONE, NONE) => NONE
                               | (SOME v1, SOME v2) => SOME (v1, v2)
                               | _ => raise Fail "Impossible"))
                  else NONE
             | _ => NONE)
         val () = logsi (concat ["Loop has ",
                                 Int.toString (Vector.length entryArgs),
                                 " entry points"], depth - 1)
         val constantArgs = findConstantStart entryArgs
         (* findConstantStart returns an empty vector when no entry transfer
            was found; treat every argument as non-constant then, rather
            than raising Subscript. *)
         fun constantArg i =
            if i < Vector.length constantArgs then
               Vector.sub (constantArgs, i)
            else
               NONE
         val unrollableArgs =
            Vector.keepAllMapi
            (headerArgs, fn (i, arg) => (
               logsi (concat ["Checking arg: ", Var.toString (#1 arg)], depth) ;
               checkArg (arg, i, constantArg i,
                         header, loopBody, loadGlobal, domInfo, depth + 1)))
      in
         if (Vector.length unrollableArgs) > 0 then
            SOME (Vector.sub (unrollableArgs, 0))
         else NONE
      end
   else
      (logsi ("Can't optimize: loop has more than 1 header", depth) ;
       ++multiHeaders ;
       NONE)
874
(* Build a replacement for oldHeader that keeps its label and arguments,
   runs newStmts, and then jumps to newEntry passing its own arguments
   through unchanged.  Returns that block paired with newVars. *)
fun makeHeader (oldHeader, (newVars, newStmts), newEntry) =
   let
      val headerArgs = Block.args oldHeader
      val passThrough = Vector.map (headerArgs, fn (x, _) => x)
   in
      (Block.T {args = headerArgs,
                label = Block.label oldHeader,
                statements = Vector.fromList newStmts,
                transfer = Transfer.Goto {args = passThrough, dst = newEntry}},
       newVars)
   end
887
888(* Copy an entire loop. In the header, rewrite the transfer to take the loop branch.
889 In the transfers to the top of the loop, rewrite the transfer to goto next.
890 Ensure that the header is the first element in the list.
891 Replace all instances of argi with argVar *)
(* Details, as implemented below:
   - Every block gets a fresh label; setBlockInfo records
     old label -> (new label, block args), and fixLabel (through blockInfo)
     redirects labels that belong to [blocks] to their copies.
   - The copied header keeps its arguments and additionally gets the
     statement "args[argi] = argVar" prepended.
     NOTE(review): that variable is then bound both as a block argument and
     by the statement, and no variables are renamed between copies;
     presumably SSA form is restored by a later step -- confirm.
   - When rewriteTransfer is true, the transfer of tBlock's copy (the bound
     test, despite the comment above saying "header") becomes a Goto to the
     copy of the loop-continuing target reported by loopExit.
   - Arith successes and Gotos targeting headerLabel are redirected to
     nextLabel.  Call continuations/handlers, Case targets and Runtime
     returns only go through fixLabel.
   - The result is fixedBlocks reversed, so the header comes first only if
     it was last in [blocks] -- presumably the caller orders them so. *)
892fun copyLoop(blocks: Block.t vector,
893 nextLabel: Label.t,
894 headerLabel: Label.t,
895 tBlock: Block.t,
896 argi: int,
897 argVar: Var.t,
898 rewriteTransfer: bool,
899 blockInfo: Label.t -> BlockInfo,
900 setBlockInfo: Label.t * BlockInfo -> unit): Block.t vector =
901 let
902 val labels = Vector.map (blocks, Block.label)
903 (* Assign a new label for each block *)
904 val newBlocks = Vector.map (blocks, fn b =>
905 let
906 val oldName = Block.label b
907 val oldArgs = Block.args b
908 val newName = Label.newNoname()
909 val () = setBlockInfo(oldName, (newName, oldArgs))
910 in
911 Block.T {args = Block.args b,
912 label = newName,
913 statements = Block.statements b,
914 transfer = Block.transfer b}
915 end)
916 (* Rewrite the transfers of each block *)
917 val fixedBlocks = Vector.map
918 (newBlocks, fn Block.T {args, label, statements, transfer} =>
919 let
(* f maps an original label to its copy (labels outside [blocks] are
   unchanged); label here is already the new name. *)
920 val f = fn l => fixLabel(blockInfo, l, labels)
921 val isHeader = Label.equals (label, f(headerLabel))
922 val (newArgs, unrolledArg) =
923 if isHeader then
924 (args, SOME(Vector.sub (args, argi)))
925 else (args, NONE)
926 val newStmts =
927 if isHeader then
928 case unrolledArg of
929 NONE => statements
930 | SOME(var, ty) =>
931 let
932 val assignExp = Exp.Var (argVar)
933 val assign = Statement.T {exp = assignExp,
934 ty = ty,
935 var = SOME(var)}
936 val assignV = Vector.new1(assign)
937 in
938 Vector.concat [assignV, statements]
939 end
940 else
941 statements
942 val newTransfer =
943 if rewriteTransfer andalso
944 Label.equals (label, f(Block.label tBlock))
945 then
946 let
947 val (_, contLabel, _) = loopExit(labels, transfer)
948 in
949 Transfer.Goto {args = Vector.new0 (), dst = f(contLabel)}
950 end
951 else
952 case transfer of
953 Transfer.Arith {args, overflow, prim, success, ty} =>
954 if Label.equals (success, headerLabel) then
955 Transfer.Arith {args = args,
956 overflow = f(overflow),
957 prim = prim,
958 success = nextLabel,
959 ty = ty}
960 else
961 Transfer.Arith {args = args,
962 overflow = f(overflow),
963 prim = prim,
964 success = f(success),
965 ty = ty}
966 | Transfer.Call {args, func, return} =>
967 let
968 val newReturn =
969 case return of
970 Return.NonTail {cont, handler} =>
971 let
972 val newHandler =
973 case handler of
974 Handler.Handle l => Handler.Handle(f(l))
975 | _ => handler
976 in
977 Return.NonTail {cont = f(cont), handler = newHandler}
978 end
979 | _ => return
980 in
981 Transfer.Call {args = args, func = func, return = newReturn}
982 end
983 | Transfer.Case {cases, default, test} =>
984 let
985 val newCases = Cases.map(cases, f)
986 val newDefault = case default of
987 NONE => default
988 | SOME(l) => SOME(f(l))
989 in
990 Transfer.Case {cases = newCases,
991 default = newDefault,
992 test = test}
993 end
994 | Transfer.Goto {args, dst} =>
995 if Label.equals (dst, headerLabel) then
996 Transfer.Goto {args = args, dst = nextLabel}
997 else
998 Transfer.Goto {args = args, dst = f(dst)}
999 | Transfer.Runtime {args, prim, return} =>
1000 Transfer.Runtime {args = args, prim = prim, return = f(return)}
1001 | _ => transfer
1002 in
1003 Block.T {args = newArgs,
1004 label = label,
1005 statements = newStmts,
1006 transfer = newTransfer}
1007 end)
1008 in
1009 Vector.rev fixedBlocks
1010 end
1011
(* Unroll a loop: one copy of loopBlocks (made by copyLoop) per element of
   argLabels.  Despite the name, argLabels is a list of variables; the n-th
   one is substituted for header argument argi in the n-th copy.  Each copy
   is given the label of the first block that follows it (the next copy, or
   exit for the last one) as its next label.  The result is every copy's
   blocks in order, followed by exit.  The header should ALWAYS be the first
   element of the returned list; this relies on copyLoop putting the header
   first. *)
fun unrollLoop (oldHeader, tBlock, argi, loopBlocks, argLabels,
                exit, rewriteTransfer, blockInfo, setBlockInfo) =
   let
      val headerLabel = Block.label oldHeader
      (* Built back to front so each copy knows the label that follows it. *)
      fun build [] = [exit]
        | build (argVar :: rest) =
             let
                val following = build rest
                val successor = Block.label (List.first following)
                val copy = copyLoop (loopBlocks, successor, headerLabel,
                                     tBlock, argi, argVar, rewriteTransfer,
                                     blockInfo, setBlockInfo)
             in
                Vector.toList copy @ following
             end
   in
      build argLabels
   end
1034
(* Given:
   - an iteration count
   - a loop body
   - a logging depth
   Returns (b, x, y, z) such that:
     if b is true
       - unroll the loop completely
       - x, y and z are meaningless (all 0).
     if b is false
       - x is the number of times to expand the loop body
       - y is the number of iterations to run the expanded body (never 0)
       - z is the number of times to peel the loop body
   The size budget is !Control.loopUnrollLimit, measured with Block.sizeV. *)
fun shouldOptimize (iterCount, loopBlocks, depth) =
   let
      val loopSize =
         IntInf.fromInt
            (Block.sizeV (loopBlocks, {sizeExp = Exp.size,
                                       sizeTransfer = Transfer.size}))
      val unrollLimit = IntInf.fromInt (!Control.loopUnrollLimit)
      val () =
         logsi ("iterations * loop size < unroll factor = can total unroll",
                depth)
      val canTotalUnroll = iterCount * loopSize < unrollLimit
      val () =
         logsi (concat [IntInf.toString iterCount, " * ",
                        IntInf.toString loopSize, " < ",
                        IntInf.toString unrollLimit, " = ",
                        Bool.toString canTotalUnroll], depth)
   in
      if iterCount = 1 orelse canTotalUnroll then
         (* Runs once, or small enough to unroll completely. *)
         (true, 0, 0, 0)
      else if loopSize >= unrollLimit then
         (* The body alone exhausts the budget: peel off one iteration. *)
         (false, 1, iterCount - 1, 1)
      else
         let
            (* How many bodies fit in the budget, and how many times that
               expanded body would run. *)
            val copies = unrollLimit div loopSize
            val runs = iterCount div copies
            val leftovers = iterCount - runs * copies
         in
            if runs < 3 then
               (* After peeling, the expanded loop would run at most once:
                  unroll the whole thing instead. *)
               (true, 0, 0, 0)
            else if leftovers = 0 then
               (* No natural remainder, so peel one full expansion. *)
               (false, copies, runs - 1, copies)
            else
               (* Peel the remainder in front of the loop. *)
               (false, copies, runs, leftovers)
         end
   end
1085
(* Build the body-expanded (partially unrolled) copy of a loop.

   The returned list begins with a new loop header. That header takes the
   original header's arguments plus one extra word-sized counter argument.
   Its statements and transfer are the ones makeHeader produced for the
   induction-variable steps of exBody body copies.

   The list then contains:
   - a back-edge block that adds 1 to the counter and jumps to the new
     header, passing the original header's argument variables;
   - the body copies produced by unrollLoop, whose first block is given the
     entry label that was handed to makeHeader.

   The exit-test block branches back while counter < iterBody - 1 (signed
   word comparison) and otherwise jumps to exitLabel. It is handed to
   unrollLoop as the exit block rather than consed onto the result here,
   so it presumably ends up in unrollLoop's output.

   NOTE(review): the counter is expected to start at 0 on entry (the caller
   optimizeLoop passes a 0 constant); confirm for any new callers. *)
fun expandLoop (oldHeader, loopBlocks, loop, tBlock, argi, argSize, oldArg,
                exBody, iterBody, exitLabel, blockInfo, setBlockInfo) =
   let
      val headerArgs = Block.args oldHeader
      val headerArgVars = Vector.map (headerArgs, fn (v, _) => v)

      (* Header computing the exBody induction steps, extended with the
         iteration counter as a trailing argument. *)
      val bodyEntry = Label.newNoname ()
      val (stepHeader, stepLabels) =
         makeHeader (oldHeader,
                     Loop.makeSteps (loop, argSize, oldArg, exBody),
                     bodyEntry)
      val counter = Var.newNoname ()
      val countedHeader =
         Block.T {args = Vector.concat [Block.args stepHeader,
                                        Vector.new1 (counter, Type.word argSize)],
                  label = Label.newNoname (),
                  statements = Block.statements stepHeader,
                  transfer = Block.transfer stepHeader}

      (* Back edge: bump the counter by one and re-enter the counted header. *)
      val backEdge =
         let
            val (next, nextStmts) = Loop.makeVarStmt (1, argSize, counter)
            val jump = Transfer.Goto {args = Vector.concat [headerArgVars,
                                                            Vector.new1 next],
                                      dst = Block.label countedHeader}
         in
            Block.T {args = Vector.new0 (),
                     label = Label.newNoname (),
                     statements = Vector.fromList nextStmts,
                     transfer = jump}
         end

      (* Exit test: keep looping while counter < iterBody - 1 (signed). *)
      val exitTest =
         let
            val (limit, limitStmt) = Loop.makeConstStmt (iterBody - 1, argSize)
            val isBelow = Var.newNoname ()
            val compare =
               Statement.T {exp = PrimApp {args = Vector.new2 (counter, limit),
                                           prim = Prim.wordLt (argSize, {signed = true}),
                                           targs = Vector.new0 ()},
                            ty = Type.datatypee Tycon.bool,
                            var = SOME isBelow}
            val branch =
               Transfer.Case {cases = Cases.Con (Vector.new1 (Con.fromBool true,
                                                              Block.label backEdge)),
                              default = SOME exitLabel,
                              test = isBelow}
         in
            Block.T {args = headerArgs,
                     label = Label.newNoname (),
                     statements = Vector.new2 (limitStmt, compare),
                     transfer = branch}
         end

      (* Copy the body exBody times. The `true` argument presumably asks
         unrollLoop to rewrite the bound's transfer. That test is known to
         always be true here, and shrink would not eliminate it. *)
      val copies = unrollLoop (oldHeader, tBlock, argi,
                               loopBlocks, stepLabels, exitTest,
                               true, blockInfo, setBlockInfo)
      val firstCopy = List.first copies
      val relabelledFirst = Block.T {args = Block.args firstCopy,
                                     label = bodyEntry,
                                     statements = Block.statements firstCopy,
                                     transfer = Block.transfer firstCopy}
   in
      countedHeader :: backEdge :: relabelledFirst :: listPop copies
   end
1173
(* Attempt to optimize a single innermost loop.

   Returns (blocksToAdd, labelsToRemove):
   - ([], []) when findOpportunity finds nothing or the loop is infinite;
   - otherwise the replacement blocks, plus the labels of all original loop
     blocks, which the replacement supersedes.

   Total unroll: a header seeded with iterCount + 1 constants is followed by
   the unrolled copies. Their exit block is a Transfer.Bug block.

   Partial unroll: the result has three parts, in order:
   - `peel` peeled copies, which then enter the expanded loop with its
     counter at 0;
   - the expanded loop itself (see expandLoop);
   - an exit goto into one final copy seeded with the loop's last constant.

   depth only controls log indentation. The temporary blockInfo property is
   destroyed on every normal return path. *)
fun optimizeLoop(allBlocks, headerNodes, loopNodes,
                 nodeBlock, loadGlobal, domInfo, depth) =
   let
      val () = ++loopCount
      val headers = Vector.map (headerNodes, nodeBlock)
      val loopBlocks = Vector.map (loopNodes, nodeBlock)
      val loopBlockNames = Vector.map (loopBlocks, Block.label)
      val optOpt = findOpportunity (allBlocks, loopBlocks, headers,
                                    loadGlobal, domInfo, depth + 1)
      val {get = blockInfo: Label.t -> BlockInfo,
           set = setBlockInfo: Label.t * BlockInfo -> unit, destroy} =
         Property.destGetSet(Label.plist,
                             Property.initRaise("blockInfo", Label.layout))
      (* Give the first block returned by unrollLoop the entry label that was
         handed to makeHeader. Its args, statements and transfer are kept. *)
      fun relabelFirst (blocks, entryLabel) =
         let
            val first = List.first blocks
         in
            Block.T {args = Block.args first,
                     label = entryLabel,
                     statements = Block.statements first,
                     transfer = Block.transfer first}
            :: listPop blocks
         end
      val result =
         case optOpt of
            NONE => ([], [])
          | SOME (argi, tBlock, loop) =>
               if Loop.isInfiniteLoop loop then
                  (logsi ("Can't unroll: infinite loop", depth) ;
                   ++infinite ;
                   logsi (concat["Index: ", Int.toString argi, Loop.toString loop],
                          depth) ;
                   ([], []))
               else
                  let
                     val () = ++optCount
                     val oldHeader = Vector.sub (headers, 0)
                     val oldArgs = Block.args oldHeader
                     val (oldArg, oldType) = Vector.sub (oldArgs, argi)
                     val () = logsi (concat["Can unroll loop on ",
                                            Var.toString oldArg], depth)
                     val () = logsi (concat["Index: ", Int.toString argi,
                                            Loop.toString loop], depth)
                     val iterCount = Loop.iterCount loop
                     val () = logsi (concat["Loop will run ",
                                            IntInf.toString iterCount,
                                            " times"], depth)
                     val () = logsi (concat["Transfer block is ",
                                            Label.toString (Block.label tBlock)],
                                     depth)
                     val () = Histogram.inc ((!histogram), iterCount)
                     val (totalUnroll, exBody, iterBody, peel) =
                        shouldOptimize (iterCount, loopBlocks, depth + 1)
                     val argSize = case Type.dest oldType of
                                      Type.Word wsize => wsize
                                    | _ => raise Fail "Argument is not of type word"
                  in
                     if totalUnroll then
                        let
                           val () = ++total
                           val () = logsi ("Completely unrolling loop", depth)
                           val newEntry = Label.newNoname()
                           val (newHeader, argLabels) =
                              makeHeader (oldHeader,
                                          Loop.makeConstants (loop, argSize, iterCount+1),
                                          newEntry)
                           val exitBlock = Block.T {args = oldArgs,
                                                    label = Label.newNoname (),
                                                    statements = Vector.new0 (),
                                                    transfer = Transfer.Bug}
                           (* For each induction variable value, copy the loop's body *)
                           val newBlocks = unrollLoop (oldHeader, tBlock, argi,
                                                       loopBlocks, argLabels, exitBlock,
                                                       false, blockInfo, setBlockInfo)
                        in
                           (newHeader :: relabelFirst (newBlocks, newEntry),
                            Vector.toList loopBlockNames)
                        end
                     else
                        let
                           val () = ++partial
                           val () = logsi ("Partially unrolling loop", depth)
                           val () = logsi (concat["Body expansion: ",
                                                  IntInf.toString exBody,
                                                  " Body iterations: ",
                                                  IntInf.toString iterBody,
                                                  " Peel iterations: ",
                                                  IntInf.toString peel],
                                           depth)
                           (* Variables (not labels) bound by the old header. *)
                           val oldArgVars = Vector.map (oldArgs, fn (a, _) => a)
                           (* Produce an exit loop iteration. *)
                           val exitEntry = Label.newNoname()
                           val (exitHeader, exitConsts) =
                              makeHeader (oldHeader,
                                          Loop.makeLastConstant (loop, argSize),
                                          exitEntry)
                           val exitHeader' =
                              Block.T {args = Block.args exitHeader,
                                       label = Label.newNoname (),
                                       statements = Block.statements exitHeader,
                                       transfer = Block.transfer exitHeader}
                           val exitBlock = Block.T {args = oldArgs,
                                                    label = Label.newNoname (),
                                                    statements = Vector.new0 (),
                                                    transfer = Transfer.Bug}
                           val exitBlocks = unrollLoop (oldHeader, tBlock, argi,
                                                        loopBlocks, exitConsts, exitBlock,
                                                        false, blockInfo, setBlockInfo)
                           val exitGotoLabel = Label.newNoname()
                           val exitGoto = Block.T {args = Vector.new0 (),
                                                   label = exitGotoLabel,
                                                   statements = Vector.new0 (),
                                                   transfer =
                                                      Transfer.Goto
                                                      {args = oldArgVars,
                                                       dst = Block.label exitHeader'}}
                           val exitBlocks' =
                              exitGoto :: exitHeader' :: relabelFirst (exitBlocks, exitEntry)

                           (* Expand the loop *)
                           val exLoopBlocks = expandLoop (oldHeader, loopBlocks, loop,
                                                          tBlock, argi, argSize,
                                                          oldArg, exBody, iterBody,
                                                          exitGotoLabel,
                                                          blockInfo, setBlockInfo)
                           (* Entry to the expanded loop: start its counter at 0 *)
                           val exLoopEntry =
                              let
                                 val (zeroVar, zeroStmt) = Loop.makeConstStmt(0, argSize)
                                 val exLoopHeader = Block.label (List.first exLoopBlocks)
                                 val transferArgs =
                                    Vector.concat [oldArgVars, Vector.new1(zeroVar)]
                                 val newTransfer = Transfer.Goto {args = transferArgs,
                                                                  dst = exLoopHeader}
                              in
                                 Block.T {args = oldArgs,
                                          label = Label.newNoname(),
                                          statements = Vector.new1 zeroStmt,
                                          transfer = newTransfer}
                              end
                           (* Replacement loop entry that peels `peel` iterations *)
                           val newEntry = Label.newNoname()
                           val (newHeader, argLabels) =
                              makeHeader (oldHeader,
                                          Loop.makeConstants (loop, argSize, peel),
                                          newEntry)
                           (* For each induction variable value, copy the loop's body *)
                           val newBlocks = unrollLoop (oldHeader, tBlock, argi,
                                                       loopBlocks, argLabels, exLoopEntry,
                                                       false, blockInfo, setBlockInfo)
                        in
                           ((newHeader :: relabelFirst (newBlocks, newEntry))
                            @ exLoopBlocks @ exitBlocks',
                            Vector.toList loopBlockNames)
                        end
                  end
      (* Release the temporary property on every path, including the ones
         that do not unroll anything. *)
      val () = destroy ()
   in
      result
   end
1352
(* Traverse sub-forests until the innermost loop is found.

   If the sub-forest contains no nested loops, the enclosing loop is
   innermost. optimizeLoop is then run on it, with the enclosing loop's
   header nodes and the nodes not contained in any nested loop. Otherwise
   every nested loop is traversed in order and the results are concatenated
   in that same order.

   Returns (blocksToAdd, labelsToRemove). Per-loop results are collected
   and flattened once, so each block is copied O(1) times rather than
   re-copying a growing accumulator for every loop. *)
fun traverseSubForest ({loops, notInLoop},
                       allBlocks,
                       enclosingHeaders,
                       labelNode, nodeBlock, loadGlobal, domInfo) =
   if (Vector.length loops) = 0 then
      optimizeLoop(allBlocks, enclosingHeaders, notInLoop,
                   nodeBlock, loadGlobal, domInfo, 1)
   else
      let
         val (newRev, removeRev) =
            Vector.fold (loops, ([], []), fn (loop, (new, remove)) =>
               let
                  val (nBlocks, rBlocks) =
                     traverseLoop (loop, allBlocks, labelNode, nodeBlock,
                                   loadGlobal, domInfo)
               in
                  (nBlocks :: new, rBlocks :: remove)
               end)
         (* The accumulators hold the per-loop lists last-first. Folding from
            the front prepends each earlier list, restoring source order. *)
         fun flatten revLists = List.fold (revLists, [], fn (l, acc) => l @ acc)
      in
         (flatten newRev, flatten removeRev)
      end

(* Traverse loops in the loop forest: descend into the loop's child forest,
   passing this loop's headers as the enclosing headers. *)
and traverseLoop ({headers, child},
                  allBlocks,
                  labelNode, nodeBlock, loadGlobal, domInfo) =
   traverseSubForest ((Forest.dest child), allBlocks,
                      headers, labelNode, nodeBlock, loadGlobal, domInfo)
1376
(* Traverse the top-level loop forest and build the function's new block
   list. The result is the original blocks that no optimization replaced,
   in their original order, followed by all newly generated blocks.

   Removed labels are marked in a temporary Label property, so filtering
   allBlocks costs O(1) per block instead of a scan of the removal list.
   The property is destroyed before returning. *)
fun traverseForest ({loops, notInLoop = _}, allBlocks, labelNode, nodeBlock, loadGlobal, domInfo) =
   let
      (* Gather the blocks to add/remove, one list per top-level loop. *)
      val (newRev, removeRev) =
         Vector.fold (loops, ([], []), fn (loop, (new, remove)) =>
            let
               val (nBlocks, rBlocks) =
                  traverseLoop (loop, allBlocks, labelNode, nodeBlock,
                                loadGlobal, domInfo)
            in
               (nBlocks :: new, rBlocks :: remove)
            end)
      (* newRev is last-loop-first; prepend while folding from the front to
         get the loops' blocks back in traversal order, copying each once. *)
      val newBlocks = List.fold (newRev, [], fn (l, acc) => l @ acc)
      val {get = isRemoved: Label.t -> bool,
           set = setRemoved: Label.t * bool -> unit, destroy} =
         Property.destGetSet (Label.plist, Property.initConst false)
      val () = List.foreach (removeRev, fn labels =>
                  List.foreach (labels, fn l => setRemoved (l, true)))
      val reducedBlocks =
         Vector.keepAll (allBlocks, fn b => not (isRemoved (Block.label b)))
      val () = destroy ()
   in
      (Vector.toList reducedBlocks) @ newBlocks
   end
1395
(* Walk a tree of blocks and record, for each block label, the labels of
   all of its proper ancestors, nearest first (the root gets []). With the
   function's dominator tree, which is what optimizeFunction passes, these
   are the block's strict dominators.

   Returns the lookup function and the destroy function of the underlying
   property. The lookup raises for labels that are not in the tree. *)
fun setDoms tree =
   let
      val {get = domInfo: Label.t -> Label.t list,
           set = setDomInfo: Label.t * Label.t list -> unit, destroy} =
         Property.destGetSet(Label.plist,
                             Property.initRaise("domInfo", Label.layout))
      fun walk (Tree.T (block, children), ancestors) =
         let
            val here = Block.label block
            val () = setDomInfo (here, ancestors)
            val below = here :: ancestors
         in
            Vector.foreach (children, fn child => walk (child, below))
         end
      val () = walk (tree, [])
   in
      (domInfo, destroy)
   end
1412
(* Unroll the loops of one function.

   Builds the function's Steensgaard loop forest (rooted at its start
   label) and a dominator lookup derived from its dominator tree. It then
   replaces the function's blocks with the result of traverseForest and
   keeps every other component as is. The dominator property is destroyed
   before the new function is built. *)
fun optimizeFunction loadGlobal function =
   let
      val {graph, labelNode, nodeBlock} = Function.controlFlow function
      val {args, blocks, mayInline, name, raises, returns, start} =
         Function.dest function
      val fnSize = Function.size (function, {sizeExp = Exp.size,
                                             sizeTransfer = Transfer.size})
      val () = logs (concat ["Optimizing function: ", Func.toString name,
                             " of size ", Int.toString fnSize])
      val loopForest =
         Graph.loopForestSteensgaard (graph, {root = labelNode start})
      val (domInfo, destroyDoms) = setDoms (Function.dominatorTree function)
      val rewritten = traverseForest (Forest.dest loopForest, blocks,
                                      labelNode, nodeBlock, loadGlobal, domInfo)
      val () = destroyDoms ()
   in
      Function.new {args = args,
                    blocks = Vector.fromList rewritten,
                    mayInline = mayInline,
                    name = name,
                    raises = raises,
                    returns = returns,
                    start = start}
   end
1438
(* Entry point of the pass.

   The pass runs these steps in order:
   1. Reset every statistics counter and the iteration histogram.
   2. Unroll loops in every function.
   3. Run restoreFunction, then shrinkFunction, over the results, logging
      each phase.
   4. Log the collected statistics and the histogram.

   Returns the program with only its functions replaced. *)
fun transform (Program.T {datatypes, globals, functions, main}) =
   let
      (* The constant value of a global bound to a word constant. It is read
         signed (WordX.toIntInfX) or unsigned (WordX.toIntInf) according to
         `signed`. Returns NONE if no global binds var or if it is not bound
         to a word constant. *)
      fun loadGlobal (var: Var.t, signed: bool): IntInf.t option =
         let
            fun binds stmt =
               case Statement.var stmt of
                  SOME v' => Var.equals (var, v')
                | NONE => false
         in
            case Vector.peek (globals, binds) of
               SOME stmt =>
                  (case Statement.exp stmt of
                      Exp.Const (Const.Word w) =>
                         SOME (if signed then WordX.toIntInfX w
                               else WordX.toIntInf w)
                    | _ => NONE)
             | NONE => NONE
         end
      val counters = [loopCount, total, partial, optCount, multiHeaders,
                      varEntryArg, variantTransfer, unsupported, ccTransfer,
                      varBound, infinite, boundDom]
      val () = List.foreach (counters, fn c => c := 0)
      val () = histogram := Histogram.new ()
      val () = logs (concat["Unrolling loops. Unrolling factor = ",
                            Int.toString (!Control.loopUnrollLimit)])
      val unrolled = List.map (functions, optimizeFunction loadGlobal)
      val restoreOne = restoreFunction {globals = globals}
      val () = logs "Performing SSA restore"
      val restored = List.map (unrolled, restoreOne)
      val shrinkOne = shrinkFunction {globals = globals}
      val () = logs "Performing shrink"
      val shrunk = List.map (restored, shrinkOne)
      val stats =
         [(loopCount, "total innermost loops"),
          (optCount, "loops optimized"),
          (total, "loops completely unrolled"),
          (partial, "loops partially unrolled"),
          (multiHeaders, "loops had multiple headers"),
          (varEntryArg, "variable entry values"),
          (variantTransfer, "loops had variant transfers to the header"),
          (unsupported, "loops had unsupported transfers to the header"),
          (ccTransfer, "loops had non-computable steps"),
          (varBound, "loops had variable bounds"),
          (infinite, "infinite loops"),
          (boundDom, "loops had non-dominating bounds")]
      val () = List.foreach (stats, logstat)
      val () = logs ("Iterations: Occurences")
      val () = logs (Histogram.toString (!histogram))
      val () = logs "Done."
   in
      Program.T {datatypes = datatypes,
                 globals = globals,
                 functions = shrunk,
                 main = main}
   end
1518
1519end