Typechecking for basic language done
[hcoop/domtool2.git] / src / tycheck.sml
1 (* HCoop Domtool (http://hcoop.sourceforge.net/)
2 * Copyright (c) 2006, Adam Chlipala
3 *
4 * This program is free software; you can redistribute it and/or
5 * modify it under the terms of the GNU General Public License
6 * as published by the Free Software Foundation; either version 2
7 * of the License, or (at your option) any later version.
8 *
9 * This program is distributed in the hope that it will be useful,
10 * but WITHOUT ANY WARRANTY; without even the implied warranty of
11 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 * GNU General Public License for more details.
13 *
14 * You should have received a copy of the GNU General Public License
15 * along with this program; if not, write to the Free Software
16 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
17 *)
18
19 (* Domtool configuration language type checking *)
20
21 structure Tycheck :> TYCHECK = struct
22
23 open Ast Print
24
(* Shorthand for the string-keyed map structure; used below for the typing
 * environment and for the two records carried by TAction types. *)
25 structure SM = StringMap
26
(* A typing environment: a map from variable names to their types. *)
27 type env = typ SM.map
(* The environment with no bindings. *)
28 val empty = SM.empty
29
(* Supply of fresh unification variables.
 *
 * resetUnif () restarts the numbering from zero.
 * newUnif ()   returns a new, unbound TUnif whose display name is derived from
 *              the current counter value, then advances the counter.
 *
 * The counter is private to this local block. *)
local
    val nextId = ref 0

    (* Counter values 0..25 are shown as "A".."Z"; value n >= 26 is shown as
     * "UNIF" followed by n - 26. *)
    fun labelFor n =
        if n >= 26 then
            "UNIF" ^ Int.toString (n - 26)
        else
            str (chr (ord #"A" + n))
in
    fun resetUnif () = nextId := 0

    fun newUnif () =
        let
            val n = !nextId
            val label = labelFor n
        in
            nextId := n + 1;
            TUnif (label, ref NONE)
        end
end
48
(* Raised internally when two records disagree; also used by unifyRecord. *)
exception UnequalDomains

(* eqRecord f (r1, r2)
 *
 * Returns true iff r1 and r2 have exactly the same keys and f (v1, v2) holds
 * for the values stored under every key (v1 taken from r1, v2 from r2).
 * r1 is walked first, then r2, so f is applied twice to each matching pair.
 * Any UnequalDomains escaping from the walk (including from f) yields false. *)
fun eqRecord f (r1, r2) =
    let
        (* Abort the walk unless the comparison succeeded. *)
        fun demand ok = if ok then () else raise UnequalDomains

        fun fromLeft (key, v1) =
            case SM.find (r2, key) of
                SOME v2 => demand (f (v1, v2))
              | NONE => raise UnequalDomains

        fun fromRight (key, v2) =
            case SM.find (r1, key) of
                SOME v1 => demand (f (v1, v2))
              | NONE => raise UnequalDomains
    in
        (SM.appi fromLeft r1;
         SM.appi fromRight r2;
         true)
        handle UnequalDomains => false
    end
70
(* Structural equality of two predicates, ignoring the second (location)
 * component of each node.  Predicates built from different constructors are
 * never equal. *)
fun eqPred ((CRoot, _), (CRoot, _)) = true
  | eqPred ((CConst a, _), (CConst b, _)) = (a = b)
  | eqPred ((CPrefix a, _), (CPrefix b, _)) = eqPred (a, b)
  | eqPred ((CNot a, _), (CNot b, _)) = eqPred (a, b)
  | eqPred ((CAnd (a1, a2), _), (CAnd (b1, b2), _)) =
    if eqPred (a1, b1) then eqPred (a2, b2) else false
  | eqPred _ = false
81
(* Equality of two types, ignoring the second (location) component of each
 * node.  A bound unification variable is compared through its binding; two
 * unbound variables are equal only when they are the very same ref cell.
 * Nothing is ever bound here. *)
fun eqTy ((TBase a, _), (TBase b, _)) = (a = b)
  | eqTy ((TList a, _), (TList b, _)) = eqTy (a, b)
  | eqTy ((TArrow (domA, ranA), _), (TArrow (domB, ranB), _)) =
    if eqTy (domA, domB) then eqTy (ranA, ranB) else false
  | eqTy ((TAction (predA, dA, rA), _), (TAction (predB, dB, rB), _)) =
    eqPred (predA, predB)
    andalso eqRecord eqTy (dA, dB)
    andalso eqRecord eqTy (rA, rB)
    (* Look through bound unification variables on either side. *)
  | eqTy ((TUnif (_, ref (SOME bound)), _), rhs) = eqTy (bound, rhs)
  | eqTy (lhs, (TUnif (_, ref (SOME bound)), _)) = eqTy (lhs, bound)
    (* Both sides unbound at this point: compare cell identity. *)
  | eqTy ((TUnif (_, cellA), _), (TUnif (_, cellB), _)) = (cellA = cellB)
  | eqTy ((TError, _), (TError, _)) = true
  | eqTy _ = false
101
(* Reasons a call to unify can fail; carried by the Unify exception. *)
102 datatype unification_error =
(* UnifyPred (p1, p2): raised by unifyPred when predImplies (p1, p2) is false. *)
103 UnifyPred of pred * pred
(* UnifyTyp (t1, t2): the pair of types that could not be unified. *)
104 | UnifyTyp of typ * typ
(* UnifyOccurs (name, t): the unbound variable called name occurs inside t. *)
105 | UnifyOccurs of string * typ
106
(* Raised by unify and unifyPred on failure. *)
107 exception Unify of unification_error
108
(* Errors reported to the user by describe_type_error. *)
109 datatype type_error =
(* WrongType (place, e, actual, needed, reason): expression e, described by
 * place, has type actual where needed was required. *)
110 WrongType of string * exp * typ * typ * unification_error option
(* WrongForm (place, form, e, t, reason): expression e has type t, which is
 * not of the expected form (e.g. "function" or "action"). *)
111 | WrongForm of string * string * exp * typ * unification_error option
(* UnboundVariable name: name was not found in the typing environment. *)
112 | UnboundVariable of string
113
(* preface (s, d): print the label s, one space, and the document d, laid out
 * together in a single hov box. *)
fun preface (s, d) =
    let
        val pieces = [PD.string s, PD.space 1, d]
        val boxed = PD.hovBox (PD.PPS.Rel 0, pieces)
    in
        printd boxed
    end
116
(* describe_unification_error t ue
 *
 * Print the underlying reason for a failed unification.
 *   t  - the type the caller has already shown to the user; a reason whose
 *        type is equal to t (per eqTy) is suppressed to avoid repeating it.
 *   ue - the unification failure to explain.
 *
 * Fix: the UnifyOccurs branch used to print t -- the type already displayed
 * by the caller -- rather than t', the type in which the occurs check
 * actually failed. *)
fun describe_unification_error t ue =
    case ue of
        UnifyPred (p1, p2) =>
        (print "Reason: Incompatible predicates.\n";
         preface ("Have:", p_pred p1);
         preface ("Need:", p_pred p2))
      | UnifyTyp (t1, t2) =>
        (* Skip when the "have" type is the one the caller already printed. *)
        if eqTy (t, t1) then
            ()
        else
            (print "Reason: Incompatible types.\n";
             preface ("Have:", p_typ t1);
             preface ("Need:", p_typ t2))
      | UnifyOccurs (name, t') =>
        if eqTy (t, t') then
            ()
        else
            (print "Reason: Occurs check failed for ";
             print name;
             print " in:\n";
             (* Show the type that contains the variable, not the caller's. *)
             printd (p_typ t'))
138
(* describe_type_error loc te
 *
 * Report the type error te at source position loc through ErrorMsg.error,
 * then print the offending expression and types.  When the error carries a
 * unification failure, its reason is printed as well. *)
fun describe_type_error loc (WrongType (place, e, t1, t2, ueo)) =
    let
        val _ = ErrorMsg.error (SOME loc) (place ^ " has wrong type.")
        val _ = preface (" Expression:", p_exp e)
        val _ = preface ("Actual type:", p_typ t1)
        val _ = preface ("Needed type:", p_typ t2)
    in
        case ueo of
            SOME ue => describe_unification_error t1 ue
          | NONE => ()
    end
  | describe_type_error loc (WrongForm (place, form, e, t, ueo)) =
    let
        val _ = ErrorMsg.error (SOME loc) (place ^ " has a non-" ^ form ^ " type.")
        val _ = preface ("Expression:", p_exp e)
        val _ = preface (" Type:", p_typ t)
    in
        case ueo of
            SOME ue => describe_unification_error t ue
          | NONE => ()
    end
  | describe_type_error loc (UnboundVariable name) =
    ErrorMsg.error (SOME loc) ("Unbound variable " ^ name ^ ".\n")
154
(* predImplies (p1, p2): a syntactic check that predicate p1 implies p2.
 * Clauses are tried top to bottom; a false result only means that no rule
 * below applied.
 *
 *  - anything implies "prefix of root";
 *  - "not (prefix of root)" implies anything;
 *  - root implies root, and a constant implies an equal constant;
 *  - prefix is covariant, negation is contravariant;
 *  - p1 implies a conjunction when it implies both conjuncts;
 *  - a conjunction implies p2 when either conjunct does. *)
fun predImplies (_, (CPrefix (CRoot, _), _)) = true
  | predImplies ((CNot (CPrefix (CRoot, _), _), _), _) = true
  | predImplies ((CRoot, _), (CRoot, _)) = true
  | predImplies ((CConst a, _), (CConst b, _)) = (a = b)
  | predImplies ((CPrefix a, _), (CPrefix b, _)) = predImplies (a, b)
  | predImplies ((CNot a, _), (CNot b, _)) = predImplies (b, a)
  | predImplies (lhs, (CAnd (b1, b2), _)) =
    if predImplies (lhs, b1) then predImplies (lhs, b2) else false
  | predImplies ((CAnd (a1, a2), _), rhs) =
    if predImplies (a1, rhs) then true else predImplies (a2, rhs)
  | predImplies _ = false
172
(* Simplify a predicate.
 *
 * Root and constants are returned unchanged; prefix and negation simplify
 * their argument.  For a conjunction, both sides are simplified first; then
 *  - if the left side is itself a conjunction (c1 and c2), the whole is
 *    re-associated as c1 and (c2 and right) and simplified again;
 *  - otherwise, if the right side implies the left side, only the right side
 *    is kept;
 *  - otherwise the conjunction of the two simplified sides is returned. *)
fun predSimpl (whole as (CRoot, _)) = whole
  | predSimpl (whole as (CConst _, _)) = whole
  | predSimpl (CPrefix inner, loc) = (CPrefix (predSimpl inner), loc)
  | predSimpl (CNot inner, loc) = (CNot (predSimpl inner), loc)
  | predSimpl (CAnd (left, right), loc) =
    let
        val left' = predSimpl left
        val right' = predSimpl right
    in
        case left' of
            (CAnd (c1, c2), _) =>
            predSimpl (CAnd (c1, (CAnd (c2, right'), loc)), loc)
          | _ =>
            (case predImplies (right', left') of
                 true => right'
               | false => (CAnd (left', right'), loc))
    end
191
(* unifyPred (p1, p2): succeed silently when p1 implies p2 according to
 * predImplies; otherwise raise Unify (UnifyPred (p1, p2)). *)
fun unifyPred (preds as (have, need)) =
    case predImplies preds of
        true => ()
      | false => raise Unify (UnifyPred (have, need))
197
(* unifyRecord f (r1, r2)
 *
 * Apply f (v1, v2) to the values stored under every key of r1 and then again
 * under every key of r2 (v1 always from r1, v2 always from r2), so each
 * matching pair is visited twice.  Raises UnequalDomains as soon as a key of
 * one record is missing from the other; whatever f raises propagates. *)
fun unifyRecord f (r1, r2) =
    let
        fun fromLeft (key, v1) =
            case SM.find (r2, key) of
                SOME v2 => f (v1, v2)
              | NONE => raise UnequalDomains

        fun fromRight (key, v2) =
            case SM.find (r1, key) of
                SOME v1 => f (v1, v2)
              | NONE => raise UnequalDomains
    in
        SM.appi fromLeft r1;
        SM.appi fromRight r2
    end
207
(* occurs u t: does the unbound unification cell u appear anywhere inside
 * type t?  Bound variables are followed through their bindings; cells are
 * compared by identity.  The predicate of an action type is not inspected. *)
fun occurs u (TBase _, _) = false
  | occurs u (TList elem, _) = occurs u elem
  | occurs u (TArrow (dom, ran), _) =
    if occurs u dom then true else occurs u ran
  | occurs u (TAction (_, d, r), _) =
    let
        val inAny = List.exists (occurs u)
    in
        inAny (SM.listItems d) orelse inAny (SM.listItems r)
    end
  | occurs u (TError, _) = false
  | occurs u (TUnif (_, ref (SOME bound)), _) = occurs u bound
  | occurs u (TUnif (_, cell), _) = (u = cell)
219
(* unify (t1, t2): make t1 and t2 equal by binding unification variables in
 * place.  Returns unit on success.
 *
 * Raises Unify (UnifyTyp ...) when the shapes differ (or two action types
 * have records with different keys), Unify (UnifyOccurs ...) when binding a
 * variable would make it occur inside its own binding, and whatever
 * unifyPred raises for the predicates of two action types.
 *
 * TError unifies with anything, so one reported error does not trigger
 * follow-on mismatches.  The order of the cases below matters: bound
 * variables are followed before unbound ones are bound, and variables are
 * handled before the TError wildcards. *)
fun unify (lhs as (t1, _), rhs as (t2, _)) =
    let
        (* The two whole types cannot be reconciled. *)
        fun mismatch () = raise Unify (UnifyTyp (lhs, rhs))

        (* Bind an unbound variable to target, subject to the occurs check. *)
        fun bind (name, cell, target) =
            if occurs cell target then
                raise Unify (UnifyOccurs (name, target))
            else
                cell := SOME target
    in
        case (t1, t2) of
            (TBase a, TBase b) => if a = b then () else mismatch ()
          | (TList a, TList b) => unify (a, b)
          | (TArrow (domA, ranA), TArrow (domB, ranB)) =>
            (unify (domA, domB);
             unify (ranA, ranB))

          | (TAction (predA, dA, rA), TAction (predB, dB, rB)) =>
            ((unifyPred (predA, predB);
              unifyRecord unify (dA, dB);
              unifyRecord unify (rA, rB))
             handle UnequalDomains => mismatch ())

          (* Follow bound variables on either side. *)
          | (TUnif (_, ref (SOME bound)), _) => unify (bound, rhs)
          | (_, TUnif (_, ref (SOME bound))) => unify (lhs, bound)

          (* Two unbound variables: alias the left one to the right one
           * unless they are already the same cell. *)
          | (TUnif (_, cellA), TUnif (_, cellB)) =>
            if cellA = cellB then () else cellA := SOME rhs

          | (TUnif (name, cell), _) => bind (name, cell, rhs)
          | (_, TUnif (name, cell)) => bind (name, cell, lhs)

          | (TError, _) => ()
          | (_, TError) => ()

          | _ => mismatch ()
    end
263
(* True exactly when the outermost node of the type is TError; unification
 * variables are not followed. *)
fun isError (TError, _) = true
  | isError _ = false
268
(* Strip bound unification variables from the head of a type until the
 * outermost node is something else (or an unbound variable).  Nested
 * components are left untouched. *)
fun whnorm (TUnif (_, ref (SOME bound)), _) = whnorm bound
  | whnorm other = other
273
(* checkExp G e
 *
 * Infer the type of expression e in typing environment G.  Type errors are
 * reported through describe_type_error at the expression's own location and
 * the result is then (TError, loc) (or a type containing TError), so that
 * checking can continue.  Unification variables may be bound as a side
 * effect of the calls to unify.
 *
 * In the action cases below, TAction (p, d, r) carries a predicate p and two
 * records; judging by the error messages, d appears to hold the environment
 * variables an action reads and r those it sets -- confirm against Ast. *)
274 fun checkExp G (eAll as (e, loc)) =
275 let
(* Error reporter specialised to this expression's location. *)
276 val dte = describe_type_error loc
277 in
278 case e of
279 EInt _ => (TBase "int", loc)
280 | EString _ => (TBase "string", loc)
(* List literal: every element is unified with one fresh element type t.
 * The fold's accumulator is the result type: it starts as TList t and
 * becomes an error type once any element is itself an error or fails to
 * unify. *)
281 | EList es =>
282 let
283 val t = (newUnif (), loc)
284 in
285 foldl (fn (e', ret) =>
286 let
287 val t' = checkExp G e'
288 in
289 (unify (t', t);
290 if isError t' then
291 (TList (TError, loc), loc)
292 else
293 ret)
294 handle Unify ue =>
295 (dte (WrongType ("List element",
296 e',
297 t',
298 t,
299 SOME ue));
300 (TError, loc))
301 end) (TList t, loc) es
302 end
303
(* Lambda: the parameter gets its annotation if one was given, otherwise a
 * fresh unification variable; the body is checked with x bound to it. *)
304 | ELam (x, to, e) =>
305 let
306 val t =
307 case to of
308 NONE => (newUnif (), loc)
309 | SOME t => t
310
311 val G' = SM.insert (G, x, t)
312 val t' = checkExp G' e
313 in
314 (TArrow (t, t'), loc)
315 end
(* Variable: look it up in G; an unbound name is reported and typed TError. *)
316 | EVar x =>
317 (case SM.find (G, x) of
318 NONE => (dte (UnboundVariable x);
319 (TError, loc))
320 | SOME t => t)
(* Application: the function's type is unified with dom -> ran for fresh dom
 * and ran, then the argument's type with dom.  The inner handler reports an
 * argument mismatch and still yields ran; the outer handler catches a
 * failure of the first unify and reports a non-function. *)
321 | EApp (func, arg) =>
322 let
323 val dom = (newUnif (), loc)
324 val ran = (newUnif (), loc)
325
326 val tf = checkExp G func
327 val ta = checkExp G arg
328 in
329 (unify (tf, (TArrow (dom, ran), loc));
330 unify (ta, dom)
331 handle Unify ue =>
332 dte (WrongType ("Function argument",
333 arg,
334 ta,
335 dom,
336 SOME ue));
337 ran)
338 handle Unify ue =>
339 (dte (WrongForm ("Function to be applied",
340 "function",
341 func,
342 tf,
343 SOME ue));
344 (TError, loc))
345 end
346
(* Set: an action with predicate "prefix of root", an empty first record and
 * a second record holding only evar at the type of e. *)
347 | ESet (evar, e) =>
348 let
349 val t = checkExp G e
350 in
351 (TAction ((CPrefix (CRoot, loc), loc),
352 SM.empty,
353 SM.insert (SM.empty, evar, t)),
354 loc)
355 end
(* Get: x is bound to a fresh type xt while checking rest, which must be an
 * action.  If rest's first record has no entry for evar, one is added at
 * type xt; otherwise the existing entry is unified with xt and rest's type
 * is returned unchanged. *)
356 | EGet (x, evar, rest) =>
357 let
358 val xt = (newUnif (), loc)
359 val G' = SM.insert (G, x, xt)
360
361 val rt = whnorm (checkExp G' rest)
362 in
363 case rt of
364 (TAction (p, d, r), _) =>
365 (case SM.find (d, evar) of
366 NONE => (TAction (p, SM.insert (d, evar, xt), r), loc)
367 | SOME xt' =>
368 (unify (xt', xt)
369 handle Unify ue =>
370 dte (WrongType ("Retrieved environment variable",
371 (EVar x, loc),
372 xt',
373 xt,
374 SOME ue));
375 rt))
376 | _ => (dte (WrongForm ("Body of environment variable read",
377 "action",
378 rest,
379 rt,
380 NONE));
381 (TError, loc))
382 end
383
(* Sequencing.  An empty sequence is an internal error; a singleton is just
 * its element; a longer sequence is treated as e1 followed by the sequence
 * of the remaining elements. *)
384 | ESeq [] => raise Fail "Empty ESeq"
385 | ESeq [e1] => checkExp G e1
386 | ESeq (e1 :: rest) =>
387 let
388 val e2 = (ESeq rest, loc)
389
390 val t1 = whnorm (checkExp G e1)
391 val t2 = whnorm (checkExp G e2)
392 in
393 case t1 of
394 (TAction (p1, d1, r1), _) =>
395 (case t2 of
396 (TAction (p2, d2, r2), _) =>
397 let
(* Combined predicate: the simplified conjunction of both predicates. *)
398 val p' = predSimpl (CAnd (p1, p2), loc)
399
(* Combined first record, starting from d1 and folding over d2.  For each
 * entry (name, t) of d2: if r1 has name, t is unified with r1's type and
 * nothing is added; else if the accumulator already has name, the two types
 * are unified; else the entry is added.  Unification failures are reported
 * and the accumulator is kept as it was. *)
400 val d' = SM.foldli (fn (name, t, d') =>
401 case SM.find (r1, name) of
402 NONE =>
403 (case SM.find (d', name) of
404 NONE => SM.insert (d', name, t)
405 | SOME t' =>
406 (unify (t, t')
407 handle Unify ue =>
408 dte (WrongType ("Shared environment variable",
409 (EVar name, loc),
410 t,
411 t',
412 SOME ue));
413 d'))
414 | SOME t' =>
415 (unify (t, t')
416 handle Unify ue =>
417 dte (WrongType ("Shared environment variable",
418 (EVar name, loc),
419 t,
420 t',
421 SOME ue));
422 d'))
423 d1 d2
424
(* Combined second record: r1 with every entry of r2 inserted over it. *)
425 val r' = SM.foldli (fn (name, t, r') => SM.insert (r', name, t))
426 r1 r2
427 in
428 (TAction (p', d', r'), loc)
429 end
430 | _ => (dte (WrongForm ("Action to be sequenced",
431 "action",
432 e2,
433 t2,
434 NONE));
435 (TError, loc)))
436 | _ => (dte (WrongForm ("Action to be sequenced",
437 "action",
438 e1,
439 t1,
440 NONE));
441 (TError, loc))
442 end
443
(* Local: the body must be an action; its predicate and first record are
 * kept and its second record is replaced by the empty record. *)
444 | ELocal e =>
445 let
446 val rt = whnorm (checkExp G e)
447 in
448 case rt of
449 (TAction (p, d, _), _) =>
450 (TAction (p, d, SM.empty), loc)
451 | _ => (dte (WrongForm ("Body of local action",
452 "action",
453 e,
454 rt,
455 NONE));
456 (TError, loc))
457 end
458 end
459
(* Raised by ununif when it meets an unbound unification variable. *)
exception Ununif

(* Rebuild a type with every bound unification variable replaced by its
 * (recursively processed) binding.  The predicate of an action type is kept
 * as is.  Raises Ununif if an unbound variable is reached. *)
fun ununif (whole as (TBase _, _)) = whole
  | ununif (TList elem, loc) = (TList (ununif elem), loc)
  | ununif (TArrow (dom, ran), loc) = (TArrow (ununif dom, ununif ran), loc)
  | ununif (TAction (p, d, r), loc) =
    let
        val d' = SM.map ununif d
        val r' = SM.map ununif r
    in
        (TAction (p, d', r'), loc)
    end
  | ununif (TUnif (_, ref (SOME bound)), _) = ununif bound
  | ununif (whole as (TError, _)) = whole
  | ununif (TUnif (_, ref NONE), _) = raise Ununif
472
(* hasError t: does TError appear anywhere inside type t?  Bound unification
 * variables are followed; unbound ones contain no error.  The predicate of
 * an action type is not inspected.
 *
 * Fix: the TError case used to return false, so this function could never
 * return true and checkUnit never took its "errors already reported" path. *)
fun hasError (t, _) =
    case t of
        TBase _ => false
      | TList t => hasError t
      | TArrow (d, r) => hasError d orelse hasError r
      | TAction (_, d, r) => List.exists hasError (SM.listItems d)
                             orelse List.exists hasError (SM.listItems r)
      | TError => true
      | TUnif (_, ref (SOME t)) => hasError t
      | TUnif (_, ref NONE) => false
483
484
(* checkUnit G e
 *
 * Type-check a whole expression e in environment G.  The unification
 * variable counter is reset first.  If hasError reports an error in the
 * inferred type, that type is returned as is.  Otherwise bound unification
 * variables are substituted away with ununif; if an unbound variable
 * remains, an error is reported at e's location, the type is printed, and
 * the unsubstituted type is returned. *)
fun checkUnit G (eAll as (_, loc)) =
    let
        val () = resetUnif ()
        val inferred = checkExp G eAll

        (* Report leftover unification variables and fall back to the raw type. *)
        fun complain () =
            (ErrorMsg.error (SOME loc) "Unification variables remain in type:";
             printd (p_typ inferred);
             inferred)
    in
        case hasError inferred of
            true => inferred
          | false => (ununif inferred handle Ununif => complain ())
    end
499
500 end