(* Backport from sid to buster
   [hcoop/debian/mlton.git] / regression / callcc2.sml
   Commit  Line  Data
   7f918cf1
   CE *)
(* Variable names and data-constructor names are both plain strings. *)
type ident = string
and con = string

(* Patterns.  PVar, PAlias and PConstruct are matched at specialization
   time by patterneq; PVar, PAliasD and PConstructD are the dynamic forms
   that generatePattern copies into residual code. *)
datatype pattern =
    PVar of ident                      (* binds the matched value *)
  | PAlias of ident * pattern          (* binds the value and matches the sub-pattern *)
  | PConstruct of con * pattern list   (* constructor applied to sub-patterns *)
  | PAliasD of ident * pattern
  | PConstructD of con * pattern list

(* Two-level expressions.  The plain constructors are handled by spec with
   standard semantics; the D-suffixed ones are dynamic and are turned into
   residual code; Lift coerces a static nullary constructor into code. *)
datatype exp =
    Var of ident
  | Lam of ident * exp
  | App of exp * exp
  | Construct of con * exp list
  | Case of exp * (pattern * exp) list
  | Let of ident * exp * exp

  (* dynamic (residualized) counterparts *)
  | LamD of ident * exp
  | AppD of exp * exp
  | ConstructD of con * exp list
  | CaseD of exp * (pattern * exp) list
  | LetD of ident * exp * exp

  | Lift of exp

(* Results of specialization: static functions, static constructor values,
   residual Code for dynamic parts, and Wrong when no static Case clause
   matches. *)
datatype value =
    Fun of (value -> value)
  | Con of con * value list
  | Code of exp
  | Wrong

(* Name of a value's outermost constructor; used to print results. *)
fun valueToString v =
  case v of
    Fun _ => "Fun"
  | Con _ => "Con"
  | Code _ => "Code"
  | Wrong => "Wrong"

(* control operators *)
(*********************)

(* The current meta-continuation: abort hands its result to it.  It starts
   as the identity, acting as a toplevel reset marker; reset replaces it. *)
val metaCont : (value -> value) ref = ref (fn v => v)

(* Run thunk and pass its result to the meta-continuation that is current
   AFTER thunk returns.  The order matters: thunk may itself update
   metaCont (e.g. via reset), so metaCont must not be read beforehand. *)
fun abort thunk =
  let
    val result = thunk ()
    val deliver = !metaCont
  in
    deliver result
  end

(* Delimit the control effects of thunk.  Capture this call's continuation,
   install a meta-continuation that first restores the previous one and then
   returns its argument as the result of this reset, and run thunk via
   abort.  A shift inside thunk therefore delivers its result back here. *)
fun reset thunk =
  let
    val outer = !metaCont
  in
    SMLofNJ.Cont.callcc
      (fn here =>
         (metaCont := (fn v => (metaCont := outer; SMLofNJ.Cont.throw here v));
          abort thunk))
  end

(* Capture the current continuation k and call f with a resumption function.
   Applying the resumption to v re-enters k with v, wrapped in a fresh reset.
   f's own result goes to the nearest enclosing reset via abort, so the rest
   of the current computation runs only if f resumes it. *)
fun shift f =
  SMLofNJ.Cont.callcc
    (fn k =>
       let
         fun resume v = reset (fn () => SMLofNJ.Cont.throw k v)
       in
         abort (fn () => f resume)
       end)

(*********************)

(* environment *)
(* Environments are association lists of (ident, value) pairs built by
   update.  lookup raises UnboundVar, carrying the missing name. *)
exception UnboundVar of ident

(* Extend env with name bound to binding.  The new pair goes in front, so it
   shadows any older binding of the same name. *)
fun update env name binding = (name, binding) :: env

(* Return the value bound to var' in r.  The search runs front to back, so
   the most recent update wins.  Raises UnboundVar var' if there is no
   binding. *)
fun lookup r var' =
  case List.find (fn (name, _) => name = var') r of
    SOME (_, value) => value
  | NONE => raise UnboundVar var'

(* pattern matcher - binds variables
   patterns are linear and pairwise disjoint *)
(* patterneq (p, value) r matches value against the static pattern p.  It
   returns r extended with p's variables and a flag saying whether the match
   succeeded; callers (spec's Case) ignore the environment when the flag is
   false.

   A PConstruct pattern matches only a Con with the same constructor and
   arity.  Any other value (Fun, Code, Wrong) simply fails to match instead
   of raising Bind.  Sub-patterns are examined only after the enclosing
   constructor has matched, and matching stops at the first failing argument.
   Dynamic patterns (PAliasD, PConstructD) are not valid here and raise
   Match. *)
fun patterneq (p, value) r =
  case (p, value) of
    (PVar x, _) => (update r x value, true)
  | (PAlias (x, sub), _) =>
      let val (r', eq) = patterneq (sub, value) r
      in (update r' x value, eq) end
  | (PConstruct (c, ps), Con (c', vs)) =>
      if c = c' andalso List.length ps = List.length vs then
        (* match arguments left to right, threading the environment *)
        ListPair.foldl
          (fn (sub, v, (r', eq)) =>
             if eq then patterneq (sub, v) r' else (r', false))
          (r, true)
          (ps, vs)
      else (r, false)
  | (PConstruct _, _) => (r, false)
  | (PAliasD _, _) => raise Match
  | (PConstructD _, _) => raise Match

(* Fresh-name generator: appends a global, ever-increasing counter to the
   given base name ("x" -> "x1", then "y" -> "y2", ...). *)
val gensym =
  let
    val counter = ref 0
    fun fresh base =
      (counter := !counter + 1;
       base ^ Int.toString (!counter))
  in
    fresh
  end

(* copies pattern with fresh variables bound in new environment *)
(* generatePattern (r, p) turns the dynamic pattern p into a residual
   (static-syntax) pattern whose variables are fresh gensym names.  It also
   extends r so that each source variable maps to Code (Var fresh).  For
   PAliasD, the bindings produced by the sub-pattern are kept as well, so
   variables bound inside an alias are in scope in the residual branch.
   Only PVar, PAliasD and PConstructD are handled; other patterns raise
   Match. *)
fun generatePattern (r, p) =
  case p of
    PVar x =>
      let val xx = gensym x
      in (update r x (Code (Var xx)), PVar xx) end
  | PAliasD (x, sub) =>
      let
        val (r', sub') = generatePattern (r, sub)
        val xx = gensym x
      in
        (* extend r' (not r) so the sub-pattern's bindings survive *)
        (update r' x (Code (Var xx)), PAlias (xx, sub'))
      end
  | PConstructD (c, ps) =>
      let
        (* foldr threads the environment through the sub-patterns
           right to left, collecting the residual sub-patterns in order *)
        val (r', ps') =
          List.foldr
            (fn (sub, (acc, done)) =>
               let val (acc', sub') = generatePattern (acc, sub)
               in (acc', sub' :: done) end)
            (r, []) ps
      in
        (r', PConstruct (c, ps'))
      end

(* the specializer *)
(* spec e r specializes the two-level expression e in environment r.

   Static constructs are evaluated directly to values (Fun, Con, or Wrong
   when no Case clause matches).  Dynamic constructs build residual code and
   return it as Code.  Their subterms must themselves specialize to Code;
   the `val Code ... =` bindings raise Bind otherwise.

   LetD and CaseD use shift to grab the context up to the nearest enclosing
   reset (for example the one around a LamD body).  They then wrap the
   residual Let/Case around that context, re-running it inside the let body
   or inside each case branch. *)
fun spec e r =
  case e of
    Var x => lookup r x

  (* Specialization of Static Stuff - standard semantics *)
  | Lam (x, e) => Fun (fn y => spec e (update r x y))

  | App (f, a) =>
      (* raises Bind if f does not specialize to a static Fun *)
      let val Fun ff = spec f r in
        ff (spec a r)
      end

  | Construct (c, es) =>
      let val vs = List.map (fn e => spec e r) es in
        Con (c, vs)
      end

  | Case (test, cls) =>
      let val testv = spec test r
          (* exhaustive by restriction on patterns *)
          (* clauses are tried in order; Wrong if none matches *)
          fun loop cls =
            (case cls of
               ((p, e) :: cls) =>
                 let val (r', eq) = patterneq (p, testv) r in
                   if eq then spec e r' else loop cls
                 end
             | [] => Wrong)
      in loop cls end

  | Let (x, e1, e2) => let val v1 = spec e1 r in spec e2 (update r x v1) end

  (* Specialization of Dynamic stuff *)
  | LamD (x, e) =>
      (* fresh residual parameter; the body runs under its own reset, so a
         LetD/CaseD inside it inserts its residual code within this lambda *)
      let val xx = gensym x
          val Code body =
            reset (fn () => spec e (update r x (Code (Var xx))))
      in
        Code (Lam (xx, body))
      end

  | AppD (f, a) =>
      let val Code ff = spec f r
          val Code aa = spec a r
      in
        Code (App (ff, aa))
      end

  | ConstructD (c, es) =>
      let val es' = List.map (fn e => let val Code v = spec e r
                                      in v end) es
      in
        Code (Construct (c, es'))
      end

  | LetD (x, e1, e2) =>
      let val xx = gensym x in
        (* capture the context k and emit Let (xx, e1', k applied to e2) *)
        shift (fn k =>
          let val Code e1' = spec e1 r
              val Code e2' =
                reset (fn () => k (spec e2 (update r x (Code (Var xx)))))
          in
            Code (Let (xx, e1', e2'))
          end)
      end

  | CaseD (test, cls) =>
      (* capture the context k and specialize it separately in every branch,
         each under the fresh bindings produced by generatePattern *)
      shift (fn k =>
        let val Code testd = spec test r
            val newCls = List.map (fn (p, e) =>
                           let val (r', p') = generatePattern(r, p)
                               val Code branch = reset (fn () => k (spec e r'))
                           in
                             (p', branch)
                           end) cls
        in
          Code (Case(testd, newCls))
        end)

  (* first-order lifting *)
  (* only nullary constructors can be lifted; anything else raises Bind *)
  | Lift e =>
      let val Con(c, []) = spec e r in
        Code(Construct (c, []))
      end

(* Specialize a program starting from the empty environment; free variables
   therefore raise UnboundVar. *)
fun specialize prog = spec prog []

(* standard evaluation *)
(* Fully static programs: the outermost Lam makes spec return a Fun. *)
val sampleProg1 =
  Lam ("q",
       App (Let ("id", App (Var "q", Var "q"), Lam ("z", Var "z")),
            Var "q"))

val sampleProg2 =
  Lam ("f",
       App (Lam ("x",
                 Case (Var "x",
                       [(PConstruct ("True", []), Lam ("x", Lam ("y", Var "x"))),
                        (PConstruct ("False", []), Lam ("x", Lam ("y", Var "y")))])),
            Var "f"))

(* partial evaluation *)
(* Programs mixing static and dynamic constructs: the outermost LamD makes
   spec return residual Code. *)
val sampleProg1D =
  LamD ("q",
        App (LetD ("id", AppD (Var "q", Var "q"), Lam ("z", Var "z")),
             Var "q"))

val sampleProg2D =
  LamD ("f",
        LamD ("x",
              App (CaseD (Var "x",
                          [(PConstructD ("True", []), Lam ("z", LamD ("y", Var "z"))),
                           (PConstructD ("False", []), Lam ("z", LamD ("y", Var "y")))]),
                   Var "f")))

(* Shadow specialize with a version that prints the name of the result's
   outermost constructor followed by a newline.  This must stay a val (not a
   fun) so that the inner call refers to the previous specialize. *)
val specialize =
  fn prog =>
    (print (valueToString (specialize prog));
     print "\n")

(* Each binding specializes one sample program and prints the outermost
   constructor of the result.
   NOTE(review): v4 re-runs sampleProg2, duplicating v2, so sampleProg2D is
   never exercised; it was presumably meant to be sampleProg2D.  It is left
   unchanged because this regression test's expected-output file is not
   visible here; confirm against it before changing. *)
val v1 = specialize sampleProg1
val v2 = specialize sampleProg2
val v3 = specialize sampleProg1D
val v4 = specialize sampleProg2