| 1 | (* Copyright (C) 2016 Matthew Surawski. |
| 2 | * |
| 3 | * MLton is released under a BSD-style license. |
| 4 | * See the file MLton-LICENSE for details. |
| 5 | *) |
| 6 | |
| 7 | (* Reduces or eliminates the iteration count of loops by duplicating |
| 8 | * the loop body. |
| 9 | *) |
| 10 | functor LoopUnroll (S: SSA_TRANSFORM_STRUCTS): SSA_TRANSFORM = |
| 11 | struct |
| 12 | |
| 13 | open S |
| 14 | open Exp Transfer Prim |
| 15 | |
(* Graph abbreviates the directed-graph library. Forest names LoopForest as
 * it is seen with Graph opened, so a LoopForest inside Graph takes
 * precedence over any other structure of that name. *)
structure Graph = DirectedGraph
structure Forest = let open Graph in LoopForest end
| 22 | |
(* Add one to the integer held in a counter cell. *)
fun ++ (counter: int ref): unit =
   counter := !counter + 1
| 25 | |
| 26 | |
(* A counting multiset of IntInf keys, stored as (key, count) pairs in a
 * HashSet hashed on the key. *)
structure Histogram =
struct
   type t = (IntInf.t * int ref) HashSet.t

   (* An empty histogram. *)
   fun new (): t =
      HashSet.new {hash = fn (key, _) => IntInf.hash key}

   (* Record one more occurrence of key: a key that is not yet present is
    * added with count 1, otherwise its existing count is incremented. *)
   fun inc (set: t, key: IntInf.t): unit =
      ignore (HashSet.insertIfNew
              (set,
               IntInf.hash key,
               (fn (k, _) => k = key),
               (fn () => (key, ref 1)),
               (fn (_, count) => ++ count)))

   (* All (key, count) pairs currently in the histogram. *)
   fun toList (set: t): (IntInf.t * int ref) list =
      HashSet.toList set

   (* One "key: count" line per entry, in toList order. *)
   fun toString (set: t): string =
      concat (List.map (toList set, fn (key, count) =>
                        concat [IntInf.toString key, ": ",
                                Int.toString (!count), "\n"]))
end
| 56 | |
(* Diagnostic counters. Those marked "presumably" are not updated in the
 * visible part of this pass; their meaning is inferred from the name. *)
val loopCount = ref 0        (* presumably: loops examined *)
and optCount = ref 0         (* presumably: loops transformed *)
and total = ref 0            (* presumably: loops fully unrolled *)
and partial = ref 0          (* presumably: loops partially unrolled *)
and multiHeaders = ref 0     (* findOpportunity: loop without exactly one header *)
and varEntryArg = ref 0      (* checkArg: entry value of the arg is not a common constant *)
and variantTransfer = ref 0  (* checkArg: back edges pass different values *)
and unsupported = ref 0      (* checkArg: unsupported kind of back edge *)
and ccTransfer = ref 0       (* checkArg: varChain could not compute the step *)
and varBound = ref 0         (* checkArg: no exit test on the arg found *)
and infinite = ref 0         (* presumably: loops judged infinite *)
and boundDom = ref 0         (* checkArg: exit test does not dominate every back edge *)
and histogram = ref (Histogram.new ())  (* presumably: iteration-count distribution *)

(* Per-block record kept while copying a loop: the label of the copy and
 * the original block's arguments. *)
type BlockInfo = Label.t * (Var.t * Type.t) vector
| 72 | |
(* Description of a counted loop: an induction variable that starts at a
 * known constant, changes by a constant step on every iteration, and is
 * compared against a constant by the loop's exit test. All arithmetic is
 * done on unbounded IntInf values, so word wrap-around is not modelled. *)
structure Loop =
struct
   datatype Bound = Eq of IntInf.t | Lt of IntInf.t | Gt of IntInf.t
   type Start = IntInf.t
   type Step = IntInf.t
   (* start:  value of the induction variable on entry to the loop
    * step:   amount added to it by each iteration
    * bound:  the comparison made by the exit test (x = b, x < b or x > b)
    * invert: false if the loop continues while the comparison holds,
    *         true if it continues while the comparison fails *)
   datatype t = T of {start: Start, step: Step, bound: Bound, invert: bool}

   fun toString (T {start, step, bound, invert}): string =
      let
         val boundStr =
            case bound of
               Eq b => if invert then
                          concat ["!= ", IntInf.toString b]
                       else
                          concat ["= ", IntInf.toString b]
             | Lt b => if invert then
                          concat ["!< ", IntInf.toString b]
                       else
                          concat ["< ", IntInf.toString b]
             | Gt b => if invert then
                          concat ["!> ", IntInf.toString b]
                       else
                          concat ["> ", IntInf.toString b]
      in
         concat [" Start: ", IntInf.toString start,
                 " Step: ", IntInf.toString step,
                 " Bound: ", boundStr]
      end

   (* Does the loop take its continue branch when the induction variable
    * holds x? *)
   fun continues (bound: Bound, invert: bool, x: IntInf.t): bool =
      let
         val holds =
            case bound of
               Eq b => x = b
             | Lt b => x < b
             | Gt b => x > b
      in
         holds <> invert
      end

   (* True if the loop never takes its exit branch. *)
   fun isInfiniteLoop (T {start, step, bound, invert}): bool =
      case bound of
         Eq b =>
            if invert then
               (* Continues while x <> b: it exits only if b is reached
                * after a whole number of steps. *)
               (if start = b then
                   false
                else if start < b andalso step > 0 then
                   not (((b - start) mod step) = 0)
                else if start > b andalso step < 0 then
                   not (((start - b) mod (~step)) = 0)
                else
                   true)
            else
               (* Continues while x = b: it spins forever only when x starts
                * at b and the step never moves it. If x starts elsewhere
                * the loop exits at once, whatever the step. *)
               start = b andalso step = 0
       | Lt b =>
            if invert then
               start >= b andalso step >= 0
            else
               start < b andalso step <= 0
       | Gt b =>
            if invert then
               start <= b andalso step <= 0
            else
               start > b andalso step >= 0

   (* Number of values start, start + step, start + 2*step, ... that are
    * strictly below max, i.e. ceil ((max - start) / step).
    * Callers guarantee step > 0 and start < max. *)
   fun iters (start: IntInf.t, step: IntInf.t, max: IntInf.t): IntInf.t =
      let
         val range = max - start
         val whole = range div step
      in
         (* A partial final stride still runs one more iteration. *)
         if range mod step = 0 then whole else whole + 1
      end

   (* Number of times the loop takes its continue branch.
    * Assumes isInfiniteLoop is false, otherwise the result is undefined. *)
   fun iterCount (T {start, step, bound, invert}): IntInf.t =
      case bound of
         Eq b =>
            if start = b then
               (* Continues while x <> b: exits immediately.
                * Continues while x = b: one iteration, after which the
                * (necessarily non-zero) step moves x off b. *)
               (if invert then 0 else 1)
            else if invert then
               (* Runs until x reaches b; isInfiniteLoop being false means
                * b - start is a whole number of steps. *)
               (b - start) div step
            else
               (* Continues while x = b, but x starts elsewhere. *)
               0
       | Lt b =>
            (case (start >= b, invert) of
                (true, false) => 0
              | (true, true) => iters (b - 1, ~step, start)
              | (false, true) => 0
              | (false, false) => iters (start, step, b))
       | Gt b =>
            (case (start <= b, invert) of
                (true, false) => 0
              | (true, true) => iters (start, step, b + 1)
              | (false, true) => 0
              | (false, false) => iters (b, ~step, start))

   (* A fresh variable and the statement binding it to the word constant v
    * of size wsize. *)
   fun makeConstStmt (v: IntInf.t, wsize: WordSize.t): Var.t * Statement.t =
      let
         val newWord = WordX.fromIntInf (v, wsize)
         val newConst = Const.word newWord
         val newExp = Exp.Const (newConst)
         val newType = Type.word wsize
         val newVar = Var.newNoname()
         val newStatement = Statement.T {exp = newExp,
                                         ty = newType,
                                         var = SOME(newVar)}
      in
         (newVar, newStatement)
      end

   (* A fresh variable bound to var + v (a wsize word add), together with
    * the constant statement and the add statement, in that order. *)
   fun makeVarStmt (v: IntInf.t, wsize: WordSize.t, var: Var.t)
      : Var.t * Statement.t list =
      let
         val (cVar, cStmt) = makeConstStmt (v, wsize)
         val newExp = Exp.PrimApp {args = Vector.new2 (var, cVar),
                                   prim = Prim.wordAdd wsize,
                                   targs = Vector.new0 ()}
         val newType = Type.word wsize
         val newVar = Var.newNoname()
         val newStatement = Statement.T {exp = newExp,
                                         ty = newType,
                                         var = SOME(newVar)}
      in
         (newVar, [cStmt, newStatement])
      end

   (* Given:
       - a loop
       - a word size
      Returns:
       A variable and statement for the constant value after the loop's final
       iteration. This value will make the loop exit.
      Assumes isInfiniteLoop is false, otherwise the result is undefined.
   *)
   fun makeLastConstant (loop as T {start, step, ...},
                         wsize: WordSize.t)
      : Var.t list * Statement.t list =
      let
         val last = start + (step * iterCount loop)
         val (newVar, newStatement) = makeConstStmt (last, wsize)
      in
         ([newVar], [newStatement])
      end

   (* Given:
       - a loop
       - a word size
       - an iteration limit
      Returns:
       A variable, and the statement that defines it, for the induction
       value of each iteration of the loop, at most limit of them.
       This goes 1 step beyond the end of the loop: the last value is the
       one on which the loop exits.
      Assumes isInfiniteLoop is false, otherwise this will run forever. *)
   fun makeConstants (T {start, step, bound, invert},
                      wsize: WordSize.t,
                      limit: IntInf.t)
      : Var.t list * Statement.t list =
      if limit <= 0 then
         ([], [])
      else
         let
            val (newVar, newStatement) = makeConstStmt (start, wsize)
         in
            if continues (bound, invert, start) then
               let
                  val nextIter = T {start = start + step,
                                    step = step,
                                    bound = bound,
                                    invert = invert}
                  val (rVars, rStmts) =
                     makeConstants (nextIter, wsize, limit - 1)
               in
                  (newVar :: rVars, newStatement :: rStmts)
               end
            else
               (* This value makes the loop exit: it is the last one. *)
               ([newVar], [newStatement])
         end

   (* Accumulate, for k = times - 1 down to 0, a variable holding
    * var + step * k. The returned variables are ordered var + 0*step, ...,
    * var + (times - 1)*step, with their statements in matching order. *)
   fun makeStepsRec (step: Step, wsize: WordSize.t,
                     var: Var.t, times: IntInf.t,
                     vList: Var.t list, sList: Statement.t list)
      : Var.t list * Statement.t list =
      if times > 0 then
         let
            val stepAdd = step * (times - 1)
            val (newVar, newStatements) = makeVarStmt(stepAdd, wsize, var)
         in
            makeStepsRec (step, wsize, var, times - 1,
                          newVar::vList, newStatements @ sList)
         end
      else (vList, sList)

   (* Variables for var + k * step, k = 0 .. times - 1, and their statements. *)
   fun makeSteps (T {step, ...},
                  wsize: WordSize.t,
                  var: Var.t,
                  times: IntInf.t)
      : Var.t list * Statement.t list =
      makeStepsRec (step, wsize, var, times, [], [])
end
| 309 | |
(* Logging helpers. Everything goes through Control.diagnostics; the
 * layout is indented by 2 * i, where i is the caller's nesting level.
 * The indentation is built inside the callback, so it is only computed
 * when Control.diagnostics invokes it. *)
fun logli (l: Layout.t, i: int): unit =
   Control.diagnostics (fn display => display (Layout.indent (l, 2 * i)))

(* Log string s at nesting level i. *)
fun logsi (s: string, i: int): unit =
   logli (Layout.str s, i)

(* Log string s at nesting level 0. *)
fun logs (s: string): unit = logsi (s, 0)

(* Log a counter as "<value> <description>". *)
fun logstat (x: int ref, s: string): unit =
   logs (Int.toString (!x) ^ " " ^ s)
| 323 | |
(* Drop the head of a list; the empty list is returned unchanged. *)
fun listPop [] = []
  | listPop (_ :: rest) = rest
| 328 | |
(* If a block was renamed, return the new name; otherwise return the old
   name. A label counts as renamed when it is one of origLabels, and its
   new name is the first component of getBlockInfo label. *)
fun fixLabel (getBlockInfo: Label.t -> BlockInfo,
              label: Label.t,
              origLabels: Label.t vector): Label.t =
   if not (Vector.contains (origLabels, label, Label.equals)) then
      label
   else
      case getBlockInfo label of
         (newName, _) => newName
| 342 | |
(* True exactly when v2 is SOME v and v is the same variable as v1. *)
fun varOptEquals (v1: Var.t, v2: Var.t option): bool =
   case v2 of
      SOME v => Var.equals (v1, v)
    | NONE => false
| 347 | |
(* For a binary operation where exactly one argument is a constant,
   load that constant.
   Returns the non-constant variable, the constant, and true if that
   variable was the first argument. NONE if both or neither argument is
   constant. *)
fun varConst (args, loadVar, signed): (Var.t * IntInf.int * bool) option =
   let
      val first = Vector.sub (args, 0)
      val second = Vector.sub (args, 1)
   in
      case (loadVar (first, signed), loadVar (second, signed)) of
         (NONE, SOME c) => SOME (first, c, true)
       | (SOME c, NONE) => SOME (second, c, false)
       | _ => NONE
   end
| 363 | |
(* Given:
    - an argument vector with two arguments
    - a primitive operation that may add or subtract a constant value
    - a function from variables to their constant values
   Returns:
    - SOME (v, x) when the primitive computes v + x, v being its
      non-constant argument, or NONE otherwise.
   Addition is commutative, so either operand may be the constant. For a
   subtraction only "v - c" qualifies (as v + (~c)); "c - v" negates v and
   is not v plus a constant, so it yields NONE. *)
fun checkPrim (args, prim, loadVar) =
   let
      fun add signed =
         case varConst (args, loadVar, signed) of
            SOME (v, c, _) => SOME (v, c)
          | NONE => NONE
      (* varConst's third component is true when v is the first operand. *)
      fun sub signed =
         case varConst (args, loadVar, signed) of
            SOME (v, c, true) => SOME (v, ~c)
          | _ => NONE
   in
      case Prim.name prim of
         Name.Word_add _ => add false
       | Name.Word_addCheck (_, {signed}) => add signed
       | Name.Word_sub _ => sub false
       | Name.Word_subCheck (_, {signed}) => sub signed
       | _ => NONE
   end
| 389 | |
(* Given:
    - origVar, a variable in the loop (the loop phi variable)
    - endVar, another variable in the loop
    - the loop body
    - a function from variables to their constant values
    - total, the constant offset accumulated so far (the starting value
      when the transfer to the header is an arith transfer)
   Returns:
    - SOME x such that the value of origVar in loop iteration i+1 is equal
      to (the value of origVar in iteration i) + x,
      or NONE if the step couldn't be computed.
   endVar is traced backwards. It may be defined by a statement that adds
   or subtracts a constant, or it may be the only argument of a block that
   every in-loop predecessor enters with the same variable and offset. An
   Arith transfer contributes its constant and a Goto contributes 0. *)
fun varChain (origVar, endVar, blocks, loadVar, total) =
   if Var.equals (origVar, endVar) then
      SOME total
   else
      let
         (* (v, c) for each statement of b defining endVar as v + c *)
         fun fromStatements b =
            Vector.keepAllMap
            (Block.statements b, fn s =>
             if not (varOptEquals (endVar, Statement.var s)) then
                NONE
             else
                case Statement.exp s of
                   Exp.PrimApp {args, prim, ...} =>
                      checkPrim (args, prim, loadVar)
                 | _ => NONE)

         (* How b' enters the block labelled label:
             NONE               b' does not enter it
             SOME NONE          b' enters it in an unsupported way
             SOME (SOME (v, c)) b' passes v + c as the block's argument *)
         fun entryInto label b' =
            case Block.transfer b' of
               Transfer.Arith {args, prim, success, ...} =>
                  if Label.equals (label, success) then
                     SOME (checkPrim (args, prim, loadVar))
                  else NONE
             | Transfer.Call {return = Return.NonTail {cont, ...}, ...} =>
                  if Label.equals (label, cont) then SOME NONE else NONE
             | Transfer.Case {cases = Cases.Con v, ...} =>
                  if Vector.exists (v, fn (_, l) => Label.equals (label, l))
                     then SOME NONE
                  else NONE
             | Transfer.Case {cases = Cases.Word (_, v), ...} =>
                  if Vector.exists (v, fn (_, l) => Label.equals (label, l))
                     then SOME NONE
                  else NONE
             | Transfer.Goto {args, dst} =>
                  if Label.equals (label, dst) then
                     SOME (SOME (Vector.sub (args, 0), 0))
                  else NONE
             | _ => NONE

         (* The step into the unary block b, when its argument is endVar
            and every entry passes the same variable and offset. *)
         fun fromEntries b =
            let
               val (blockArg, _) = Vector.sub (Block.args b, 0)
               val entries =
                  Vector.keepAllMap (blocks, entryInto (Block.label b))
            in
               if not (Var.equals (endVar, blockArg))
                  orelse Vector.length entries = 0 then
                  NONE
               else
                  case Vector.sub (entries, 0) of
                     NONE => NONE
                   | first as SOME (v0, c0) =>
                        if Vector.forall
                              (entries,
                               fn SOME (v, c) => Var.equals (v0, v) andalso c0 = c
                                | NONE => false)
                           then first
                        else NONE
            end

         (* The (variable, constant) that endVar is derived from in b. *)
         fun stepFrom b =
            let
               val assignments = fromStatements b
            in
               case Vector.length assignments of
                  0 => if Vector.length (Block.args b) = 1
                          then fromEntries b
                       else NONE
                | 1 => SOME (Vector.sub (assignments, 0))
                | _ => raise Fail "Multiple assignments in SSA form!"
            end
      in
         case Vector.peekMap (blocks, stepFrom) of
            NONE => NONE
          | SOME (nextVar, x) =>
               varChain (origVar, nextVar, blocks, loadVar, x + total)
      end
| 494 | |
(* Given:
    - a vector of loop body labels
    - a transfer on a boolean value where one branch exits the loop
      and the other continues
   Returns:
    - the label that exits the loop
    - the label that continues the loop
    - true if the continue branch is the true branch
   The transfer must be a Case over constructors: either one explicit case
   plus a default, or two explicit cases. Anything else raises Fail. *)
fun loopExit (loopLabels: Label.t vector, transfer: Transfer.t)
    : (Label.t * Label.t * bool) =
   let
      fun inLoop l = Vector.contains (loopLabels, l, Label.equals)
      fun isCon b c = Con.equals (Con.fromBool b, c)
   in
      case transfer of
         Transfer.Case {cases = Cases.Con v, default, ...} =>
            (case default of
                SOME dflt =>
                   let
                      val (con, lbl) = Vector.sub (v, 0)
                   in
                      (* With a default, the explicit case covers one
                         truth value and the default covers the other. *)
                      if inLoop dflt then (lbl, dflt, isCon false con)
                      else (dflt, lbl, isCon true con)
                   end
              | NONE =>
                   let
                      val (c1, d1) = Vector.sub (v, 0)
                      val (c2, d2) = Vector.sub (v, 1)
                   in
                      if inLoop d1 then (d2, d1, isCon true c1)
                      else (d1, d2, isCon true c2)
                   end)
       | Transfer.Case _ => raise Fail "This should be a con"
       | _ => raise Fail "This should be a case statement"
   end
| 543 | |
(* True when a Case over constructors has exactly two targets with one
   inside the loop and one outside. That means either a single explicit
   case plus a default, or two explicit cases and no default. Word cases
   never qualify. *)
fun isLoopBranch (loopLabels, cases, default) =
   let
      fun inLoop l = Vector.contains (loopLabels, l, Label.equals)
   in
      case (cases, default) of
         (Cases.Con v, SOME defaultLabel) =>
            Vector.length v = 1
            andalso
            let
               val (_, caseLabel) = Vector.sub (v, 0)
            in
               inLoop defaultLabel <> inLoop caseLabel
            end
       | (Cases.Con v, NONE) =>
            Vector.length v = 2
            andalso
            let
               val (_, l1) = Vector.sub (v, 0)
               val (_, l2) = Vector.sub (v, 1)
            in
               inLoop l1 <> inLoop l2
            end
       | _ => false
   end
| 577 | |
(* True when block's transfer can pass control to the block labelled
   headerLabel: as an Arith success, as a non-tail call's continuation or
   exception handler, as a Case target, as a Goto, or as a Runtime return.
   Arith overflow labels and Case defaults are not checked. This function
   is only applied to headers that take arguments, and the code assumes
   such headers cannot be those nullary targets. *)
fun transfersToHeader (headerLabel, block) =
   let
      fun isHeader l = Label.equals (headerLabel, l)
   in
      case Block.transfer block of
         Transfer.Arith {success, ...} => isHeader success
       | Transfer.Call {return, ...} =>
            (case return of
                Return.NonTail {cont, handler} =>
                   (* Returning into the header is a transfer to it, just
                      like raising into it. *)
                   isHeader cont
                   orelse (case handler of
                              Handler.Handle l => isHeader l
                            | _ => false)
              | _ => false)
       | Transfer.Case {cases, ...} =>
            (case cases of
                Cases.Con v =>
                   Vector.exists (v, fn (_, lbl) => isHeader lbl)
              | Cases.Word (_, v) =>
                   Vector.exists (v, fn (_, lbl) => isHeader lbl))
       | Transfer.Goto {dst, ...} => isHeader dst
       | Transfer.Runtime {return, ...} => isHeader return
       | _ => false
   end
| 601 | |
(* Given:
    - a loop phi variable (a header argument and its type)
    - that variable's index in the loop header's arguments
    - that variable's constant entry value, if it has one, as an
      (unsigned, signed) pair
    - the loop header block
    - the loop body blocks
    - a function from variables to their constant values
    - domInfo, a function from a label to a list of labels. Each back
      edge's list must contain the exit-test block; presumably the list is
      the label's dominators.
    - the logging depth
   Returns:
    - SOME (argIndex, boundBlock, loop) describing how the phi variable
      evolves, where boundBlock holds the exit test, or
    - NONE, after logging the reason and bumping the matching counter, when
      the loop can't be unrolled on this variable. *)
fun checkArg ((argVar, _), argIndex, entryArg, header, loopBody,
              loadVar: Var.t * bool -> IntInf.t option, domInfo, depth) =
   case entryArg of
      NONE => (logsi ("Can't unroll: entry arg not constant", depth) ;
               ++varEntryArg ;
               NONE)
    | SOME (entryX, entryXSigned) =>
      let
         val headerLabel = Block.label header
         fun isHeader l = Label.equals (headerLabel, l)
         val loopLabels = Vector.map (loopBody, Block.label)
         val unsupportedTransfer = ref false
         fun unsupportedEdge () = (unsupportedTransfer := true ; NONE)

         (* For every transfer to the start of the loop, get the variable
            passed at argIndex plus the constant added on the way. Back
            edges of a kind that can't be analysed set unsupportedTransfer. *)
         val loopVars = Vector.keepAllMap (loopBody, fn block =>
            case Block.transfer block of
               Transfer.Arith {args, prim, success, ...} =>
                  if isHeader success then
                     (case checkPrim (args, prim, loadVar) of
                         NONE => unsupportedEdge ()
                       | SOME (arg, x) => SOME (arg, x))
                  else NONE
             | Transfer.Call {return, ...} =>
                  (case return of
                      Return.NonTail {cont, ...} =>
                         if isHeader cont then unsupportedEdge () else NONE
                    | _ => NONE)
             | Transfer.Case {cases, ...} =>
                  (case cases of
                      Cases.Con v =>
                         if Vector.exists (v, fn (_, lbl) => isHeader lbl)
                            then unsupportedEdge ()
                         else NONE
                    | Cases.Word (_, v) =>
                         if Vector.exists (v, fn (_, lbl) => isHeader lbl)
                            then unsupportedEdge ()
                         else NONE)
             | Transfer.Goto {args, dst} =>
                  if isHeader dst then SOME (Vector.sub (args, argIndex), 0)
                  else NONE
             | _ => NONE)

         (* Does a back edge agree with the first one? *)
         fun agrees (arg, x) =
            let
               val (arg0, x0) = Vector.sub (loopVars, 0)
            in
               Var.equals (arg0, arg) andalso (x0 = x)
            end

         (* Exit test "argVar < c" / "c < argVar" as Lt / Gt. *)
         fun ltOrGt (vc, signed) =
            case vc of
               NONE => NONE
             | SOME (_, c, varFirst) =>
                  if varFirst then SOME (Loop.Lt c, signed)
                  else SOME (Loop.Gt c, signed)

         (* Exit test "argVar = c". *)
         fun eq (vc, signed) =
            case vc of
               NONE => NONE
             | SOME (_, c, _) => SOME (Loop.Eq c, signed)

         (* SOME (bound, signed) when statement s defines tVar as a
            comparison of argVar against a constant. *)
         fun boundOf (tVar, s) =
            case Statement.var s of
               NONE => NONE
             | SOME sVar =>
                  if not (Var.equals (tVar, sVar)) then NONE
                  else
                     case Statement.exp s of
                        PrimApp {args, prim, ...} =>
                           if not (Vector.contains (args, argVar, Var.equals))
                              then NONE
                           else
                              (case Prim.name prim of
                                  Name.Word_lt (_, {signed}) =>
                                     ltOrGt (varConst (args, loadVar, signed),
                                             signed)
                                | Name.Word_equal _ =>
                                     eq (varConst (args, loadVar, false), false)
                                | _ => NONE)
                      | _ => NONE

         (* SOME (bound, b, signed) when b ends in a branch that leaves the
            loop on a comparison of argVar against a constant. *)
         fun exitTest b =
            case Block.transfer b of
               Transfer.Case {cases, default, test} =>
                  if isLoopBranch (loopLabels, cases, default) then
                     (case Vector.peekMap (Block.statements b,
                                           fn s => boundOf (test, s)) of
                         NONE => NONE
                       | SOME (bound, signed) => SOME (bound, b, signed))
                  else NONE
             | _ => NONE

         (* Locate the exit test and build the Loop for a known step. *)
         fun boundFor step =
            case Vector.peekMap (loopBody, exitTest) of
               NONE =>
                  (logsi ("Can't unroll: can't determine bound", depth) ;
                   ++varBound ;
                   NONE)
             | SOME (bound, block, signed) =>
                  let
                     val boundLabel = Block.label block
                     val headerTransferBlocks =
                        Vector.keepAll (loopBody, fn b =>
                                        transfersToHeader (headerLabel, b))
                     (* Every back edge's domInfo list must contain the
                        exit-test block. *)
                     val boundDominates =
                        Vector.forall
                        (headerTransferBlocks, fn b =>
                         List.exists (domInfo (Block.label b), fn l =>
                                      Label.equals (boundLabel, l)))
                     val (_, _, contIsTrue) =
                        loopExit (loopLabels, Block.transfer block)
                     val entryVal = if signed then entryXSigned else entryX
                  in
                     if boundDominates then
                        SOME (argIndex,
                              block,
                              Loop.T {start = entryVal,
                                      step = step,
                                      bound = bound,
                                      invert = not contIsTrue})
                     else
                        (logsi ("Can't unroll: bound doesn't dominate", depth) ;
                         ++boundDom ;
                         NONE)
                  end
      in
         if (Vector.length loopVars) > 1
            andalso not (Vector.forall (loopVars, agrees))
         then
            (logsi ("Can't unroll: variant transfer to head of loop", depth) ;
             ++variantTransfer ;
             NONE)
         else if (!unsupportedTransfer) orelse (Vector.length loopVars = 0)
         then
            (* An empty loopVars means no Goto or Arith back edge was found
               (e.g. the header is only re-entered through an exception
               handler), so there is no step to follow. *)
            (logsi ("Can't unroll: unsupported transfer to head of loop",
                    depth) ;
             ++unsupported ;
             NONE)
         else
            let
               val (loopVar, x) = Vector.sub (loopVars, 0)
            in
               case varChain (argVar, loopVar, loopBody, loadVar, x) of
                  NONE => (logsi ("Can't unroll: can't compute transfer",
                                  depth) ;
                           ++ccTransfer ;
                           NONE)
                | SOME step => boundFor step
            end
      end
(* Check all of a loop's entry point arguments for constant values.
   entryArgs holds one vector per entry into the loop header, each with one
   element per header argument: SOME (unsigned, signed) when that entry
   passes a known constant and NONE otherwise.
   Returns one element per header argument: SOME x when every entry passes
   the same constant x, NONE otherwise. With no entries at all the result
   is the empty vector. *)
fun findConstantStart (entryArgs:
                       (((IntInf.t * IntInf.t) option) vector) vector)
    : ((IntInf.t * IntInf.t) option) vector =
   if (Vector.length entryArgs) = 0 then
      Vector.new0 ()
   else
      (* Combine the entries argument-by-argument, keeping the position of
         every element fixed so argument i is only ever compared with
         argument i of the other entries. *)
      Vector.fold
      (entryArgs, Vector.sub (entryArgs, 0), fn (entry, acc) =>
       Vector.map2
       (entry, acc, fn (SOME (x1, x1'), SOME (x2, _)) =>
                          if x1 = x2 then SOME (x1, x1') else NONE
                     | _ => NONE))
| 797 | |
(* Look for any optimization opportunities in the loop.
   Only loops with exactly one header are considered. Every transfer into
   the header from outside the loop is an entry. A Goto contributes the
   constants (if any) it passes, and any other kind of entry makes all
   arguments unknown.
   Returns SOME (argIndex, boundBlock, loop) for the first header argument
   that checkArg accepts, or NONE. *)
fun findOpportunity(functionBody: Block.t vector,
                    loopBody: Block.t vector,
                    loopHeaders: Block.t vector,
                    loadGlobal: Var.t * bool -> IntInf.t option,
                    domInfo: Label.t -> Label.t list,
                    depth: int):
                    (int * Block.t * Loop.t) option =
   if (Vector.length loopHeaders) = 1 then
      let
         val header = Vector.sub (loopHeaders, 0)
         val headerArgs = Block.args header
         val headerLabel = Block.label header
         val () = logsi (concat["Evaluating loop with header: ",
                                Label.toString headerLabel], depth - 1)
         fun blockEquals (b1, b2) =
            Label.equals (Block.label b1, Block.label b2)
         fun isHeader l = Label.equals (headerLabel, l)
         (* An entry whose arguments are all unknown *)
         val emptyArgs = SOME(Vector.new (Vector.length headerArgs, NONE))
         val entryArgs = Vector.keepAllMap(functionBody, fn block =>
            if Vector.contains (loopBody, block, blockEquals) then
               NONE
            else case Block.transfer block of
               Transfer.Arith {success, ...} =>
                  if isHeader success then emptyArgs else NONE
             | Transfer.Call {return, ...} =>
                  (case return of
                      Return.NonTail {cont, handler} =>
                         let
                            val viaHandler =
                               case handler of
                                  Handler.Handle l => isHeader l
                                | _ => false
                         in
                            if isHeader cont orelse viaHandler then emptyArgs
                            else NONE
                         end
                    | _ => NONE)
             | Transfer.Case {cases, ...} =>
                  (case cases of
                      Cases.Con v =>
                         if Vector.exists (v, fn (_, lbl) => isHeader lbl)
                            then emptyArgs
                         else NONE
                    | Cases.Word (_, v) =>
                         if Vector.exists (v, fn (_, lbl) => isHeader lbl)
                            then emptyArgs
                         else NONE)
             | Transfer.Goto {args, dst} =>
                  if isHeader dst then
                     SOME(Vector.map (args, fn a =>
                        case (loadGlobal(a, false), loadGlobal(a, true)) of
                           (NONE, NONE) => NONE
                         | (SOME v1, SOME v2) => SOME (v1, v2)
                         | _ => raise Fail "Impossible"))
                  else NONE
             | Transfer.Runtime {return, ...} =>
                  if isHeader return then emptyArgs else NONE
             | _ => NONE)
         val () = logsi (concat["Loop has ",
                                Int.toString (Vector.length entryArgs),
                                " entry points"], depth - 1)
         val constantArgs = findConstantStart entryArgs
         (* With no recognised entry there is no known start value; the
            vector is then empty rather than one element per argument. *)
         fun entryValue i =
            if i < Vector.length constantArgs then
               Vector.sub (constantArgs, i)
            else NONE
         val unrollableArgs =
            Vector.keepAllMapi
            (headerArgs, fn (i, arg) => (
               logsi (concat["Checking arg: ", Var.toString (#1 arg)], depth) ;
               checkArg (arg, i, entryValue i,
                         header, loopBody, loadGlobal, domInfo, depth + 1)))
      in
         if (Vector.length unrollableArgs) > 0 then
            SOME(Vector.sub (unrollableArgs, 0))
         else NONE
      end
   else
      (logsi ("Can't optimize: loop has more than 1 header", depth) ;
       ++multiHeaders ;
       NONE)
| 874 | |
(* Build a replacement for oldHeader, keeping its label and arguments. It
   runs newStmts and then jumps to newEntry, forwarding its own arguments
   unchanged. Returns the new block paired with newVars. *)
fun makeHeader (oldHeader, (newVars, newStmts), newEntry) =
   let
      val headerArgs = Block.args oldHeader
      val forwarded = Vector.map (headerArgs, fn (v, _) => v)
      val block =
         Block.T {args = headerArgs,
                  label = Block.label oldHeader,
                  statements = Vector.fromList newStmts,
                  transfer = Transfer.Goto {args = forwarded, dst = newEntry}}
   in
      (block, newVars)
   end
| 887 | |
(* Copy an entire loop, producing one fresh copy of [blocks].

   - Every block gets a fresh label (Label.newNoname); the mapping
     old label -> (new label, old args) is recorded with [setBlockInfo] and
     read back through [fixLabel] to retarget intra-loop jumps to the copy.
     NOTE(review): fixLabel is defined elsewhere; presumably it leaves labels
     outside [labels] unchanged -- confirm.
   - In the copy of the header ([headerLabel]), a statement binding the
     argi-th header parameter to [argVar] is prepended, so this copy runs
     with the induction variable fixed to [argVar].
   - When [rewriteTransfer] is true, the copy of [tBlock] has its transfer
     replaced by an unconditional Goto to the copy of the continuation label
     reported by [loopExit] (i.e. it always takes the loop branch).
   - Jumps back to [headerLabel] (Goto dst / Arith success) are redirected
     to [nextLabel], chaining this copy to whatever comes next.

   Returns the fixed-up blocks in reverse order. The original comment says
   the header must be the first element of the result.
   NOTE(review): the only ordering step here is Vector.rev, so this relies
   on the header being the LAST element of [blocks] -- confirm with the
   caller. *)
fun copyLoop(blocks: Block.t vector,
             nextLabel: Label.t,
             headerLabel: Label.t,
             tBlock: Block.t,
             argi: int,
             argVar: Var.t,
             rewriteTransfer: bool,
             blockInfo: Label.t -> BlockInfo,
             setBlockInfo: Label.t * BlockInfo -> unit): Block.t vector =
   let
      val labels = Vector.map (blocks, Block.label)
      (* Assign a new label for each block *)
      val newBlocks = Vector.map (blocks, fn b =>
         let
            val oldName = Block.label b
            val oldArgs = Block.args b
            val newName = Label.newNoname()
            (* Remember where this block's copy lives so fixLabel can
               retarget jumps to it below. *)
            val () = setBlockInfo(oldName, (newName, oldArgs))
         in
            Block.T {args = Block.args b,
                     label = newName,
                     statements = Block.statements b,
                     transfer = Block.transfer b}
         end)
      (* Rewrite the transfers of each block *)
      val fixedBlocks = Vector.map
         (newBlocks, fn Block.T {args, label, statements, transfer} =>
         let
            val f = fn l => fixLabel(blockInfo, l, labels)
            (* This copy's header is the one whose label is the relabelled
               original header. *)
            val isHeader = Label.equals (label, f(headerLabel))
            val (newArgs, unrolledArg) =
               if isHeader then
                  (args, SOME(Vector.sub (args, argi)))
               else (args, NONE)
            (* In the header copy, bind the unrolled induction variable to
               argVar ahead of the block's own statements. *)
            val newStmts =
               if isHeader then
                  case unrolledArg of
                     NONE => statements
                   | SOME(var, ty) =>
                        let
                           val assignExp = Exp.Var (argVar)
                           val assign = Statement.T {exp = assignExp,
                                                     ty = ty,
                                                     var = SOME(var)}
                           val assignV = Vector.new1(assign)
                        in
                           Vector.concat [assignV, statements]
                        end
               else
                  statements
            val newTransfer =
               if rewriteTransfer andalso
                  Label.equals (label, f(Block.label tBlock))
               then
                  (* The loop test's outcome is known here: always continue
                     into the loop body. *)
                  let
                     val (_, contLabel, _) = loopExit(labels, transfer)
                  in
                     Transfer.Goto {args = Vector.new0 (), dst = f(contLabel)}
                  end
               else
                  (* Retarget labels to this copy; back-edges to the original
                     header go to nextLabel instead. *)
                  case transfer of
                     Transfer.Arith {args, overflow, prim, success, ty} =>
                        if Label.equals (success, headerLabel) then
                           Transfer.Arith {args = args,
                                           overflow = f(overflow),
                                           prim = prim,
                                           success = nextLabel,
                                           ty = ty}
                        else
                           Transfer.Arith {args = args,
                                           overflow = f(overflow),
                                           prim = prim,
                                           success = f(success),
                                           ty = ty}
                   | Transfer.Call {args, func, return} =>
                        let
                           val newReturn =
                              case return of
                                 Return.NonTail {cont, handler} =>
                                    let
                                       val newHandler =
                                          case handler of
                                             Handler.Handle l => Handler.Handle(f(l))
                                           | _ => handler
                                    in
                                       Return.NonTail {cont = f(cont), handler = newHandler}
                                    end
                               | _ => return
                        in
                           Transfer.Call {args = args, func = func, return = newReturn}
                        end
                   | Transfer.Case {cases, default, test} =>
                        let
                           val newCases = Cases.map(cases, f)
                           val newDefault = case default of
                                               NONE => default
                                             | SOME(l) => SOME(f(l))
                        in
                           Transfer.Case {cases = newCases,
                                          default = newDefault,
                                          test = test}
                        end
                   | Transfer.Goto {args, dst} =>
                        if Label.equals (dst, headerLabel) then
                           Transfer.Goto {args = args, dst = nextLabel}
                        else
                           Transfer.Goto {args = args, dst = f(dst)}
                   | Transfer.Runtime {args, prim, return} =>
                        Transfer.Runtime {args = args, prim = prim, return = f(return)}
                   (* Remaining transfers carry no labels to retarget. *)
                   | _ => transfer
         in
            Block.T {args = newArgs,
                     label = label,
                     statements = newStmts,
                     transfer = newTransfer}
         end)
   in
      Vector.rev fixedBlocks
   end
| 1011 | |
(* Unroll a loop: for each value in [argLabels] (the successive values of the
   unrolled induction variable), emit a copy of [loopBlocks] via copyLoop.
   The copies are chained: back-edges of one copy jump to the first block of
   the following copy, and those of the last copy jump to [exit]. Copies are
   built from the last value towards the first, so each one already knows
   its successor's label. The header should ALWAYS be the first element in
   the returned list; [exit] is always the final element. *)
fun unrollLoop (oldHeader, tBlock, argi, loopBlocks, argLabels,
                exit, rewriteTransfer, blockInfo, setBlockInfo) =
   let
      val headerLabel = Block.label oldHeader
      (* Put a fresh copy of the loop, bound to [argVar], in front of
         [following]; its back-edges target the head of [following]. *)
      fun prependCopy (argVar, following) =
         let
            val successor = Block.label (List.first following)
            val copy = copyLoop (loopBlocks, successor, headerLabel,
                                 tBlock, argi, argVar, rewriteTransfer,
                                 blockInfo, setBlockInfo)
         in
            Vector.toList copy @ following
         end
   in
      List.foldr (argLabels, [exit], prependCopy)
   end
| 1034 | |
(* Decide how a loop should be unrolled.

   Inputs: the loop's iteration count [iterCount], its blocks [loopBlocks]
   (measured with Block.sizeV), and a logging [depth].

   Result (b, x, y, z):
     b = true  -> unroll the loop completely; x, y and z are meaningless.
     b = false -> x is the number of body copies in the expanded loop,
                  y is how many times the expanded body runs (never 0),
                  z is how many iterations are peeled off in front of it.

   The budget is !Control.loopUnrollLimit. *)
fun shouldOptimize (iterCount, loopBlocks, depth) =
   let
      val loopSize =
         IntInf.fromInt
            (Block.sizeV (loopBlocks, {sizeExp = Exp.size,
                                       sizeTransfer = Transfer.size}))
      val limit = IntInf.fromInt (!Control.loopUnrollLimit)
      val () = logsi ("iterations * loop size < unroll factor = can total unroll",
                      depth)
      val fitsEntirely = iterCount * loopSize < limit
      val () = logsi (concat [IntInf.toString iterCount, " * ",
                              IntInf.toString loopSize, " < ",
                              IntInf.toString limit, " = ",
                              Bool.toString fitsEntirely], depth)
   in
      (* A single iteration, or a loop whose full unrolling fits the budget,
         is unrolled completely. When the body alone already exceeds the
         budget, one iteration is peeled and the rest runs one copy at a
         time. Otherwise as many copies as fit the budget form the expanded
         body. *)
      if iterCount = 1 orelse fitsEntirely then
         (true, 0, 0, 0)
      else if loopSize >= limit then
         (false, 1, iterCount - 1, 1)
      else
         let
            val copies = limit div loopSize
            val passes = iterCount div copies
            val remainder = iterCount mod copies
         in
            (* passes < 3 <=> passes - 1 < 2: the expanded loop would be left
               with at most one pass, so unroll the whole thing instead. *)
            if passes < 3 then
               (true, 0, 0, 0)
            else if remainder = 0 then
               (* No natural leftovers: peel one full pass. *)
               (false, copies, passes - 1, copies)
            else
               (* Leftover iterations are peeled in front of the loop. *)
               (false, copies, passes, remainder)
         end
   end
| 1085 | |
(* Build the expanded (partially unrolled) steady-state loop.

   The returned list starts with a new loop header that takes the original
   header's arguments plus an extra iteration counter of type
   [Type.word argSize]. It jumps to copies of the loop body, one per value
   produced by Loop.makeSteps (presumably [exBody] of them). These copies
   are chained by unrollLoop with rewriteTransfer = true, because the
   bound's test is known to succeed inside the expanded body. The last copy
   falls into a latch block that compares the counter with [iterBody - 1]:
     - while the counter is smaller, control goes to a block that
       increments it and jumps back to the new header;
     - otherwise control goes to [exitLabel].
   Assuming the caller enters with the counter at 0, the expanded body
   runs [iterBody] times, so iterBody must be >= 1.

   The counter comparison is UNSIGNED. The counter only ever holds values
   in [0, iterBody - 1], and iterBody can exceed the signed range of argSize
   (e.g. a Word8 loop with 200 iterations). With a signed comparison such a
   limit would read as negative, and the loop would exit after one pass,
   silently skipping iterations. *)
fun expandLoop (oldHeader, loopBlocks, loop, tBlock, argi, argSize, oldArg,
                exBody, iterBody, exitLabel, blockInfo, setBlockInfo) =
   let
      (* Make a new loop header with an additional arg *)
      val newLoopEntry = Label.newNoname()
      val (newLoopHeader, loopArgLabels) =
         makeHeader (oldHeader,
                     Loop.makeSteps (loop, argSize, oldArg, exBody),
                     newLoopEntry)
      val iterVar = Var.newNoname ()
      val newLoopHeaderArgs' = Vector.concat
                                  [Block.args newLoopHeader,
                                   Vector.new1 (iterVar, Type.word argSize)]
      (* makeHeader reuses the original header's label; give the expanded
         loop's header a fresh one so it does not clash with the replacement
         entry that keeps the original label. *)
      val newLoopHeader' =
         Block.T {args = newLoopHeaderArgs',
                  label = Label.newNoname (),
                  statements = Block.statements newLoopHeader,
                  transfer = Block.transfer newLoopHeader}

      (* Make a new goto to the top of the loop increasing the iter by 1 *)
      val loopHeaderGoto =
         let
            val (newVar, newVarStmts) = Loop.makeVarStmt (1, argSize, iterVar)
            val nonIterArgs = Vector.map (Block.args oldHeader, fn (a, _) => a)
            val newArgs = Vector.concat [nonIterArgs, Vector.new1 (newVar)]
            val newTransfer = Transfer.Goto {args = newArgs,
                                             dst = Block.label newLoopHeader'}
         in
            Block.T {args = Vector.new0 (),
                     label = Label.newNoname (),
                     statements = Vector.fromList newVarStmts,
                     transfer = newTransfer}
         end

      (* Latch: continue while iterVar < iterBody - 1, else leave the loop. *)
      val newLoopExit =
         let
            val (newLimitVar, newLimitStmt) =
               Loop.makeConstStmt (iterBody - 1, argSize)
            val (newComp, newCompVar) =
               let
                  val newVar = Var.newNoname ()
                  val newTy = Type.datatypee Tycon.bool
                  (* Unsigned: the counter is never negative, and the limit
                     may not fit the signed range of argSize. *)
                  val newExp =
                     PrimApp {args = Vector.new2 (iterVar, newLimitVar),
                              prim = Prim.wordLt (argSize, {signed = false}),
                              targs = Vector.new0 ()}
               in
                  (Statement.T {exp = newExp,
                                ty = newTy,
                                var = SOME(newVar)},
                   newVar)
               end
            val exitStatements = Vector.new2(newLimitStmt, newComp)
            val exitCases = Cases.Con (
                               Vector.new1 (Con.fromBool true,
                                            Block.label loopHeaderGoto))
            val exitTransfer = Transfer.Case {cases = exitCases,
                                              default = SOME(exitLabel),
                                              test = newCompVar}
         in
            Block.T {args = Block.args oldHeader,
                     label = Label.newNoname (),
                     statements = exitStatements,
                     transfer = exitTransfer}
         end

      (* Expand the loop exBody times. Rewrite the bound's transfer,
         because we know it will always be true and it won't be eliminated
         by shrink. *)
      val newLoopBlocks = unrollLoop (oldHeader, tBlock, argi,
                                      loopBlocks, loopArgLabels, newLoopExit,
                                      true, blockInfo, setBlockInfo)
      (* The first copy must carry the label the new header jumps to. *)
      val firstLoopBlock = List.first newLoopBlocks
      val loopArgs' = Block.args firstLoopBlock
      val loopStatements' = Block.statements firstLoopBlock
      val loopTransfer' = Block.transfer firstLoopBlock
      val newLoopHead = Block.T {args = loopArgs',
                                 label = newLoopEntry,
                                 statements = loopStatements',
                                 transfer = loopTransfer'}
      val newLoopBlocks' = newLoopHeader'::
                           (loopHeaderGoto::
                            (newLoopHead::
                             (listPop newLoopBlocks)))
   in
      newLoopBlocks'
   end
| 1173 | |
(* Attempt to optimize a single loop. Returns a pair
   (blocks to add to the program, labels of blocks to remove from it).

   Parameters:
     allBlocks   - blocks of the enclosing function (passed to findOpportunity).
     headerNodes - graph nodes of the loop header(s); the first one is the
                   header that gets rewritten.
     loopNodes   - graph nodes of the loop body.
     nodeBlock   - maps a graph node to its block.
     loadGlobal  - resolves a variable to a constant global, if possible.
     domInfo     - dominator lookup (passed to findOpportunity).
     depth       - indentation depth for logging.

   When no opportunity is found, or the loop is infinite, both lists are
   empty. Otherwise the label of every original loop block is returned for
   removal, and the replacement code reuses the original header's label
   (via makeHeader), so existing jumps into the loop reach the new code.

   Raises Fail if the chosen induction argument is not a word. *)
fun optimizeLoop(allBlocks, headerNodes, loopNodes,
                 nodeBlock, loadGlobal, domInfo, depth) =
   let
      val () = ++loopCount
      val headers = Vector.map (headerNodes, nodeBlock)
      val loopBlocks = Vector.map (loopNodes, nodeBlock)
      val loopBlockNames = Vector.map (loopBlocks, Block.label)
      (* NOTE(review): findOpportunity is defined elsewhere; presumably it
         yields (index of the induction argument, block holding the bound
         test, loop description) -- confirm. *)
      val optOpt = findOpportunity (allBlocks, loopBlocks, headers,
                                    loadGlobal, domInfo, depth + 1)
      (* Old label -> (label of its current copy, args); filled in by
         copyLoop. Only destroyed on the two transforming paths; the other
         paths never set it. *)
      val {get = blockInfo: Label.t -> BlockInfo,
           set = setBlockInfo: Label.t * BlockInfo -> unit, destroy} =
         Property.destGetSet(Label.plist,
                             Property.initRaise("blockInfo", Label.layout))
   in
      case optOpt of
         NONE => ([], [])
       | SOME (argi, tBlock, loop) =>
            if Loop.isInfiniteLoop loop then
               (logsi ("Can't unroll: infinite loop", depth) ;
                ++infinite ;
                logsi (concat["Index: ", Int.toString argi, Loop.toString loop],
                       depth) ;
                ([], []))
            else
               let
                  val () = ++optCount
                  val oldHeader = Vector.sub (headers, 0)
                  val oldArgs = Block.args oldHeader
                  val (oldArg, oldType) = Vector.sub (oldArgs, argi)
                  val () = logsi (concat["Can unroll loop on ",
                                         Var.toString oldArg], depth)
                  val () = logsi (concat["Index: ", Int.toString argi,
                                         Loop.toString loop], depth)
                  val iterCount = Loop.iterCount loop
                  val () = logsi (concat["Loop will run ",
                                         IntInf.toString iterCount,
                                         " times"], depth)
                  val () = logsi (concat["Transfer block is ",
                                         Label.toString (Block.label tBlock)],
                                  depth)
                  val () = Histogram.inc ((!histogram), iterCount)
                  val (totalUnroll, exBody, iterBody, peel) =
                     shouldOptimize (iterCount, loopBlocks, depth + 1)
                  val argSize = case Type.dest oldType of
                                   Type.Word wsize => wsize
                                 | _ => raise Fail "Argument is not of type word"
               in
                  if totalUnroll then
                     (* Layout: newHeader -> one body copy per constant
                        from Loop.makeConstants -> exitBlock (Bug). *)
                     let
                        val () = ++total
                        val () = logsi ("Completely unrolling loop", depth)
                        val newEntry = Label.newNoname()
                        val (newHeader, argLabels) =
                           makeHeader (oldHeader,
                                       Loop.makeConstants (loop, argSize, iterCount+1),
                                       newEntry)
                        (* Reached only if the last copy loops back again;
                           the transfer is Bug. *)
                        val exitBlock = Block.T {args = oldArgs,
                                                 label = Label.newNoname (),
                                                 statements = Vector.new0 (),
                                                 transfer = Transfer.Bug}
                        (* For each induction variable value, copy the loop's body *)
                        val newBlocks = unrollLoop (oldHeader, tBlock, argi,
                                                    loopBlocks, argLabels, exitBlock,
                                                    false, blockInfo, setBlockInfo)
                        (* Fix the first entry's label *)
                        val firstBlock = List.first newBlocks
                        val args' = Block.args firstBlock
                        val statements' = Block.statements firstBlock
                        val transfer' = Block.transfer firstBlock
                        val newHead = Block.T {args = args',
                                               label = newEntry,
                                               statements = statements',
                                               transfer = transfer'}
                        (* NOTE(review): listPop presumably drops the head of
                           the list (replaced by newHead) -- confirm. *)
                        val newBlocks' = newHeader::(newHead::(listPop newBlocks))
                        val () = destroy()
                     in
                        (newBlocks', (Vector.toList loopBlockNames))
                     end
                  else
                     (* Layout: newHeader -> [peel] peeled copies ->
                        exLoopEntry (counter := 0) -> expanded loop ->
                        exitGoto -> exitHeader' -> final iteration copy
                        (Loop.makeLastConstant) -> exitBlock (Bug). *)
                     let
                        val () = ++partial
                        val () = logsi ("Partially unrolling loop", depth)
                        val () = logsi (concat["Body expansion: ",
                                               IntInf.toString exBody,
                                               " Body iterations: ",
                                               IntInf.toString iterBody,
                                               " Peel iterations: ",
                                               IntInf.toString peel],
                                        depth)
                        (* The header's argument variables (not labels,
                           despite the name). *)
                        val oldArgLabels = Vector.map (oldArgs, fn (a, _) => a)
                        (* Produce an exit loop iteration. *)
                        val exitEntry = Label.newNoname()
                        val (exitHeader, exitConsts) =
                           makeHeader (oldHeader,
                                       Loop.makeLastConstant (loop, argSize),
                                       exitEntry)
                        (* makeHeader reuses the original header label; that
                           label belongs to newHeader below, so relabel. *)
                        val exitHeader' =
                           Block.T {args = Block.args exitHeader,
                                    label = Label.newNoname (),
                                    statements = Block.statements exitHeader,
                                    transfer = Block.transfer exitHeader}
                        val exitBlock = Block.T {args = oldArgs,
                                                 label = Label.newNoname (),
                                                 statements = Vector.new0 (),
                                                 transfer = Transfer.Bug}
                        val exitBlocks = unrollLoop (oldHeader, tBlock, argi,
                                                     loopBlocks, exitConsts, exitBlock,
                                                     false, blockInfo, setBlockInfo)
                        val exitFirstBlock = List.first exitBlocks
                        val exitArgs = Block.args exitFirstBlock
                        val exitStatements = Block.statements exitFirstBlock
                        val exitTransfer = Block.transfer exitFirstBlock
                        val exitHead = Block.T {args = exitArgs,
                                                label = exitEntry,
                                                statements = exitStatements,
                                                transfer = exitTransfer}
                        (* Argument-less target for the expanded loop's exit
                           Case; forwards the header args to exitHeader'. *)
                        val exitGotoLabel = Label.newNoname()
                        val exitGoto = Block.T {args = Vector.new0 (),
                                                label = exitGotoLabel,
                                                statements = Vector.new0 (),
                                                transfer =
                                                   Transfer.Goto
                                                      {args = oldArgLabels,
                                                       dst = Block.label exitHeader'}}
                        val exitBlocks' = exitGoto::
                                          exitHeader'::
                                          (exitHead::(listPop exitBlocks))

                        (* Expand the loop *)
                        val exLoopBlocks = expandLoop (oldHeader, loopBlocks, loop,
                                                       tBlock, argi, argSize,
                                                       oldArg, exBody, iterBody,
                                                       exitGotoLabel,
                                                       blockInfo, setBlockInfo)
                        (* Make an entry to the expanded loop *)
                        val exLoopEntry =
                           let
                              (* The expanded loop's iteration counter starts at 0. *)
                              val (zeroVar, zeroStmt) = Loop.makeConstStmt(0, argSize)
                              val exLoopHeader = Block.label (List.first exLoopBlocks)
                              val transferArgs =
                                 Vector.concat [oldArgLabels, Vector.new1(zeroVar)]
                              val newTransfer = Transfer.Goto {args = transferArgs,
                                                               dst = exLoopHeader}
                           in
                              Block.T {args = oldArgs,
                                       label = Label.newNoname(),
                                       statements = Vector.new1 zeroStmt,
                                       transfer = newTransfer}
                           end
                        (* Make a replacement loop entry *)
                        val newEntry = Label.newNoname()
                        val (newHeader, argLabels) =
                           makeHeader (oldHeader,
                                       Loop.makeConstants (loop, argSize, peel),
                                       newEntry)
                        (* For each induction variable value, copy the loop's body *)
                        val newBlocks = unrollLoop (oldHeader, tBlock, argi,
                                                    loopBlocks, argLabels, exLoopEntry,
                                                    false, blockInfo, setBlockInfo)
                        (* Fix the first entry's label *)
                        val firstBlock = List.first newBlocks
                        val args' = Block.args firstBlock
                        val statements' = Block.statements firstBlock
                        val transfer' = Block.transfer firstBlock
                        val newHead = Block.T {args = args',
                                               label = newEntry,
                                               statements = statements',
                                               transfer = transfer'}
                        val newBlocks' = newHeader::(newHead::(listPop newBlocks))
                        val () = destroy()
                     in
                        (newBlocks' @ exLoopBlocks @ exitBlocks',
                         (Vector.toList loopBlockNames))
                     end
               end
   end
| 1352 | |
(* Traverse sub-forests until the innermost loop is found. A forest with no
   nested loops is innermost: its non-nested nodes form the loop body, which
   is handed to optimizeLoop with the enclosing loop's headers. Otherwise
   recurse into every nested loop and concatenate, in order, the blocks to
   add and the labels to remove. *)
fun traverseSubForest ({loops, notInLoop},
                       allBlocks,
                       enclosingHeaders,
                       labelNode, nodeBlock, loadGlobal, domInfo) =
   if Vector.length loops <> 0 then
      let
         fun gather (child, (added, removed)) =
            let
               val (moreAdded, moreRemoved) =
                  traverseLoop (child, allBlocks, labelNode, nodeBlock,
                                loadGlobal, domInfo)
            in
               (added @ moreAdded, removed @ moreRemoved)
            end
      in
         Vector.fold (loops, ([], []), gather)
      end
   else
      optimizeLoop (allBlocks, enclosingHeaders, notInLoop,
                    nodeBlock, loadGlobal, domInfo, 1)

(* Traverse loops in the loop forest: descend into the loop's child forest,
   carrying this loop's headers along. *)
and traverseLoop ({headers, child},
                  allBlocks,
                  labelNode, nodeBlock, loadGlobal, domInfo) =
   let
      val childForest = Forest.dest child
   in
      traverseSubForest (childForest, allBlocks, headers,
                         labelNode, nodeBlock, loadGlobal, domInfo)
   end
| 1376 | |
(* Traverse the top-level loop forest of a function.

   Every top-level loop is traversed (only innermost loops are actually
   transformed). The resulting new blocks are appended after the
   function's surviving original blocks; any block whose label a
   transformation asked to remove is dropped.

   Removal marks labels with a temporary Label property instead of
   searching the removal list for every block. That search
   (List.contains inside Vector.keepAll) was
   O(|allBlocks| * |blocksToRemove|). The property is destroyed before
   returning. *)
fun traverseForest ({loops, notInLoop = _}, allBlocks, labelNode, nodeBlock, loadGlobal, domInfo) =
   let
      (* Gather the blocks to add/remove *)
      val (newBlocks, blocksToRemove) =
         Vector.fold(loops, ([], []), fn (loop, (new, remove)) =>
            let
               val (nBlocks, rBlocks) =
                  traverseLoop(loop, allBlocks, labelNode, nodeBlock, loadGlobal, domInfo)
            in
               ((new @ nBlocks), (remove @ rBlocks))
            end)
      val {get = isRemoved: Label.t -> bool,
           set = setRemoved: Label.t * bool -> unit,
           destroy} =
         Property.destGetSet (Label.plist, Property.initConst false)
      val () = List.foreach (blocksToRemove, fn l => setRemoved (l, true))
      val keep: Block.t -> bool =
         (fn b => not (isRemoved (Block.label b)))
      val reducedBlocks = Vector.keepAll(allBlocks, keep)
      val () = destroy ()
   in
      (Vector.toList reducedBlocks) @ newBlocks
   end
| 1395 | |
(* Record, for every block in the dominator [tree], the list of labels of
   its strict dominators (nearest first). Returns the lookup function and
   the destroy function of the backing Label property; looking up a label
   that was never recorded raises. *)
fun setDoms tree =
   let
      val {get = domInfo: Label.t -> Label.t list,
           set = setDomInfo: Label.t * Label.t list -> unit, destroy} =
         Property.destGetSet(Label.plist,
                             Property.initRaise("domInfo", Label.layout))
      fun visit (Tree.T (block, children), dominators) =
         let
            val label = Block.label block
            val () = setDomInfo (label, dominators)
            val childDominators = label :: dominators
         in
            Vector.foreach (children, fn child => visit (child, childDominators))
         end
      val () = visit (tree, [])
   in
      (domInfo, destroy)
   end
| 1412 | |
(* Performs the optimization on the body of a single function. Builds the
   loop forest (Steensgaard) rooted at the start block and the dominator
   table, then rewrites the function's blocks via traverseForest. Every
   other component of the function is carried over unchanged. *)
fun optimizeFunction loadGlobal function =
   let
      val {graph, labelNode, nodeBlock} = Function.controlFlow function
      val {args, blocks, mayInline, name, raises, returns, start} =
         Function.dest function
      val size =
         Function.size (function, {sizeExp = Exp.size,
                                   sizeTransfer = Transfer.size})
      val () = logs (concat ["Optimizing function: ", Func.toString name,
                             " of size ", Int.toString size])
      val forest =
         Graph.loopForestSteensgaard (graph, {root = labelNode start})
      val (domInfo, destroyDoms) = setDoms (Function.dominatorTree function)
      val rewritten =
         traverseForest (Forest.dest forest, blocks, labelNode, nodeBlock,
                         loadGlobal, domInfo)
      val () = destroyDoms ()
   in
      Function.new {args = args,
                    blocks = Vector.fromList rewritten,
                    mayInline = mayInline,
                    name = name,
                    raises = raises,
                    returns = returns,
                    start = start}
   end
| 1438 | |
(* Entry point. Resets the pass statistics, unrolls loops in every function,
   then runs SSA restore and shrink over the result, logs the statistics and
   the iteration-count histogram, and returns the program with the rewritten
   functions (datatypes, globals and main are unchanged). *)
fun transform (Program.T {datatypes, globals, functions, main}) =
   let
      (* If [var] is bound by a global to a word constant, return its value
         (interpreted as signed when [signed] is true); otherwise NONE. *)
      fun loadGlobal (var: Var.t, signed: bool): IntInf.t option =
         let
            fun definesVar stmt =
               case Statement.var stmt of
                  SOME v => Var.equals (var, v)
                | NONE => false
         in
            case Vector.peek (globals, definesVar) of
               NONE => NONE
             | SOME stmt =>
                  (case Statement.exp stmt of
                      Exp.Const (Const.Word w) =>
                         SOME (if signed then WordX.toIntInfX w
                               else WordX.toIntInf w)
                    | _ => NONE)
         end
      (* The statistics live in global refs; zero them for this run. *)
      val () = List.foreach ([loopCount, total, partial, optCount,
                              multiHeaders, varEntryArg, variantTransfer,
                              unsupported, ccTransfer, varBound, infinite,
                              boundDom],
                             fn counter => counter := 0)
      val () = histogram := Histogram.new ()
      val () = logs (concat ["Unrolling loops. Unrolling factor = ",
                             Int.toString (!Control.loopUnrollLimit)])
      val unrolled = List.map (functions, optimizeFunction loadGlobal)
      val restore = restoreFunction {globals = globals}
      val () = logs "Performing SSA restore"
      val restored = List.map (unrolled, restore)
      val shrink = shrinkFunction {globals = globals}
      val () = logs "Performing shrink"
      val shrunk = List.map (restored, shrink)
      val () = List.foreach
                  ([(loopCount, "total innermost loops"),
                    (optCount, "loops optimized"),
                    (total, "loops completely unrolled"),
                    (partial, "loops partially unrolled"),
                    (multiHeaders, "loops had multiple headers"),
                    (varEntryArg, "variable entry values"),
                    (variantTransfer, "loops had variant transfers to the header"),
                    (unsupported, "loops had unsupported transfers to the header"),
                    (ccTransfer, "loops had non-computable steps"),
                    (varBound, "loops had variable bounds"),
                    (infinite, "infinite loops"),
                    (boundDom, "loops had non-dominating bounds")],
                   logstat)
      val () = logs ("Iterations: Occurences")
      val () = logs (Histogram.toString (!histogram))
      val () = logs "Done."
   in
      Program.T {datatypes = datatypes,
                 globals = globals,
                 functions = shrunk,
                 main = main}
   end
| 1518 | |
| 1519 | end |