1 (* Copyright (C) 1999-2006 Henry Cejtin, Matthew Fluet, Suresh
2 * Jagannathan, and Stephen Weeks.
3 * Copyright (C) 1997-2000 NEC Research Institute.
5 * MLton is released under a BSD-style license.
6 * See the file MLton-LICENSE for details.
9 functor Uncurry(S: XML_TRANSFORM_STRUCTS): XML_TRANSFORM =
15 fun transform (program as Program.T{datatypes, body, overflow}) =
17 datatype D = T of {var: Var.t, lambda : Lambda.t}
19 val {get = getArity: Var.t -> int,
20 set = setArity, ...} =
21 Property.getSet(Var.plist,
24 val {get = curriedRep: Var.t -> {unCurriedFun: D, curriedFun: D} option,
25 set = setCurriedRep, ...} =
26 Property.getSet(Var.plist,
27 Property.initConst NONE)
29 val {get = getType: Var.t -> {args: Type.t vector, result: Type.t},
31 Property.getSet(Var.plist,
32 Property.initConst {args = Vector.new1 Type.unit,
35 fun getResultType(exp) =
37 val {decs,result} = Exp.dest(exp)
40 (decs, Type.unit, fn (d, i) =>
42 MonoVal {var, ty, exp} =>
43 if Var.equals(var,VarExp.var(result))
46 | Fun {tyvars, decs} =>
48 (decs, Type.unit, fn ({var,ty,lambda}, i) =>
49 if Var.equals(var,VarExp.var(result))
55 fun buildLambda(f,args,types,resultType) =
57 val newArg' = Var.newString("c")
58 val newArg'' = Var.newString("c")
61 {arg = Vector.last(args),
62 argType = Vector.last(types),
64 (Vector.tabulate(Vector.length args - 2,
65 fn i => Vector.sub(args, i + 1)),
66 Vector.tabulate(Vector.length types - 2,
67 fn i => Vector.sub(types, i + 1)),
69 val newVar = Var.newString("c")
70 val arg = Vector.sub(args,0)
71 val argType = Vector.sub(types,0)
75 ty = Type.arrow(argType,resultType),
84 ty = Type.tuple(Vector.rev(types)),
85 exp = Tuple(Vector.map
87 fn a => VarExp.mono(a)))},
92 arg = VarExp.mono(newArg')}}],
93 result = VarExp.mono(newArg'')}})}]
94 val result = VarExp.mono(newVar)
97 {decs = decs, result = result}
101 val newVar = Var.newString("c")
106 ty = Type.arrow(atype, getResultType(i)),
107 exp = Lambda(Lambda.new {arg = a,
110 result = VarExp.mono(newVar)}
114 fun uncurryFun(dec) =
116 fun lamExp(decs,result,args,types,newDecs,e) =
121 Dec.MonoVal{var, ty, exp = Const c} =>
122 lamExp(rest, result, args,types,d::newDecs,e)
123 | Dec.MonoVal{var, ty, exp = Var v} =>
124 lamExp(rest, result, args,types,d::newDecs,e)
125 | Dec.MonoVal{var, ty, exp = Select tuple} =>
126 lamExp(rest, result, args,types,d::newDecs,e)
127 | Dec.MonoVal{var, ty, exp = Lambda l} =>
129 val body = Lambda.body(l)
131 val {decs,result} = Exp.dest(body)
132 val newDecs = List.append(newDecs,decs)
133 val new = Exp.new{decs = newDecs,result = result}
135 if Var.equals(var, VarExp.var(r))
136 andalso List.isEmpty(rest)
140 Lambda.argType(l)::types,
145 | _ => (args,types,e)
147 let val (args,types,e) = lamExp x
148 in (Vector.fromList args, Vector.fromList types, e)
151 val T{var,lambda} = dec
153 val arg = Lambda.arg(lambda)
154 val argType = Lambda.argType(lambda)
155 val body = Lambda.body(lambda)
156 val {decs,result} = Exp.dest(body)
158 (var, lamExp(decs, result, [arg], [argType], [],body))
161 fun buildCurried (f,args,types,e) =
163 val newVar = Var.newString("c")
164 val newArg = Var.newString("c")
169 ([],0), fn (a, mtype, (l, i)) =>
173 exp = PrimExp.Select {tuple = VarExp.mono(newArg),
176 val newExp = Exp.new {decs = List.append(newDecs, Exp.decs(e)),
177 result = Exp.result(e)}
178 val resultType = getResultType(newExp)
181 lambda = Lambda.new {arg = newArg,
182 argType = Type.tuple(Vector.rev(types)),
184 val newArgs = Vector.map(args, fn z => Var.newString("c"))
185 val newFun = buildLambda(VarExp.mono(newVar),newArgs,types,resultType)
187 val newFunBinding = T{var = f, lambda = newFun}
189 setCurriedRep(f, SOME {unCurriedFun = unCurriedFun,
190 curriedFun = newFunBinding})
195 (setArity(f, Vector.length(args));
196 setType(f, {args = types, result = getResultType(e)});
198 then buildCurried(f,args,types,e)
202 fun replaceVar(decs,old,new) =
204 fun compare(v) = if Var.equals(VarExp.var(v),old)
207 fun replaceExp(e) = let
208 val {decs,result} = Exp.dest(e)
209 val newDecs = replaceVar(decs,old,new)
210 val newResult = compare(result)
212 Exp.new {decs = newDecs,
219 MonoVal {var, ty, exp} =>
223 Var v => PrimExp.Var(compare(v))
225 Tuple(Vector.map(vs, fn v => compare(v)))
226 | Select {tuple,offset} =>
227 Select {tuple = compare(tuple),
231 val {arg,argType,body} = Lambda.dest(l)
232 val {decs,result} = Exp.dest(body)
233 val newDecs = replaceVar(decs,old,new)
238 body=Exp.new {decs = newDecs,
241 | ConApp {con,targs,arg} =>
244 | SOME v => ConApp {con = con,
246 arg = SOME (compare(v))})
247 | PrimApp {prim,targs,args} =>
248 PrimApp {prim = prim,
250 args = Vector.map(args, fn a => compare(a))}
252 App {func = compare(func),
254 | Raise {exn,filePos} =>
255 Raise {exn = compare(exn),
257 | Case {test,cases,default} =>
258 Case {test=compare(test),
263 (default, fn (e,r) =>
265 | Handle {try,catch,handler} =>
266 Handle {try = replaceExp(try),
268 handler = replaceExp(handler)}
270 | Fun {tyvars,decs} =>
273 (decs, fn {var,ty,lambda} =>
277 val {arg,argType,body} =
283 body = replaceExp(body)})
288 fun uncurryApp(decs,expResult) =
290 fun makeUncurryApp(f,arguments,lastCall) =
292 val newArg = Var.newString("c")
293 val newArg' = Var.newString("c")
294 val varF = VarExp.var(f)
295 val {args,result} = getType(varF)
296 val var = (case curriedRep(varF) of
297 NONE => Error.bug "Uncurry: uncurryApp"
298 | SOME {unCurriedFun,curriedFun} =>
299 let val T{var,lambda} = unCurriedFun
302 val argDec = MonoVal{var = newArg,
303 ty = Type.tuple(Vector.rev(args)),
304 exp = Tuple(Vector.rev(arguments))}
305 val appDec = MonoVal{var = newArg',
307 exp = App {func = VarExp.mono(var),
308 arg = VarExp.mono(newArg)}}
309 val newR = if Var.equals(lastCall, VarExp.var(expResult))
312 in (appDec::[argDec],newR,newArg')
315 [] => Error.bug "Uncurry: uncurryApp"
317 MonoVal {var, ty, exp = App {func,arg}} =>
318 (case curriedRep(VarExp.var(func)) of
319 NONE => Error.bug "Uncurry: uncurryApp"
321 val arity = getArity(VarExp.var(func))
322 fun loop(args,arity,d,f) =
324 then SOME (Vector.fromList args,d,f)
331 exp = App {func,arg}} =>
332 if Var.equals(VarExp.var(func),f)
340 case loop([arg],arity-1,r,var) of
342 | SOME (args,r,lastCall) =>
344 val (newDecs,newR,newArg) =
345 makeUncurryApp(func,args,lastCall)
346 val r = (replaceVar(r,lastCall,
347 VarExp.mono(newArg)))
352 | _ => Error.bug "Uncurry: uncurryApp")
355 fun singleUse(var,decs) =
357 fun compare(e) = (case e of
358 App {func,arg} => Var.equals(VarExp.var(func),var)
362 (decs, false, fn (d,r) =>
364 MonoVal {var,ty,exp} => compare(exp)
369 fun transform(body) =
371 val {decs,result} = Exp.dest(body)
378 fun loop(decs,newDecs) =
383 MonoVal {var,ty, exp = Lambda l} =>
384 (case curriedRep(var) of
387 val lamBody = Lambda.body(l)
388 val arg = Lambda.arg(l)
389 val argType = Lambda.argType(l)
393 body = transform(lamBody)}
394 val newDec = MonoVal{var=var,
398 loop(rest,newDec::newDecs)
400 | SOME {unCurriedFun,curriedFun} =>
402 val T{var,lambda} = unCurriedFun
403 val body = Lambda.body(lambda)
404 val newBody = transform(body)
405 val resultType = getResultType(newBody)
406 val argType = Lambda.argType(lambda)
407 val l = Lambda(Lambda.new
412 val b1 = MonoVal{var=var,
413 ty = Type.arrow(argType,resultType),
415 val T{var,lambda} = curriedFun
416 val argType = Lambda.argType(lambda)
417 val resultType = getResultType(Lambda.body(lambda))
418 val b2 = MonoVal{var=var,
420 Type.arrow(argType, resultType),
422 in loop(rest,b2::(b1::newDecs))
424 | MonoVal {var,ty,exp = App {func,arg}} =>
425 (case curriedRep(VarExp.var(func)) of
426 NONE => loop(rest,d::newDecs)
428 if singleUse(var,rest)
431 val (appDecs,r,newResult) =
432 uncurryApp(decs,result)
433 in (newR := newResult;
434 loop(r,List.append(appDecs,newDecs)))
436 else loop(rest,d::newDecs))
437 | MonoVal {var,ty,exp = Case {test,cases,default}} =>
440 Cases.map(cases, fn e => transform(e))
441 val default = Option.map
442 (default, fn (e,r) =>
448 exp = Case {test = test,
450 default = default}}::
453 | MonoVal {var,ty, exp = Handle {try,catch,handler}} =>
457 exp = Handle {try = transform(try),
459 handler = transform(handler)}}::
461 | Fun {tyvars,decs} =>
463 Fun {tyvars = Vector.new0 (),
470 lambda:Lambda.t} list,
475 (case curriedRep(var) of
478 val body = Lambda.body(lambda)
479 val arg = Lambda.arg(lambda)
480 val argType = Lambda.argType(lambda)
481 val newBody = transform(body)
482 val newLam = Lambda.new{arg = arg,
490 | SOME {unCurriedFun,curriedFun} =>
492 val T{var,lambda} = unCurriedFun
493 val body = Lambda.body(lambda)
494 val newBody = transform(body)
495 val argType = Lambda.argType(lambda)
496 val resultType = getResultType(newBody)
498 ty = Type.arrow(argType,resultType),
500 Lambda.new{arg = Lambda.arg(lambda),
503 val T{var,lambda} = curriedFun
504 val argType = Lambda.argType(lambda)
505 val newBody = transform(Lambda.body(lambda))
506 val resultType = getResultType(newBody)
508 ty = Type.arrow(argType,resultType),
512 | _ => loop(rest,d::newDecs))
515 result = (case !newR of
517 | SOME r => VarExp.mono(r))}
523 val {decs,result} = Exp.dest(e)
528 MonoVal {var,ty,exp = Lambda l} =>
529 uncurryFun(T{var=var,lambda=l})
530 | Fun {tyvars,decs} =>
532 (decs, fn {var,ty,lambda} =>
533 uncurryFun(T{var=var,lambda=lambda}))
536 let val newBody = transform(body)
538 Program.T{datatypes = datatypes, body = newBody, overflow = overflow}