Commit | Line | Data |
---|---|---|
7f918cf1 CE |
1 | (* Copyright (C) 1999-2006, 2008 Henry Cejtin, Matthew Fluet, Suresh |
2 | * Jagannathan, and Stephen Weeks. | |
3 | * Copyright (C) 1997-2000 NEC Research Institute. | |
4 | * | |
5 | * MLton is released under a BSD-style license. | |
6 | * See the file MLton-LICENSE for details. | |
7 | *) | |
8 | ||
9 | (* | |
10 | * Duplicate a let-bound function at each variable reference | |
11 | * if the estimated cost of doing so does not exceed a threshold. | |
12 | * | |
13 | *) | |
14 | functor Polyvariance (S: XML_TRANSFORM_STRUCTS): XML_TRANSFORM = | |
15 | struct | |
16 | ||
17 | open S | |
18 | datatype z = datatype Dec.t | |
19 | datatype z = datatype PrimExp.t | |
20 | ||
(* Type, extended with the two predicates used by the polyvariance analysis. *)
structure Type =
   struct
      open Type

      (* True when the arrow type constructor occurs somewhere in ty. *)
      fun containsArrow ty = containsTycon (ty, Tycon.arrow)

      (* ty is higher order when deArrowOpt splits it into (dom, cod) and
       * either dom contains an arrow or cod is itself higher order.
       * A type that deArrowOpt does not split is not higher order.
       *)
      fun isHigherOrder ty =
         let
            fun decide NONE = false
              | decide (SOME (dom, cod)) =
                   containsArrow dom orelse isHigherOrder cod
         in
            decide (deArrowOpt ty)
         end

      (* Tracing wrapper, disabled:
       *
       * val isHigherOrder =
       *    Trace.trace
       *    ("Polyvariance.isHigherOrder", layout, Bool.layout)
       *    isHigherOrder
       *)
   end
40 | ||
(* lambdaSize program
 *
 * Traverses the body of program once, storing a size on every Lambda.t it
 * reaches, and returns the function that reads those sizes back.
 *
 * The size of a lambda is the count accumulated over its own body, starting
 * from zero.  Declarations contribute:
 *   - Exception: 1;
 *   - MonoVal:   1 plus the contribution of its primitive expression;
 *   - PolyVal:   1 plus the contribution of its nested expression;
 *   - Fun:       only the sizes of the lambdas it binds.
 * Primitive expressions contribute:
 *   - Profile:   nothing;
 *   - Case:      1 plus the default (if any) and every branch;
 *   - Handle:    1 plus the handler and the try expression;
 *   - Lambda:    1 plus the size of the lambda's body;
 *   - any other: 1.
 *
 * NOTE(review): the reader is built with Property.initRaise, so applying it
 * to a lambda that was not visited presumably raises -- confirm.
 *)
fun lambdaSize (Program.T {body, ...}): Lambda.t -> int =
   let
      val {get = size: Lambda.t -> int, set = record, ...} =
         Property.getSetOnce
         (Lambda.plist, Property.initRaise ("size", Lambda.layout))

      (* Each helper threads an accumulator: it returns acc plus the
       * contribution of the piece of syntax it is given.
       *)
      fun expCost (e: Exp.t, acc: int): int =
         List.fold (Exp.decs e, acc, decCost)
      and decCost (d, acc: int): int =
         case d of
            Exception _ => acc + 1
          | Fun {decs, ...} =>
               Vector.fold
               (decs, acc, fn ({lambda, ...}, acc) => lambdaCost (lambda, acc))
          | MonoVal {exp, ...} => primExpCost (exp, acc + 1)
          | PolyVal {exp, ...} => expCost (exp, acc + 1)
      and lambdaCost (lam: Lambda.t, acc): int =
         let
            (* The lambda's own size is measured from zero and recorded
             * before being added to the enclosing count.
             *)
            val own = expCost (Lambda.body lam, 0)
            val _ = record (lam, own)
         in
            own + acc
         end
      and primExpCost (pe: PrimExp.t, acc: int): int =
         case pe of
            Case {cases, default, ...} =>
               let
                  val acc = acc + 1
                  val afterDefault =
                     case default of
                        NONE => acc
                      | SOME (e, _) => expCost (e, acc)
               in
                  Cases.fold (cases, afterDefault, expCost)
               end
          | Handle {try, handler, ...} =>
               let
                  val afterHandler = expCost (handler, acc + 1)
               in
                  expCost (try, afterHandler)
               end
          | Lambda lam => lambdaCost (lam, acc + 1)
          | Profile _ => acc
          | _ => acc + 1

      (* Run the traversal for its side effect of recording sizes. *)
      val _ = expCost (body, 0)
   in
      size
   end
81 | ||
(* shouldDuplicate (program, hofo, small, product)
 *
 * Analyzes the body of program once and returns a predicate on variables.
 * The predicate answers true for the variables whose shouldDuplicate flag
 * was set during the analysis, and false for variables with no recorded
 * info.
 *
 * Candidates are the variables bound to lambdas, either by a MonoVal whose
 * expression is a Lambda or by a Fun group.  When hofo is true, a variable
 * is only a candidate if its type satisfies Type.isHigherOrder.
 *
 * A candidate with a given size and occurrence count is accepted when
 *    (numOccurrences - 1) * (size - small) <= product.
 *)
82 | fun shouldDuplicate (program as Program.T {body, ...}, hofo, small, product) | |
83 | : Var.t -> bool = | |
84 | let | |
(* Log of every isOK evaluation; read only by the diagnostics at the end. *)
85 | val costs: (Var.t * int * int * int) list ref = ref [] | |
86 | val lambdaSize = lambdaSize program | |
(* Computes the cost, logs it in costs, and accepts when it does not
 * exceed product.
 *)
87 | fun isOK (var: Var.t, size: int, numOccurrences: int): bool = | |
88 | let val cost = (numOccurrences - 1) * (size - small) | |
89 | in List.push (costs, (var, size, numOccurrences, cost)) | |
90 | ; cost <= product | |
91 | end | |
(* Per-candidate state: the weighted reference count and the verdict. *)
92 | type info = {numOccurrences: int ref, | |
93 | shouldDuplicate: bool ref} | |
94 | val {get = varInfo: Var.t -> info option, set = setVarInfo, ...} = | |
95 | Property.getSetOnce (Var.plist, Property.initConst NONE) | |
(* Registers var as a candidate, unless hofo is set and ty is not
 * higher order.
 *)
96 | fun new {lambda = _, ty, var}: unit = | |
97 | if not hofo orelse Type.isHigherOrder ty | |
98 | then setVarInfo (var, SOME {numOccurrences = ref 0, | |
99 | shouldDuplicate = ref false}) | |
100 | else () | |
(* Traverses e.  numDuplicates is the weight added to a candidate's
 * occurrence count for each reference found in e: the top-level call uses
 * 1, and the body of a lambda accepted for duplication is traversed with
 * the occurrence count of that lambda (or of its Fun group).
 *)
101 | fun loopExp (e: Exp.t, numDuplicates: int): unit = | |
102 | let | |
103 | fun loopVar (x: VarExp.t): unit = | |
104 | case varInfo (VarExp.var x) of | |
105 | NONE => () | |
106 | | SOME {numOccurrences, ...} => | |
107 | numOccurrences := !numOccurrences + numDuplicates | |
108 | fun loopVars xs = Vector.foreach (xs, loopVar) | |
109 | val {decs, result} = Exp.dest e | |
110 | val rec loopDecs = | |
111 | fn [] => loopVar result | |
112 | | dec :: decs => | |
113 | case dec of | |
114 | MonoVal {var, ty, exp} => | |
115 | (case exp of | |
116 | Lambda l => | |
117 | (new {var = var, ty = ty, lambda = l} | |
(* The remaining decs are traversed before the lambda body, so the
 * occurrence count read below already includes every reference found
 * in them.
 *)
118 | ; loopDecs decs | |
119 | ; let | |
120 | val body = Lambda.body l | |
(* Weight for the body: the lambda's own occurrence count when it is
 * accepted, otherwise the weight of the enclosing expression.
 *)
121 | val numDuplicates = | |
122 | case varInfo var of | |
123 | NONE => numDuplicates | |
124 | | SOME {numOccurrences, | |
125 | shouldDuplicate} => | |
126 | if isOK (var, lambdaSize l, | |
127 | !numOccurrences) | |
128 | then (shouldDuplicate := true | |
129 | ; !numOccurrences) | |
130 | else numDuplicates | |
131 | in loopExp (body, numDuplicates) | |
132 | end) | |
133 | | _ => | |
134 | let | |
(* Within this branch loopExp is shadowed by a version that fixes the
 * current weight.
 *)
135 | val loopExp = | |
136 | fn e => loopExp (e, numDuplicates) | |
137 | val _ = | |
138 | case exp of | |
139 | App {func, arg} => | |
140 | (loopVar func; loopVar arg) | |
141 | | Case {test, cases, default} => | |
142 | (loopVar test | |
143 | ; Cases.foreach (cases, loopExp) | |
144 | ; (Option.app | |
145 | (default, loopExp o #1))) | |
146 | | ConApp {arg, ...} => | |
147 | Option.app (arg, loopVar) | |
148 | | Const _ => () | |
149 | | Handle {try, handler, ...} => | |
150 | (loopExp try; loopExp handler) | |
(* Lambdas are handled by the branch above, so none can reach here. *)
151 | | Lambda _ => | |
152 | Error.bug "Polyvariance.loopExp.loopDecs: unexpected Lambda" | |
153 | | PrimApp {args, ...} => loopVars args | |
154 | | Profile _ => () | |
155 | | Raise {exn, ...} => loopVar exn | |
156 | | Select {tuple, ...} => loopVar tuple | |
157 | | Tuple xs => loopVars xs | |
158 | | Var x => loopVar x | |
159 | in | |
160 | loopDecs decs | |
161 | end) | |
(* A Fun group is accepted or rejected as a whole.  All members are
 * registered and the remaining decs traversed first; then the sizes and
 * occurrence counts of the candidate members are summed and checked by a
 * single call to isOK.
 *)
162 | | Fun {decs = lambdas, ...} => | |
163 | let | |
164 | val _ = (Vector.foreach (lambdas, new) | |
165 | ; loopDecs decs) | |
(* Bodies of non-candidate members are traversed right away with the
 * current weight; candidate members are collected for the group
 * decision.
 *)
166 | val dups = | |
167 | Vector.fold | |
168 | (lambdas, [], fn ({var, lambda, ...}, dups) => | |
169 | let val body = Lambda.body lambda | |
170 | in case varInfo var of | |
171 | NONE => | |
172 | (loopExp (body, numDuplicates); dups) | |
173 | | SOME info => | |
174 | {body = body, | |
175 | size = lambdaSize lambda, | |
176 | info = info} :: dups | |
177 | end) | |
178 | in case dups of | |
179 | [] => () | |
180 | | _ => | |
181 | let | |
182 | val size = | |
183 | List.fold | |
184 | (dups, 0, fn ({size, ...}, n) => n + size) | |
(* Summed before any candidate body is traversed, so references inside
 * those bodies do not contribute to the count passed to isOK.
 *)
185 | val numOccurrences = | |
186 | List.fold | |
187 | (dups, 0, | |
188 | fn ({info = {numOccurrences, ...}, ...}, | |
189 | n) => n + !numOccurrences) | |
(* The group is logged under the variable of its first lambda. *)
190 | in if isOK (if Vector.isEmpty lambdas | |
191 | then Error.bug "Polyvariance.loopExp.loopDecs: empty lambdas" | |
192 | else | |
193 | #var (Vector.first lambdas), | |
194 | size, numOccurrences) | |
195 | then (List.foreach | |
196 | (dups, | |
197 | fn {body, | |
198 | info = {shouldDuplicate, ...}, | |
199 | ...} => | |
200 | (shouldDuplicate := true | |
201 | ; loopExp (body, numOccurrences)))) | |
202 | else | |
203 | List.foreach | |
204 | (dups, fn {body, ...} => | |
205 | loopExp (body, numDuplicates)) | |
206 | end | |
207 | end | |
208 | | _ => Error.bug "Polyvariance.loopExp.loopDecs: strange dec" | |
209 | in loopDecs decs | |
210 | end | |
(* Run the analysis over the whole program with weight 1. *)
211 | val _ = loopExp (body, 1) | |
(* Orders the logged entries by their cost component, using c < c' as the
 * comparison given to List.insertionSort.
 *)
212 | fun sort l = | |
213 | List.insertionSort (l, fn ((_, _, _, c), (_, _, _, c')) => c < c') | |
(* Diagnostic output: one line per logged entry with the variable, size,
 * occurrence count and cost.
 *)
214 | val _ = | |
215 | Control.diagnostics | |
216 | (fn layout => | |
217 | List.foreach | |
218 | (sort (!costs), fn (x, size, numOcc, c) => | |
219 | layout (let open Layout | |
220 | in seq [Var.layout x, | |
221 | str " ", Int.layout size, | |
222 | str " ", Int.layout numOcc, | |
223 | str " ", Int.layout c] | |
224 | end))) | |
225 | in | |
(* The resulting predicate: the recorded verdict, or false when the
 * variable was never registered as a candidate.
 *)
226 | fn x => | |
227 | case varInfo x of | |
228 | NONE => false | |
229 | | SOME {shouldDuplicate, ...} => !shouldDuplicate | |
230 | end | |
231 | ||
(* transform (program, hofo, small, product)
 *
 * Rebuilds program with fresh names for the variables it binds and with one
 * copy of each lambda selected by shouldDuplicate per reference to it.
 * datatypes is passed through unchanged; body and overflow are rewritten.
 *)
232 | fun transform (program as Program.T {datatypes, body, overflow}, | |
233 | hofo: bool, | |
234 | small: int, | |
235 | product: int) = | |
236 | let | |
237 | val shouldDuplicate = shouldDuplicate (program, hofo, small, product) | |
(* What to do with references to a variable of the input program:
 *   Replace y -- rewrite every reference to y;
 *   Dup       -- hand out a fresh name per reference; duplicates collects
 *                the names handed out so far.
 *)
238 | datatype info = | |
239 | Replace of Var.t | |
240 | | Dup of { | |
241 | duplicates: Var.t list ref | |
242 | } | |
243 | val {get = varInfo: Var.t -> info, set = setVarInfo, ...} = | |
244 | Property.getSet (Var.plist, | |
245 | Property.initRaise ("Polyvariance.info", Var.layout)) | |
(* Rewrites one variable reference according to the variable's info.  For a
 * Dup variable the fresh name is also pushed onto duplicates so that a
 * binding for it can be emitted later.
 *)
246 | fun loopVar (x: VarExp.t): VarExp.t = | |
247 | VarExp.mono | |
248 | (let val x = VarExp.var x | |
249 | in case varInfo x of | |
250 | Replace y => y | |
251 | | Dup {duplicates, ...} => | |
252 | let val x' = Var.new x | |
253 | in List.push (duplicates, x') | |
254 | ; x' | |
255 | end | |
256 | end) | |
257 | fun loopVars xs = Vector.map (xs, loopVar) | |
(* Creates a fresh variable from x, records it as x's replacement, and
 * returns it.
 *)
258 | fun bind (x: Var.t): Var.t = | |
259 | let val x' = Var.new x | |
260 | in setVarInfo (x, Replace x') | |
261 | ; x' | |
262 | end | |
263 | fun bindVarType (x, t) = (bind x, t) | |
(* Renames the variable bound by a constructor pattern, if there is one. *)
264 | fun bindPat (Pat.T {con, targs, arg}) = | |
265 | Pat.T {con = con, | |
266 | targs = targs, | |
267 | arg = Option.map (arg, bindVarType)} | |
(* Marks a lambda-bound variable either for duplication or for plain
 * renaming, as decided by shouldDuplicate.
 *)
268 | fun new {lambda = _, ty = _, var}: unit = | |
269 | if shouldDuplicate var | |
270 | then setVarInfo (var, Dup {duplicates = ref []}) | |
271 | else ignore (bind var) | |
272 | fun loopExp (e: Exp.t): Exp.t = | |
273 | let | |
274 | val {decs, result} = Exp.dest e | |
275 | in | |
276 | Exp.make (loopDecs (decs, result)) | |
277 | end | |
(* Rewrites a lambda: its argument gets a fresh name (bound before the body
 * is rewritten, since record fields are evaluated in order) and the other
 * fields are kept.
 *)
278 | and loopLambda (l: Lambda.t): Lambda.t = | |
279 | let | |
280 | val {arg, argType, body, mayInline} = Lambda.dest l | |
281 | in | |
282 | Lambda.make {arg = bind arg, | |
283 | argType = argType, | |
284 | body = loopExp body, | |
285 | mayInline = mayInline} | |
286 | end | |
(* Rewrites a declaration list together with the result variable of the
 * expression it belongs to.
 *)
287 | and loopDecs (ds: Dec.t list, result): {decs: Dec.t list, | |
288 | result: VarExp.t} = | |
289 | case ds of | |
290 | [] => {decs = [], result = loopVar result} | |
291 | | d :: ds => | |
292 | case d of | |
293 | MonoVal {var, ty, exp} => | |
294 | (case exp of | |
295 | Lambda l => | |
296 | let | |
297 | val _ = new {var = var, ty = ty, lambda = l} | |
(* The rest of the declarations is rewritten first; for a Dup variable
 * this is what fills in its duplicates list.
 *)
298 | val {decs, result} = loopDecs (ds, result) | |
299 | val decs = | |
300 | case varInfo var of | |
301 | Replace var => | |
302 | MonoVal {var = var, ty = ty, | |
303 | exp = Lambda (loopLambda l)} | |
304 | :: decs | |
(* One binding per collected duplicate, each built by its own call to
 * loopLambda.  If no duplicates were collected, no binding is emitted.
 *)
305 | | Dup {duplicates, ...} => | |
306 | List.fold | |
307 | (!duplicates, decs, fn (var, decs) => | |
308 | MonoVal {var = var, ty = ty, | |
309 | exp = Lambda (loopLambda l)} | |
310 | :: decs) | |
311 | in {decs = decs, result = result} | |
312 | end | |
313 | | _ => | |
314 | let | |
(* exp is rewritten before var receives its fresh name. *)
315 | val exp = | |
316 | case exp of | |
317 | App {func, arg} => | |
318 | App {func = loopVar func, | |
319 | arg = loopVar arg} | |
320 | | Case {test, cases, default} => | |
321 | let | |
322 | datatype z = datatype Cases.t | |
323 | val cases = | |
324 | case cases of | |
325 | Con cases => | |
326 | Con | |
327 | (Vector.map | |
328 | (cases, fn (p, e) => | |
329 | (bindPat p, loopExp e))) | |
330 | | Word (s, v) => | |
331 | Word | |
332 | (s, (Vector.map | |
333 | (v, fn (z, e) => | |
334 | (z, loopExp e)))) | |
335 | in | |
336 | Case {test = loopVar test, | |
337 | cases = cases, | |
338 | default = | |
339 | Option.map | |
340 | (default, fn (e, r) => | |
341 | (loopExp e, r))} | |
342 | end | |
343 | | ConApp {con, targs, arg} => | |
344 | ConApp {con = con, | |
345 | targs = targs, | |
346 | arg = Option.map (arg, loopVar)} | |
347 | | Const _ => exp | |
348 | | Handle {try, catch, handler} => | |
349 | Handle {try = loopExp try, | |
350 | catch = bindVarType catch, | |
351 | handler = loopExp handler} | |
(* Lambdas are handled by the branch above, so none can reach here. *)
352 | | Lambda _ => | |
353 | Error.bug "Polyvariance.loopDecs: unexpected Lambda" | |
354 | | PrimApp {prim, targs, args} => | |
355 | PrimApp {prim = prim, | |
356 | targs = targs, | |
357 | args = loopVars args} | |
358 | | Profile _ => exp | |
359 | | Raise {exn, extend} => | |
360 | Raise {exn = loopVar exn, | |
361 | extend = extend} | |
362 | | Select {tuple, offset} => | |
363 | Select {tuple = loopVar tuple, | |
364 | offset = offset} | |
365 | | Tuple xs => Tuple (loopVars xs) | |
366 | | Var x => Var (loopVar x) | |
367 | val var = bind var | |
368 | val {decs, result} = loopDecs (ds, result) | |
369 | in {decs = (MonoVal {var = var, ty = ty, exp = exp} | |
370 | :: decs), | |
371 | result = result} | |
372 | end) | |
(* Fun group.  Every member is first marked Dup or renamed, and the rest of
 * the declarations is rewritten, which collects the duplicate names.  The
 * resulting group is assembled in ac and emitted as a single Fun.
 *)
373 | | Fun {decs, ...} => | |
374 | let | |
375 | val _ = Vector.foreach (decs, new) | |
376 | val {decs = ds, result} = loopDecs (ds, result) | |
(* ac starts with the members that are not duplicated, each rewritten
 * once under its replacement name.
 *)
377 | val ac = | |
378 | ref [Vector.keepAllMap | |
379 | (decs, fn {var, ty, lambda} => | |
380 | case varInfo var of | |
381 | Replace var => | |
382 | SOME {var = var, ty = ty, | |
383 | lambda = loopLambda lambda} | |
384 | | Dup _ => NONE)] | |
(* The duplicated members, each paired with a snapshot of the names
 * collected for it.
 *)
385 | val dups = | |
386 | Vector.keepAllMap | |
387 | (decs, fn dec as {var, ...} => | |
388 | case varInfo var of | |
389 | Replace _ => NONE | |
390 | | Dup {duplicates, ...} => SOME (dec, !duplicates)) | |
(* For every collected name var' of every duplicated member var, push a
 * copy of all duplicated members onto ac.  In that copy var is bound to
 * var' and the other members get fresh names; the infos are switched to
 * Replace before loopLambda runs, so references to group members inside
 * the copied bodies are rewritten to this copy's names.
 *)
391 | val _ = | |
392 | Vector.foreach | |
393 | (dups, fn ({var, ...}, duplicates) => | |
394 | List.foreach | |
395 | (duplicates, fn var' => | |
396 | let | |
397 | val vars = | |
398 | Vector.map | |
399 | (dups, fn ({var = var'', ...}, _) => | |
400 | if Var.equals (var, var'') | |
401 | then (setVarInfo (var, Replace var') | |
402 | ; var') | |
403 | else bind var'') | |
404 | in List.push | |
405 | (ac, | |
406 | Vector.map2 | |
407 | (dups, vars, | |
408 | fn (({ty, lambda, ...}, _), var) => | |
409 | {var = var, ty = ty, | |
410 | lambda = loopLambda lambda})) | |
411 | end)) | |
412 | val decs = Vector.concat (!ac) | |
413 | in {decs = Fun {tyvars = Vector.new0 (), | |
414 | decs = decs} :: ds, | |
415 | result = result} | |
416 | end | |
417 | | _ => Error.bug "Polyvariance.loopDecs: saw bogus dec" | |
418 | val body = loopExp body | |
(* The overflow variable must have been renamed, not duplicated. *)
419 | val overflow = | |
420 | Option.map (overflow, fn x => | |
421 | case varInfo x of | |
422 | Replace y => y | |
423 | | _ => Error.bug "Polyvariance.duplicate: duplicating Overflow?") | |
424 | val program = | |
425 | Program.T {datatypes = datatypes, | |
426 | body = body, | |
427 | overflow = overflow} | |
(* NOTE(review): Program.clear presumably resets the properties attached
 * while rewriting -- confirm against its definition.
 *)
428 | val _ = Program.clear program | |
429 | in | |
430 | program | |
431 | end | |
432 | ||
(* transform program
 *
 * Entry point of the pass.  When Control.polyvariance is NONE the program
 * is returned unchanged.  Otherwise the four-argument transform defined
 * above is applied, followed by shrink, for the configured number of
 * rounds; each round runs as its own Control.pass named "duplicate1",
 * "duplicate2", and so on.
 *
 * This is deliberately a non-recursive val binding: the transform used in
 * the thunk is the earlier four-argument function, not this one.
 *)
val transform =
   fn program =>
   case !Control.polyvariance of
      NONE => program
    | SOME {hofo, rounds, small, product} =>
         let
            (* Runs one numbered round on p and returns the new program. *)
            fun runRound (p, round) =
               Control.pass
               {display = Control.Layouts Program.layouts,
                name = "duplicate" ^ (Int.toString round),
                stats = Program.layoutStats,
                style = Control.No,
                suffix = "post.xml",
                thunk = fn () => shrink (transform (p, hofo, small, product))}
            (* completed counts the rounds already run; stop once it
             * equals rounds.
             *)
            fun iterate (p, completed) =
               if completed = rounds
                  then p
               else iterate (runRound (p, completed + 1), completed + 1)
         in
            iterate (program, 0)
         end
456 | ||
457 | end |