Domains example
[hcoop/domtool2.git] / src / tycheck.sml
1 (* HCoop Domtool (http://hcoop.sourceforge.net/)
2 * Copyright (c) 2006, Adam Chlipala
3 *
4 * This program is free software; you can redistribute it and/or
5 * modify it under the terms of the GNU General Public License
6 * as published by the Free Software Foundation; either version 2
7 * of the License, or (at your option) any later version.
8 *
9 * This program is distributed in the hope that it will be useful,
10 * but WITHOUT ANY WARRANTY; without even the implied warranty of
11 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 * GNU General Public License for more details.
13 *
14 * You should have received a copy of the GNU General Public License
15 * along with this program; if not, write to the Free Software
16 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
17 *)
18
19 (* Domtool configuration language type checking *)
20
21 structure Tycheck :> TYCHECK = struct
22
23 open Ast Print
24
25 structure SS = StringSet
26 structure SM = StringMap
27
(* A typing environment: the set of known type names, paired with a map
 * from value names to their types. *)
type env = SS.set * typ SM.map

(* The initial environment: only the base type names "int" and "string"
 * are known, and no values are bound. *)
val empty : env =
    let
        val builtinTypes = SS.add (SS.singleton "int", "string")
    in
        (builtinTypes, SM.empty)
    end
31
(* Is [name] a known type name in the environment? *)
fun lookupType (typeNames, _) name = SS.member (typeNames, name)

(* The type bound to value [name], if any. *)
fun lookupVal (_, valueTypes) name = SM.find (valueTypes, name)

(* Extend the environment with a new type name. *)
fun bindType (typeNames, valueTypes) name =
    (SS.add (typeNames, name), valueTypes)

(* Extend the environment with value [name] of type [t]. *)
fun bindVal (typeNames, valueTypes) (name, t) =
    (typeNames, SM.insert (valueTypes, name, t))
37
local
    (* Number of unification variables created since the last reset. *)
    val counter = ref 0

    (* Display name for the n-th variable: "A" .. "Z", then "UNIF0", "UNIF1", ... *)
    fun nameFor n =
        if n >= 26 then
            "UNIF" ^ Int.toString (n - 26)
        else
            str (chr (ord #"A" + n))
in
    (* Restart unification-variable naming from "A". *)
    fun resetUnif () = counter := 0

    (* Create a fresh, unsolved unification variable (without a location). *)
    fun newUnif () =
        let
            val n = !counter
        in
            counter := n + 1;
            TUnif (nameFor n, ref NONE)
        end
end
56
(* Raised when two records (string maps) do not line up key-for-key;
 * eqRecord uses it internally, and subRecord/subTyp raise/handle it too. *)
exception UnequalDomains

(* Do [r1] and [r2] have exactly the same keys, with [f] holding for the
 * values at every key?  [f] is always applied as f (value in r1, value in r2). *)
fun eqRecord f (r1, r2) =
    let
        (* Every key of [src] must be in [dst] with rel (srcValue, dstValue). *)
        fun covers (src, dst, rel) =
            SM.appi (fn (key, x) =>
                        case SM.find (dst, key) of
                            SOME y =>
                            if rel (x, y) then () else raise UnequalDomains
                          | NONE => raise UnequalDomains)
                    src
    in
        covers (r1, r2, f);
        covers (r2, r1, fn (x, y) => f (y, x));
        true
    end
    handle UnequalDomains => false
78
(* Structural equality of context predicates, ignoring source locations. *)
fun eqPred ((a, _), (b, _)) =
    case (a, b) of
        (CAnd (l1, r1), CAnd (l2, r2)) => eqPred (l1, l2) andalso eqPred (r1, r2)
      | (CNot x, CNot y) => eqPred (x, y)
      | (CPrefix x, CPrefix y) => eqPred (x, y)
      | (CConst x, CConst y) => x = y
      | (CRoot, CRoot) => true
      | _ => false
89
(* Structural equality of types, ignoring locations.  Solved unification
 * variables are looked through; two unsolved variables are equal only when
 * they are the very same variable (same ref cell). *)
fun eqTy (lhs as (l, _), rhs as (r, _)) =
    case (l, r) of
        (TUnif (_, ref (SOME l')), _) => eqTy (l', rhs)
      | (_, TUnif (_, ref (SOME r'))) => eqTy (lhs, r')
      | (TUnif (_, cell1), TUnif (_, cell2)) => cell1 = cell2

      | (TBase b1, TBase b2) => b1 = b2
      | (TList e1, TList e2) => eqTy (e1, e2)
      | (TArrow (a1, b1), TArrow (a2, b2)) =>
        eqTy (a1, a2) andalso eqTy (b1, b2)
      | (TAction (p1, d1, w1), TAction (p2, d2, w2)) =>
        eqPred (p1, p2)
        andalso eqRecord eqTy (d1, d2)
        andalso eqRecord eqTy (w1, w2)
      | (TNested (p1, i1), TNested (p2, i2)) =>
        eqPred (p1, p2) andalso eqTy (i1, i2)
      | (TError, TError) => true
      | _ => false
112
(* Why subTyp/subPred failed to relate two contexts or types. *)
datatype unification_error =
         (* Context (have, need): the first does not imply the second. *)
         UnifyPred of pred * pred
         (* Types (have, need) that could not be related. *)
       | UnifyTyp of typ * typ
         (* Occurs check: the named unification variable appears in the type. *)
       | UnifyOccurs of string * typ

(* Raised by subTyp and subPred on failure; handled in checkExp/checkFile. *)
exception Unify of unification_error

(* Errors reported by checkExp through describe_type_error. *)
datatype type_error =
         (* place description, expression, actual type, needed type, optional reason *)
         WrongType of string * exp * typ * typ * unification_error option
         (* place description, expected form (e.g. "action"), expression, its type, optional reason *)
       | WrongForm of string * string * exp * typ * unification_error option
       | UnboundVariable of string
         (* place description, context we have, context we need *)
       | WrongPred of string * pred * pred
125
(* Print [label] followed by one space and the document [doc], laid out
 * together in a single hov box. *)
fun preface (label, doc) =
    let
        val pieces = [PD.string label, PD.space 1, doc]
    in
        printd (PD.hovBox (PD.PPS.Rel 0, pieces))
    end
128
(* Print the "Reason:" part of a type error for unification failure [ue].
 * [t] is the type the caller has already printed; a type-based reason whose
 * type equals [t] is suppressed so it is not shown twice. *)
fun describe_unification_error t ue =
    case ue of
        UnifyPred (p1, p2) =>
        (print "Reason: Incompatible contexts.\n";
         preface ("Have:", p_pred p1);
         preface ("Need:", p_pred p2))
      | UnifyTyp (t1, t2) =>
        if eqTy (t, t1) then
            ()
        else
            (print "Reason: Incompatible types.\n";
             preface ("Have:", p_typ t1);
             preface ("Need:", p_typ t2))
      | UnifyOccurs (name, t') =>
        if eqTy (t, t') then
            ()
        else
            (* Show the type the variable actually occurs in (t'); the outer
             * type t has already been printed by the caller. *)
            (print "Reason: Occurs check failed for ";
             print name;
             print " in:\n";
             printd (p_typ t'))
150
(* Report type error [te] at [loc]: a headline through ErrorMsg.error,
 * followed by the relevant expression/types/contexts and, when present,
 * the unification failure that caused it. *)
fun describe_type_error loc te =
    let
        val headline = ErrorMsg.error (SOME loc)
    in
        case te of
            UnboundVariable name =>
            headline ("Unbound variable " ^ name ^ ".\n")
          | WrongPred (place, have, need) =>
            (headline ("Context incompatibility for " ^ place ^ ".");
             preface ("Have:", p_pred have);
             preface ("Need:", p_pred need))
          | WrongType (place, e, actual, needed, reason) =>
            (headline (place ^ " has wrong type.");
             preface (" Expression:", p_exp e);
             preface ("Actual type:", p_typ actual);
             preface ("Needed type:", p_typ needed);
             Option.app (describe_unification_error actual) reason)
          | WrongForm (place, form, e, actual, reason) =>
            (headline (place ^ " has a non-" ^ form ^ " type.");
             preface ("Expression:", p_exp e);
             preface (" Type:", p_typ actual);
             Option.app (describe_unification_error actual) reason)
    end
170
(* Decide whether context predicate p1 implies p2.  This is a syntactic,
 * conservative check: any combination not matched below yields false.
 * The first matching case wins, so the order of the cases is significant. *)
fun predImplies (p1All as (p1, _), p2All as (p2, _)) =
    case (p1, p2) of
        (* p1 implies a conjunction when it implies both conjuncts.
         * (The inner p1/p2 shadow the outer names here.) *)
        (_, CAnd (p1, p2)) => predImplies (p1All, p1) andalso predImplies (p1All, p2)
        (* A conjunction implies p2 when either of its conjuncts does. *)
      | (CAnd (p1, p2), _) => predImplies (p1, p2All) orelse predImplies (p2, p2All)

        (* CPrefix CRoot is treated as implied by anything, and
         * CNot (CPrefix CRoot) as implying anything. *)
      | (_, CPrefix (CRoot, _)) => true
      | (CNot (CPrefix (CRoot, _), _), _) => true

      | (CRoot, CRoot) => true

      | (CConst s1, CConst s2) => s1 = s2

        (* Two prefixes: compare the underlying predicates.  Otherwise p1
         * implies CPrefix p2 whenever it implies p2 itself. *)
      | (CPrefix p1, CPrefix p2) => predImplies (p1, p2)
      | (_, CPrefix p2) => predImplies (p1All, p2)

        (* Contrapositive: CNot p1 implies CNot p2 when p2 implies p1. *)
      | (CNot p1, CNot p2) => predImplies (p2, p1)

      | _ => false
189
(* Simplify a context predicate.  Conjunctions whose left side is itself a
 * conjunction are re-associated to the right, and when one conjunct implies
 * the other (per predImplies) only the stronger conjunct is kept. *)
fun predSimpl (whole as (p, loc)) =
    let
        (* Conjoin two already-simplified predicates. *)
        fun conjoin (left, right) =
            case left of
                (CAnd (a, b), _) =>
                predSimpl (CAnd (a, (CAnd (b, right), loc)), loc)
              | _ =>
                if predImplies (right, left) then
                    right
                else if predImplies (left, right) then
                    left
                else
                    (CAnd (left, right), loc)
    in
        case p of
            CAnd (a, b) => conjoin (predSimpl a, predSimpl b)
          | CNot inner => (CNot (predSimpl inner), loc)
          | CPrefix inner => (CPrefix (predSimpl inner), loc)
          | CConst _ => whole
          | CRoot => whole
    end
210
(* Require that the first context implies the second; otherwise raise
 * Unify (UnifyPred (have, need)). *)
fun subPred pair =
    if predImplies pair then
        ()
    else
        raise (Unify (UnifyPred pair))
216
(* For every key of [r2], require the key in [r1] as well and apply
 * f (value in r1, value in r2).  A missing key raises UnequalDomains;
 * keys present only in [r1] are ignored. *)
fun subRecord f (r1, r2) =
    let
        fun visit (key, wanted) =
            case SM.find (r1, key) of
                SOME have => f (have, wanted)
              | NONE => raise UnequalDomains
    in
        SM.appi visit r2
    end
222
(* Occurs check: does the unification-variable cell [u] appear anywhere in
 * the given type (looking through solved variables)? *)
fun occurs u (t, _) =
    let
        val inAny = List.exists (occurs u)
    in
        case t of
            TUnif (_, ref (SOME t')) => occurs u t'
          | TUnif (_, u') => u = u'
          | TList t' => occurs u t'
          | TArrow (d, r) => occurs u d orelse occurs u r
          | TAction (_, d, r) => inAny (SM.listItems d) orelse inAny (SM.listItems r)
          | TNested (_, t') => occurs u t'
          | TBase _ => false
          | TError => false
    end
235
(* Require that t1 be usable where t2 is expected, solving unification
 * variables by assigning their ref cells as a side effect.  Raises Unify on
 * failure.  The first matching case wins, so order matters: solved variables
 * are looked through before unsolved ones are bound, and unsolved variables
 * are bound (even to TError) before the TError cases are reached. *)
fun subTyp (t1All as (t1, _), t2All as (t2, _)) =
    case (t1, t2) of
        (TBase s1, TBase s2) =>
        if s1 = s2 then
            ()
        else
            raise Unify (UnifyTyp (t1All, t2All))
      | (TList t1, TList t2) => subTyp (t1, t2)
        (* Arrows: the domain is compared in the reverse direction, the
         * range in the same direction. *)
      | (TArrow (d1, r1), TArrow (d2, r2)) =>
        (subTyp (d2, d1);
         subTyp (r1, r2))

        (* Actions: t2's context must imply t1's; every key of d1 must also be
         * in d2 (checked as subTyp (d2 value, d1 value)); r1 and r2 must have
         * the same keys, with values related in both directions.  A missing
         * key (UnequalDomains) becomes a Unify error for the whole pair. *)
      | (TAction (p1, d1, r1), TAction (p2, d2, r2)) =>
        ((subPred (p2, p1);
          subRecord subTyp (d2, d1);
          subRecord subTyp (r1, r2);
          subRecord subTyp (r2, r1))
         handle UnequalDomains => raise Unify (UnifyTyp (t1All, t2All)))

        (* Nested: contexts compared in reverse, bodies in the same direction. *)
      | (TNested (d1, r1), TNested (d2, r2)) =>
        (subPred (d2, d1);
         subTyp (r1, r2))

        (* Look through already-solved unification variables. *)
      | (TUnif (_, ref (SOME t1)), _) => subTyp (t1, t2All)
      | (_, TUnif (_, ref (SOME t2))) => subTyp (t1All, t2)

        (* Two unsolved variables: nothing to do if they are the same cell,
         * otherwise bind the first to the second. *)
      | (TUnif (_, r1), TUnif (_, r2)) =>
        if r1 = r2 then
            ()
        else
            r1 := SOME t2All

        (* An unsolved variable against any other type: bind it, unless the
         * occurs check shows it appears inside that type. *)
      | (TUnif (name, r), _) =>
        if occurs r t2All then
            raise (Unify (UnifyOccurs (name, t2All)))
        else
            r := SOME t2All

      | (_, TUnif (name, r)) =>
        if occurs r t1All then
            raise (Unify (UnifyOccurs (name, t1All)))
        else
            r := SOME t1All

        (* TError is accepted on either side so one error does not cascade. *)
      | (TError, _) => ()
      | (_, TError) => ()

      | _ => raise Unify (UnifyTyp (t1All, t2All))
284
(* Is this type (after looking through solved unification variables) the
 * error type?  Looking through variables keeps this consistent with
 * occurs/hasError/ununif, so a variable solved to TError also counts. *)
fun isError t =
    case t of
        (TError, _) => true
      | (TUnif (_, ref (SOME t')), _) => isError t'
      | _ => false
289
(* Follow solved unification variables until the outermost constructor is
 * something other than a solved TUnif. *)
fun whnorm (TUnif (_, ref (SOME t)), _) = whnorm t
  | whnorm t = t
294
(* Validate a parser-produced type against environment [G]: each base-type
 * name must be known.  An unknown name is reported and replaced by TError.
 * Parser output should never contain TError or TUnif; seeing one raises Fail. *)
fun checkTyp G (tAll as (t, loc)) =
    let
        val err = ErrorMsg.error (SOME loc)
        val recur = checkTyp G
        fun here t' = (t', loc)
    in
        case t of
            TUnif _ => raise Fail "TUnif in parser-generated type"
          | TError => raise Fail "TError in parser-generated type"
          | TBase name =>
            if not (lookupType G name) then
                (err ("Unbound type name " ^ name);
                 here TError)
            else
                tAll
          | TList t' => here (TList (recur t'))
          | TArrow (d, r) => here (TArrow (recur d, recur r))
          | TAction (p, d, r) => here (TAction (p, SM.map recur d, SM.map recur r))
          | TNested (p, t') => here (TNested (p, recur t'))
    end
314
(* Infer the type of expression [e] in environment [G].  Type errors are
 * reported through describe_type_error and replaced by TError so checking
 * can continue.  Action types are TAction (context, read vars, written vars):
 * EGet adds to the read record and ESet to the written one. *)
fun checkExp G (e, loc) =
    let
        val dte = describe_type_error loc

        (* Report that operand [e'] of a sequence (type [t']) is not an
         * action, and yield the error type. *)
        fun notSequenceable (e', t') =
            (dte (WrongForm ("Action to be sequenced",
                             "action",
                             e',
                             t',
                             NONE));
             (TError, loc))

        (* Read-variable record for "first action (reads d1, writes r1), then
         * an action reading d2".  Starting from d1, each variable in d2 is
         * added unless it is written by the first action (r1) or already
         * present.  In those cases the two types are related with subTyp
         * instead, and a failure is reported as a WrongType error. *)
        fun mergeNeeds (d1, r1, d2) =
            SM.foldli (fn (name, t, d') =>
                          let
                              fun agree t' =
                                  (subTyp (t, t')
                                   handle Unify ue =>
                                          dte (WrongType ("Shared environment variable",
                                                          (EVar name, loc),
                                                          t,
                                                          t',
                                                          SOME ue));
                                   d')
                          in
                              case SM.find (r1, name) of
                                  SOME t' => agree t'
                                | NONE =>
                                  (case SM.find (d', name) of
                                       NONE => SM.insert (d', name, t)
                                     | SOME t' => agree t')
                          end)
                      d1 d2

        (* Typing rule shared by ESeq and ELocal: check e1 then e2, require
         * both to be actions, conjoin their contexts, and merge their read
         * records.  [finishWrites (r1, r2)] decides the written record. *)
        fun composeActions (e1, e2, finishWrites) =
            let
                val t1 = whnorm (checkExp G e1)
                val t2 = whnorm (checkExp G e2)
            in
                case t1 of
                    (TAction (p1, d1, r1), _) =>
                    (case t2 of
                         (TAction (p2, d2, r2), _) =>
                         let
                             val p' = predSimpl (CAnd (p1, p2), loc)
                             val d' = mergeNeeds (d1, r1, d2)
                         in
                             (TAction (p', d', finishWrites (r1, r2)), loc)
                         end
                       | (TError, _) => t2
                       | _ => notSequenceable (e2, t2))
                  | (TError, _) => t1
                  | _ => notSequenceable (e1, t1)
            end
    in
        case e of
            EInt _ => (TBase "int", loc)
          | EString _ => (TBase "string", loc)
          | EList es =>
            let
                val t = (newUnif (), loc)
            in
                foldl (fn (e', ret) =>
                          let
                              val t' = checkExp G e'
                          in
                              (subTyp (t', t);
                               if isError t' then
                                   (TList (TError, loc), loc)
                               else
                                   ret)
                              handle Unify ue =>
                                     (dte (WrongType ("List element",
                                                      e',
                                                      t',
                                                      t,
                                                      SOME ue));
                                      (TError, loc))
                          end) (TList t, loc) es
            end

          | ELam (x, to, e) =>
            let
                val t =
                    case to of
                        NONE => (newUnif (), loc)
                      | SOME t => checkTyp G t

                val G' = bindVal G (x, t)
                val t' = checkExp G' e
            in
                (TArrow (t, t'), loc)
            end
          | EVar x =>
            (case lookupVal G x of
                 NONE => (dte (UnboundVariable x);
                          (TError, loc))
               | SOME t => t)
          | EApp (func, arg) =>
            let
                val dom = (newUnif (), loc)
                val ran = (newUnif (), loc)

                val tf = checkExp G func
                val ta = checkExp G arg
            in
                (subTyp (tf, (TArrow (dom, ran), loc));
                 subTyp (ta, dom)
                 handle Unify ue =>
                        dte (WrongType ("Function argument",
                                        arg,
                                        ta,
                                        dom,
                                        SOME ue));
                 ran)
                handle Unify ue =>
                       (dte (WrongForm ("Function to be applied",
                                        "function",
                                        func,
                                        tf,
                                        SOME ue));
                        (TError, loc))
            end

          | ESet (evar, e) =>
            let
                val t = checkExp G e
            in
                (TAction ((CPrefix (CRoot, loc), loc),
                          SM.empty,
                          SM.insert (SM.empty, evar, t)),
                 loc)
            end
          | EGet (x, evar, rest) =>
            let
                val xt = (newUnif (), loc)
                val G' = bindVal G (x, xt)

                val rt = whnorm (checkExp G' rest)
            in
                case rt of
                    (TAction (p, d, r), _) =>
                    (case SM.find (d, evar) of
                         NONE => (TAction (p, SM.insert (d, evar, xt), r), loc)
                       | SOME xt' =>
                         (subTyp (xt', xt)
                          handle Unify ue =>
                                 dte (WrongType ("Retrieved environment variable",
                                                 (EVar x, loc),
                                                 xt',
                                                 xt,
                                                 SOME ue));
                          rt))
                  | (TError, _) => rt
                  | _ => (dte (WrongForm ("Body of environment variable read",
                                          "action",
                                          rest,
                                          rt,
                                          NONE));
                          (TError, loc))
            end

          | ESeq [] => raise Fail "Empty ESeq"
          | ESeq [e1] => checkExp G e1
          | ESeq (e1 :: rest) =>
            (* Sequencing: later writes override earlier ones. *)
            composeActions (e1, (ESeq rest, loc),
                            fn (r1, r2) =>
                               SM.foldli (fn (name, t, r') => SM.insert (r', name, t))
                                         r1 r2)

          | ELocal (e1, e2) =>
            (* Local: only the second action's writes are visible afterwards. *)
            composeActions (e1, e2, fn (_, r2) => r2)

          | EWith (e1, e2) =>
            let
                val t1 = whnorm (checkExp G e1)
                val t2 = whnorm (checkExp G e2)
            in
                case t1 of
                    (TNested (pd, (TAction (pr, d1, r1), _)), _) =>
                    (case t2 of
                         (TAction (p, d, r), _) =>
                         if predImplies (pd, p) then
                             let
                                 val combineRecs =
                                     SM.unionWithi (fn (name, t1, t2) =>
                                                       (subTyp (t1, t2)
                                                        handle Unify ue =>
                                                               dte (WrongType ("Environment variable",
                                                                               (EVar name, loc),
                                                                               t1,
                                                                               t2,
                                                                               SOME ue));
                                                        t2))
                             in
                                 (TAction (pr, combineRecs (d, d1),
                                           combineRecs (r, r1)), loc)
                             end
                         else
                             (dte (WrongPred ("nested action",
                                              pd,
                                              p));
                              (TError, loc))
                       | (TError, _) => t2
                       | _ =>
                         (dte (WrongForm ("Body of nested action",
                                          "action",
                                          e2,
                                          t2,
                                          NONE));
                          (TError, loc)))
                  | (TError, _) => t1
                  | _ =>
                    (dte (WrongForm ("Container of nested action",
                                     "action",
                                     e1,
                                     t1,
                                     NONE));
                     (TError, loc))
            end

          | ESkip => (TAction ((CPrefix (CRoot, loc), loc),
                               SM.empty, SM.empty), loc)
    end
595
(* Raised by ununif when an unsolved unification variable remains. *)
exception Ununif

(* Remove solved unification variables from a type, recursing into every
 * component, including the body of a nested-action type (previously
 * skipped, unlike TAction/TArrow/TList).  Raises Ununif on an unsolved
 * variable. *)
fun ununif (tAll as (t, loc)) =
    case t of
        TBase _ => tAll
      | TList t => (TList (ununif t), loc)
      | TArrow (d, r) => (TArrow (ununif d, ununif r), loc)
      | TAction (p, d, r) => (TAction (p, SM.map ununif d, SM.map ununif r), loc)
      | TNested (p, t) => (TNested (p, ununif t), loc)
      | TError => tAll
      | TUnif (_, ref (SOME t)) => ununif t
      | TUnif (_, ref NONE) => raise Ununif
609
(* Does TError occur anywhere in the type (looking through solved
 * unification variables)?  The original returned false for TError itself
 * and skipped TNested bodies, so it could never return true; checkUnit
 * relies on this to avoid a spurious "Unification variables remain" report
 * after an error has already been emitted. *)
fun hasError (t, _) =
    case t of
        TBase _ => false
      | TList t => hasError t
      | TArrow (d, r) => hasError d orelse hasError r
      | TAction (_, d, r) => List.exists hasError (SM.listItems d)
                             orelse List.exists hasError (SM.listItems r)
      | TNested (_, t) => hasError t
      | TError => true
      | TUnif (_, ref (SOME t)) => hasError t
      | TUnif (_, ref NONE) => false
621
622
(* Type-check a standalone expression, restarting unification-variable names.
 * If hasError finds no error in the result, solved variables are stripped
 * with ununif; a remaining unsolved variable is reported and the unstripped
 * type is returned. *)
fun checkUnit G (eAll as (_, loc)) =
    let
        val () = resetUnif ()
        val ty = checkExp G eAll

        fun reportUnsolved () =
            (ErrorMsg.error (SOME loc) "Unification variables remain in type:";
             printd (p_typ ty);
             ty)
    in
        if not (hasError ty) then
            (ununif ty handle Ununif => reportUnsolved ())
        else
            ty
    end
637
(* Extend [G] with one declaration: an extern type name, or an extern value
 * whose declared type is first validated with checkTyp. *)
fun checkDecl G (decl, _, _) =
    case decl of
        DExternVal (name, t) => bindVal G (name, checkTyp G t)
      | DExternType name => bindType G name
642
(* Type-check a source file: process declarations [ds] in order starting from
 * [G].  If a final expression is present, require its type to be usable as
 * [tInit] (via subTyp), reporting an error otherwise.  Returns the
 * environment extended with the declarations. *)
fun checkFile G tInit (ds, eo) =
    let
        val G' = foldl (fn (decl, env) => checkDecl env decl) G ds

        fun checkFinal (e as (_, loc)) =
            let
                val actual = checkExp G' e
            in
                subTyp (actual, tInit)
                handle Unify _ =>
                       (ErrorMsg.error (SOME loc) "Bad type for final expression of source file.";
                        preface ("Actual:", p_typ actual);
                        preface ("Needed:", p_typ tInit))
            end
    in
        Option.app checkFinal eo;
        G'
    end
661
662 end