Various improvements made while working on relwiki
[bpt/mlt.git] / src / mlt.sml
index eecf08d..b853295 100644 (file)
@@ -37,6 +37,10 @@ struct
     val ppstream = PrettyPrint.mk_ppstream {consumer = TextIO.print, flush = fn () => TextIO.flushOut TextIO.stdOut,
                                            linewidth = 80}
 
+    datatype unify = (* the source term blamed in unify's type-mismatch error message *)
+            ExpUn of exp (* an expression; rendered in the message with Tree.expString *)
+          | PatUn of pat (* a pattern; currently rendered only as "<pat>" *)
+
     (* States to thread throughout translation *)
     local
        datatype state = STATE of {env: StaticEnv.staticEnv,
@@ -103,7 +107,8 @@ struct
        fun getVal (STRCT {elements, ...}, v, pos) =
            (case ModuleUtil.getSpec (elements, Symbol.varSymbol v) of
                 Modules.VALspec {spec, ...} => #1 (TypesUtil.instantiatePoly spec)
-              | _ => raise Fail "Unexpected spec in getVal")
+              | Modules.CONspec {spec = Types.DATACON {typ, ...}, ...} => #1 (TypesUtil.instantiatePoly typ)
+              | _ => raise Fail ("Unexpected spec in getVal for " ^ v))
            handle ModuleUtil.Unbound _ => (case ModuleUtil.getSpec (elements, Symbol.tycSymbol v) of
                                   Modules.CONspec {spec = Types.DATACON {typ, ...}, ...} => #1 (TypesUtil.instantiatePoly typ)
                                 | _ => raise Fail "Unexpected spec in getVal")
@@ -116,7 +121,7 @@ struct
            handle ModuleUtil.Unbound _ => (error (SOME pos, "Unbound constructor " ^ v);
                               errorTy)
 
-       fun unify (STATE {env, ...}) (pos, t1, t2) =
+       fun unify (STATE {env, ...}) (pos, e, t1, t2) =
            (*let
                val t1 = ModuleUtil.transType eenv t1
                val t2 = ModuleUtil.transType eenv t2
@@ -139,7 +144,8 @@ struct
                    PrettyPrint.end_block ppstream;
                    PrettyPrint.add_break ppstream (1, 0);
                    PrettyPrint.flush_ppstream ppstream;
-                   error (SOME pos, Unify.failMessage msg))
+                   error (SOME pos, Unify.failMessage msg ^ " for " ^ (case e of ExpUn e => Tree.expString e
+                                                                               | PatUn p => "<pat>")))
                                      
        fun resolvePath (getter, transer) (pos, state, path) =
            let
@@ -248,9 +254,10 @@ struct
 
    val templateTy = BasicTypes.--> (Types.CONty (BasicTypes.listTycon,
                                                  [mkTuple [BasicTypes.stringTy,
-                                                           BasicTypes.stringTy]]), BasicTypes.unitTy)
+                                                           Types.CONty (BasicTypes.listTycon,
+                                                                        [BasicTypes.stringTy])]]), BasicTypes.unitTy) (* i.e. (string * string list) list -> unit; the type xexp gives a "(Web.withParams <Name>_.exec)" template reference *)
 
-    fun xexp state (EXP (e, pos)) =
+    fun xexp state (exp as EXP (e, pos)) =
        (case e of
             Int_e n =>
                 (BasicTypes.intTy, Int.toString n)
@@ -271,16 +278,29 @@ struct
 
                     val xt = mkTuple [ty1, ty2]
                 in
-                    unify state (pos, dom, xt);
+                    unify state (pos, ExpUn exp, dom, xt);
                     (ran, "(" ^ es1 ^ ") :: (" ^ es2 ^ ")")
                 end
+          | Compose_e (e1, e2) => (* e1 o e2: translated to SML function composition *)
+                let
+                    val (ty1, es1) = xexp state e1
+                    val (ty2, es2) = xexp state e2
+                    (* fresh tyvars so that e2 : dom1 -> ran1dom2 and e1 : ran1dom2 -> ran2 *)
+                    val dom1 = newTyvar false
+                    val ran1dom2 = newTyvar false
+                    val ran2 = newTyvar false
+                in (* either mismatch is blamed on the whole composition, exp *)
+                    unify state (pos, ExpUn exp, ty2, BasicTypes.--> (dom1, ran1dom2));
+                    unify state (pos, ExpUn exp, ty1, BasicTypes.--> (ran1dom2, ran2));
+                    (BasicTypes.--> (dom1, ran2), "(" ^ es1 ^ ") o (" ^ es2 ^ ")") (* result : dom1 -> ran2 *)
+                end
           | StrCat_e (e1, e2) =>
                 let
                     val (ty1, es1) = xexp state e1
                     val (ty2, es2) = xexp state e2
                 in
-                    unify state (pos, ty1, BasicTypes.stringTy);
-                    unify state (pos, ty2, BasicTypes.stringTy);
+                    unify state (pos, ExpUn e1, ty1, BasicTypes.stringTy);
+                    unify state (pos, ExpUn e2, ty2, BasicTypes.stringTy);
                     (BasicTypes.stringTy, "(" ^ es1 ^ ") ^ (" ^ es2 ^ ")")
                 end
           | Orelse_e (e1, e2) =>
@@ -288,8 +308,8 @@ struct
                     val (ty1, es1) = xexp state e1
                     val (ty2, es2) = xexp state e2
                 in
-                    unify state (pos, ty1, BasicTypes.boolTy);
-                    unify state (pos, ty2, BasicTypes.boolTy);
+                    unify state (pos, ExpUn e1, ty1, BasicTypes.boolTy);
+                    unify state (pos, ExpUn e2, ty2, BasicTypes.boolTy);
                     (BasicTypes.boolTy, "(" ^ es1 ^ ") orelse (" ^ es2 ^ ")")
                 end
           | Andalso_e (e1, e2) =>
@@ -297,8 +317,8 @@ struct
                     val (ty1, es1) = xexp state e1
                     val (ty2, es2) = xexp state e2
                 in
-                    unify state (pos, ty1, BasicTypes.boolTy);
-                    unify state (pos, ty2, BasicTypes.boolTy);
+                    unify state (pos, ExpUn e1, ty1, BasicTypes.boolTy);
+                    unify state (pos, ExpUn e2, ty2, BasicTypes.boolTy);
                     (BasicTypes.boolTy, "(" ^ es1 ^ ") andalso (" ^ es2 ^ ")")
                 end
           | Plus_e (e1, e2) =>
@@ -306,8 +326,8 @@ struct
                     val (ty1, es1) = xexp state e1
                     val (ty2, es2) = xexp state e2
                 in
-                    unify state (pos, ty1, BasicTypes.intTy);
-                    unify state (pos, ty2, BasicTypes.intTy);
+                    unify state (pos, ExpUn e1, ty1, BasicTypes.intTy);
+                    unify state (pos, ExpUn e2, ty2, BasicTypes.intTy);
                     (BasicTypes.intTy, "(" ^ es1 ^ ") + (" ^ es2 ^ ")")
                 end
           | Minus_e (e1, e2) =>
@@ -315,8 +335,8 @@ struct
                     val (ty1, es1) = xexp state e1
                     val (ty2, es2) = xexp state e2
                 in
-                    unify state (pos, ty1, BasicTypes.intTy);
-                    unify state (pos, ty2, BasicTypes.intTy);
+                    unify state (pos, ExpUn e1, ty1, BasicTypes.intTy);
+                    unify state (pos, ExpUn e2, ty2, BasicTypes.intTy);
                     (BasicTypes.intTy, "(" ^ es1 ^ ") - (" ^ es2 ^ ")")
                 end
           | Times_e (e1, e2) =>
@@ -324,8 +344,8 @@ struct
                     val (ty1, es1) = xexp state e1
                     val (ty2, es2) = xexp state e2
                 in
-                    unify state (pos, ty1, BasicTypes.intTy);
-                    unify state (pos, ty2, BasicTypes.intTy);
+                    unify state (pos, ExpUn e1, ty1, BasicTypes.intTy);
+                    unify state (pos, ExpUn e2, ty2, BasicTypes.intTy);
                     (BasicTypes.intTy, "(" ^ es1 ^ ") * (" ^ es2 ^ ")")
                 end
           | Divide_e (e1, e2) =>
@@ -333,8 +353,8 @@ struct
                     val (ty1, es1) = xexp state e1
                     val (ty2, es2) = xexp state e2
                 in
-                    unify state (pos, ty1, BasicTypes.intTy);
-                    unify state (pos, ty2, BasicTypes.intTy);
+                    unify state (pos, ExpUn e1, ty1, BasicTypes.intTy);
+                    unify state (pos, ExpUn e2, ty2, BasicTypes.intTy);
                     (BasicTypes.intTy, "(" ^ es1 ^ ") div (" ^ es2 ^ ")")
                 end
           | Mod_e (e1, e2) =>
@@ -342,8 +362,8 @@ struct
                     val (ty1, es1) = xexp state e1
                     val (ty2, es2) = xexp state e2
                 in
-                    unify state (pos, ty1, BasicTypes.intTy);
-                    unify state (pos, ty2, BasicTypes.intTy);
+                    unify state (pos, ExpUn e1, ty1, BasicTypes.intTy);
+                    unify state (pos, ExpUn e2, ty2, BasicTypes.intTy);
                     (BasicTypes.intTy, "(" ^ es1 ^ ") mod (" ^ es2 ^ ")")
                 end
           | Lt_e (e1, e2) =>
@@ -351,8 +371,8 @@ struct
                     val (ty1, es1) = xexp state e1
                     val (ty2, es2) = xexp state e2
                 in
-                    unify state (pos, ty1, BasicTypes.intTy);
-                    unify state (pos, ty2, BasicTypes.intTy);
+                    unify state (pos, ExpUn e1, ty1, BasicTypes.intTy);
+                    unify state (pos, ExpUn e2, ty2, BasicTypes.intTy);
                     (BasicTypes.boolTy, "(" ^ es1 ^ ") < (" ^ es2 ^ ")")
                 end
           | Lte_e (e1, e2) =>
@@ -360,8 +380,8 @@ struct
                     val (ty1, es1) = xexp state e1
                     val (ty2, es2) = xexp state e2
                 in
-                    unify state (pos, ty1, BasicTypes.intTy);
-                    unify state (pos, ty2, BasicTypes.intTy);
+                    unify state (pos, ExpUn e1, ty1, BasicTypes.intTy);
+                    unify state (pos, ExpUn e2, ty2, BasicTypes.intTy);
                     (BasicTypes.boolTy, "(" ^ es1 ^ ") <= (" ^ es2 ^ ")")
                 end
           | Gt_e (e1, e2) =>
@@ -369,8 +389,8 @@ struct
                     val (ty1, es1) = xexp state e1
                     val (ty2, es2) = xexp state e2
                 in
-                    unify state (pos, ty1, BasicTypes.intTy);
-                    unify state (pos, ty2, BasicTypes.intTy);
+                    unify state (pos, ExpUn e1, ty1, BasicTypes.intTy);
+                    unify state (pos, ExpUn e2, ty2, BasicTypes.intTy);
                     (BasicTypes.boolTy, "(" ^ es1 ^ ") > (" ^ es2 ^ ")")
                 end
           | Gte_e (e1, e2) =>
@@ -378,8 +398,8 @@ struct
                     val (ty1, es1) = xexp state e1
                     val (ty2, es2) = xexp state e2
                 in
-                    unify state (pos, ty1, BasicTypes.intTy);
-                    unify state (pos, ty2, BasicTypes.intTy);
+                    unify state (pos, ExpUn e1, ty1, BasicTypes.intTy);
+                    unify state (pos, ExpUn e2, ty2, BasicTypes.intTy);
                     (BasicTypes.boolTy, "(" ^ es1 ^ ") >= (" ^ es2 ^ ")")
                 end
           | Param_e => (BasicTypes.--> (BasicTypes.stringTy, BasicTypes.stringTy), "Web.getParam")
@@ -390,7 +410,7 @@ struct
                     fun toUpper ch = chr (ord ch + ord #"A" - ord #"a")
                     val name = str (toUpper (String.sub (name, 0))) ^ String.extract (name, 1, NONE)
                 in
-                    (templateTy, "(Web.withParams " ^ name ^ ".exec)")
+                    (templateTy, "(Web.withParams " ^ name ^ "_.exec)")
                 end
             else
                 (error (SOME pos, "Unknown template " ^ name);
@@ -406,8 +426,8 @@ struct
                     val (ty1, s1) = xexp state e1
                     val (ty2, s2) = xexp state e2
                 in
-                    unify state (pos, ty1, ty2);
-                    unify state (pos, ty1, newTyvar true);
+                    unify state (pos, ExpUn e1, ty1, ty2);
+                    unify state (pos, ExpUn e2, ty1, newTyvar true);
                     (BasicTypes.boolTy, "(" ^ s1 ^ ") = (" ^ s2 ^ ")")
                 end
           | Neq_e (e1, e2) =>
@@ -415,8 +435,8 @@ struct
                     val (ty1, s1) = xexp state e1
                     val (ty2, s2) = xexp state e2
                 in
-                    unify state (pos, ty1, ty2);
-                    unify state (pos, ty1, newTyvar true);
+                    unify state (pos, ExpUn e1, ty1, ty2);
+                    unify state (pos, ExpUn e2, ty1, newTyvar true);
                     (BasicTypes.boolTy, "(" ^ s1 ^ ") <> (" ^ s2 ^ ")")
                 end
           | Ident_e [] => raise Fail "Impossible empty variable path"
@@ -431,13 +451,13 @@ struct
                    let
                        val (ft, fs) = xexp state f
                        val (xt, xs) = xexp state x
-
-                       (*val (ft, _) = TypesUtil.instantiatePoly ft*)
-                       val dom = domain ft
-                       val ran = range ft
                    in
-                       unify state (pos, dom, xt);
-                       (ran, "(" ^ fs ^ ") (" ^ xs ^ ")")
+                       if BasicTypes.isArrowType ft then
+                           (unify state (pos, ExpUn x, domain ft, xt);
+                            (range ft, "(" ^ fs ^ ") (" ^ xs ^ ")"))
+                       else
+                           (error (SOME pos, "Applying non-function");
+                            (errorTy, "<error>"))
                    end
           | Case_e (e, matches) =>
             let
@@ -447,11 +467,11 @@ struct
                     let
                         val (pty, vars', ps) = xpat state p
                                                
-                        val _ = unify state (pos, ty, pty)
+                        val _ = unify state (pos, ExpUn e, ty, pty)
                                 
                         val (ty', str') = xexp (addVars (state, vars')) e'
                     in
-                        unify state (pos, ty', bodyTy);
+                        unify state (pos, ExpUn e', ty', bodyTy);
                         (false,
                          str ^ (if first then "   " else " | ") ^ "(" ^ ps ^ ") => " ^
                          str' ^ "\n",
@@ -493,11 +513,11 @@ struct
                     let
                         val (pty, vars', ps) = xpat state p
                                                
-                        val _ = unify state (pos, dom, pty)
+                        val _ = unify state (pos, ExpUn exp, dom, pty)
                                 
                         val (ty', str') = xexp (addVars (state, vars')) e'
                     in
-                        unify state (pos, ty', ran);
+                        unify state (pos, ExpUn e', ty', ran);
                         (false,
                          str ^ (if first then "   " else " | ") ^ "(" ^ ps ^ ") => " ^
                          str' ^ "\n")
@@ -512,9 +532,26 @@ struct
             let
                 val (ty, es) = xexp state e
             in
-                unify state (pos, ty, BasicTypes.exnTy);
+                unify state (pos, ExpUn e, ty, BasicTypes.exnTy);
                 (newTyvar false, "(raise (" ^ es ^ "))")
             end
+          | Let_e (b, e) => (* let <block b> in e end *)
+            let
+                val (state, str) = xblock state b (* rebinds state: e is translated in the state returned by xblock *)
+                val (ty, es) = xexp state e
+            in
+                (ty, "let\n" ^ str ^ "\nin\n" ^ es ^ "\nend\n") (* the let's type is the body's type *)
+            end
+          | If_e (c, t, e) => (* if c then t else e; this e shadows the e bound by EXP (e, pos) *)
+            let
+                val (bty, ce) = xexp state c
+                val (ty, te) = xexp state t
+                val (ty', ee) = xexp state e
+            in
+                unify state (pos, ExpUn c, bty, BasicTypes.boolTy); (* the condition must be bool *)
+                unify state (pos, ExpUn exp, ty, ty'); (* branch types must agree; blamed on the whole if *)
+                (ty, "(if (" ^ ce ^ ") then (" ^ te ^ ") else (" ^ ee ^ "))") (* then-branch type, unified with the else-branch type above *)
+            end
           | RecordUpd_e (e, cs) =>
                 let
                     val (ty, es) = xexp state e
@@ -532,7 +569,7 @@ struct
                                                              let
                                                                  val (ty', s) = xexp state e
                                                              in
-                                                                 unify state (pos, ty, ty');
+                                                                 unify state (pos, ExpUn e, ty, ty');
                                                                  (n + 1, str ^ ", " ^ Symbol.name id ^ " = " ^ s)
                                                              end) (0, "") cs'
 
@@ -555,7 +592,7 @@ struct
                               NONE => StringMap.insert (vars, v, ty)
                             | SOME _ => error (SOME pos, "Duplicate variable " ^ v ^ " in pattern"))) vars1 vars2
 
-    and xpat state (PAT (p, pos)) =
+    and xpat state (pat as PAT (p, pos)) =
        (case p of
             Ident_p [] => raise Fail "Impossible empty Ident_p"
           | Ident_p [id] =>
@@ -580,7 +617,7 @@ struct
                 val tyc = lookCon (state, id, pos)
                 val dom = domain tyc
             in
-                unify state (pos, dom, ty);
+                unify state (pos, PatUn p, dom, ty);
                 (range tyc, vars, id ^ " (" ^ s ^ ")")
             end
           | App_p (path as (fst::rest), p) =>
@@ -589,7 +626,7 @@ struct
                     val tyc = resolveCon (pos, state, path)
                     val dom = domain tyc
                 in
-                    unify state (pos, dom, ty);
+                    unify state (pos, PatUn p, dom, ty);
                     (range tyc, vars, foldl (fn (n, st) => st ^ "." ^ n) fst rest ^ " (" ^ s ^ ")")
                 end
           | Cons_p (p1, p2) =>
@@ -599,7 +636,7 @@ struct
 
                     val resty = Types.CONty (BasicTypes.listTycon, [ty1])
                 in
-                    unify state (pos, ty2, resty);
+                    unify state (pos, PatUn pat, ty2, resty);
                     (resty, mergePatVars pos (vars', vars''), "(" ^ s1 ^ ")::(" ^ s2 ^ ")")
                 end
           | As_p (id, p) =>
@@ -652,7 +689,7 @@ struct
             error (SOME pos, "Not done yet!!!")*))
        handle Skip => (errorTy, StringMap.empty, "<error>")
 
-    fun xblock state (BLOCK (blocks, pos)) =
+    and xblock state (BLOCK (blocks, pos)) =
        let
            fun folder (BITEM (bi, pos), (state, str)) =
                (case bi of
@@ -681,7 +718,7 @@ struct
 
                             val (ty, es) = xexp state e
                         in
-                            unify state (pos, ty, vty);
+                            unify state (pos, ExpUn e, ty, vty);
                             (state, str ^ "val _ = " ^ id ^ " := (" ^ es ^ ")\n")
                         end
                   | Val_i (p, e) =>
@@ -690,7 +727,7 @@ struct
                             val state' = addVars (state, vars)
                             val (ty, es) = xexp state e
                         in
-                            unify state (pos, pty, ty);
+                            unify state (pos, ExpUn e, pty, ty);
                             (state', str ^ "val " ^ ps ^ " = (" ^ es ^ ")\n")
                         end
                   | Exp_i e =>
@@ -714,7 +751,7 @@ struct
                             val str = str ^ "val _ = "
                             val (ty, s) = xexp state e
                             val (_, str') = xblock state b
-                            val _ = unify state (pos, ty, BasicTypes.boolTy)
+                            val _ = unify state (pos, ExpUn e, ty, BasicTypes.boolTy)
                             val str = str ^ "if (" ^ s ^ ") then let\n" ^
                                       str' ^
                                       "in () end\n"
@@ -733,20 +770,20 @@ struct
                         in
                             (state, str)
                         end
-                  | Foreach_i (id, e, b) =>
+                  | Foreach_i (p, e, b) =>
                         let
                             val parm = newTyvar false
-
+       
+                            val (pty, vars, ps) = xpat state p
                             val (ty, es) = xexp state e
 
-                            val _ = unify state (pos, ty, Types.CONty (BasicTypes.listTycon, [parm]))
-
-                            (*val _ = print ("... to " ^ tyToString (context, ivmap, pty) ^ "\n")*)
+                            val _ = unify state (pos, ExpUn e, ty, Types.CONty (BasicTypes.listTycon, [parm]))
+                            val _ = unify state (pos, PatUn p, pty, parm)
 
-                            val state = addVar (state, id, VAR parm)
-                            val (_, bs) = xblock state b
-                        in
-                            (state, str ^ "fun foreach (" ^ id ^ " : " ^
+                            val state' = addVars (state, vars)
+                            val (_, bs) = xblock state' b
+                        in                          
+                            (state, str ^ "fun foreach ((" ^ ps ^ ") : " ^
                                     tyToString state parm ^ ") = let\n" ^
                              bs ^
                              "in () end\n" ^
@@ -755,10 +792,10 @@ struct
                   | For_i (id, eFrom, eTo, b) =>
                         let
                             val (ty1, es1) = xexp state eFrom
-                            val _ = unify state (pos, ty1, BasicTypes.intTy)
+                            val _ = unify state (pos, ExpUn eFrom, ty1, BasicTypes.intTy)
 
                             val (ty2, es2) = xexp state eTo
-                            val _ = unify state (pos, ty2, BasicTypes.intTy)
+                            val _ = unify state (pos, ExpUn eTo, ty2, BasicTypes.intTy)
 
                             val state = addVar (state, id, VAR BasicTypes.intTy)
                             val (_, bs) = xblock state b
@@ -776,7 +813,7 @@ struct
                                 let
                                     val (pty, vars', ps) = xpat state p
 
-                                    val _ = unify state (pos, ty, pty)
+                                    val _ = unify state (pos, PatUn p, ty, pty)
 
                                     val (_, str') = xblock (addVars (state, vars')) b
 
@@ -803,7 +840,7 @@ struct
                                     val state = addVars (state, vars)
                                     val (_, str') = xblock state b
                                 in
-                                    unify state (pos, BasicTypes.exnTy, pty);
+                                    unify state (pos, PatUn p, BasicTypes.exnTy, pty);
                                     (false,
                                      str ^ (if first then "   " else " | ") ^ "(" ^ ps ^ ") => let\n" ^
                                      str' ^