elixir, erlang, lua, php, r, vimscript: Fix (first nil) and (rest nil)
[jackhill/mal.git] / lua / step2_eval.lua
1 #!/usr/bin/env lua
2
3 local table = require('table')
4
5 local readline = require('readline')
6 local utils = require('utils')
7 local types = require('types')
8 local reader = require('reader')
9 local printer = require('printer')
10 local List, Vector, HashMap = types.List, types.Vector, types.HashMap
11
12 -- read
--- Reader stage: turn one line of source text into a mal form.
-- Delegates entirely to reader.read_str; every value it returns is passed
-- through unchanged to the caller.
-- @param str source text
-- @return the form(s) produced by reader.read_str
function READ(str)
    local parse = reader.read_str
    return parse(str)
end
16
17 -- eval
--- Evaluate an AST node without applying it.
-- Symbols are looked up in `env`. An unbound symbol is passed to types.throw
-- (presumably raises a mal exception -- confirm in types.lua). Lists and
-- vectors are rebuilt with every element passed through EVAL. Hash maps are
-- rebuilt with both keys and values passed through EVAL. Any other value is
-- returned as-is.
-- @param ast mal form
-- @param env table mapping symbol names to values
-- @return the evaluated form
function eval_ast(ast, env)
    if types._symbol_Q(ast) then
        local name = ast.val
        local bound = env[name]
        if bound == nil then
            types.throw("'" .. name .. "' not found")
        end
        return bound
    end

    -- Lists and vectors share the element-wise evaluation; only the
    -- constructor used for the result differs.
    local is_list = types._list_Q(ast)
    if is_list or types._vector_Q(ast) then
        local items = utils.map(function(form) return EVAL(form, env) end, ast)
        if is_list then
            return List:new(items)
        end
        return Vector:new(items)
    end

    if types._hash_map_Q(ast) then
        local evaluated = {}
        for key, val in pairs(ast) do
            evaluated[EVAL(key, env)] = EVAL(val, env)
        end
        return HashMap:new(evaluated)
    end

    return ast
end
38
--- Evaluate a mal form in `env`.
-- Non-list forms are delegated to eval_ast. A non-empty list is evaluated
-- element-wise (via eval_ast); its first element is then called as a Lua
-- function with the remaining evaluated elements as arguments. The empty
-- list `()` evaluates to itself, because there is nothing to apply.
-- @param ast mal form produced by the reader
-- @param env table mapping symbol names to values
-- @return the result of evaluation
function EVAL(ast, env)
    --print("EVAL: "..printer._pr_str(ast,true))
    if not types._list_Q(ast) then return eval_ast(ast, env) end
    -- Without this guard, table.remove returns nil and the call below fails
    -- with "attempt to call a nil value".
    if #ast == 0 then return ast end
    local args = eval_ast(ast, env)
    local f = table.remove(args, 1)
    -- `unpack` is a global only in Lua 5.1/LuaJIT; 5.2+ provides table.unpack.
    return f((table.unpack or unpack)(args))
end
46
47 -- print
--- Printer stage: render a mal value as a string.
-- Calls printer._pr_str with `true` as its second argument (presumably the
-- "print readably" flag -- confirm in printer.lua).
-- @param exp mal value
-- @return whatever printer._pr_str returns
function PRINT(exp)
    local readably = true
    return printer._pr_str(exp, readably)
end
51
52 -- repl
-- Builtins for the step-2 REPL, keyed by symbol name. Each takes exactly two
-- numeric arguments; '/' rounds the quotient down with math.floor.
local repl_env = {
    ['+'] = function(x, y) return x + y end,
    ['-'] = function(x, y) return x - y end,
    ['*'] = function(x, y) return x * y end,
    ['/'] = function(x, y) return math.floor(x / y) end,
}

--- Run one read-eval-print cycle on `str` against repl_env.
-- @param str source text
-- @return the printed representation of the result
function rep(str)
    local form = READ(str)
    local value = EVAL(form, repl_env)
    return PRINT(value)
end
60
-- Command-line handling: `--raw` as the first argument sets readline.raw.
-- `arg` is nil when this chunk is run by an embedding host rather than the
-- standalone interpreter, so guard it before indexing.
if arg and arg[1] == "--raw" then
    readline.raw = true
end

-- REPL: read lines until readline returns nil/false, printing each result.
-- Errors are reported by the xpcall handler and the loop continues.
while true do
    -- local: the original leaked `line` as a global
    local line = readline.readline("user> ")
    if not line then break end
    xpcall(function()
        print(rep(line))
    end, function(exc)
        -- errors raised with a nil/false value are silently ignored
        if exc then
            if types._malexception_Q(exc) then
                -- mal-level exception: render its payload readably
                exc = printer._pr_str(exc.val, true)
            end
            -- tostring: a non-string error value (e.g. a table passed to
            -- error()) would otherwise make this `..` raise inside the handler
            print("Error: " .. tostring(exc))
            print(debug.traceback())
        end
    end)
end