Import Upstream version 20180207
[hcoop/debian/mlton.git] / mlton / ssa / share-zero-vec.fun
CommitLineData
7f918cf1
CE
(* Copyright (C) 2017 Matthew Fluet.
 *
 * MLton is released under a BSD-style license.
 * See the file MLton-LICENSE for details.
 *)

functor ShareZeroVec (S: SSA_TRANSFORM_STRUCTS): SSA_TRANSFORM =
struct

open S
open Exp

(* transform program
 *
 * In every function that contains at least one Array_toVector, each
 * statement
 *    val a: arrTy = Array_alloc(eltTy) (len)
 * whose bound variable `a` is an argument of some Array_toVector in that
 * function is replaced by a test of `len` against a seq-index zero:
 *  - when zero, control jumps to a join block passing a global
 *    zero-length array of element type eltTy (one such global per element
 *    type, created on demand and appended to the program's globals);
 *  - otherwise, the original allocation is performed into a fresh variable
 *    and passed to the join block.
 * The join block rebinds `a` as its argument and holds the statements that
 * followed the allocation plus the original transfer.  Every rewritten
 * function is then run through the shrinker; functions without an
 * Array_toVector are returned unchanged.
 *)
fun transform (Program.T {datatypes, globals, functions, main}) =
   let
      val seqIndexSize = WordSize.seqIndex ()
      val seqIndexTy = Type.word seqIndexSize

      (* SOME v when the global binds v to a zero word of seq-index type. *)
      fun zeroGlobal (Statement.T {var = SOME v, ty, exp = Exp.Const (Const.Word w)}) =
             if WordX.isZero w andalso Type.equals (seqIndexTy, ty)
                then SOME v
             else NONE
        | zeroGlobal _ = NONE

      (* Reuse an existing seq-index zero global when one exists; otherwise
       * create one and append it to the globals. *)
      val (zeroVar, globals) =
         case Vector.peekMap (globals, zeroGlobal) of
            SOME v => (v, globals)
          | NONE =>
               let
                  val v = Var.newString "zero"
                  val stmt =
                     Statement.T
                     {var = SOME v,
                      ty = seqIndexTy,
                      exp = Exp.Const (Const.word (WordX.zero seqIndexSize))}
               in
                  (v, Vector.concat [globals, Vector.new1 stmt])
               end

      val shrink = shrinkFunction {globals = globals}

      (* Zero-length array globals created by getZeroArrVar; each new one is
       * pushed on the front of this list. *)
      val newGlobals = ref []

      (* getZeroArrVar ty
       * returns the variable of the global zero-length `ty array`, creating
       * (and recording in newGlobals) that global the first time `ty` is
       * requested.
       *)
      local
         val table: {eltTy: Type.t, zeroArrVar: Var.t} HashSet.t =
            HashSet.new {hash = fn {eltTy, ...} => Type.hash eltTy}

         fun makeEntry ty () =
            let
               val arr = Var.newString "zeroArr"
               val alloc =
                  Statement.T
                  {var = SOME arr,
                   ty = Type.array ty,
                   exp = PrimApp {args = Vector.new1 zeroVar,
                                  prim = Prim.arrayAlloc {raw = false},
                                  targs = Vector.new1 ty}}
            in
               List.push (newGlobals, alloc)
               ; {eltTy = ty, zeroArrVar = arr}
            end
      in
         fun getZeroArrVar (ty: Type.t): Var.t =
            #zeroArrVar
            (HashSet.lookupOrInsert
             (table, Type.hash ty,
              fn {eltTy, ...} => Type.equals (eltTy, ty),
              makeEntry ty))
      end

      (* splitStmts (stmts, arrVars)
       * returns SOME (preStmts, (arrVar, arrTy, eltTy, lenVar), postStmts)
       * for the first statement of stmts of the form
       *    val arrVar: arrTy = Array_alloc(eltTy) (lenVar)
       * whose arrVar is in arrVars; preStmts and postStmts are the
       * statements before and after it.  Returns NONE if there is none.
       *)
      fun splitStmts (stmts, arrVars) =
         let
            fun isCandidate (Statement.T {var = SOME v, ty, exp = PrimApp {prim, args, targs}}) =
                   (case Prim.name prim of
                       Prim.Name.Array_alloc {raw = false} =>
                          if List.contains (arrVars, v, Var.equals)
                             then SOME (v, ty, Vector.first targs, Vector.first args)
                          else NONE
                     | _ => NONE)
              | isCandidate _ = NONE
         in
            case Vector.peekMapi (stmts, isCandidate) of
               NONE => NONE
             | SOME (i, alloc) =>
                  SOME (Vector.prefix (stmts, i),
                        alloc,
                        Vector.dropPrefix (stmts, i + 1))
         end

      (* transformBlock (block, arrVars)
       * splits block at its first matching Array_alloc (see splitStmts) into
       *    pre:     the original label and args, the preceding statements,
       *             and a Case on (zero = lenVar);
       *    ifZero:  jumps to join with the shared zero-length array;
       *    ifNonZero: performs the original allocation into a fresh
       *             variable and jumps to join with it;
       *    join:    binds arrVar as its argument and holds the following
       *             statements and the original transfer.
       * Returns NONE when block has no matching allocation.
       *)
      fun transformBlock (block, arrVars) =
         case splitStmts (Block.statements block, arrVars) of
            NONE => NONE
          | SOME (preStmts, (arrVar, arrTy, eltTy, lenVar), postStmts) =>
               let
                  val Block.T {label, args, transfer, ...} = block
                  val ifZeroLab = Label.newString "L_zeroLen"
                  val ifNonZeroLab = Label.newString "L_nonZeroLen"
                  val joinLab = Label.newString "L_join"

                  fun gotoJoin v =
                     Transfer.Goto {args = Vector.new1 v, dst = joinLab}

                  val isZeroVar = Var.newString "isZero"
                  val preBlock =
                     Block.T
                     {label = label,
                      args = args,
                      statements =
                         Vector.concat
                         [preStmts,
                          Vector.new1
                          (Statement.T
                           {var = SOME isZeroVar,
                            ty = Type.bool,
                            exp = PrimApp {args = Vector.new2 (zeroVar, lenVar),
                                           prim = Prim.wordEqual seqIndexSize,
                                           targs = Vector.new0 ()}})],
                      transfer =
                         Transfer.Case
                         {cases = Cases.Con (Vector.new2 ((Con.truee, ifZeroLab),
                                                          (Con.falsee, ifNonZeroLab))),
                          default = NONE,
                          test = isZeroVar}}

                  val ifZeroBlock =
                     Block.T {label = ifZeroLab,
                              args = Vector.new0 (),
                              statements = Vector.new0 (),
                              transfer = gotoJoin (getZeroArrVar eltTy)}

                  val arrVar' = Var.new arrVar
                  val ifNonZeroBlock =
                     Block.T
                     {label = ifNonZeroLab,
                      args = Vector.new0 (),
                      statements =
                         Vector.new1
                         (Statement.T
                          {var = SOME arrVar',
                           ty = arrTy,
                           exp = PrimApp {args = Vector.new1 lenVar,
                                          prim = Prim.arrayAlloc {raw = false},
                                          targs = Vector.new1 eltTy}}),
                      transfer = gotoJoin arrVar'}

                  val joinBlock =
                     Block.T {label = joinLab,
                              args = Vector.new1 (arrVar, arrTy),
                              statements = postStmts,
                              transfer = transfer}
               in
                  SOME (preBlock, ifZeroBlock, ifNonZeroBlock, joinBlock)
               end

      (* Prepend the array argument of an Array_toVector statement to acc. *)
      fun addToVectorArg (Statement.T {exp = PrimApp {prim, args, ...}, ...}, acc) =
             (case Prim.name prim of
                 Prim.Name.Array_toVector => Vector.first args :: acc
               | _ => acc)
        | addToVectorArg (_, acc) = acc

      (* Analysis: the array variables converted to vectors in these blocks. *)
      fun toVectorArgs blocks =
         Vector.fold (blocks, [], fn (Block.T {statements, ...}, acc) =>
                      Vector.fold (statements, acc, addToVectorArg))

      (* Rewrite one function; returned as-is if it has no Array_toVector. *)
      fun transformFunction f =
         let
            val {args, blocks, mayInline, name, raises, returns, start} =
               Function.dest f
            val arrVars = toVectorArgs blocks

            (* Keep splitting: a join block may contain further matches. *)
            fun splitAll (b, acc) =
               case transformBlock (b, arrVars) of
                  NONE => b :: acc
                | SOME (pre, ifZero, ifNonZero, join) =>
                     splitAll (join, ifNonZero :: ifZero :: pre :: acc)
         in
            if List.isEmpty arrVars
               then f
            else shrink (Function.new
                         {args = args,
                          blocks = Vector.fromListRev
                                   (Vector.fold (blocks, [], splitAll)),
                          mayInline = mayInline,
                          name = name,
                          raises = raises,
                          returns = returns,
                          start = start})
         end

      val functions = List.revMap (functions, transformFunction)
   in
      (* newGlobals is read only after every function has been rewritten. *)
      Program.T {datatypes = datatypes,
                 globals = Vector.concat [globals, Vector.fromList (!newGlobals)],
                 functions = functions,
                 main = main}
   end

end