Change quasiquote algorithm
[jackhill/mal.git] / impls / python.2 / stepA_mal.py
index 6ed2977..0cbb09f 100644 (file)
@@ -1,3 +1,4 @@
+import functools
 import readline
 import sys
 from typing import List, Dict
@@ -41,38 +42,32 @@ def eval_ast(ast: MalExpression, env: Env) -> MalExpression:
     return ast
 
 
-def is_pair(x: MalExpression) -> bool:
-    if (isinstance(x, MalList) or isinstance(x, MalVector)) and len(x.native()) > 0:
-        return True
-    return False
+def qq_loop(acc: MalList, elt: MalExpression) -> MalList:
+    """Fold step for quasiquote: put the form for ``elt`` in front of ``acc``.
+
+    If ``elt`` is exactly the two-element list ``(splice-unquote x)``, return
+    the form ``(concat x acc)``. Otherwise return
+    ``(cons <quasiquote(elt)> acc)``.
+    """
+    if isinstance(elt, MalList):
+        lst = elt.native()
+        # Only a well-formed (splice-unquote x) is spliced; any other list,
+        # including a malformed splice-unquote, is quasiquoted recursively.
+        if len(lst) == 2:
+            fst = lst[0]
+            if isinstance(fst, MalSymbol) and fst.native() == "splice-unquote":
+                return MalList([MalSymbol("concat"), lst[1], acc])
+    return MalList([MalSymbol("cons"), quasiquote(elt), acc])
 
+def qq_foldr(xs: List[MalExpression]) -> MalList:
+    """Right-fold ``qq_loop`` over ``xs``, starting from the empty list.
+
+    Elements are visited last-to-first, so each element's form wraps the
+    forms already built for the elements that follow it.
+    """
+    folded = MalList([])
+    for item in reversed(xs):
+        folded = qq_loop(folded, item)
+    return folded
 
 def quasiquote(ast: MalExpression) -> MalExpression:
-    if not is_pair(ast):
+    if isinstance(ast, MalList):
+        lst = ast.native()
+        if len(lst) == 2:
+            fst = lst[0]
+            if isinstance(fst, MalSymbol) and fst.native() == u'unquote':
+                return lst[1]
+        return qq_foldr(lst)
+    elif isinstance(ast, MalVector):
+        return MalList([MalSymbol("vec"), qq_foldr(ast.native())])
+    elif isinstance(ast, MalSymbol) or isinstance(ast, MalHash_map):
         return MalList([MalSymbol("quote"), ast])
-    elif core.equal(ast.native()[0], MalSymbol("unquote")).native():
-        return ast.native()[1]
-    elif (
-        is_pair(ast.native()[0])
-        and core.equal(
-            ast.native()[0].native()[0], MalSymbol("splice-unquote")
-        ).native()
-    ):
-        return MalList(
-            [
-                MalSymbol("concat"),
-                ast.native()[0].native()[1],
-                quasiquote(MalList(ast.native()[1:])),
-            ]
-        )
     else:
-        return MalList(
-            [
-                MalSymbol("cons"),
-                quasiquote(ast.native()[0]),
-                quasiquote(MalList(ast.native()[1:])),
-            ]
-        )
+        return ast
 
 
 def EVAL(ast: MalExpression, env: Env) -> MalExpression:
@@ -149,6 +144,8 @@ def EVAL(ast: MalExpression, env: Env) -> MalExpression:
                 if isinstance(ast_native[1], MalVector)
                 else ast_native[1]
             )
+        elif first_str == "quasiquoteexpand":
+            return quasiquote(ast_native[1])
         elif first_str == "quasiquote":
             ast = quasiquote(ast_native[1])
             continue