gnu-smalltalk, go, js, lua, php, rust, ts: Detect more unterminated strings.
[jackhill/mal.git] / lua / core.lua
1 local utils = require('utils')
2 local types = require('types')
3 local reader = require('reader')
4 local printer = require('printer')
5 local readline = require('readline')
6 local socket = require('socket')
7
-- Short local aliases for frequently used type objects and the printer helper.
local Nil = types.Nil
local List = types.List
local HashMap = types.HashMap
local _pr_str = printer._pr_str

-- Module table; `ns` is attached to it further down and it is returned at the end.
local M = {}
11
12 -- string functions
13
-- Render every argument with `_pr_str` in readable mode (second argument
-- `true`) and join the renderings with a single space.
-- @param ... values to render
-- @return string
function pr_str(...)
    local render_readably = function(value) return _pr_str(value, true) end
    local rendered = utils.map(render_readably, table.pack(...))
    return table.concat(rendered, " ")
end
19
-- Render every argument with `_pr_str` in non-readable mode (second argument
-- `false`) and concatenate the renderings with no separator.
-- @param ... values to render
-- @return string
function str(...)
    local render_plain = function(value) return _pr_str(value, false) end
    local rendered = utils.map(render_plain, table.pack(...))
    return table.concat(rendered, "")
end
25
-- Print the readable (`_pr_str(..., true)`) rendering of every argument,
-- space-separated, to stdout; flush stdout and return `Nil`.
function prn(...)
    local parts = utils.map(function(value) return _pr_str(value, true) end,
                            table.pack(...))
    local line = table.concat(parts, " ")
    print(line)
    io.flush()
    return Nil
end
33
-- Print the non-readable (`_pr_str(..., false)`) rendering of every argument,
-- space-separated, to stdout; flush stdout and return `Nil`.
function println(...)
    local parts = utils.map(function(value) return _pr_str(value, false) end,
                            table.pack(...))
    local line = table.concat(parts, " ")
    print(line)
    io.flush()
    return Nil
end
41
-- Return the entire contents of `file` as a string.
-- The previous implementation re-joined `io.lines` output with "\n" and
-- always appended one more "\n". A file without a trailing newline therefore
-- gained one, and an empty file came back as "\n". Reading the whole file in
-- one call returns the contents unchanged.
-- @param file path of the file to read
-- @return string with the file's contents
-- Raises a Lua error if the file cannot be opened or read.
function slurp(file)
    local handle = assert(io.open(file, "r"))
    local content, err = handle:read("*a")
    handle:close()  -- close explicitly; io.open does not auto-close like io.lines
    if content == nil then
        error("slurp: cannot read " .. tostring(file) .. ": " .. tostring(err))
    end
    return content
end
49
-- Read one line via `readline.readline(prompt)`.
-- Returns the line, or `Nil` when `readline.readline` returns nil
-- (presumably end of input; depends on the readline module).
function do_readline(prompt)
    local input = readline.readline(prompt)
    if input ~= nil then
        return input
    end
    return Nil
end
58
59 -- hash map functions
60
-- Copy `hm` with `types.copy`, then apply the key/value arguments to the
-- copy via `types._assoc_BANG`, returning its result.
function assoc(hm, ...)
    local updated = types.copy(hm)
    return types._assoc_BANG(updated, ...)
end
64
-- Copy `hm` with `types.copy`, then remove the given keys from the copy via
-- `types._dissoc_BANG`, returning its result.
function dissoc(hm, ...)
    local trimmed = types.copy(hm)
    return types._dissoc_BANG(trimmed, ...)
end
68
-- Look up `key` in `hm`; a missing entry (Lua nil) is reported as `Nil`.
-- A stored `false` is returned as-is.
function get(hm, key)
    local found = hm[key]
    if found ~= nil then
        return found
    end
    return Nil
end
74
-- Collect every key of `hm` (in `pairs` order, which is unspecified) into a new List.
function keys(hm)
    local collected, count = {}, 0
    for key in pairs(hm) do
        count = count + 1
        collected[count] = key
    end
    return List:new(collected)
end
82
-- Collect every value of `hm` (in `pairs` order, which is unspecified) into a new List.
function vals(hm)
    local collected, count = {}, 0
    for _, value in pairs(hm) do
        count = count + 1
        collected[count] = value
    end
    return List:new(collected)
end
90
91 -- sequential functions
92
-- Return a new List with `a` in front of the elements of `lst`.
-- `lst:slice(1)` supplies a copy of `lst`'s elements, which is then extended.
function cons(a, lst)
    local result = lst:slice(1)
    table.insert(result, 1, a)
    return List:new(result)
end
98
-- Concatenate the elements (indices 1..#s) of every argument sequence, in
-- argument order, into a single new List.
function concat(...)
    local sources = table.pack(...)
    local merged = {}
    for s = 1, #sources do
        local source = sources[s]
        for k = 1, #source do
            merged[#merged + 1] = source[k]
        end
    end
    return List:new(merged)
end
109
-- Return the element at zero-based index `idx` of `seq`.
-- Calls types.throw("nth: index out of range") unless `idx` is an integer in
-- [0, #seq). The previous upper-bound-only check let negative or fractional
-- indices through, which then silently returned Lua nil (e.g. seq[0]).
function nth(seq, idx)
    if idx >= 0 and idx < #seq and idx == math.floor(idx) then
        return seq[idx + 1]
    end
    types.throw("nth: index out of range")
end
117
-- Return the first element of `a`, or `Nil` when `#a` is zero.
function first(a)
    if #a ~= 0 then
        return a[1]
    end
    return Nil
end
125
-- Return a List of every element of `a` after the first (via `a:slice(2)`).
-- `Nil` yields an empty List.
function rest(a)
    if a ~= Nil then
        return List:new(a:slice(2))
    end
    return List:new()
end
133
-- Call `f` with the leading arguments followed by the elements of the
-- last argument (a sequence), as in (apply f a b [c d]).
-- A mal function (per `types._malfunc_Q`) is unwrapped to its `.fn` first.
function apply(f, ...)
    local packed = table.pack(...)
    local last = #packed
    local callee = f
    if types._malfunc_Q(callee) then
        callee = callee.fn
    end
    local spread = concat(types.slice(packed, 1, last - 1), packed[last])
    return callee(table.unpack(spread))
end
143
-- Apply `f` to each element of `lst` via `utils.map` and wrap the result in
-- a List. A mal function (per `types._malfunc_Q`) is unwrapped to `.fn` first.
function map(f, lst)
    local callee = f
    if types._malfunc_Q(callee) then
        callee = callee.fn
    end
    local results = utils.map(callee, lst)
    return List:new(results)
end
150
151 -- metadata functions
152
-- Return the `meta` field stored in `obj`'s metatable, or `Nil` when there
-- is no metatable or it has no `meta` entry.
function meta(obj)
    local mt = getmetatable(obj)
    if mt ~= nil and mt.meta ~= nil then
        return mt.meta
    end
    return Nil
end
158
-- Return a `types.copy` of `obj` whose metatable's `meta` field is set to `meta`.
-- Assumes the copy has its own metatable (otherwise the original's metadata
-- would be changed too). TODO(review): confirm against types.copy.
function with_meta(obj, meta)
    local clone = types.copy(obj)
    local mt = getmetatable(clone)
    mt.meta = meta
    return clone
end
164
165 -- atom functions
166
-- Atomically (single-threaded) update an atom: atm.val = f(atm.val, ...).
-- Returns the new value. A mal function (per `types._malfunc_Q`) is unwrapped
-- to its `.fn` first.
-- The previous version packed the extra arguments into a throwaway List,
-- inserted atm.val at the front and used table.unpack, both of which rely on
-- `#` (unspecified if any argument is nil). Forwarding `...` directly passes
-- the arguments through exactly and avoids the allocation.
function swap_BANG(atm, f, ...)
    if types._malfunc_Q(f) then
        f = f.fn
    end
    atm.val = f(atm.val, ...)
    return atm.val
end
176
-- Add the given items to a `types.copy` of `obj`.
-- When the copy is a list (per `types._list_Q`), each item is inserted at the
-- front, so the items end up in reverse order. Otherwise each item is appended.
-- Items are taken via ipairs, so iteration stops at the first nil.
local function conj(obj, ...)
    local items = table.pack(...)
    local result = types.copy(obj)
    local prepend = types._list_Q(result)
    for _, item in ipairs(items) do
        if prepend then
            table.insert(result, 1, item)
        else
            table.insert(result, item)
        end
    end
    return result
end
191
-- Convert `obj` to a List, or return `Nil` when it is `Nil` or empty (#obj == 0).
--   * lists are returned as-is
--   * vectors are copied into a new List
--   * strings become a List of one-character strings
--   * anything else yields `Nil`
-- Extra arguments are ignored.
local function seq(obj, ...)
    if obj == Nil or #obj == 0 then
        return Nil
    elseif types._list_Q(obj) then
        return obj
    elseif types._vector_Q(obj) then
        -- Copy the elements instead of passing `obj` itself to List:new.
        -- Every other List:new call in this file hands it a freshly built
        -- table, which suggests List:new may adopt (re-metatable) its
        -- argument. If so, passing the vector directly would silently turn
        -- the caller's vector into a list. TODO(review): confirm against
        -- types.List:new.
        local items = {}
        for i = 1, #obj do
            items[i] = obj[i]
        end
        return List:new(items)
    elseif types._string_Q(obj) then
        local chars = {}
        for i = 1, #obj do
            chars[i] = string.sub(obj, i, i)
        end
        return List:new(chars)
    end
    return Nil
end
208
-- Convert a value returned from Lua code into a mal value.
--   nil                       -> Nil
--   boolean, number, string   -> unchanged
--   table with keys 1..n only -> List (elements converted recursively)
--   any other table           -> HashMap (keys and values converted recursively)
--   anything else             -> tostring(a)
local function lua_to_mal(a)
    if a == nil then
        return Nil
    end
    local kind = type(a)
    if kind == "boolean" or kind == "number" or kind == "string" then
        return a
    elseif kind == "table" then
        -- Count every key, then check that keys 1..count are all present.
        -- If they are, the key set is exactly {1..count}, i.e. a proper
        -- sequence. The old check looked only at the first key returned by
        -- `next`, whose order is unspecified. Sparse or mixed tables were
        -- therefore classified arbitrarily and non-sequence entries dropped.
        local count = 0
        for _ in pairs(a) do
            count = count + 1
        end
        local is_sequence = true
        for i = 1, count do
            if a[i] == nil then
                is_sequence = false
                break
            end
        end
        if is_sequence then
            local list = {}
            for i = 1, count do
                list[i] = lua_to_mal(a[i])
            end
            return List:new(list)
        end
        local hashmap = {}
        for k, v in pairs(a) do
            hashmap[lua_to_mal(k)] = lua_to_mal(v)
        end
        return HashMap:new(hashmap)
    end
    return tostring(a)
end
234
-- Evaluate the Lua expression `str` and convert its first result with lua_to_mal.
-- Calls types.throw("lua-eval: can't load code: <err>") if `str` does not compile.
-- NOTE(review): runs arbitrary Lua code; only feed it trusted input.
local function lua_eval(str)
    -- `loadstring` is absent in default Lua 5.3/5.4 builds (it survives only
    -- behind compatibility flags), while this file also relies on the 5.2+
    -- `table.pack`/`table.unpack`. Fall back to `load`, which accepts a
    -- source string in 5.2+.
    local compile = loadstring or load
    local f, err = compile("return " .. str)
    if not f then
        types.throw("lua-eval: can't load code: " .. err)
    end
    return lua_to_mal(f())
end
242
-- Core namespace: maps mal symbol names to their Lua implementations.
M.ns = {
    ['='] = types._equal_Q,
    throw = types.throw,

    ['nil?'] = function(a) return a==Nil end,
    ['true?'] = function(a) return a==true end,
    ['false?'] = function(a) return a==false end,
    ['number?'] = function(a) return types._number_Q(a) end,
    symbol = function(a) return types.Symbol:new(a) end,
    ['symbol?'] = function(a) return types._symbol_Q(a) end,
    -- Keywords are represented as strings prefixed with "\177", so they are
    -- excluded here.
    ['string?'] = function(a) return types._string_Q(a) and "\177" ~= string.sub(a,1,1) end,
    -- Return an existing keyword (already "\177"-prefixed) unchanged instead
    -- of prefixing it a second time; otherwise build one from the string.
    keyword = function(a)
        if type(a) == "string" and string.sub(a, 1, 1) == "\177" then
            return a
        end
        return "\177"..a
    end,
    ['keyword?'] = function(a) return types._keyword_Q(a) end,
    ['fn?'] = function(a) return types._fn_Q(a) end,
    ['macro?'] = function(a) return types._macro_Q(a) end,

    ['pr-str'] = pr_str,
    str = str,
    prn = prn,
    println = println,
    ['read-string'] = reader.read_str,
    readline = do_readline,
    slurp = slurp,

    ['<'] = function(a,b) return a<b end,
    ['<='] = function(a,b) return a<=b end,
    ['>'] = function(a,b) return a>b end,
    ['>='] = function(a,b) return a>=b end,
    ['+'] = function(a,b) return a+b end,
    ['-'] = function(a,b) return a-b end,
    ['*'] = function(a,b) return a*b end,
    -- Floors the quotient (rounds toward negative infinity).
    ['/'] = function(a,b) return math.floor(a/b) end,
    ['time-ms'] = function() return math.floor(socket.gettime() * 1000) end,

    list = function(...) return List:new(table.pack(...)) end,
    ['list?'] = function(a) return types._list_Q(a) end,
    vector = function(...) return types.Vector:new(table.pack(...)) end,
    ['vector?'] = types._vector_Q,
    ['hash-map'] = types.hash_map,
    ['map?'] = types._hash_map_Q,
    assoc = assoc,
    dissoc = dissoc,
    get = get,
    ['contains?'] = function(a,b) return a[b] ~= nil end,
    keys = keys,
    vals = vals,

    ['sequential?'] = types._sequential_Q,
    cons = cons,
    concat = concat,
    nth = nth,
    first = first,
    rest = rest,
    ['empty?'] = function(a) return a==Nil or #a == 0 end,
    count = function(a) return #a end,
    apply = apply,
    map = map,
    conj = conj,
    seq = seq,

    meta = meta,
    ['with-meta'] = with_meta,
    atom = function(a) return types.Atom:new(a) end,
    ['atom?'] = types._atom_Q,
    deref = function(a) return a.val end,
    ['reset!'] = function(a,b) a.val = b; return b end,
    ['swap!'] = swap_BANG,

    ['lua-eval'] = lua_eval,
}

return M
315