(* Copyright (C) 2017 Matthew Fluet.
 *
 * MLton is released under a BSD-style license.
 * See the file MLton-LICENSE for details.
 *)

functor ShareZeroVec (S: SSA_TRANSFORM_STRUCTS): SSA_TRANSFORM =
struct

open S
open Exp

(* transform program
 *
 * For every function that applies Array_toVector, each non-raw Array_alloc
 * binding one of the converted arrays is split into a length test:
 *
 *    <orig label> (<orig args>):
 *       ...statements before the allocation...
 *       val isZero: bool = wordEqual (zero, lenVar)
 *       case isZero of true => L_zeroLen | false => L_nonZeroLen
 *    L_zeroLen ():
 *       L_join (zeroArr)      shared global array, one per element type
 *    L_nonZeroLen ():
 *       val arrVar' = Array_alloc(eltTy) (lenVar)
 *       L_join (arrVar')
 *    L_join (arrVar):
 *       ...statements after the allocation...
 *       <orig transfer>
 *
 * The shared zero-length arrays are appended to the globals, and every
 * rewritten function is run through the shrinker.
 *)
fun transform (Program.T {datatypes, globals, functions, main}) =
   let
      val seqIndexSize = WordSize.seqIndex ()
      val seqIndexTy = Type.word seqIndexSize

      (* SOME v when the global binds v to a zero word constant of the
       * sequence-index type; NONE otherwise. *)
      fun zeroSeqIndexGlobal (Statement.T {var, ty, exp}) =
         case (var, exp) of
            (SOME v, Exp.Const (Const.Word w)) =>
               if WordX.isZero w andalso Type.equals (seqIndexTy, ty)
                  then SOME v
               else NONE
          | _ => NONE

      (* Reuse an existing zero global if there is one; otherwise create
       * one and append it to the globals. *)
      val (zeroVar, globals) =
         case Vector.peekMap (globals, zeroSeqIndexGlobal) of
            SOME v => (v, globals)
          | NONE =>
               let
                  val v = Var.newString "zero"
               in
                  (v,
                   Vector.concat
                   [globals,
                    Vector.new1
                    (Statement.T
                     {var = SOME v,
                      ty = seqIndexTy,
                      exp = Exp.Const (Const.word (WordX.zero seqIndexSize))})])
               end

      (* Built from the globals known at this point (including the zero
       * global); the shared-array globals created below are not included. *)
      val shrink = shrinkFunction {globals = globals}

      (* Statements defining the shared zero-length arrays; each new one is
       * pushed onto the front of the list. *)
      val newGlobals = ref []
      local
         val zeroArrs: {eltTy: Type.t, zeroArrVar: Var.t} HashSet.t =
            HashSet.new {hash = fn {eltTy, ...} => Type.hash eltTy}

         (* Record the global statement
          *    val zeroArr: ty array = Array_alloc(ty) (zero)
          * in newGlobals and return the hash-set entry describing it. *)
         fun newZeroArr (ty: Type.t) () =
            let
               val zeroArrVar = Var.newString "zeroArr"
               val () =
                  List.push
                  (newGlobals,
                   Statement.T
                   {var = SOME zeroArrVar,
                    ty = Type.array ty,
                    exp = PrimApp {args = Vector.new1 zeroVar,
                                   prim = Prim.arrayAlloc {raw = false},
                                   targs = Vector.new1 ty}})
            in
               {eltTy = ty, zeroArrVar = zeroArrVar}
            end
      in
         (* The shared zero-length array variable for element type ty;
          * its global is created the first time ty is requested. *)
         fun getZeroArrVar (ty: Type.t): Var.t =
            #zeroArrVar
            (HashSet.lookupOrInsert
             (zeroArrs, Type.hash ty,
              fn {eltTy, ...} => Type.equals (eltTy, ty),
              newZeroArr ty))
      end

      (* splitStmts (stmts, arrVars)
       * Finds the first statement of the form
       *    val arrVar: arrTy = Array_alloc(eltTy) (lenVar)
       * (non-raw) whose arrVar is in arrVars, and returns
       *    SOME (statements before it, (arrVar, arrTy, eltTy, lenVar),
       *          statements after it)
       * or NONE when there is no such statement.
       *)
      fun splitStmts (stmts, arrVars) =
         let
            fun trackedAlloc (Statement.T {var, ty, exp}) =
               case exp of
                  PrimApp {prim, args, targs} =>
                     (case (var, Prim.name prim) of
                         (SOME arrVar, Prim.Name.Array_alloc {raw = false}) =>
                            if List.contains (arrVars, arrVar, Var.equals)
                               then SOME (arrVar, ty,
                                          Vector.first targs,
                                          Vector.first args)
                            else NONE
                       | _ => NONE)
                | _ => NONE
         in
            case Vector.peekMapi (stmts, trackedAlloc) of
               NONE => NONE
             | SOME (i, alloc) =>
                  SOME (Vector.prefix (stmts, i),
                        alloc,
                        Vector.dropPrefix (stmts, i + 1))
         end

      (* Split block at its first tracked allocation into the four blocks
       * (pre, zero-length, non-zero-length, join); NONE when the block has
       * no tracked allocation. *)
      fun transformBlock (block, arrVars) =
         case splitStmts (Block.statements block, arrVars) of
            NONE => NONE
          | SOME (preStmts, (arrVar, arrTy, eltTy, lenVar), postStmts) =>
               let
                  val Block.T {label, args, transfer, ...} = block
                  val ifZeroLab = Label.newString "L_zeroLen"
                  val ifNonZeroLab = Label.newString "L_nonZeroLen"
                  val joinLab = Label.newString "L_join"

                  fun mkBlock (lab, blockArgs, stmts, xfer) =
                     Block.T {label = lab,
                              args = blockArgs,
                              statements = stmts,
                              transfer = xfer}
                  fun gotoJoin x =
                     Transfer.Goto {args = Vector.new1 x, dst = joinLab}

                  (* Keeps the original label and arguments; ends by
                   * comparing lenVar against zero. *)
                  val isZeroVar = Var.newString "isZero"
                  val testStmt =
                     Statement.T
                     {var = SOME isZeroVar,
                      ty = Type.bool,
                      exp = PrimApp {args = Vector.new2 (zeroVar, lenVar),
                                     prim = Prim.wordEqual seqIndexSize,
                                     targs = Vector.new0 ()}}
                  val preBlock =
                     mkBlock (label, args,
                              Vector.concat [preStmts, Vector.new1 testStmt],
                              Transfer.Case
                              {cases = Cases.Con
                                       (Vector.new2
                                        ((Con.truee, ifZeroLab),
                                         (Con.falsee, ifNonZeroLab))),
                               default = NONE,
                               test = isZeroVar})

                  (* Zero length: hand the shared array to the join block. *)
                  val ifZeroBlock =
                     mkBlock (ifZeroLab, Vector.new0 (), Vector.new0 (),
                              gotoJoin (getZeroArrVar eltTy))

                  (* Non-zero length: allocate as the original code did. *)
                  val arrVar' = Var.new arrVar
                  val ifNonZeroBlock =
                     mkBlock (ifNonZeroLab, Vector.new0 (),
                              Vector.new1
                              (Statement.T
                               {var = SOME arrVar',
                                ty = arrTy,
                                exp = PrimApp {args = Vector.new1 lenVar,
                                               prim = Prim.arrayAlloc {raw = false},
                                               targs = Vector.new1 eltTy}}),
                              gotoJoin arrVar')

                  (* Rebinds arrVar, then runs the remaining statements and
                   * the original transfer. *)
                  val joinBlock =
                     mkBlock (joinLab, Vector.new1 (arrVar, arrTy),
                              postStmts, transfer)
               in
                  SOME (preBlock, ifZeroBlock, ifNonZeroBlock, joinBlock)
               end

      (* SOME arr when the statement applies Array_toVector to arr. *)
      fun arrayToVectorArg (Statement.T {exp, ...}) =
         case exp of
            PrimApp {prim, args, ...} =>
               (case Prim.name prim of
                   Prim.Name.Array_toVector => SOME (Vector.first args)
                 | _ => NONE)
          | _ => NONE

      (* Analysis: every array variable passed to Array_toVector in blocks. *)
      fun vectorizedArrays blocks =
         Vector.fold
         (blocks, [], fn (Block.T {statements, ...}, acc) =>
          Vector.fold
          (statements, acc, fn (s, acc') =>
           case arrayToVectorArg s of
              SOME arr => arr :: acc'
            | NONE => acc'))

      (* Accumulate the rewritten blocks of b (in reverse) onto acc; the
       * join block of a split is rewritten again so that every tracked
       * allocation in b is handled. *)
      fun rewriteBlock arrVars (b, acc) =
         case transformBlock (b, arrVars) of
            NONE => b :: acc
          | SOME (preBlock, ifZeroBlock, ifNonZeroBlock, joinBlock) =>
               rewriteBlock arrVars
               (joinBlock, ifNonZeroBlock :: ifZeroBlock :: preBlock :: acc)

      (* Functions without Array_toVector are returned untouched; the others
       * are branched and joined at each tracked Array_alloc, then shrunk. *)
      fun transformFunction f =
         let
            val {args, blocks, mayInline, name, raises, returns, start} =
               Function.dest f
            val arrVars = vectorizedArrays blocks
         in
            if List.isEmpty arrVars
               then f (* no Array_toVector found in the function *)
            else
               shrink (Function.new
                       {args = args,
                        blocks = Vector.fromListRev
                                 (Vector.fold (blocks, [], rewriteBlock arrVars)),
                        mayInline = mayInline,
                        name = name,
                        raises = raises,
                        returns = returns,
                        start = start})
         end

      val functions = List.revMap (functions, transformFunction)
      (* Read only after every function has been transformed, so all shared
       * arrays requested above are included. *)
      val globals = Vector.concat [globals, Vector.fromList (!newGlobals)]
   in
      Program.T {datatypes = datatypes,
                 globals = globals,
                 functions = functions,
                 main = main}
   end

end