fix `nil?' type inference
[bpt/guile.git] / module / language / cps / type-fold.scm
index 20abc36..6cc1284 100644 (file)
   #:use-module (language cps dfg)
   #:use-module (language cps renumber)
   #:use-module (language cps types)
+  #:use-module (system base target)
   #:export (type-fold))
 
+
+\f
+
+;; Branch folders.
+
 (define &scalar-types
-  (logior &exact-integer &flonum &char &unspecified &boolean &nil &null))
+  (logior &exact-integer &flonum &char &unspecified &false &true &nil &null))
 
 (define *branch-folders* (make-hash-table))
 
@@ -61,7 +67,6 @@
 ;; All the cases that are in compile-bytecode.
 (define-unary-type-predicate-folder pair? &pair)
 (define-unary-type-predicate-folder null? &null)
-(define-unary-type-predicate-folder nil? &nil)
 (define-unary-type-predicate-folder symbol? &symbol)
 (define-unary-type-predicate-folder variable? &box)
 (define-unary-type-predicate-folder vector? &vector)
     ((= <= <) (values #t #f))
     (else (values #f #f))))
 
-(define (compute-folded fun dfg min-label min-var)
+(define-binary-branch-folder (logtest type0 min0 max0 type1 min1 max1)
+  (define (logand-min a b)
+    (if (< a b 0)
+        (min a b)
+        0))
+  (define (logand-max a b)
+    (if (< a b 0)
+        0
+        (max a b)))
+  (if (and (= min0 max0) (= min1 max1) (eqv? type0 type1 &exact-integer))
+      (values #t (logtest min0 min1))
+      (values #f #f)))
+
+
+\f
+
+;; Strength reduction.
+
+(define *primcall-reducers* (make-hash-table))
+
+(define-syntax-rule (define-primcall-reducer name f)
+  (hashq-set! *primcall-reducers* 'name f))
+
+(define-syntax-rule (define-unary-primcall-reducer (name dfg k src
+                                                         arg type min max)
+                      body ...)
+  (define-primcall-reducer name
+    (lambda (dfg k src arg type min max) body ...)))
+
+(define-syntax-rule (define-binary-primcall-reducer (name dfg k src
+                                                          arg0 type0 min0 max0
+                                                          arg1 type1 min1 max1)
+                      body ...)
+  (define-primcall-reducer name
+    (lambda (dfg k src arg0 type0 min0 max0 arg1 type1 min1 max1) body ...)))
+
+(define-binary-primcall-reducer (mul dfg k src
+                                     arg0 type0 min0 max0
+                                     arg1 type1 min1 max1)
+  (define (negate arg)
+    (let-fresh (kzero) (zero)
+      (build-cps-term
+        ($letk ((kzero ($kargs (#f) (zero)
+                         ($continue k src ($primcall 'sub (zero arg))))))
+          ($continue kzero src ($const 0))))))
+  (define (zero)
+    (build-cps-term ($continue k src ($const 0))))
+  (define (identity arg)
+    (build-cps-term ($continue k src ($values (arg)))))
+  (define (double arg)
+    (build-cps-term ($continue k src ($primcall 'add (arg arg)))))
+  (define (power-of-two constant arg)
+    (let ((n (let lp ((bits 0) (constant constant))
+               (if (= constant 1) bits (lp (1+ bits) (ash constant -1))))))
+      (let-fresh (kbits) (bits)
+        (build-cps-term
+          ($letk ((kbits ($kargs (#f) (bits)
+                           ($continue k src ($primcall 'ash (arg bits))))))
+            ($continue kbits src ($const n)))))))
+  (define (mul/constant constant constant-type arg arg-type)
+    (and (or (= constant-type &exact-integer) (= constant-type arg-type))
+         (case constant
+           ;; (* arg -1) -> (- 0 arg)
+           ((-1) (negate arg))
+           ;; (* arg 0) -> 0 if arg is not a flonum or complex
+           ((0) (and (= constant-type &exact-integer)
+                     (zero? (logand arg-type
+                                    (lognot (logior &flonum &complex))))
+                     (zero)))
+           ;; (* arg 1) -> arg
+           ((1) (identity arg))
+           ;; (* arg 2) -> (+ arg arg)
+           ((2) (double arg))
+           (else (and (= constant-type arg-type &exact-integer)
+                      (positive? constant)
+                      (zero? (logand constant (1- constant)))
+                      (power-of-two constant arg))))))
+  (cond
+   ((logtest (logior type0 type1) (lognot &number)) #f)
+   ((= min0 max0) (mul/constant min0 type0 arg1 type1))
+   ((= min1 max1) (mul/constant min1 type1 arg0 type0))
+   (else #f)))
+
+(define-binary-primcall-reducer (logbit? dfg k src
+                                         arg0 type0 min0 max0
+                                         arg1 type1 min1 max1)
+  (define (convert-to-logtest bool-term)
+    (let-fresh (kt kf kmask kbool) (mask bool)
+     (build-cps-term
+       ($letk ((kt ($kargs () ()
+                     ($continue kbool src ($const #t))))
+               (kf ($kargs () ()
+                     ($continue kbool src ($const #f))))
+               (kbool ($kargs (#f) (bool)
+                        ,(bool-term bool)))
+               (kmask ($kargs (#f) (mask)
+                        ($continue kf src
+                          ($branch kt ($primcall 'logtest (mask arg1)))))))
+         ,(if (eq? min0 max0)
+              ($continue kmask src ($const (ash 1 min0)))
+              (let-fresh (kone) (one)
+                (build-cps-term
+                  ($letk ((kone ($kargs (#f) (one)
+                                  ($continue kmask src
+                                    ($primcall 'ash (one arg0))))))
+                    ($continue kone src ($const 1))))))))))
+  ;; Hairiness because we are converting from a primcall with unknown
+  ;; arity to a branching primcall.
+  (let ((positive-fixnum-bits (- (* (target-word-size) 8) 3)))
+    (and (= type0 &exact-integer)
+         (<= 0 min0 positive-fixnum-bits)
+         (<= 0 max0 positive-fixnum-bits)
+         (match (lookup-cont k dfg)
+           (($ $kreceive arity kargs)
+            (match arity
+              (($ $arity (_) () (not #f) () #f)
+               (convert-to-logtest
+                (lambda (bool)
+                  (let-fresh (knil) (nil)
+                    (build-cps-term
+                      ($letk ((knil ($kargs (#f) (nil)
+                                      ($continue kargs src
+                                        ($values (bool nil))))))
+                        ($continue knil src ($const '()))))))))
+              (_
+               (convert-to-logtest
+                (lambda (bool)
+                  (build-cps-term
+                    ($continue k src ($primcall 'values (bool)))))))))
+           (($ $ktail)
+            (convert-to-logtest
+             (lambda (bool)
+               (build-cps-term
+                 ($continue k src ($primcall 'return (bool)))))))))))
+
+
+\f
+
+;;
+
+(define (fold-and-reduce fun dfg min-label min-var)
   (define (scalar-value type val)
     (cond
      ((eqv? type &exact-integer) val)
      ((eqv? type &flonum) (exact->inexact val))
      ((eqv? type &char) (integer->char val))
      ((eqv? type &unspecified) *unspecified*)
-     ((eqv? type &boolean) (not (zero? val)))
+     ((eqv? type &false) #f)
+     ((eqv? type &true) #t)
      ((eqv? type &nil) #nil)
      ((eqv? type &null) '())
      (else (error "unhandled type" type val))))
-  (let* ((typev (infer-types fun dfg #:max-label-count 3000))
-         (folded? (and typev
-                       (make-bitvector (/ (vector-length typev) 2) #f)))
-         (folded-values (and typev
-                             (make-vector (bitvector-length folded?) #f))))
+  (let* ((typev (infer-types fun dfg))
+         (label-count ((make-local-cont-folder label-count)
+                       (lambda (k cont label-count) (1+ label-count))
+                       fun 0))
+         (folded? (make-bitvector label-count #f))
+         (folded-values (make-vector label-count #f))
+         (reduced-terms (make-vector label-count #f)))
     (define (label->idx label) (- label min-label))
     (define (var->idx var) (- var min-var))
-    (define (maybe-fold-value! label name k def)
-      (call-with-values (lambda () (lookup-post-type typev label def))
+    (define (maybe-reduce-primcall! label k src name args)
+      (let* ((reducer (hashq-ref *primcall-reducers* name)))
+        (when reducer
+          (vector-set!
+           reduced-terms
+           (label->idx label)
+           (match args
+             ((arg0)
+              (call-with-values (lambda () (lookup-pre-type typev label arg0))
+                (lambda (type0 min0 max0)
+                  (reducer dfg k src arg0 type0 min0 max0))))
+             ((arg0 arg1)
+              (call-with-values (lambda () (lookup-pre-type typev label arg0))
+                (lambda (type0 min0 max0)
+                  (call-with-values (lambda () (lookup-pre-type typev label arg1))
+                    (lambda (type1 min1 max1)
+                      (reducer dfg k src arg0 type0 min0 max0
+                               arg1 type1 min1 max1))))))
+             (_ #f))))))
+    (define (maybe-fold-value! label name def)
+      (call-with-values (lambda () (lookup-post-type typev label def 0))
         (lambda (type min max)
-          (when (and (not (zero? type))
-                     (zero? (logand type (1- type)))
-                     (zero? (logand type (lognot &scalar-types)))
-                     (eqv? min max))
-            (bitvector-set! folded? label #t)
-            (vector-set! folded-values label (scalar-value type min))))))
+          (cond
+           ((and (not (zero? type))
+                 (zero? (logand type (1- type)))
+                 (zero? (logand type (lognot &scalar-types)))
+                 (eqv? min max))
+            (bitvector-set! folded? (label->idx label) #t)
+            (vector-set! folded-values (label->idx label)
+                         (scalar-value type min))
+            #t)
+           (else #f)))))
     (define (maybe-fold-unary-branch! label name arg)
       (let* ((folder (hashq-ref *branch-folders* name)))
         (when folder
             (lambda (type min max)
               (call-with-values (lambda () (folder type min max))
                 (lambda (f? v)
-                  (bitvector-set! folded? label f?)
-                  (vector-set! folded-values label v))))))))
+                  (bitvector-set! folded? (label->idx label) f?)
+                  (vector-set! folded-values (label->idx label) v))))))))
     (define (maybe-fold-binary-branch! label name arg0 arg1)
       (let* ((folder (hashq-ref *branch-folders* name)))
         (when folder
                   (call-with-values (lambda ()
                                       (folder type0 min0 max0 type1 min1 max1))
                     (lambda (f? v)
-                      (bitvector-set! folded? label f?)
-                      (vector-set! folded-values label v))))))))))
+                      (bitvector-set! folded? (label->idx label) f?)
+                      (vector-set! folded-values (label->idx label) v))))))))))
     (define (visit-cont cont)
       (match cont
         (($ $cont label ($ $kargs _ _ body))
         (($ $letrec _ _ _ body)
          (visit-term body label))
         (($ $continue k src ($ $primcall name args))
-         ;; We might be able to fold primcalls that define a value or
-         ;; that branch.
+         ;; We might be able to fold primcalls that define a value.
          (match (lookup-cont k dfg)
            (($ $kargs (_) (def))
-            (maybe-fold-value! (label->idx label) name (label->idx k)
-                               (var->idx def)))
-           (($ $kif kt kf)
-            (match args
-              ((arg)
-               (maybe-fold-unary-branch! (label->idx label) name
-                                         (var->idx arg)))
-              ((arg0 arg1)
-               (maybe-fold-binary-branch! (label->idx label) name
-                                          (var->idx arg0) (var->idx arg1)))))
-           (_ #f)))
+            ;(pk 'maybe-fold-value src name args)
+            (unless (maybe-fold-value! label name def)
+              (maybe-reduce-primcall! label k src name args)))
+           (_
+            (maybe-reduce-primcall! label k src name args))))
         (($ $continue kf src ($ $branch kt ($ $primcall name args)))
          ;; We might be able to fold primcalls that branch.
+         ;(pk 'maybe-fold-branch label src name args)
          (match args
            ((arg)
-            (maybe-fold-unary-branch! (label->idx label) name
-                                      (var->idx arg)))
+            (maybe-fold-unary-branch! label name arg))
            ((arg0 arg1)
-            (maybe-fold-binary-branch! (label->idx label) name
-                                       (var->idx arg0) (var->idx arg1)))))
+            (maybe-fold-binary-branch! label name arg0 arg1))))
         (_ #f)))
     (when typev
       (match fun
         (($ $cont kfun ($ $kfun src meta self tail clause))
          (visit-cont clause))))
-    (values folded? folded-values)))
+    (values folded? folded-values reduced-terms)))
 
 (define (fold-constants* fun dfg)
   (match fun
     (($ $cont min-label ($ $kfun _ _ min-var))
-     (call-with-values (lambda () (compute-folded fun dfg min-label min-var))
-       (lambda (folded? folded-values)
+     (call-with-values (lambda () (fold-and-reduce fun dfg min-label min-var))
+       (lambda (folded? folded-values reduced-terms)
          (define (label->idx label) (- label min-label))
          (define (var->idx var) (- var min-var))
          (define (visit-cont cont)
                 ,(visit-term body label)))
              (($ $continue k src (and fun ($ $fun)))
               ($continue k src ,(visit-fun fun)))
-             (($ $continue k src (and primcall ($ $primcall)))
-              ,(if (and folded?
-                        (bitvector-ref folded? (label->idx label)))
-                   (let ((val (vector-ref folded-values (label->idx label))))
-                     ;; Uncomment for debugging.
-                     ;; (pk 'folded src primcall val)
-                     (match (lookup-cont k dfg)
-                       (($ $kargs)
-                        (let-fresh (k*) (v*)
-                          ;; Rely on DCE to elide this expression, if
-                          ;; possible.
-                          (build-cps-term
-                            ($letk ((k* ($kargs (#f) (v*)
-                                          ($continue k src ($const val)))))
-                              ($continue k* src ,primcall)))))
-                       (($ $kif kt kf)
-                        ;; Folded branch.
-                        (build-cps-term
-                          ($continue (if val kt kf) src ($values ()))))))
-                   term))
+             (($ $continue k src (and primcall ($ $primcall name args)))
+              ,(cond
+                ((bitvector-ref folded? (label->idx label))
+                 (let ((val (vector-ref folded-values (label->idx label))))
+                   ;; Uncomment for debugging.
+                   ;; (pk 'folded src primcall val)
+                   (let-fresh (k*) (v*)
+                     ;; Rely on DCE to elide this expression, if
+                     ;; possible.
+                     (build-cps-term
+                       ($letk ((k* ($kargs (#f) (v*)
+                                     ($continue k src ($const val)))))
+                         ($continue k* src ,primcall))))))
+                (else
+                 (or (vector-ref reduced-terms (label->idx label))
+                     term))))
              (($ $continue kf src ($ $branch kt ($ $primcall)))
-              ,(if (and folded?
-                        (bitvector-ref folded? (label->idx label)))
+              ,(if (bitvector-ref folded? (label->idx label))
                    ;; Folded branch.
                    (let ((val (vector-ref folded-values (label->idx label))))
                      (build-cps-term