Merge pull request #238 from prt2121/pt/haskell-7.10.1
[jackhill/mal.git] / lua / core.lua
1 local utils = require('utils')
2 local types = require('types')
3 local reader = require('reader')
4 local printer = require('printer')
5 local readline = require('readline')
6 local socket = require('socket')
7
8 local Nil, List, HashMap, _pr_str = types.Nil, types.List, types.HashMap, printer._pr_str
9
10 local M = {}
11
12 -- string functions
13
14 function pr_str(...)
15 return table.concat(
16 utils.map(function(e) return _pr_str(e, true) end, arg), " ")
17 end
18
19 function str(...)
20 return table.concat(
21 utils.map(function(e) return _pr_str(e, false) end, arg), "")
22 end
23
24 function prn(...)
25 print(table.concat(
26 utils.map(function(e) return _pr_str(e, true) end, arg), " "))
27 io.flush()
28 return Nil
29 end
30
31 function println(...)
32 print(table.concat(
33 utils.map(function(e) return _pr_str(e, false) end, arg), " "))
34 io.flush()
35 return Nil
36 end
37
38 function slurp(file)
39 local lines = {}
40 for line in io.lines(file) do
41 lines[#lines+1] = line
42 end
43 return table.concat(lines, "\n") .. "\n"
44 end
45
46 function do_readline(prompt)
47 local line = readline.readline(prompt)
48 if line == nil then
49 return Nil
50 else
51 return line
52 end
53 end
54
55 -- hash map functions
56
57 function assoc(hm, ...)
58 return types._assoc_BANG(types.copy(hm), unpack(arg))
59 end
60
61 function dissoc(hm, ...)
62 return types._dissoc_BANG(types.copy(hm), unpack(arg))
63 end
64
65 function get(hm, key)
66 local res = hm[key]
67 if res == nil then return Nil end
68 return res
69 end
70
71 function keys(hm)
72 local res = {}
73 for k,v in pairs(hm) do
74 res[#res+1] = k
75 end
76 return List:new(res)
77 end
78
79 function vals(hm)
80 local res = {}
81 for k,v in pairs(hm) do
82 res[#res+1] = v
83 end
84 return List:new(res)
85 end
86
87 -- sequential functions
88
89 function cons(a,lst)
90 local new_lst = lst:slice(1)
91 table.insert(new_lst, 1, a)
92 return List:new(new_lst)
93 end
94
95 function concat(...)
96 local new_lst = {}
97 for i = 1, #arg do
98 for j = 1, #arg[i] do
99 table.insert(new_lst, arg[i][j])
100 end
101 end
102 return List:new(new_lst)
103 end
104
105 function nth(seq, idx)
106 if idx+1 <= #seq then
107 return seq[idx+1]
108 else
109 types.throw("nth: index out of range")
110 end
111 end
112
113 function first(a)
114 if #a == 0 then
115 return Nil
116 else
117 return a[1]
118 end
119 end
120
121 function rest(a)
122 if a == Nil then
123 return List:new()
124 else
125 return List:new(a:slice(2))
126 end
127 end
128
129 function apply(f, ...)
130 if types._malfunc_Q(f) then
131 f = f.fn
132 end
133 local args = concat(types.slice(arg, 1, #arg-1),
134 arg[#arg])
135 return f(unpack(args))
136 end
137
138 function map(f, lst)
139 if types._malfunc_Q(f) then
140 f = f.fn
141 end
142 return List:new(utils.map(f, lst))
143 end
144
145 -- metadata functions
146
147 function meta(obj)
148 local m = getmetatable(obj)
149 if m == nil or m.meta == nil then return Nil end
150 return m.meta
151 end
152
153 function with_meta(obj, meta)
154 local new_obj = types.copy(obj)
155 getmetatable(new_obj).meta = meta
156 return new_obj
157 end
158
159 -- atom functions
160
161 function swap_BANG(atm,f,...)
162 if types._malfunc_Q(f) then
163 f = f.fn
164 end
165 local args = List:new(arg)
166 table.insert(args, 1, atm.val)
167 atm.val = f(unpack(args))
168 return atm.val
169 end
170
171 local function conj(obj, ...)
172 local new_obj = types.copy(obj)
173 if types._list_Q(new_obj) then
174 for i, v in ipairs(arg) do
175 table.insert(new_obj, 1, v)
176 end
177 else
178 for i, v in ipairs(arg) do
179 table.insert(new_obj, v)
180 end
181 end
182 return new_obj
183 end
184
185 local function seq(obj, ...)
186 if obj == Nil or #obj == 0 then
187 return Nil
188 elseif types._list_Q(obj) then
189 return obj
190 elseif types._vector_Q(obj) then
191 return List:new(obj)
192 elseif types._string_Q(obj) then
193 local chars = {}
194 for i = 1, #obj do
195 chars[#chars+1] = string.sub(obj,i,i)
196 end
197 return List:new(chars)
198 end
199 return Nil
200 end
201
202 local function lua_to_mal(a)
203 if a == nil then
204 return Nil
205 elseif type(a) == "boolean" or type(a) == "number" or type(a) == "string" then
206 return a
207 elseif type(a) == "table" then
208 local first_key, _ = next(a)
209 if first_key == nil then
210 return List:new({})
211 elseif type(first_key) == "number" then
212 local list = {}
213 for i, v in ipairs(a) do
214 list[i] = lua_to_mal(v)
215 end
216 return List:new(list)
217 else
218 local hashmap = {}
219 for k, v in pairs(a) do
220 hashmap[lua_to_mal(k)] = lua_to_mal(v)
221 end
222 return HashMap:new(hashmap)
223 end
224 end
225 return tostring(a)
226 end
227
228 local function lua_eval(str)
229 local f, err = loadstring("return "..str)
230 if err then
231 types.throw("lua-eval: can't load code: "..err)
232 end
233 return lua_to_mal(f())
234 end
235
236 M.ns = {
237 ['='] = types._equal_Q,
238 throw = types.throw,
239
240 ['nil?'] = function(a) return a==Nil end,
241 ['true?'] = function(a) return a==true end,
242 ['false?'] = function(a) return a==false end,
243 symbol = function(a) return types.Symbol:new(a) end,
244 ['symbol?'] = function(a) return types._symbol_Q(a) end,
245 ['string?'] = function(a) return types._string_Q(a) and "\177" ~= string.sub(a,1,1) end,
246 keyword = function(a) return "\177"..a end,
247 ['keyword?'] = function(a) return types._keyword_Q(a) end,
248
249 ['pr-str'] = pr_str,
250 str = str,
251 prn = prn,
252 println = println,
253 ['read-string'] = reader.read_str,
254 readline = do_readline,
255 slurp = slurp,
256
257 ['<'] = function(a,b) return a<b end,
258 ['<='] = function(a,b) return a<=b end,
259 ['>'] = function(a,b) return a>b end,
260 ['>='] = function(a,b) return a>=b end,
261 ['+'] = function(a,b) return a+b end,
262 ['-'] = function(a,b) return a-b end,
263 ['*'] = function(a,b) return a*b end,
264 ['/'] = function(a,b) return math.floor(a/b) end,
265 ['time-ms'] = function() return math.floor(socket.gettime() * 1000) end,
266
267 list = function(...) return List:new(arg) end,
268 ['list?'] = function(a) return types._list_Q(a) end,
269 vector = function(...) return types.Vector:new(arg) end,
270 ['vector?'] = types._vector_Q,
271 ['hash-map'] = types.hash_map,
272 ['map?'] = types._hash_map_Q,
273 assoc = assoc,
274 dissoc = dissoc,
275 get = get,
276 ['contains?'] = function(a,b) return a[b] ~= nil end,
277 keys = keys,
278 vals = vals,
279
280 ['sequential?'] = types._sequential_Q,
281 cons = cons,
282 concat = concat,
283 nth = nth,
284 first = first,
285 rest = rest,
286 ['empty?'] = function(a) return a==Nil or #a == 0 end,
287 count = function(a) return #a end,
288 apply = apply,
289 map = map,
290 conj = conj,
291 seq = seq,
292
293 meta = meta,
294 ['with-meta'] = with_meta,
295 atom = function(a) return types.Atom:new(a) end,
296 ['atom?'] = types._atom_Q,
297 deref = function(a) return a.val end,
298 ['reset!'] = function(a,b) a.val = b; return b end,
299 ['swap!'] = swap_BANG,
300
301 ['lua-eval'] = lua_eval,
302 }
303
304 return M
305