Import Debian changes 20180207-1
[hcoop/debian/mlton.git] / mlton / xml / uncurry.fun
1 (* Copyright (C) 1999-2006 Henry Cejtin, Matthew Fluet, Suresh
2 * Jagannathan, and Stephen Weeks.
3 * Copyright (C) 1997-2000 NEC Research Institute.
4 *
5 * MLton is released under a BSD-style license.
6 * See the file MLton-LICENSE for details.
7 *)
8
9 functor Uncurry(S: XML_TRANSFORM_STRUCTS): XML_TRANSFORM =
10 struct
11
12 open S
13 open Dec PrimExp
14
15 fun transform (program as Program.T{datatypes, body, overflow}) =
16 let
(* A function binding under analysis: the bound variable paired with the
 * lambda it is bound to. *)
datatype D = T of {var: Var.t, lambda: Lambda.t}

(* Arity recorded for a variable by uncurryFun (the number of lambda
 * arguments it collected).  Variables that were never set report 0. *)
val {get = getArity: Var.t -> int, set = setArity, ...} =
   Property.getSet (Var.plist, Property.initConst 0)

(* The uncurried function and the curried wrapper registered for a
 * variable by uncurryFun.  NONE for variables that were never set. *)
val {get = curriedRep: Var.t -> {unCurriedFun: D, curriedFun: D} option,
     set = setCurriedRep, ...} =
   Property.getSet (Var.plist, Property.initConst NONE)

(* Argument types and result type recorded for a variable by uncurryFun.
 * The default is a single unit argument and a unit result. *)
val {get = getType: Var.t -> {args: Type.t vector, result: Type.t},
     set = setType, ...} =
   Property.getSet
   (Var.plist,
    Property.initConst {args = Vector.new1 Type.unit,
                        result = Type.unit})
34
(* getResultType exp
 *
 * Returns the declared type of the variable that `exp` yields as its
 * result.  Only the MonoVal and Fun declarations directly inside `exp`
 * are inspected; if none of them binds the result variable the answer
 * is Type.unit.
 *)
fun getResultType exp =
   let
      val {decs, result} = Exp.dest exp
      val resultVar = VarExp.var result
      (* Keep the type found so far unless this binding is the result. *)
      fun pick (var, ty, found) =
         if Var.equals (var, resultVar)
            then ty
         else found
   in
      List.fold
      (decs, Type.unit, fn (d, found) =>
       case d of
          MonoVal {var, ty, ...} => pick (var, ty, found)
        | Fun {decs = funDecs, ...} =>
             (* Seed the inner fold with `found`, not Type.unit, so a type
              * located in an earlier declaration is not discarded by a
              * later Fun group that does not bind the result variable. *)
             Vector.fold
             (funDecs, found, fn ({var, ty, ...}, found) =>
              pick (var, ty, found))
        | _ => found)
   end
54
(* buildLambda (f, args, types, resultType)
 *
 * Builds a chain of nested one-argument lambdas that finally calls `f`
 * with a single tuple.
 *
 *   args, types  - parallel vectors.  The LAST element becomes the
 *                  outermost lambda's argument and element 0 becomes the
 *                  innermost lambda's argument.
 *   resultType   - type given to the result of the call to `f`.
 *
 * The innermost lambda packs Vector.rev args into a tuple of type
 * Type.tuple (Vector.rev types), applies `f` to it and returns the
 * result.  Each remaining element (indices 1 .. n-2, in increasing
 * order) wraps the expression built so far in one more lambda.
 * Callers are expected to pass at least two arguments.
 *)
fun buildLambda (f, args, types, resultType) =
   let
      val tupleVar = Var.newString "c"
      val callVar = Var.newString "c"
      val outerArg = Vector.last args
      val outerType = Vector.last types
      (* Everything except the first and the last element. *)
      fun middle v =
         Vector.tabulate (Vector.length v - 2, fn i => Vector.sub (v, i + 1))
      val middleArgs = middle args
      val middleTypes = middle types
      (* Expression binding (and returning) the innermost lambda. *)
      val innermost =
         let
            val lamVar = Var.newString "c"
            val arg = Vector.sub (args, 0)
            val argType = Vector.sub (types, 0)
            val packArgs =
               MonoVal {var = tupleVar,
                        ty = Type.tuple (Vector.rev types),
                        exp = Tuple (Vector.map (Vector.rev args,
                                                 VarExp.mono))}
            val callF =
               MonoVal {var = callVar,
                        ty = resultType,
                        exp = App {func = f,
                                   arg = VarExp.mono tupleVar}}
            val lam =
               Lambda.new {arg = arg,
                           argType = argType,
                           body = Exp.new {decs = [packArgs, callF],
                                           result = VarExp.mono callVar}}
         in
            Exp.new {decs = [MonoVal {var = lamVar,
                                      ty = Type.arrow (argType, resultType),
                                      exp = Lambda lam}],
                     result = VarExp.mono lamVar}
         end
      (* Wrap `inner` in one more lambda taking `a`, and return an
       * expression that binds and yields that lambda. *)
      fun wrap (a, atype, inner) =
         let
            val lamVar = Var.newString "c"
            val lam = Lambda.new {arg = a, argType = atype, body = inner}
         in
            Exp.new {decs = [MonoVal {var = lamVar,
                                      ty = Type.arrow (atype,
                                                       getResultType inner),
                                      exp = Lambda lam}],
                     result = VarExp.mono lamVar}
         end
   in
      Lambda.new {arg = outerArg,
                  argType = outerType,
                  body = Vector.fold2 (middleArgs, middleTypes,
                                       innermost, wrap)}
   end
113
(* uncurryFun (T {var, lambda})
 *
 * Analyses one function binding.  Starting from `lambda`, it peels off
 * directly nested lambdas for as long as the body consists only of
 * Const/Var/Select bindings followed by a single lambda binding that is
 * also the body's result.  It then records on `var`:
 *   - its arity (number of lambda arguments collected), via setArity;
 *   - its argument types and result type, via setType;
 *   - when the arity exceeds 1, an uncurried version taking one tuple
 *     plus a curried wrapper that calls it, via setCurriedRep.
 * Returns unit.
 *)
fun uncurryFun dec =
   let
      (* lamExp (decs, result, args, types, newDecs, e)
       * Scans `decs`.  `args`/`types` hold the arguments collected so
       * far, innermost first.  `newDecs` holds, in REVERSE order, the
       * simple bindings skipped since the last lambda.  `e` is the body
       * of the innermost lambda accepted so far and is what gets
       * returned when peeling stops. *)
      fun lamExp (decs, result, args, types, newDecs, e) =
         case decs of
            [] => (args, types, e)
          | d :: rest =>
               case d of
                  Dec.MonoVal {exp = Const _, ...} =>
                     lamExp (rest, result, args, types, d :: newDecs, e)
                | Dec.MonoVal {exp = Var _, ...} =>
                     lamExp (rest, result, args, types, d :: newDecs, e)
                | Dec.MonoVal {exp = Select _, ...} =>
                     lamExp (rest, result, args, types, d :: newDecs, e)
                | Dec.MonoVal {var, exp = Lambda l, ...} =>
                     let
                        val {decs = bodyDecs, result = bodyResult} =
                           Exp.dest (Lambda.body l)
                        (* newDecs was accumulated in reverse; restore the
                         * original order before moving those bindings in
                         * front of the lambda body, so that a binding
                         * never precedes one it refers to. *)
                        val hoisted =
                           List.append (List.rev newDecs, bodyDecs)
                        val new =
                           Exp.new {decs = hoisted, result = bodyResult}
                     in
                        (* Peel only when this lambda is the value of the
                         * enclosing body and nothing follows it. *)
                        if Var.equals (var, VarExp.var result)
                           andalso List.isEmpty rest
                           then lamExp (hoisted,
                                        bodyResult,
                                        Lambda.arg l :: args,
                                        Lambda.argType l :: types,
                                        [],
                                        new)
                        else (args, types, e)
                     end
                | _ => (args, types, e)

      val T {var = f, lambda} = dec
      val body = Lambda.body lambda
      val {decs, result} = Exp.dest body
      val (argList, typeList, e) =
         lamExp (decs, result,
                 [Lambda.arg lambda], [Lambda.argType lambda],
                 [], body)
      val args = Vector.fromList argList
      val types = Vector.fromList typeList

      (* Registers, for `f`, an uncurried function that takes all
       * arguments as one tuple and a curried wrapper bound to `f`. *)
      fun buildCurried (f, args, types, e) =
         let
            val newVar = Var.newString "c"
            val newArg = Var.newString "c"
            (* Rebind every original argument as a projection from the
             * tuple; offset 0 corresponds to the outermost argument. *)
            val (selects, _) =
               Vector.fold2
               (Vector.rev args, Vector.rev types, ([], 0),
                fn (a, ty, (acc, offset)) =>
                (MonoVal
                 {var = a,
                  ty = ty,
                  exp = PrimExp.Select {tuple = VarExp.mono newArg,
                                        offset = offset}} :: acc,
                 offset + 1))
            val newExp =
               Exp.new {decs = List.append (selects, Exp.decs e),
                        result = Exp.result e}
            val resultType = getResultType newExp
            val unCurriedFun =
               T {var = newVar,
                  lambda = Lambda.new {arg = newArg,
                                       argType = Type.tuple (Vector.rev types),
                                       body = newExp}}
            (* The wrapper gets fresh argument variables of its own. *)
            val newArgs = Vector.map (args, fn _ => Var.newString "c")
            val newFun =
               buildLambda (VarExp.mono newVar, newArgs, types, resultType)
         in
            setCurriedRep (f, SOME {unCurriedFun = unCurriedFun,
                                    curriedFun = T {var = f, lambda = newFun}})
         end
   in
      setArity (f, Vector.length args)
      ; setType (f, {args = types, result = getResultType e})
      ; if getArity f > 1
           then buildCurried (f, args, types, e)
        else ()
   end
201
(* replaceVar (decs, old, new)
 *
 * Returns `decs` with every use of the variable `old` replaced by the
 * variable expression `new`.  Uses are rewritten inside MonoVal
 * right-hand sides (Var, Tuple, Select, Lambda, ConApp, PrimApp, App,
 * Raise, Case, Handle) and inside the lambdas of Fun groups, descending
 * into nested expressions.  Other declarations and other primitive
 * expressions are returned unchanged.
 *)
fun replaceVar (decs, old, new) =
   let
      (* Substitute on a single variable occurrence. *)
      fun compare v =
         if Var.equals (VarExp.var v, old)
            then new
         else v
      (* Substitute throughout an expression, including its result. *)
      fun replaceExp e =
         let
            val {decs, result} = Exp.dest e
         in
            Exp.new {decs = replaceVar (decs, old, new),
                     result = compare result}
         end
      (* Substitute in a lambda body.  Going through replaceExp also
       * rewrites the body's result variable, which the MonoVal Lambda
       * case previously left untouched. *)
      fun replaceLambda l =
         let
            val {arg, argType, body} = Lambda.dest l
         in
            Lambda.new {arg = arg,
                        argType = argType,
                        body = replaceExp body}
         end
      fun replacePrimExp exp =
         case exp of
            Var v => PrimExp.Var (compare v)
          | Tuple vs => Tuple (Vector.map (vs, compare))
          | Select {tuple, offset} =>
               Select {tuple = compare tuple,
                       offset = offset}
          | Lambda l => Lambda (replaceLambda l)
          | ConApp {con, targs, arg} =>
               ConApp {con = con,
                       targs = targs,
                       arg = Option.map (arg, compare)}
          | PrimApp {prim, targs, args} =>
               PrimApp {prim = prim,
                        targs = targs,
                        args = Vector.map (args, compare)}
          | App {func, arg} =>
               App {func = compare func,
                    arg = compare arg}
          | Raise {exn, filePos} =>
               Raise {exn = compare exn,
                      filePos = filePos}
          | Case {test, cases, default} =>
               Case {test = compare test,
                     cases = Cases.map (cases, replaceExp),
                     default = Option.map (default, fn (e, r) =>
                                           (replaceExp e, r))}
          | Handle {try, catch, handler} =>
               Handle {try = replaceExp try,
                       catch = catch,
                       handler = replaceExp handler}
          | _ => exp
   in
      List.map
      (decs, fn d =>
       case d of
          MonoVal {var, ty, exp} =>
             MonoVal {var = var,
                      ty = ty,
                      exp = replacePrimExp exp}
        | Fun {tyvars, decs} =>
             Fun {tyvars = tyvars,
                  decs = Vector.map
                         (decs, fn {var, ty, lambda} =>
                          {var = var,
                           ty = ty,
                           lambda = replaceLambda lambda})}
        | _ => d)
   end
287
(* uncurryApp (decs, expResult)
 *
 * `decs` must start with `val v = f a` where `f` has a curried
 * representation; anything else is reported through Error.bug.
 *
 * When the declarations that follow form a complete chain of
 * applications (each one applying the previous result, one per remaining
 * argument of `f`), the chain is replaced by a tuple construction plus a
 * single call of the uncurried function.  Returns
 *   (replacement declarations, in reverse order,
 *    the remaining declarations with the chain's last variable renamed,
 *    SOME of the new call variable when the chain's last variable was
 *    `expResult`, otherwise NONE).
 * When the chain is incomplete, returns ([first declaration], the
 * remaining declarations unchanged, NONE).
 *)
fun uncurryApp (decs, expResult) =
   let
      fun bug () = Error.bug "Uncurry: uncurryApp"

      (* Build the tuple binding and the call of the uncurried version of
       * `f`.  `arguments` arrives last-argument-first. *)
      fun makeUncurryApp (f, arguments, lastCall) =
         let
            val tupleVar = Var.newString "c"
            val callVar = Var.newString "c"
            val fVar = VarExp.var f
            val {args, result} = getType fVar
            val uncurriedVar =
               case curriedRep fVar of
                  NONE => bug ()
                | SOME {unCurriedFun = T {var, ...}, ...} => var
            val argDec =
               MonoVal {var = tupleVar,
                        ty = Type.tuple (Vector.rev args),
                        exp = Tuple (Vector.rev arguments)}
            val appDec =
               MonoVal {var = callVar,
                        ty = result,
                        exp = App {func = VarExp.mono uncurriedVar,
                                   arg = VarExp.mono tupleVar}}
            val newR =
               if Var.equals (lastCall, VarExp.var expResult)
                  then SOME callVar
               else NONE
         in
            ([appDec, argDec], newR, callVar)
         end

      (* Collect `remaining` more applications, each of which must apply
       * the variable bound by the previous one.  Yields the collected
       * arguments, the declarations left over and the last bound
       * variable, or NONE if the chain breaks early. *)
      fun collect (args, remaining, ds, f) =
         if remaining = 0
            then SOME (Vector.fromList args, ds, f)
         else
            case ds of
               MonoVal {var, exp = App {func, arg}, ...} :: ds' =>
                  if Var.equals (VarExp.var func, f)
                     then collect (arg :: args, remaining - 1, ds', var)
                  else NONE
             | _ => NONE
   in
      case decs of
         (d as MonoVal {var, exp = App {func, arg}, ...}) :: rest =>
            let
               val fVar = VarExp.var func
               val () =
                  case curriedRep fVar of
                     NONE => bug ()
                   | SOME _ => ()
            in
               case collect ([arg], getArity fVar - 1, rest, var) of
                  NONE => ([d], rest, NONE)
                | SOME (args, rest, lastCall) =>
                     let
                        val (newDecs, newR, newArg) =
                           makeUncurryApp (func, args, lastCall)
                     in
                        (newDecs,
                         replaceVar (rest, lastCall, VarExp.mono newArg),
                         newR)
                     end
            end
       | _ => bug ()
   end
354
(* singleUse (var, decs)
 *
 * True exactly when the LAST declaration in `decs` is a MonoVal whose
 * right-hand side applies `var` as the function; false otherwise,
 * including for an empty list.
 *
 * NOTE(review): earlier declarations never influence the answer, which
 * does not obviously match the name "singleUse" -- confirm the intended
 * check before relying on this as a use count.
 *)
fun singleUse (var, decs) =
   case List.rev decs of
      MonoVal {exp = App {func, ...}, ...} :: _ =>
         Var.equals (VarExp.var func, var)
    | _ => false
367
368
(* transform body
 *
 * Rewrites one expression using the information recorded by uncurryFun:
 *   - a function binding with a curried representation is replaced by
 *     its uncurried version followed by the curried wrapper;
 *   - an application of such a function, when singleUse accepts it, is
 *     handed to uncurryApp, which may collapse a whole application chain
 *     into one call of the uncurried version;
 *   - lambda bodies, Case branches, Handle bodies and the lambdas of Fun
 *     groups are transformed recursively;
 *   - every other declaration is kept unchanged.
 * If a collapsed chain produced the expression's result variable, the
 * result is redirected to the variable bound by the new call.
 *)
fun transform body =
   let
      val {decs, result} = Exp.dest body
      (* Replacement for the result variable, once a collapsed chain is
       * found to have produced it. *)
      val newR = ref NONE

      (* Same lambda with a transformed body. *)
      fun transformLambda l =
         Lambda.new {arg = Lambda.arg l,
                     argType = Lambda.argType l,
                     body = transform (Lambda.body l)}

      (* Bindings for the uncurried function (body transformed) and for
       * the curried wrapper (used as registered). *)
      fun expand {unCurriedFun = T {var = uVar, lambda = uLam},
                  curriedFun = T {var = cVar, lambda = cLam}} =
         let
            val uBody = transform (Lambda.body uLam)
            val uArgType = Lambda.argType uLam
            val uncurried =
               {var = uVar,
                ty = Type.arrow (uArgType, getResultType uBody),
                lambda = Lambda.new {arg = Lambda.arg uLam,
                                     argType = uArgType,
                                     body = uBody}}
            val curried =
               {var = cVar,
                ty = Type.arrow (Lambda.argType cLam,
                                 getResultType (Lambda.body cLam)),
                lambda = cLam}
         in
            (uncurried, curried)
         end

      fun toVal {var, ty, lambda} =
         MonoVal {var = var, ty = ty, exp = Lambda lambda}

      (* Walks the declarations; `newDecs` is built in reverse order. *)
      fun loop (decs, newDecs) =
         case decs of
            [] => newDecs
          | d :: rest =>
               (case d of
                   MonoVal {var, ty, exp = Lambda l} =>
                      (case curriedRep var of
                          NONE =>
                             loop (rest,
                                   MonoVal {var = var,
                                            ty = ty,
                                            exp = Lambda (transformLambda l)}
                                   :: newDecs)
                        | SOME rep =>
                             let
                                val (uncurried, curried) = expand rep
                             in
                                (* Reversed list: the uncurried function
                                 * ends up before the wrapper using it. *)
                                loop (rest,
                                      toVal curried
                                      :: toVal uncurried
                                      :: newDecs)
                             end)
                 | MonoVal {var, exp = App {func, ...}, ...} =>
                      (case curriedRep (VarExp.var func) of
                          NONE => loop (rest, d :: newDecs)
                        | SOME _ =>
                             if singleUse (var, rest)
                                then
                                   let
                                      val (appDecs, rest', newResult) =
                                         uncurryApp (decs, result)
                                      (* Record a replacement result only
                                       * when this chain produced one, so a
                                       * later chain cannot erase an
                                       * earlier replacement. *)
                                      val () =
                                         case newResult of
                                            NONE => ()
                                          | SOME _ => newR := newResult
                                   in
                                      loop (rest',
                                            List.append (appDecs, newDecs))
                                   end
                             else loop (rest, d :: newDecs))
                 | MonoVal {var, ty, exp = Case {test, cases, default}} =>
                      let
                         val newCases = Cases.map (cases, transform)
                         val newDefault =
                            Option.map (default, fn (e, r) =>
                                        (transform e, r))
                      in
                         loop (rest,
                               MonoVal {var = var,
                                        ty = ty,
                                        exp = Case {test = test,
                                                    cases = newCases,
                                                    default = newDefault}}
                               :: newDecs)
                      end
                 | MonoVal {var, ty, exp = Handle {try, catch, handler}} =>
                      loop (rest,
                            MonoVal {var = var,
                                     ty = ty,
                                     exp = Handle {try = transform try,
                                                   catch = catch,
                                                   handler = transform handler}}
                            :: newDecs)
                 | Fun {tyvars, decs = funDecs} =>
                      let
                         val newFunDecs =
                            Vector.fold
                            (funDecs, [], fn ({var, ty, lambda}, acc) =>
                             case curriedRep var of
                                NONE =>
                                   {var = var,
                                    ty = ty,
                                    lambda = transformLambda lambda} :: acc
                              | SOME rep =>
                                   let
                                      val (uncurried, curried) = expand rep
                                   in
                                      uncurried :: curried :: acc
                                   end)
                      in
                         (* Keep the group's type variables instead of
                          * replacing them with an empty vector. *)
                         loop (rest,
                               Fun {tyvars = tyvars,
                                    decs = Vector.fromList newFunDecs}
                               :: newDecs)
                      end
                 | _ => loop (rest, d :: newDecs))

      (* Run the walk before reading newR. *)
      val newDecs = List.rev (loop (decs, []))
   in
      Exp.new {decs = newDecs,
               result = (case !newR of
                            NONE => result
                          | SOME r => VarExp.mono r)}
   end
519 in
(* Pass 1: run uncurryFun on every function binding (MonoVal lambdas and
 * the members of Fun groups) in every expression of the program body. *)
Exp.foreachExp
(body, fn e =>
 List.foreach
 (Exp.decs e, fn d =>
  case d of
     MonoVal {var, exp = Lambda l, ...} =>
        uncurryFun (T {var = var, lambda = l})
   | Fun {decs = funDecs, ...} =>
        Vector.foreach
        (funDecs, fn {var, lambda, ...} =>
         uncurryFun (T {var = var, lambda = lambda}))
   | _ => ()));
(* Pass 2: rewrite the body using the recorded information and rebuild
 * the program around it. *)
Program.T {datatypes = datatypes,
           body = transform body,
           overflow = overflow}
540 end
541 end