Commit | Line | Data |
---|---|---|
7f918cf1 CE |
1 | (* Copyright (C) 2017 Matthew Fluet. |
2 | * | |
3 | * MLton is released under a BSD-style license. | |
4 | * See the file MLton-LICENSE for details. | |
5 | *) | |
6 | ||
functor ShareZeroVec (S: SSA_TRANSFORM_STRUCTS): SSA_TRANSFORM =
struct

open S
open Exp

(* Share zero-length arrays that are later converted to vectors.
 *
 * For each function containing at least one Array_toVector, every
 * non-raw Array_alloc whose result variable appears as the array
 * argument of some Array_toVector in that function is rewritten into a
 * test of its length:
 *   - length = 0: jump to a join block passing a global zero-length
 *     array (one global per element type, created on first use);
 *   - otherwise: perform the original allocation under a fresh variable
 *     and jump to the join block passing that variable.
 * The join block binds the original array variable as its argument and
 * continues with the statements that followed the allocation.  Rewritten
 * functions are passed through the shrinker; all other functions are
 * returned unchanged.
 *
 * NOTE(review): after this pass, separate zero-length allocations of the
 * same element type yield the very same array; this presumes nothing
 * observes their identity -- confirm.
 *)
fun transform (Program.T {datatypes, globals, functions, main}) =
   let
      val idxSize = WordSize.seqIndex ()
      val idxTy = Type.word idxSize

      (* SOME v when the global statement binds v to the constant 0 of
       * the sequence-index type; NONE otherwise. *)
      fun zeroGlobalVar (Statement.T {var, ty, exp}) =
         case (var, exp) of
            (SOME v, Exp.Const (Const.Word w)) =>
               if WordX.isZero w andalso Type.equals (idxTy, ty)
                  then SOME v
               else NONE
          | _ => NONE

      (* Reuse an existing zero global if there is one; otherwise append a
       * freshly named one to the end of the globals. *)
      val (zeroVar, globals) =
         case Vector.peekMap (globals, zeroGlobalVar) of
            NONE =>
               let
                  val fresh = Var.newString "zero"
                  val decl =
                     Statement.T
                     {var = SOME fresh,
                      ty = idxTy,
                      exp = Exp.Const (Const.word (WordX.zero idxSize))}
               in
                  (fresh, Vector.concat [globals, Vector.new1 decl])
               end
          | SOME v => (v, globals)

      (* Built from the globals that include zeroVar, but not from the
       * zero-length array globals that are created further below. *)
      val shrink = shrinkFunction {globals = globals}

      (* Zero-length array declarations created on demand; List.push puts
       * the most recently created one first. *)
      val newGlobals = ref []

      (* getZeroArrVar eltTy returns the global variable holding the shared
       * zero-length array of element type eltTy.  The first request for a
       * given eltTy creates the variable and pushes its declaration
       *    val zeroArr: eltTy array = Array_alloc(eltTy) (zeroVar)
       * onto newGlobals; later requests return that same variable. *)
      local
         val table: {eltTy: Type.t, zeroArrVar: Var.t} HashSet.t =
            HashSet.new {hash = fn {eltTy = t, ...} => Type.hash t}
      in
         fun getZeroArrVar (eltTy: Type.t): Var.t =
            let
               fun create () =
                  let
                     val arr = Var.newString "zeroArr"
                     val () =
                        List.push
                        (newGlobals,
                         Statement.T
                         {var = SOME arr,
                          ty = Type.array eltTy,
                          exp = PrimApp {args = Vector.new1 zeroVar,
                                         prim = Prim.arrayAlloc {raw = false},
                                         targs = Vector.new1 eltTy}})
                  in
                     {eltTy = eltTy, zeroArrVar = arr}
                  end
               val {zeroArrVar, ...} =
                  HashSet.lookupOrInsert
                  (table, Type.hash eltTy,
                   fn {eltTy = t, ...} => Type.equals (t, eltTy),
                   create)
            in
               zeroArrVar
            end
      end

      (* allocOf arrVars stmt is SOME (arrVar, arrTy, eltTy, lenVar) when
       *    stmt = val arrVar: arrTy = Array_alloc(eltTy) (lenVar)
       * is a non-raw allocation and arrVar is a member of arrVars;
       * NONE otherwise. *)
      fun allocOf arrVars (Statement.T {var, ty, exp}) =
         case (var, exp) of
            (SOME arrVar, PrimApp {prim, args, targs}) =>
               (case Prim.name prim of
                   Prim.Name.Array_alloc {raw = false} =>
                      if List.contains (arrVars, arrVar, Var.equals)
                         then SOME (arrVar, ty,
                                    Vector.first targs,
                                    Vector.first args)
                      else NONE
                 | _ => NONE)
          | _ => NONE

      (* splitStmts (stmts, arrVars)
       * returns SOME (preStmts, (arrVar, arrTy, eltTy, lenVar), postStmts)
       * when stmts = ...pre...
       *              val arrVar: arrTy = Array_alloc(eltTy) (lenVar)
       *              ...post...
       * for the FIRST statement accepted by allocOf arrVars, and NONE when
       * no statement is accepted.  The matched statement itself is in
       * neither preStmts nor postStmts. *)
      fun splitStmts (stmts, arrVars) =
         case Vector.peekMapi (stmts, allocOf arrVars) of
            NONE => NONE
          | SOME (i, alloc) =>
               SOME (Vector.prefix (stmts, i),
                     alloc,
                     Vector.dropPrefix (stmts, i + 1))

      (* transformBlock (block, arrVars) is NONE when splitStmts finds no
       * matching allocation in block.  Otherwise it is
       * SOME (preBlock, zeroBlock, nonZeroBlock, joinBlock) where
       *   preBlock     keeps block's label and args, runs preStmts, then
       *                branches on lenVar = zeroVar;
       *   zeroBlock    jumps to joinBlock with the shared zero-length array;
       *   nonZeroBlock redoes the allocation into a fresh variable and
       *                jumps to joinBlock with it;
       *   joinBlock    takes arrVar as its argument, runs postStmts and ends
       *                with block's original transfer. *)
      fun transformBlock (block, arrVars) =
         case splitStmts (Block.statements block, arrVars) of
            NONE => NONE
          | SOME (preStmts, (arrVar, arrTy, eltTy, lenVar), postStmts) =>
               let
                  val Block.T {label, args, transfer = origTransfer, ...} = block
                  val zeroLab = Label.newString "L_zeroLen"
                  val nonZeroLab = Label.newString "L_nonZeroLen"
                  val joinLab = Label.newString "L_join"

                  (* Prefix of the original block, ending in the length test. *)
                  val testVar = Var.newString "isZero"
                  val testStmt =
                     Statement.T
                     {var = SOME testVar,
                      ty = Type.bool,
                      exp = PrimApp {args = Vector.new2 (zeroVar, lenVar),
                                     prim = Prim.wordEqual idxSize,
                                     targs = Vector.new0 ()}}
                  val preBlock =
                     Block.T
                     {label = label,
                      args = args,
                      statements = Vector.concat [preStmts, Vector.new1 testStmt],
                      transfer =
                         Transfer.Case
                         {cases = Cases.Con (Vector.new2 ((Con.truee, zeroLab),
                                                          (Con.falsee, nonZeroLab))),
                          default = NONE,
                          test = testVar}}

                  (* Zero length: hand the shared zero-length array to the join. *)
                  val zeroBlock =
                     Block.T
                     {label = zeroLab,
                      args = Vector.new0 (),
                      statements = Vector.new0 (),
                      transfer = Transfer.Goto {args = Vector.new1 (getZeroArrVar eltTy),
                                                dst = joinLab}}

                  (* Non-zero length: the original allocation, under a fresh name. *)
                  val freshArr = Var.new arrVar
                  val nonZeroBlock =
                     Block.T
                     {label = nonZeroLab,
                      args = Vector.new0 (),
                      statements =
                         Vector.new1
                         (Statement.T
                          {var = SOME freshArr,
                           ty = arrTy,
                           exp = PrimApp {args = Vector.new1 lenVar,
                                          prim = Prim.arrayAlloc {raw = false},
                                          targs = Vector.new1 eltTy}}),
                      transfer = Transfer.Goto {args = Vector.new1 freshArr,
                                                dst = joinLab}}

                  (* Join point: rebinds arrVar, then the rest of the block. *)
                  val joinBlock =
                     Block.T
                     {label = joinLab,
                      args = Vector.new1 (arrVar, arrTy),
                      statements = postStmts,
                      transfer = origTransfer}
               in
                  SOME (preBlock, zeroBlock, nonZeroBlock, joinBlock)
               end

      (* SOME of the array argument when the statement applies
       * Array_toVector; NONE for any other statement. *)
      fun toVectorArg (Statement.T {exp = PrimApp {prim, args, ...}, ...}) =
             (case Prim.name prim of
                 Prim.Name.Array_toVector => SOME (Vector.first args)
               | _ => NONE)
        | toVectorArg _ = NONE

      (* Analysis: every variable used as the array argument of an
       * Array_toVector somewhere in blocks (duplicates possible). *)
      fun toVectorArgs blocks =
         Vector.fold
         (blocks, [], fn (Block.T {statements, ...}, acc) =>
          Vector.fold
          (statements, acc, fn (stmt, acc) =>
           case toVectorArg stmt of
              NONE => acc
            | SOME v => v :: acc))

      (* Rewrite one function.  Functions without Array_toVector are
       * returned as-is; otherwise every block is split at each matching
       * allocation (repeatedly, via the join block) and the result is
       * shrunk. *)
      fun shareFunction f =
         let
            val {args, blocks, mayInline, name, raises, returns, start} =
               Function.dest f
            val arrVars = toVectorArgs blocks
            (* Accumulates blocks in reverse; the join block is revisited
             * so that later allocations in the same block are handled. *)
            fun loop (b, acc) =
               case transformBlock (b, arrVars) of
                  NONE => b :: acc
                | SOME (pre, zero, nonZero, join) =>
                     loop (join, nonZero :: zero :: pre :: acc)
         in
            if List.isEmpty arrVars
               then f
            else
               shrink (Function.new
                       {args = args,
                        blocks = Vector.fromListRev (Vector.fold (blocks, [], loop)),
                        mayInline = mayInline,
                        name = name,
                        raises = raises,
                        returns = returns,
                        start = start})
         end

      val functions = List.revMap (functions, shareFunction)
      (* Read only after all functions are processed, so every zero-length
       * array global requested by getZeroArrVar is included. *)
      val globals = Vector.concat [globals, Vector.fromList (!newGlobals)]
   in
      Program.T {datatypes = datatypes,
                 globals = globals,
                 functions = functions,
                 main = main}
   end

end