(* Source: hcoop/debian/mlton.git, mlton/xml/cps-transform.fun
 * (gitweb blame view; annotation 7f918cf1 CE,
 * "Import Debian changes 20180207-1").
 *)
(* Copyright (C) 2007-2007 Henry Cejtin, Matthew Fluet, Suresh
 *    Jagannathan, and Stephen Weeks.
 *
 * MLton is released under a BSD-style license.
 * See the file MLton-LICENSE for details.
 *)

functor CPSTransform (S: XML_TRANSFORM_STRUCTS): XML_TRANSFORM =
struct

open S
datatype z = datatype Dec.t
datatype z = datatype PrimExp.t

(* Continuation-passing-style transformation of an XML program.
 *
 * A function of source type  a -> b  becomes a function that also
 * receives a normal continuation (b -> ans) and an exception
 * continuation (exn -> ans); [style] below selects how those arguments
 * are packaged.  Each MonoVal declaration is translated so that the rest
 * of the enclosing expression becomes the continuation of the bound
 * expression.  The translated body runs under trivial initial
 * continuations and a top-level handler, and the wrapper datatypes made
 * for overflow-checking primitives are appended to the datatypes.
 *
 * Calls Error.bug on Exception and PolyVal declarations and on Profile
 * expressions, which this pass does not support.
 *)
fun transform (prog: Program.t): Program.t =
   let
      val Program.T {datatypes, body, overflow} = prog

      (* Answer type is always unit in an XML IL program. *)
      val ansTy = Type.unit
      (* Exception type is always exn in an XML IL program. *)
      val exnTy = Type.exn

      (* Style of function-type translation.  For a source arrow a -> b:
       *    Curried:    (b -> ans) -> (exn -> ans) -> a -> ans
       *    Mixed:      (b -> ans) * (exn -> ans) -> a -> ans
       *    Uncurried:  (b -> ans) * (exn -> ans) * a -> ans
       *)
      datatype style = Curried | Mixed | Uncurried
      val style = Uncurried

      (* Translate a type, rewriting every arrow according to [style];
       * any other constructor is rebuilt over its component types.  The
       * component types handed to [con] are presumably the already
       * translated ones supplied by makeMonoHom (no recursive call here).
       *)
      val {hom = transType, destroy = destroyTransType} =
         Type.makeMonoHom
         {con = fn (_, c, tys) =>
          if Tycon.equals (c, Tycon.arrow)
             then let
                     val argTy = Vector.sub (tys, 0)
                     val resTy = Vector.sub (tys, 1)
                     (* Normal and exception continuation types. *)
                     val kTy = Type.arrow (resTy, ansTy)
                     val hTy = Type.arrow (exnTy, ansTy)
                  in
                     case style of
                        Curried =>
                           Type.arrow
                           (kTy, Type.arrow (hTy, Type.arrow (argTy, ansTy)))
                      | Mixed =>
                           Type.arrow
                           ((Type.tuple o Vector.new2) (kTy, hTy),
                            Type.arrow (argTy, ansTy))
                      | Uncurried =>
                           Type.arrow
                           ((Type.tuple o Vector.new3) (kTy, hTy, argTy),
                            ansTy)
                  end
          else Type.con (c, tys)}

      (* Original (source) type of each bound variable.  It is set for
       * every variable bound in [body] (see foreachBoundVar below) before
       * the translation runs.
       *)
      val {get = getVarOrigType: Var.t -> Type.t, set = setVarOrigType, ...} =
         Property.getSetOnce
         (Var.plist, Property.initRaise ("getVarOrigType", Var.layout))
      val getVarExpOrigType = getVarOrigType o VarExp.var

      (* A mayOverflow primitive needs a special translation with a
       * wrapper datatype; see the PrimApp case of transPrimExp.  For a
       * (translated) result type ty, [getWrap ty] supplies the
       * constructors of a datatype
       *    Wrap = Success of ty | Failure of exn
       * and every such datatype created is pushed onto [wrapDatatypes]
       * so it can be added to the output program.
       *)
      val wrapDatatypes = ref []
      val {get = getWrap, destroy = destroyWrap, ...} =
         Property.destGet
         (Type.plist,
          Property.initFun
          (fn ty =>
           let
              val successCon = Con.newString "Success"
              val failureCon = Con.newString "Failure"
              val wrapTycon = Tycon.newString "Wrap"
              val wrapTy = Type.con (wrapTycon, Vector.new0 ())
              val () =
                 List.push
                 (wrapDatatypes,
                  {cons = Vector.new2
                          ({arg = SOME ty, con = successCon},
                           {arg = SOME exnTy, con = failureCon}),
                   tycon = wrapTycon,
                   tyvars = Vector.new0 ()})
           in
              {successCon = successCon,
               failureCon = failureCon,
               wrapTy = wrapTy}
           end))

      (* Translate a variable occurrence, also returning its translated
       * type.
       *)
      fun transVarExpWithType (x: VarExp.t): DirectExp.t * Type.t =
         let
            val xTy = transType (getVarExpOrigType x)
         in
            (DirectExp.varExp (x, xTy), xTy)
         end
      val transVarExp = #1 o transVarExpWithType

      (* let var = #offset tuple in body *)
      fun bindSelect {tuple, offset, var, ty, body} =
         DirectExp.let1
         {var = var,
          exp = DirectExp.select {tuple = tuple, offset = offset, ty = ty},
          body = body}

      (* Translate a lambda to the calling convention selected by [style].
       * Fresh variables k and h name the normal and exception
       * continuations of the translated body.  The lambda through which
       * the original argument arrives keeps the original mayInline flag;
       * the extra lambdas of the Curried and Mixed styles, which only bind
       * continuations, are always marked inlinable.
       *)
      fun transLambda (l: Lambda.t): Lambda.t =
         let
            val {arg = argVar, argType = argTy, body, mayInline} = Lambda.dest l
            val resTy = getVarExpOrigType (Exp.result body)

            val argTy = transType argTy
            val resTy = transType resTy
            val kVar = Var.newString "k"
            val kTy = Type.arrow (resTy, ansTy)
            val hVar = Var.newString "h"
            val hTy = Type.arrow (exnTy, ansTy)
            val bodyKHA = transExp (body, kVar, kTy, hVar, hTy)
            (* fn arg => body; built on demand because Uncurried binds the
             * argument with a projection instead.
             *)
            fun argLambda () =
               DirectExp.lambda
               {arg = argVar,
                argType = argTy,
                body = bodyKHA,
                bodyType = ansTy,
                mayInline = mayInline}
         in
            case style of
               Curried =>
                  (* fn k => fn h => fn arg => body *)
                  let
                     val bodyK =
                        DirectExp.lambda
                        {arg = hVar,
                         argType = hTy,
                         body = argLambda (),
                         bodyType = Type.arrow (argTy, ansTy),
                         mayInline = true}
                  in
                     Lambda.make
                     {arg = kVar,
                      argType = kTy,
                      body = DirectExp.toExp bodyK,
                      mayInline = true}
                  end
             | Mixed =>
                  (* fn x => let k = #0 x; h = #1 x in fn arg => body *)
                  let
                     val xVar = Var.newNoname ()
                     val xTy = Type.tuple (Vector.new2 (kTy, hTy))
                     val x = DirectExp.monoVar (xVar, xTy)
                     val bodyX =
                        bindSelect
                        {tuple = x, offset = 0, var = kVar, ty = kTy,
                         body = bindSelect
                                {tuple = x, offset = 1, var = hVar, ty = hTy,
                                 body = argLambda ()}}
                  in
                     Lambda.make
                     {arg = xVar,
                      argType = xTy,
                      body = DirectExp.toExp bodyX,
                      mayInline = true}
                  end
             | Uncurried =>
                  (* fn x => let k = #0 x; h = #1 x; arg = #2 x in body *)
                  let
                     val xVar = Var.newNoname ()
                     val xTy = Type.tuple (Vector.new3 (kTy, hTy, argTy))
                     val x = DirectExp.monoVar (xVar, xTy)
                     val bodyX =
                        bindSelect
                        {tuple = x, offset = 0, var = kVar, ty = kTy,
                         body = bindSelect
                                {tuple = x, offset = 1, var = hVar, ty = hTy,
                                 body = bindSelect
                                        {tuple = x, offset = 2, var = argVar,
                                         ty = argTy, body = bodyKHA}}}
                  in
                     Lambda.make
                     {arg = xVar,
                      argType = xTy,
                      body = DirectExp.toExp bodyX,
                      mayInline = mayInline}
                  end
         end

      (* Translate primitive expression [e] of source type [eTy].  Its
       * value is passed to kVar (of type kTy), and any exception it raises
       * is passed to hVar (of type hTy).
       *)
      and transPrimExp (e: PrimExp.t, eTy: Type.t,
                        kVar: Var.t, kTy: Type.t,
                        hVar: Var.t, hTy: Type.t): DirectExp.t =
         let
            val eTy = transType eTy
            val k = DirectExp.monoVar (kVar, kTy)
            val h = DirectExp.monoVar (hVar, hTy)
            (* Deliver a value to the normal continuation. *)
            fun return x = DirectExp.app {func = k, arg = x, ty = ansTy}
            (* Deliver an exception to the exception continuation. *)
            fun throw x = DirectExp.app {func = h, arg = x, ty = ansTy}
            (* Translate a sub-expression under the same continuations. *)
            fun transSub e = transExp (e, kVar, kTy, hVar, hTy)
         in
            case e of
               App {arg, func} =>
                  let
                     val (arg, argTy) = transVarExpWithType arg
                     val func = transVarExp func
                  in
                     case style of
                        Curried =>
                           (* ((func k) h) arg *)
                           let
                              val app1 =
                                 DirectExp.app
                                 {func = func,
                                  arg = k,
                                  ty = Type.arrow (hTy, Type.arrow (argTy, ansTy))}
                              val app2 =
                                 DirectExp.app
                                 {func = app1,
                                  arg = h,
                                  ty = Type.arrow (argTy, ansTy)}
                           in
                              DirectExp.app {func = app2, arg = arg, ty = ansTy}
                           end
                      | Mixed =>
                           (* (func (k, h)) arg *)
                           let
                              val app2 =
                                 DirectExp.app
                                 {func = func,
                                  arg = DirectExp.tuple
                                        {exps = Vector.new2 (k, h),
                                         ty = (Type.tuple o Vector.new2) (kTy, hTy)},
                                  ty = Type.arrow (argTy, ansTy)}
                           in
                              DirectExp.app {func = app2, arg = arg, ty = ansTy}
                           end
                      | Uncurried =>
                           (* func (k, h, arg) *)
                           DirectExp.app
                           {func = func,
                            arg = DirectExp.tuple
                                  {exps = Vector.new3 (k, h, arg),
                                   ty = (Type.tuple o Vector.new3) (kTy, hTy, argTy)},
                            ty = ansTy}
                  end
             | Case {cases, default, test} =>
                  (* Every branch is translated under the same continuations,
                   * so the case itself produces an answer.
                   *)
                  let
                     val cases =
                        case cases of
                           Cases.Con cases =>
                              Cases.Con
                              (Vector.map
                               (cases, fn (Pat.T {arg, con, targs}, e) =>
                                (Pat.T {arg = Option.map
                                              (arg, fn (x, xTy) =>
                                               (x, transType xTy)),
                                        con = con,
                                        targs = Vector.map (targs, transType)},
                                 transSub e)))
                         | Cases.Word (ws, cases) =>
                              Cases.Word
                              (ws, Vector.map (cases, fn (w, e) => (w, transSub e)))
                     val default =
                        Option.map (default, fn (e, r) => (transSub e, r))
                  in
                     DirectExp.casee
                     {cases = cases,
                      default = default,
                      test = transVarExp test,
                      ty = ansTy}
                  end
             | ConApp {arg, con, targs} =>
                  (return o DirectExp.conApp)
                  {arg = Option.map (arg, transVarExp),
                   con = con,
                   targs = Vector.map (targs, transType),
                   ty = eTy}
             | Const c => return (DirectExp.const c)
             | Handle {catch = (cVar, _), handler, try} =>
                  (* Run [try] under a new exception continuation that binds
                   * the caught exception to cVar and runs [handler] under
                   * the current continuations.
                   *)
                  let
                     val h'Var = Var.newString "h"
                     val h'Ty = Type.arrow (exnTy, ansTy)
                     val h'Body =
                        DirectExp.lambda
                        {arg = cVar,
                         argType = exnTy,
                         body = transSub handler,
                         bodyType = ansTy,
                         mayInline = true}
                  in
                     DirectExp.let1
                     {var = h'Var,
                      exp = h'Body,
                      body = transExp (try, kVar, kTy, h'Var, h'Ty)}
                  end
             | Lambda l => return (DirectExp.fromLambda (transLambda l, eTy))
             | PrimApp {args, prim, targs} =>
                  let
                     val primAppExp =
                        DirectExp.primApp
                        {args = Vector.map (args, transVarExp),
                         prim = prim,
                         targs = Vector.map (targs, transType),
                         ty = eTy}
                  in
                     if Prim.mayOverflow prim
                        then let
                                (* A mayOverflow primitive has an implicit
                                 * raise, which is introduced explicitly by
                                 * closure-convert (transformation from SXML
                                 * to SSA).
                                 *
                                 * We leave an explicit Handle around the
                                 * primitive to catch the exception.  The
                                 * non-exceptional result goes to the
                                 * (normal) continuation, while the
                                 * exception goes to the exception
                                 * continuation.
                                 *
                                 * Naively, we would do:
                                 *    (k (primApp)) handle x => h x
                                 * But this evaluates the (normal)
                                 * continuation in the context of the
                                 * handler.
                                 *
                                 * Rather, we do:
                                 *    case ((Success (primApp))
                                 *          handle x => Failure x) of
                                 *       Success x => k x
                                 *     | Failure x => h x
                                 * This evaluates the (normal) continuation
                                 * outside the context of the handler.
                                 *
                                 * See <src>/lib/mlton/basic/exn0.sml and
                                 * "Exceptional Syntax" by Benton and
                                 * Kennedy.
                                 *)
                                val {successCon, failureCon, wrapTy} =
                                   getWrap eTy

                                (* (Success primApp) handle x => Failure x *)
                                val testExp =
                                   let
                                      val xVar = Var.newNoname ()
                                      val x = DirectExp.monoVar (xVar, exnTy)
                                   in
                                      DirectExp.handlee
                                      {try = DirectExp.conApp
                                             {arg = SOME primAppExp,
                                              con = successCon,
                                              targs = Vector.new0 (),
                                              ty = wrapTy},
                                       catch = (xVar, exnTy),
                                       handler = DirectExp.conApp
                                                 {arg = SOME x,
                                                  con = failureCon,
                                                  targs = Vector.new0 (),
                                                  ty = wrapTy},
                                       ty = wrapTy}
                                   end

                                (* Success x => k x *)
                                val successCase =
                                   let
                                      val xVar = Var.newNoname ()
                                   in
                                      (Pat.T {arg = SOME (xVar, eTy),
                                              con = successCon,
                                              targs = Vector.new0 ()},
                                       return (DirectExp.monoVar (xVar, eTy)))
                                   end
                                (* Failure x => h x *)
                                val failureCase =
                                   let
                                      val xVar = Var.newNoname ()
                                   in
                                      (Pat.T {arg = SOME (xVar, exnTy),
                                              con = failureCon,
                                              targs = Vector.new0 ()},
                                       throw (DirectExp.monoVar (xVar, exnTy)))
                                   end
                             in
                                DirectExp.casee
                                {test = testExp,
                                 cases = Cases.Con
                                         (Vector.new2 (successCase, failureCase)),
                                 default = NONE,
                                 ty = ansTy}
                             end
                     else return primAppExp
                  end
             | Profile _ =>
                  (* Profile statements won't properly nest after
                   * CPS conversion.
                   *)
                  Error.bug "CPSTransform.transPrimExp: Profile"
             | Raise {exn, ...} => throw (transVarExp exn)
             | Select {offset, tuple} =>
                  (return o DirectExp.select)
                  {tuple = transVarExp tuple,
                   offset = offset,
                   ty = eTy}
             | Tuple xs =>
                  (return o DirectExp.tuple)
                  {exps = Vector.map (xs, transVarExp),
                   ty = eTy}
             | Var x => return (transVarExp x)
         end

      (* Translate declaration [d] in front of [kBody], the already
       * translated rest of the enclosing expression.  Exceptions go to
       * hVar (of type hTy).
       *)
      and transDec (d: Dec.t,
                    kBody: DirectExp.t,
                    hVar: Var.t, hTy: Type.t): DirectExp.t =
         case d of
            Exception _ => Error.bug "CPSTransform.transDec: Exception"
          | Fun {decs, tyvars} =>
               (* Function declarations stay in direct style; only their
                * types and lambdas are translated.
                *)
               DirectExp.lett
               {decs = [Fun {decs = Vector.map
                                    (decs, fn {var, ty, lambda} =>
                                     {var = var,
                                      ty = transType ty,
                                      lambda = transLambda lambda}),
                             tyvars = tyvars}],
                body = kBody}
          | MonoVal {var, ty, exp} =>
               (* let k' = fn var => kBody in <exp delivered to k'> *)
               let
                  val argTy = transType ty
                  val k'Var = Var.newString "k"
                  val k'Ty = Type.arrow (argTy, ansTy)
                  val k'Body =
                     DirectExp.lambda
                     {arg = var,
                      argType = argTy,
                      body = kBody,
                      bodyType = ansTy,
                      mayInline = true}
               in
                  DirectExp.let1
                  {var = k'Var,
                   exp = k'Body,
                   body = transPrimExp (exp, ty, k'Var, k'Ty, hVar, hTy)}
               end
          | PolyVal _ => Error.bug "CPSTransform.transDec: PolyVal"

      (* Translate expression [e].  Its result goes to kVar (of type kTy),
       * and raised exceptions go to hVar (of type hTy).  The declarations
       * are folded from the right, so each one wraps the translation of
       * the declarations after it.
       *)
      and transExp (e: Exp.t,
                    kVar: Var.t, kTy: Type.t,
                    hVar: Var.t, hTy: Type.t): DirectExp.t =
         let
            val {decs, result} = Exp.dest e
            val k = DirectExp.monoVar (kVar, kTy)
         in
            List.foldr
            (decs,
             DirectExp.app {func = k, arg = transVarExp result, ty = ansTy},
             fn (dec, kBody) => transDec (dec, kBody, hVar, hTy))
         end

      (* Set (original) type of each bound variable. *)
      val () =
         Exp.foreachBoundVar
         (body, fn (v, _, ty) => setVarOrigType (v, ty))

      (* Translate the argument types of datatype constructors. *)
      val datatypes =
         Vector.map
         (datatypes, fn {cons, tycon, tyvars} =>
          {cons = Vector.map (cons, fn {arg, con} =>
                              {arg = Option.map (arg, transType),
                               con = con}),
           tycon = tycon,
           tyvars = tyvars})

      (* Initial continuation: fn _ => ().  Its declared type is derived
       * from the lambda's own argType/bodyType.
       *)
      val k0 = Var.newString "k0"
      val k0Body =
         DirectExp.lambda
         {arg = Var.newNoname (),
          argType = ansTy,
          body = DirectExp.unit (),
          bodyType = ansTy,
          mayInline = true}
      val k0Ty = Type.arrow (ansTy, ansTy)
      (* Initial exception continuation: fn _ => (). *)
      val h0 = Var.newString "h0"
      val h0Body =
         DirectExp.lambda
         {arg = Var.newNoname (),
          argType = exnTy,
          body = DirectExp.unit (),
          bodyType = ansTy,
          mayInline = true}
      val h0Ty = Type.arrow (exnTy, ansTy)

      (* Translate body, in context of initial continuations. *)
      val body =
         DirectExp.let1
         {var = k0, exp = k0Body, body =
          DirectExp.let1
          {var = h0, exp = h0Body, body =
           transExp (body, k0, k0Ty, h0, h0Ty)}}

      (* Closure-convert (transformation from SXML to SSA) introduces
       * every (non-main) SSA function with "raises = [exn]";
       * we need a top-level handler to avoid a "raise mismatch" type
       * error in the SSA IL.
       *)
      val body =
         DirectExp.handlee
         {try = body,
          catch = (Var.newNoname (), exnTy),
          handler = DirectExp.unit (),
          ty = ansTy}
      val body = DirectExp.toExp body

      (* Fetch accumulated wrap datatypes. *)
      val wrapDatatypes = Vector.fromList (!wrapDatatypes)
      val datatypes = Vector.concat [datatypes, wrapDatatypes]

      val prog = Program.T {datatypes = datatypes,
                            body = body,
                            overflow = overflow}

      (* Clear properties on the translated body's variables (the
       * translation reuses the original bound variables), then destroy
       * the type-indexed properties.
       *)
      val () = Exp.clear body
      val () = destroyTransType ()
      val () = destroyWrap ()
   in
      prog
   end

end