1 (* Copyright (C) 2017 Matthew Fluet.
3 * MLton is released under a BSD-style license.
4 * See the file MLton-LICENSE for details.
7 functor ShareZeroVec (S: SSA_TRANSFORM_STRUCTS): SSA_TRANSFORM =
13 fun transform (Program.T {datatypes, globals, functions, main}) =
15 val seqIndexSize = WordSize.seqIndex ()
16 val seqIndexTy = Type.word seqIndexSize
17 val (zeroVar, globals) =
18 case Vector.peekMap (globals, fn Statement.T {var, ty, exp} =>
20 (SOME var, Exp.Const (Const.Word w)) =>
22 andalso Type.equals (seqIndexTy, ty)
26 SOME zeroVar => (zeroVar, globals)
28 val zeroVar = Var.newString "zero"
33 exp = Exp.Const (Const.word (WordX.zero seqIndexSize))}
35 (zeroVar, Vector.concat [globals, Vector.new1 zeroVarStmt])
38 val shrink = shrinkFunction {globals = globals}
40 (* initialize a HashSet for new zero-length array globals *)
41 val newGlobals = ref []
43 val hs: {eltTy: Type.t, zeroArrVar: Var.t} HashSet.t =
44 HashSet.new {hash = fn {eltTy, ...} => Type.hash eltTy}
46 fun getZeroArrVar (ty: Type.t): Var.t =
48 val {zeroArrVar, ...} =
49 HashSet.lookupOrInsert
51 fn {eltTy, ...} => Type.equals (eltTy, ty),
54 val zeroArrVar = Var.newString "zeroArr"
57 {var = SOME zeroArrVar,
60 {args = Vector.new1 zeroVar,
61 prim = Prim.arrayAlloc {raw = false},
62 targs = Vector.new1 ty}}
63 val () = List.push (newGlobals, statement)
66 zeroArrVar = zeroArrVar}
73 (* splitStmts (stmts, arrVars)
74 * returns (preStmts, (arrVar, arrTy, eltTy, lenVar), postStmts)
75 * when stmts = ...pre...
76 * val arrVar: arrTy = Array_alloc(eltTy) (lenVar)
78 * and arrVar in arrVars
80 fun splitStmts (stmts, arrVars) =
82 (stmts, fn Statement.T {var, ty, exp} =>
84 PrimApp ({prim, args, targs}) =>
85 (case (var, Prim.name prim) of
86 (SOME var, Prim.Name.Array_alloc {raw = false}) =>
87 if List.contains (arrVars, var, Var.equals)
95 | SOME (i, (arrVar, arrTy, eltTy, lenVar)) =>
96 SOME (Vector.prefix (stmts, i),
97 (* val arrVar: arrTy = Array_alloc(eltTy) (lenVar) *)
98 (arrVar, arrTy, eltTy, lenVar),
99 Vector.dropPrefix (stmts, i + 1))
101 fun transformBlock (block, arrVars) =
102 case splitStmts (Block.statements block, arrVars) of
104 | SOME (preStmts, (arrVar, arrTy, eltTy, lenVar), postStmts) =>
106 val Block.T {label, args, transfer, ...} = block
107 val ifZeroLab = Label.newString "L_zeroLen"
108 val ifNonZeroLab = Label.newString "L_nonZeroLen"
109 val joinLab = Label.newString "L_join"
111 (* new block up to Array_alloc match *)
114 val isZeroVar = Var.newString "isZero"
118 {var = SOME isZeroVar,
121 {args = Vector.new2 (zeroVar, lenVar),
122 prim = Prim.wordEqual seqIndexSize,
123 targs = Vector.new0 ()}})
126 {cases = (Cases.Con o Vector.new2)
127 ((Con.truee, ifZeroLab),
128 (Con.falsee, ifNonZeroLab)),
132 Block.T {label = label,
134 statements = Vector.concat [preStmts,
139 (* new block for if zero array *)
144 {args = Vector.new1 (getZeroArrVar eltTy),
147 Block.T {label = ifZeroLab,
148 args = Vector.new0 (),
149 statements = Vector.new0 (),
153 (* new block for if non-zero array *)
156 val arrVar' = Var.new arrVar
163 {args = Vector.new1 lenVar,
164 prim = Prim.arrayAlloc {raw = false},
165 targs = Vector.new1 eltTy}})
168 {args = Vector.new1 arrVar',
171 Block.T {label = ifNonZeroLab,
172 args = Vector.new0 (),
173 statements = statements,
177 (* new block with statements following match *)
179 Block.T {label = joinLab,
180 args = Vector.new1 (arrVar, arrTy),
181 statements = postStmts,
184 SOME (preBlock, ifZeroBlock, ifNonZeroBlock, joinBlock)
191 val {args, blocks, mayInline, name, raises, returns, start} =
194 (* analysis: compile a list of array vars cast to vectors *)
197 (blocks, [], fn (Block.T {statements, ...}, acc) =>
199 (statements, acc, fn (Statement.T {exp, ...}, acc) =>
201 PrimApp ({prim, args, ...}) =>
202 (case Prim.name prim of
203 Prim.Name.Array_toVector =>
204 (Vector.first args)::acc
208 if List.isEmpty arrVars
209 then f (* no Array_toVector found in the function *)
210 else (* transformation: branch and join at Array_alloc *)
212 fun doBlock (b, acc) =
213 case transformBlock (b, arrVars) of
216 ifZeroBlock, ifNonZeroBlock,
219 ifNonZeroBlock::ifZeroBlock::preBlock::acc)
220 val blocks = Vector.fold (blocks, [], doBlock)
221 val blocks = Vector.fromListRev blocks
223 shrink (Function.new {args = args,
225 mayInline = mayInline,
232 val globals = Vector.concat [globals, Vector.fromList (!newGlobals)]
234 Program.T {datatypes = datatypes,
236 functions = functions,