Import Upstream version 20180207
[hcoop/debian/mlton.git] / mlton / xml / polyvariance.fun
CommitLineData
7f918cf1
CE
1(* Copyright (C) 1999-2006, 2008 Henry Cejtin, Matthew Fluet, Suresh
2 * Jagannathan, and Stephen Weeks.
3 * Copyright (C) 1997-2000 NEC Research Institute.
4 *
5 * MLton is released under a BSD-style license.
6 * See the file MLton-LICENSE for details.
7 *)
8
(*
 * Duplicate a let-bound function at each variable reference
 * if the cost of doing so does not exceed the threshold.
 *)
functor Polyvariance (S: XML_TRANSFORM_STRUCTS): XML_TRANSFORM =
struct

open S
(* Bring the constructors of Dec.t and PrimExp.t into scope unqualified. *)
datatype z = datatype Dec.t
datatype z = datatype PrimExp.t

(* Local extension of Type with the higher-order test used when [hofo]
 * restricts duplication to higher-order functions.
 *)
structure Type =
   struct
      open Type

      (* True iff the arrow type constructor occurs anywhere in [t]. *)
      fun containsArrow t = containsTycon (t, Tycon.arrow)

      (* True iff [t] is an arrow type whose argument type mentions an
       * arrow, or whose result type is itself higher-order; curried
       * functions are thus checked argument by argument.  Non-arrow
       * types are never higher-order.
       *)
      fun isHigherOrder t =
         case deArrowOpt t of
            NONE => false
          | SOME (t1, t2) => containsArrow t1 orelse isHigherOrder t2

(*
      val isHigherOrder =
         Trace.trace
         ("Polyvariance.isHigherOrder", layout, Bool.layout)
         isHigherOrder
*)

   end

(* lambdaSize program walks the whole program body once, records a size for
 * every lambda it meets, and returns the lookup function.  The size is a
 * rough count of the declarations and primitive expressions in the lambda's
 * body; sizes of nested lambdas are included in the enclosing lambda's size.
 * The property default is Property.initRaise, so asking for the size of a
 * lambda that was not visited is an error.
 *)
fun lambdaSize (Program.T {body, ...}): Lambda.t -> int =
   let
      val {get = size: Lambda.t -> int, set, ...} =
         Property.getSetOnce (Lambda.plist,
                              Property.initRaise ("size", Lambda.layout))
      fun loopExp (e: Exp.t, n: int): int =
         List.fold
         (Exp.decs e, n, fn (d, n) =>
          case d of
             MonoVal {exp, ...} => loopPrimExp (exp, n + 1)
           | PolyVal {exp, ...} => loopExp (exp, n + 1)
           | Fun {decs, ...} => Vector.fold (decs, n, fn ({lambda, ...}, n) =>
                                             loopLambda (lambda, n))
           | Exception _ => n + 1)
      (* Records the size of [l]'s body on its own, and returns that size
       * added to the running total [n] of the enclosing code.
       *)
      and loopLambda (l: Lambda.t, n): int =
         let val m = loopExp (Lambda.body l, 0)
         in set (l, m); m + n
         end
      and loopPrimExp (e: PrimExp.t, n: int): int =
         case e of
            Case {cases, default, ...} =>
               let
                  val n = n + 1
               in
                  Cases.fold
                  (cases,
                   (case default of
                       NONE => n
                     | SOME (e, _) => loopExp (e, n)),
                   fn (e, n) => loopExp (e, n))
               end
          | Handle {try, handler, ...} =>
               loopExp (try, loopExp (handler, n + 1))
          | Lambda l => loopLambda (l, n + 1)
          | Profile _ => n   (* profiling annotations add nothing *)
          | _ => n + 1
      val _ = loopExp (body, 0)
   in
      size
   end

(* shouldDuplicate (program, hofo, small, product) decides which let-bound
 * functions to duplicate and returns a predicate over variables.
 *
 * A function is a candidate when [hofo] is false or its type is
 * higher-order.  For every candidate the (weighted) number of references is
 * counted, and duplication is accepted when
 *    cost = (numOccurrences - 1) times (size - small)  <=  product.
 * References inside the body of an accepted function are weighted by that
 * function's occurrence count, since the body will be copied that many
 * times.  The candidate members of a Fun group are judged together, using
 * their summed sizes and summed occurrence counts.  Every decision is
 * reported through Control.diagnostics, sorted by increasing cost.
 *)
fun shouldDuplicate (program as Program.T {body, ...}, hofo, small, product)
   : Var.t -> bool =
   let
      (* One (var, size, numOccurrences, cost) entry per isOK call. *)
      val costs: (Var.t * int * int * int) list ref = ref []
      val lambdaSize = lambdaSize program
      fun isOK (var: Var.t, size: int, numOccurrences: int): bool =
         let val cost = (numOccurrences - 1) * (size - small)
         in List.push (costs, (var, size, numOccurrences, cost))
            ; cost <= product
         end
      type info = {numOccurrences: int ref,
                   shouldDuplicate: bool ref}
      (* Only candidate functions get SOME info; every other var is NONE. *)
      val {get = varInfo: Var.t -> info option, set = setVarInfo, ...} =
         Property.getSetOnce (Var.plist, Property.initConst NONE)
      fun new {lambda = _, ty, var}: unit =
         if not hofo orelse Type.isHigherOrder ty
            then setVarInfo (var, SOME {numOccurrences = ref 0,
                                        shouldDuplicate = ref false})
         else ()
      (* [numDuplicates] is the number of copies of [e] that will exist
       * after duplication; each variable reference in [e] counts that many
       * times.
       *)
      fun loopExp (e: Exp.t, numDuplicates: int): unit =
         let
            fun loopVar (x: VarExp.t): unit =
               case varInfo (VarExp.var x) of
                  NONE => ()
                | SOME {numOccurrences, ...} =>
                     numOccurrences := !numOccurrences + numDuplicates
            fun loopVars xs = Vector.foreach (xs, loopVar)
            val {decs, result} = Exp.dest e
            val rec loopDecs =
               fn [] => loopVar result
                | dec :: decs =>
                     case dec of
                        MonoVal {var, ty, exp} =>
                           (case exp of
                               Lambda l =>
                                  (new {var = var, ty = ty, lambda = l}
                                   (* Count the references in the scope of
                                    * [var] first, so the decision below sees
                                    * the full count; then walk the body
                                    * with the resulting number of copies.
                                    *)
                                   ; loopDecs decs
                                   ; let
                                        val body = Lambda.body l
                                        val numDuplicates =
                                           case varInfo var of
                                              NONE => numDuplicates
                                            | SOME {numOccurrences,
                                                    shouldDuplicate} =>
                                                 if isOK (var, lambdaSize l,
                                                          !numOccurrences)
                                                    then (shouldDuplicate := true
                                                          ; !numOccurrences)
                                                 else numDuplicates
                                     in loopExp (body, numDuplicates)
                                     end)
                             | _ =>
                                  let
                                     val loopExp =
                                        fn e => loopExp (e, numDuplicates)
                                     val _ =
                                        case exp of
                                           App {func, arg} =>
                                              (loopVar func; loopVar arg)
                                         | Case {test, cases, default} =>
                                              (loopVar test
                                               ; Cases.foreach (cases, loopExp)
                                               ; (Option.app
                                                  (default, loopExp o #1)))
                                         | ConApp {arg, ...} =>
                                              Option.app (arg, loopVar)
                                         | Const _ => ()
                                         | Handle {try, handler, ...} =>
                                              (loopExp try; loopExp handler)
                                         | Lambda _ =>
                                              Error.bug "Polyvariance.loopExp.loopDecs: unexpected Lambda"
                                         | PrimApp {args, ...} => loopVars args
                                         | Profile _ => ()
                                         | Raise {exn, ...} => loopVar exn
                                         | Select {tuple, ...} => loopVar tuple
                                         | Tuple xs => loopVars xs
                                         | Var x => loopVar x
                                  in
                                     loopDecs decs
                                  end)
                      | Fun {decs = lambdas, ...} =>
                           let
                              val _ = (Vector.foreach (lambdas, new)
                                       ; loopDecs decs)
                              (* Bodies of non-candidates are walked right
                               * away; candidates are collected so that the
                               * group can be judged as a whole.
                               *)
                              val dups =
                                 Vector.fold
                                 (lambdas, [], fn ({var, lambda, ...}, dups) =>
                                  let val body = Lambda.body lambda
                                  in case varInfo var of
                                        NONE =>
                                           (loopExp (body, numDuplicates); dups)
                                      | SOME info =>
                                           {body = body,
                                            size = lambdaSize lambda,
                                            info = info} :: dups
                                  end)
                           in case dups of
                                 [] => ()
                               | _ =>
                                    let
                                       val size =
                                          List.fold
                                          (dups, 0, fn ({size, ...}, n) => n + size)
                                       val numOccurrences =
                                          List.fold
                                          (dups, 0,
                                           fn ({info = {numOccurrences, ...}, ...},
                                               n) => n + !numOccurrences)
                                    in if isOK (if Vector.isEmpty lambdas
                                                   then Error.bug "Polyvariance.loopExp.loopDecs: empty lambdas"
                                                else
                                                   #var (Vector.first lambdas),
                                                size, numOccurrences)
                                          then (List.foreach
                                                (dups,
                                                 fn {body,
                                                     info = {shouldDuplicate, ...},
                                                     ...} =>
                                                 (shouldDuplicate := true
                                                  ; loopExp (body, numOccurrences))))
                                       else
                                          List.foreach
                                          (dups, fn {body, ...} =>
                                           loopExp (body, numDuplicates))
                                    end
                           end
                      | _ => Error.bug "Polyvariance.loopExp.loopDecs: strange dec"
         in loopDecs decs
         end
      val _ = loopExp (body, 1)
      fun sort l =
         List.insertionSort (l, fn ((_, _, _, c), (_, _, _, c')) => c < c')
      val _ =
         Control.diagnostics
         (fn layout =>
          List.foreach
          (sort (!costs), fn (x, size, numOcc, c) =>
           layout (let open Layout
                   in seq [Var.layout x,
                           str " ", Int.layout size,
                           str " ", Int.layout numOcc,
                           str " ", Int.layout c]
                   end)))
   in
      fn x =>
      case varInfo x of
         NONE => false
       | SOME {shouldDuplicate, ...} => !shouldDuplicate
   end

(* transform (program, hofo, small, product) performs one round of
 * duplication and returns the rewritten program.
 *
 * Every binder is renamed.  A binder that is not duplicated gets a single
 * fresh name (Replace).  A function chosen by shouldDuplicate is marked Dup:
 * each reference to it gets its own fresh name, and its declaration is then
 * emitted once per such name.  Program.clear is applied to the new program
 * before it is returned.
 *)
fun transform (program as Program.T {datatypes, body, overflow},
               hofo: bool,
               small: int,
               product: int) =
   let
      val shouldDuplicate = shouldDuplicate (program, hofo, small, product)
      datatype info =
         Replace of Var.t
       | Dup of {
                 duplicates: Var.t list ref
                 }
      val {get = varInfo: Var.t -> info, set = setVarInfo, ...} =
         Property.getSet (Var.plist,
                          Property.initRaise ("Polyvariance.info", Var.layout))
      (* A reference to a Dup function gets its own fresh variable, which is
       * remembered so that a copy of the declaration can be made for it.
       *)
      fun loopVar (x: VarExp.t): VarExp.t =
         VarExp.mono
         (let val x = VarExp.var x
          in case varInfo x of
                Replace y => y
              | Dup {duplicates, ...} =>
                   let val x' = Var.new x
                   in List.push (duplicates, x')
                      ; x'
                   end
          end)
      fun loopVars xs = Vector.map (xs, loopVar)
      fun bind (x: Var.t): Var.t =
         let val x' = Var.new x
         in setVarInfo (x, Replace x')
            ; x'
         end
      fun bindVarType (x, t) = (bind x, t)
      fun bindPat (Pat.T {con, targs, arg}) =
         Pat.T {con = con,
                targs = targs,
                arg = Option.map (arg, bindVarType)}
      fun new {lambda = _, ty = _, var}: unit =
         if shouldDuplicate var
            then setVarInfo (var, Dup {duplicates = ref []})
         else ignore (bind var)
      fun loopExp (e: Exp.t): Exp.t =
         let
            val {decs, result} = Exp.dest e
         in
            Exp.make (loopDecs (decs, result))
         end
      and loopLambda (l: Lambda.t): Lambda.t =
         let
            val {arg, argType, body, mayInline} = Lambda.dest l
         in
            Lambda.make {arg = bind arg,
                         argType = argType,
                         body = loopExp body,
                         mayInline = mayInline}
         end
      and loopDecs (ds: Dec.t list, result): {decs: Dec.t list,
                                              result: VarExp.t} =
         case ds of
            [] => {decs = [], result = loopVar result}
          | d :: ds =>
               case d of
                  MonoVal {var, ty, exp} =>
                     (case exp of
                         Lambda l =>
                            let
                               val _ = new {var = var, ty = ty, lambda = l}
                               (* Process the scope first, so that every
                                * duplicate name of [var] is known before
                                * the copies are built.
                                *)
                               val {decs, result} = loopDecs (ds, result)
                               val decs =
                                  case varInfo var of
                                     Replace var =>
                                        MonoVal {var = var, ty = ty,
                                                 exp = Lambda (loopLambda l)}
                                        :: decs
                                   | Dup {duplicates, ...} =>
                                        List.fold
                                        (!duplicates, decs, fn (var, decs) =>
                                         MonoVal {var = var, ty = ty,
                                                  exp = Lambda (loopLambda l)}
                                         :: decs)
                            in {decs = decs, result = result}
                            end
                       | _ =>
                            let
                               val exp =
                                  case exp of
                                     App {func, arg} =>
                                        App {func = loopVar func,
                                             arg = loopVar arg}
                                   | Case {test, cases, default} =>
                                        let
                                           datatype z = datatype Cases.t
                                           val cases =
                                              case cases of
                                                 Con cases =>
                                                    Con
                                                    (Vector.map
                                                     (cases, fn (p, e) =>
                                                      (bindPat p, loopExp e)))
                                               | Word (s, v) =>
                                                    Word
                                                    (s, (Vector.map
                                                         (v, fn (z, e) =>
                                                          (z, loopExp e))))
                                        in
                                           Case {test = loopVar test,
                                                 cases = cases,
                                                 default =
                                                 Option.map
                                                 (default, fn (e, r) =>
                                                  (loopExp e, r))}
                                        end
                                   | ConApp {con, targs, arg} =>
                                        ConApp {con = con,
                                                targs = targs,
                                                arg = Option.map (arg, loopVar)}
                                   | Const _ => exp
                                   | Handle {try, catch, handler} =>
                                        Handle {try = loopExp try,
                                                catch = bindVarType catch,
                                                handler = loopExp handler}
                                   | Lambda _ =>
                                        Error.bug "Polyvariance.loopDecs: unexpected Lambda"
                                   | PrimApp {prim, targs, args} =>
                                        PrimApp {prim = prim,
                                                 targs = targs,
                                                 args = loopVars args}
                                   | Profile _ => exp
                                   | Raise {exn, extend} =>
                                        Raise {exn = loopVar exn,
                                               extend = extend}
                                   | Select {tuple, offset} =>
                                        Select {tuple = loopVar tuple,
                                                offset = offset}
                                   | Tuple xs => Tuple (loopVars xs)
                                   | Var x => Var (loopVar x)
                               (* Bound after its own expression, which
                                * cannot refer to it.
                                *)
                               val var = bind var
                               val {decs, result} = loopDecs (ds, result)
                            in {decs = (MonoVal {var = var, ty = ty, exp = exp}
                                        :: decs),
                                result = result}
                            end)
                | Fun {decs, ...} =>
                     let
                        val _ = Vector.foreach (decs, new)
                        val {decs = ds, result} = loopDecs (ds, result)
                        (* The first group holds the functions that are not
                         * duplicated.  It must be computed before [dups]:
                         * walking these bodies can push new names onto the
                         * duplicates lists that [dups] reads.
                         *)
                        val ac =
                           ref [Vector.keepAllMap
                                (decs, fn {var, ty, lambda} =>
                                 case varInfo var of
                                    Replace var =>
                                       SOME {var = var, ty = ty,
                                             lambda = loopLambda lambda}
                                  | Dup _ => NONE)]
                        val dups =
                           Vector.keepAllMap
                           (decs, fn dec as {var, ...} =>
                            case varInfo var of
                               Replace _ => NONE
                             | Dup {duplicates, ...} => SOME (dec, !duplicates))
                        (* For each fresh name [var'] of each duplicated
                         * [var], copy every duplicated function of the
                         * group: [var] becomes [var'] and the others get
                         * fresh names, so references among them inside the
                         * copy resolve to the copy itself.
                         *)
                        val _ =
                           Vector.foreach
                           (dups, fn ({var, ...}, duplicates) =>
                            List.foreach
                            (duplicates, fn var' =>
                             let
                                val vars =
                                   Vector.map
                                   (dups, fn ({var = var'', ...}, _) =>
                                    if Var.equals (var, var'')
                                       then (setVarInfo (var, Replace var')
                                             ; var')
                                    else bind var'')
                             in List.push
                                (ac,
                                 Vector.map2
                                 (dups, vars,
                                  fn (({ty, lambda, ...}, _), var) =>
                                  {var = var, ty = ty,
                                   lambda = loopLambda lambda}))
                             end))
                        val decs = Vector.concat (!ac)
                     in {decs = Fun {tyvars = Vector.new0 (),
                                     decs = decs} :: ds,
                         result = result}
                     end
                | _ => Error.bug "Polyvariance.loopDecs: saw bogus dec"
      val body = loopExp body
      val overflow =
         Option.map (overflow, fn x =>
                     case varInfo x of
                        Replace y => y
                      | _ => Error.bug "Polyvariance.duplicate: duplicating Overflow?")
      val program =
         Program.T {datatypes = datatypes,
                    body = body,
                    overflow = overflow}
      val _ = Program.clear program
   in
      program
   end

(* Entry point.  When !Control.polyvariance is SOME, run [rounds] rounds of
 * duplication, each followed by shrink, as separate Control.pass
 * invocations named duplicate1, duplicate2, and so on.  When it is NONE the
 * program is returned unchanged.
 *)
val transform =
   fn p =>
   case !Control.polyvariance of
      NONE => p
    | SOME {hofo, rounds, small, product} =>
         let
            (* Stop with >= rather than =: a negative [rounds] would never
             * equal [n], and the loop would never reach its exit test.
             * For rounds >= 0 the two tests are identical.
             *)
            fun loop (p, n) =
               if n >= rounds
                  then p
               else let
                       val p =
                          Control.pass
                          {display = Control.Layouts Program.layouts,
                           name = "duplicate" ^ (Int.toString (n + 1)),
                           stats = Program.layoutStats,
                           style = Control.No,
                           suffix = "post.xml",
                           thunk = fn () => shrink (transform (p, hofo, small, product))}
                    in
                       loop (p, n + 1)
                    end
         in loop (p, 0)
         end

end