Merge pull request #238 from prt2121/pt/haskell-7.10.1
[jackhill/mal.git] / lua / step2_eval.lua
1 #!/usr/bin/env lua
2
3 local table = require('table')
4
5 local readline = require('readline')
6 local utils = require('utils')
7 local types = require('types')
8 local reader = require('reader')
9 local printer = require('printer')
10 local List, Vector, HashMap = types.List, types.Vector, types.HashMap
11
-- read

--- Reader step of the REPL.
-- Hands the raw input string to `reader.read_str` and returns whatever it
-- produces, unchanged (all results are forwarded via the tail call).
-- @tparam string str text entered at the prompt
-- @return the result(s) of reader.read_str
function READ(str)
    local parse = reader.read_str
    return parse(str)
end
16
-- eval

--- Evaluate one AST node without applying it as a function call.
-- * symbol   -> its binding in `env`; throws "'<name>' not found" if unbound
-- * list     -> new List with every element passed through EVAL
-- * vector   -> new Vector with every element passed through EVAL
-- * hash-map -> new HashMap with both keys and values passed through EVAL
-- * anything else is returned as-is
-- @param ast the form to evaluate
-- @tparam table env mapping from symbol name (ast.val) to value
-- @return the evaluated form
function eval_ast(ast, env)
    -- Shared mapper for the sequence cases; only the element is used.
    local function eval_in_env(form)
        return EVAL(form, env)
    end

    if types._symbol_Q(ast) then
        local name = ast.val
        if env[name] == nil then
            types.throw("'"..name.."' not found")
        end
        return env[name]
    end

    if types._list_Q(ast) then
        return List:new(utils.map(eval_in_env, ast))
    end

    if types._vector_Q(ast) then
        return Vector:new(utils.map(eval_in_env, ast))
    end

    if types._hash_map_Q(ast) then
        local evaluated = {}
        for key, val in pairs(ast) do
            evaluated[EVAL(key, env)] = EVAL(val, env)
        end
        return HashMap:new(evaluated)
    end

    return ast
end
38
--- Evaluate a mal form in `env`.
-- Non-list forms are delegated to eval_ast. An empty list evaluates to
-- itself. Otherwise every element is evaluated, the first result is called,
-- and the remaining results are passed as its arguments.
-- @param ast the form to evaluate
-- @tparam table env symbol environment passed through to eval_ast
-- @return the result of the evaluation / function application
function EVAL(ast, env)
    --print("EVAL: "..printer._pr_str(ast,true))
    if not types._list_Q(ast) then return eval_ast(ast, env) end
    if #ast == 0 then return ast end
    local args = eval_ast(ast, env)
    local f = table.remove(args, 1)
    -- The global `unpack` exists only in Lua 5.1/LuaJIT (or 5.2 with compat
    -- flags); Lua 5.2+ provides table.unpack, and default 5.3/5.4 builds have
    -- no global `unpack` at all. Prefer table.unpack and fall back.
    local unpack_args = table.unpack or unpack
    return f(unpack_args(args))
end
47
-- print

--- Printer step of the REPL.
-- Renders a value with printer._pr_str in readable mode (second argument
-- true) and returns that rendering.
-- @param exp value produced by EVAL
-- @return the printer's rendering of `exp`
function PRINT(exp)
    local readably = true
    return printer._pr_str(exp, readably)
end
52
-- repl

-- Root environment for this step: each arithmetic symbol is bound to a
-- two-argument Lua function. Division rounds the quotient down with
-- math.floor.
local repl_env = {
    ['+'] = function(a, b) return a + b end,
    ['-'] = function(a, b) return a - b end,
    ['*'] = function(a, b) return a * b end,
    ['/'] = function(a, b) return math.floor(a / b) end,
}

--- Run one read-eval-print cycle over a line of input.
-- @tparam string str a line of mal source
-- @return what PRINT returns for the evaluated form
function rep(str)
    local ast = READ(str)
    local result = EVAL(ast, repl_env)
    return PRINT(result)
end
61
-- A first command-line argument of "--raw" sets readline.raw = true
-- (NOTE(review): presumably switches readline to plain, non-line-edited
-- input — confirm in readline.lua). The global `arg` table is only filled in
-- by the standalone interpreter, so guard against it being nil instead of
-- taking its length unconditionally; `arg[1] == "--raw"` already implies
-- `#arg > 0`.
if arg ~= nil and arg[1] == "--raw" then
    readline.raw = true
end
65
-- Main REPL loop: prompt, run rep() on the line, and print the result.
-- Errors are reported and the loop continues; it stops when readline
-- returns nil/false (presumably end of input — confirm in readline.lua).
while true do
    -- `local` so the loop does not leak a global named `line`.
    local line = readline.readline("user> ")
    if not line then break end
    xpcall(function()
        print(rep(line))
    end, function(exc)
        if exc then
            -- mal exceptions carry a mal value; render it readably.
            if types._malexception_Q(exc) then
                exc = printer._pr_str(exc.val, true)
            end
            -- tostring() so a non-string error object (e.g. a table or
            -- boolean raised by non-mal code) cannot make this handler fail.
            print("Error: " .. tostring(exc))
            print(debug.traceback())
        end
    end)
end