#:use-module (language cps dfg)
#:use-module (language cps renumber)
#:use-module (language cps types)
+ #:use-module (system base target)
#:export (type-fold))
+
+\f
+
+;; Branch folders.
+
(define &scalar-types
- (logior &exact-integer &flonum &char &unspecified &boolean &nil &null))
+ ;; Type bits for which a single-bit type with MIN = MAX names exactly one
+ ;; Scheme value: maybe-fold-value! folds only these, and scalar-value must
+ ;; handle every bit listed here (it signals an error for anything else).
+ (logior &exact-integer &flonum &char &unspecified &false &true &nil &null))
+;; Branch folders, keyed by the name of the branching primcall.  A folder
+;; receives the type, min and max of each argument and returns two values:
+;; whether the branch can be folded and, if so, which way it goes (#t/#f).
(define *branch-folders* (make-hash-table))
;; All the cases that are in compile-bytecode.
(define-unary-type-predicate-folder pair? &pair)
(define-unary-type-predicate-folder null? &null)
-(define-unary-type-predicate-folder nil? &nil)
(define-unary-type-predicate-folder symbol? &symbol)
(define-unary-type-predicate-folder variable? &box)
(define-unary-type-predicate-folder vector? &vector)
((= <= <) (values #t #f))
(else (values #f #f))))
-(define (compute-folded fun dfg min-label min-var)
+(define-binary-branch-folder (logtest type0 min0 max0 type1 min1 max1)
+  ;; Fold (logtest A B) when both operands are exact integers whose values
+  ;; are known exactly (MIN = MAX).  The branch outcome is then simply
+  ;; (logtest MIN0 MIN1).  In every other case the branch is not folded.
+  ;; Checking the types first means MIN0/MIN1 are exact integers whenever
+  ;; they reach `logtest'.
+  (if (and (eqv? type0 type1 &exact-integer)
+           (= min0 max0)
+           (= min1 max1))
+      (values #t (logtest min0 min1))
+      (values #f #f)))
+
+
+\f
+
+;; Strength reduction.
+
+(define *primcall-reducers* (make-hash-table))
+
+;; (define-primcall-reducer NAME PROC)
+;; Install PROC in *primcall-reducers* under the symbol NAME.  PROC yields
+;; a replacement term for the primcall, or #f to leave the call unchanged.
+(define-syntax define-primcall-reducer
+  (syntax-rules ()
+    ((_ name f)
+     (hashq-set! *primcall-reducers* 'name f))))
+
+;; (define-unary-primcall-reducer (NAME DFG K SRC ARG TYPE MIN MAX) BODY ...)
+;; Register a reducer for a one-argument primcall.  The formals are bound
+;; to the DFG, the primcall's continuation and source, and the argument
+;; together with its type and range.
+(define-syntax define-unary-primcall-reducer
+  (syntax-rules ()
+    ((_ (name dfg k src arg type min max) body ...)
+     (define-primcall-reducer name
+       (lambda (dfg k src arg type min max) body ...)))))
+
+;; Like define-unary-primcall-reducer, for two-argument primcalls.
+(define-syntax define-binary-primcall-reducer
+  (syntax-rules ()
+    ((_ (name dfg k src arg0 type0 min0 max0 arg1 type1 min1 max1) body ...)
+     (define-primcall-reducer name
+       (lambda (dfg k src arg0 type0 min0 max0 arg1 type1 min1 max1)
+         body ...)))))
+
+(define-binary-primcall-reducer (mul dfg k src
+                                     arg0 type0 min0 max0
+                                     arg1 type1 min1 max1)
+  ;; Strength-reduce (* ARG0 ARG1) when one operand is known to be a
+  ;; constant: multiplying by -1, 0, 1, 2 or a positive power of two
+  ;; becomes a subtraction from 0, the constant 0, the other operand
+  ;; itself, an addition or a shift.  Returns the replacement term, or #f
+  ;; to leave the call alone.
+  (define (negate arg)
+    ;; (- 0 ARG).  The local `zero' variable shadows the helper below.
+    (let-fresh (kzero) (zero)
+      (build-cps-term
+        ($letk ((kzero ($kargs (#f) (zero)
+                         ($continue k src ($primcall 'sub (zero arg))))))
+          ($continue kzero src ($const 0))))))
+  (define (zero)
+    ;; The exact constant 0.
+    (build-cps-term ($continue k src ($const 0))))
+  (define (identity arg)
+    (build-cps-term ($continue k src ($values (arg)))))
+  (define (double arg)
+    ;; (+ ARG ARG)
+    (build-cps-term ($continue k src ($primcall 'add (arg arg)))))
+  (define (power-of-two constant arg)
+    ;; (ash ARG N), where CONSTANT = 2^N.
+    (let ((n (let lp ((bits 0) (constant constant))
+               (if (= constant 1) bits (lp (1+ bits) (ash constant -1))))))
+      (let-fresh (kbits) (bits)
+        (build-cps-term
+          ($letk ((kbits ($kargs (#f) (bits)
+                           ($continue k src ($primcall 'ash (arg bits))))))
+            ($continue kbits src ($const n)))))))
+  (define (mul/constant constant constant-type arg arg-type)
+    ;; CONSTANT is the known value of one operand, of type CONSTANT-TYPE;
+    ;; ARG is the other operand.  Only an exact-integer constant, or one
+    ;; whose type is the same as ARG's, is considered.
+    (and (or (= constant-type &exact-integer) (= constant-type arg-type))
+         (case constant
+           ;; (* arg -1) -> (- 0 arg)
+           ((-1) (negate arg))
+           ;; (* arg 0) -> 0 if arg is not a flonum or complex.  An inexact
+           ;; ARG times exact 0 need not be exact 0 (e.g. infinities and
+           ;; NaNs), so ARG's type must have no flonum or complex bits.
+           ((0) (and (= constant-type &exact-integer)
+                     (zero? (logand arg-type (logior &flonum &complex)))
+                     (zero)))
+           ;; (* arg 1) -> arg
+           ((1) (identity arg))
+           ;; (* arg 2) -> (+ arg arg)
+           ((2) (double arg))
+           ;; (* arg 2^n) -> (ash arg n), exact integers only
+           (else (and (= constant-type arg-type &exact-integer)
+                      (positive? constant)
+                      (zero? (logand constant (1- constant)))
+                      (power-of-two constant arg))))))
+  (cond
+   ;; Operands that might not be numbers are not reduced.
+   ((logtest (logior type0 type1) (lognot &number)) #f)
+   ((= min0 max0) (mul/constant min0 type0 arg1 type1))
+   ((= min1 max1) (mul/constant min1 type1 arg0 type0))
+   (else #f)))
+
+(define-binary-primcall-reducer (logbit? dfg k src
+                                         arg0 type0 min0 max0
+                                         arg1 type1 min1 max1)
+  ;; Strength-reduce (logbit? ARG0 ARG1) into a `logtest' branch of ARG1
+  ;; against the mask (ash 1 ARG0).  The branch outcome is then turned back
+  ;; into #t or #f and delivered to K.  Returns the replacement term, or #f
+  ;; when the call is left alone.
+  (define (convert-to-logtest bool-term)
+    ;; BOOL-TERM maps the variable holding the boolean result to the term
+    ;; that passes it on to K.
+    (let-fresh (kt kf kmask kbool) (mask bool)
+      (build-cps-term
+        ($letk ((kt ($kargs () ()
+                      ($continue kbool src ($const #t))))
+                (kf ($kargs () ()
+                      ($continue kbool src ($const #f))))
+                (kbool ($kargs (#f) (bool)
+                         ,(bool-term bool)))
+                (kmask ($kargs (#f) (mask)
+                         ($continue kf src
+                           ($branch kt ($primcall 'logtest (mask arg1)))))))
+          ,(if (eqv? min0 max0)
+               ;; Bit index known exactly: the mask is a constant.  This arm
+               ;; is unquoted, so it must build its CPS term explicitly.
+               (build-cps-term
+                 ($continue kmask src ($const (ash 1 min0))))
+               ;; Otherwise compute the mask at run time as (ash 1 ARG0).
+               (let-fresh (kone) (one)
+                 (build-cps-term
+                   ($letk ((kone ($kargs (#f) (one)
+                                   ($continue kmask src
+                                     ($primcall 'ash (one arg0))))))
+                     ($continue kone src ($const 1))))))))))
+  ;; Hairiness because we are converting from a primcall with unknown
+  ;; arity to a branching primcall.  ARG0 must be an exact integer in
+  ;; [0, positive-fixnum-bits] so that the mask (ash 1 ARG0) is a fixnum.
+  (let ((positive-fixnum-bits (- (* (target-word-size) 8) 3)))
+    (and (= type0 &exact-integer)
+         (<= 0 min0 positive-fixnum-bits)
+         (<= 0 max0 positive-fixnum-bits)
+         (match (lookup-cont k dfg)
+           (($ $kreceive arity kargs)
+            (match arity
+              (($ $arity (_) () (not #f) () #f)
+               ;; One required value plus a rest list: pass the boolean
+               ;; and an empty rest list.
+               (convert-to-logtest
+                (lambda (bool)
+                  (let-fresh (knil) (nil)
+                    (build-cps-term
+                      ($letk ((knil ($kargs (#f) (nil)
+                                      ($continue kargs src
+                                        ($values (bool nil))))))
+                        ($continue knil src ($const '()))))))))
+              (_
+               (convert-to-logtest
+                (lambda (bool)
+                  (build-cps-term
+                    ($continue k src ($primcall 'values (bool)))))))))
+           (($ $ktail)
+            (convert-to-logtest
+             (lambda (bool)
+               (build-cps-term
+                 ($continue k src ($primcall 'return (bool)))))))
+           ;; Any other continuation (e.g. a $kargs binding the result,
+           ;; which the driver also hands to reducers): leave the call
+           ;; unchanged instead of failing to match.
+           (_ #f)))))
+
+
+\f
+
+;; Constant folding and strength reduction.
+
+(define (fold-and-reduce fun dfg min-label min-var)
(define (scalar-value type val)
+ ;; Rebuild the Scheme value described by TYPE and VAL.  TYPE is a single
+ ;; type bit from &scalar-types, and VAL is that type's MIN (= MAX) as
+ ;; computed by type inference.  Flonums are recovered with exact->inexact
+ ;; and chars with integer->char.  Any other TYPE signals an error.
(cond
((eqv? type &exact-integer) val)
((eqv? type &flonum) (exact->inexact val))
((eqv? type &char) (integer->char val))
((eqv? type &unspecified) *unspecified*)
- ((eqv? type &boolean) (not (zero? val)))
+ ((eqv? type &false) #f)
+ ((eqv? type &true) #t)
((eqv? type &nil) #nil)
((eqv? type &null) '())
(else (error "unhandled type" type val))))
- (let* ((typev (infer-types fun dfg #:max-label-count 3000))
- (folded? (and typev
- (make-bitvector (/ (vector-length typev) 2) #f)))
- (folded-values (and typev
- (make-vector (bitvector-length folded?) #f))))
+ (let* ((typev (infer-types fun dfg))
+ (label-count ((make-local-cont-folder label-count)
+ (lambda (k cont label-count) (1+ label-count))
+ fun 0))
+ (folded? (make-bitvector label-count #f))
+ (folded-values (make-vector label-count #f))
+ (reduced-terms (make-vector label-count #f)))
(define (label->idx label) (- label min-label))
(define (var->idx var) (- var min-var))
- (define (maybe-fold-value! label name k def)
- (call-with-values (lambda () (lookup-post-type typev label def))
+ (define (maybe-reduce-primcall! label k src name args)
+   ;; If NAME has a registered strength reducer, call it with the
+   ;; pre-types (type, min, max) at LABEL of the primcall's one or two
+   ;; arguments.  Its result, a replacement term or #f, is stored in
+   ;; REDUCED-TERMS at LABEL's index.  Any other argument count stores #f.
+   (define (with-pre-type var consumer)
+     (call-with-values (lambda () (lookup-pre-type typev label var))
+       consumer))
+   (let ((reducer (hashq-ref *primcall-reducers* name)))
+     (when reducer
+       (let ((result
+              (match args
+                ((a)
+                 (with-pre-type a
+                   (lambda (ta lo-a hi-a)
+                     (reducer dfg k src a ta lo-a hi-a))))
+                ((a b)
+                 (with-pre-type a
+                   (lambda (ta lo-a hi-a)
+                     (with-pre-type b
+                       (lambda (tb lo-b hi-b)
+                         (reducer dfg k src a ta lo-a hi-a
+                                  b tb lo-b hi-b))))))
+                (_ #f))))
+         (vector-set! reduced-terms (label->idx label) result)))))
+ (define (maybe-fold-value! label name def)
+ ;; Try to fold the value DEF defined by the primcall at LABEL.  If DEF's
+ ;; post-type at LABEL is exactly one &scalar-types bit with MIN = MAX,
+ ;; record the constant in FOLDED?/FOLDED-VALUES and return #t.
+ ;; Otherwise return #f so the caller can try strength reduction instead.
+ ;; NAME is not consulted here.
+ (call-with-values (lambda () (lookup-post-type typev label def 0))
(lambda (type min max)
- (when (and (not (zero? type))
- (zero? (logand type (1- type)))
- (zero? (logand type (lognot &scalar-types)))
- (eqv? min max))
- (bitvector-set! folded? label #t)
- (vector-set! folded-values label (scalar-value type min))))))
+ (cond
+ ((and (not (zero? type))
+ (zero? (logand type (1- type))) ; exactly one type bit set
+ (zero? (logand type (lognot &scalar-types))) ; and it is a scalar type
+ (eqv? min max)) ; with a single possible value
+ (bitvector-set! folded? (label->idx label) #t)
+ (vector-set! folded-values (label->idx label)
+ (scalar-value type min))
+ #t)
+ (else #f)))))
(define (maybe-fold-unary-branch! label name arg)
(let* ((folder (hashq-ref *branch-folders* name)))
(when folder
(lambda (type min max)
(call-with-values (lambda () (folder type min max))
(lambda (f? v)
- (bitvector-set! folded? label f?)
- (vector-set! folded-values label v))))))))
+ (bitvector-set! folded? (label->idx label) f?)
+ (vector-set! folded-values (label->idx label) v))))))))
(define (maybe-fold-binary-branch! label name arg0 arg1)
(let* ((folder (hashq-ref *branch-folders* name)))
(when folder
(call-with-values (lambda ()
(folder type0 min0 max0 type1 min1 max1))
(lambda (f? v)
- (bitvector-set! folded? label f?)
- (vector-set! folded-values label v))))))))))
+ (bitvector-set! folded? (label->idx label) f?)
+ (vector-set! folded-values (label->idx label) v))))))))))
(define (visit-cont cont)
(match cont
(($ $cont label ($ $kargs _ _ body))
(($ $letrec _ _ _ body)
(visit-term body label))
(($ $continue k src ($ $primcall name args))
- ;; We might be able to fold primcalls that define a value or
- ;; that branch.
+ ;; We might be able to fold primcalls that define a value.
(match (lookup-cont k dfg)
(($ $kargs (_) (def))
- (maybe-fold-value! (label->idx label) name (label->idx k)
- (var->idx def)))
- (($ $kif kt kf)
- (match args
- ((arg)
- (maybe-fold-unary-branch! (label->idx label) name
- (var->idx arg)))
- ((arg0 arg1)
- (maybe-fold-binary-branch! (label->idx label) name
- (var->idx arg0) (var->idx arg1)))))
- (_ #f)))
+ ;(pk 'maybe-fold-value src name args)
+ (unless (maybe-fold-value! label name def)
+ (maybe-reduce-primcall! label k src name args)))
+ (_
+ (maybe-reduce-primcall! label k src name args))))
(($ $continue kf src ($ $branch kt ($ $primcall name args)))
;; We might be able to fold primcalls that branch.
+ ;(pk 'maybe-fold-branch label src name args)
(match args
((arg)
- (maybe-fold-unary-branch! (label->idx label) name
- (var->idx arg)))
+ (maybe-fold-unary-branch! label name arg))
((arg0 arg1)
- (maybe-fold-binary-branch! (label->idx label) name
- (var->idx arg0) (var->idx arg1)))))
+ (maybe-fold-binary-branch! label name arg0 arg1))))
(_ #f)))
(when typev
(match fun
(($ $cont kfun ($ $kfun src meta self tail clause))
(visit-cont clause))))
- (values folded? folded-values)))
+ (values folded? folded-values reduced-terms)))
(define (fold-constants* fun dfg)
(match fun
(($ $cont min-label ($ $kfun _ _ min-var))
- (call-with-values (lambda () (compute-folded fun dfg min-label min-var))
- (lambda (folded? folded-values)
+ (call-with-values (lambda () (fold-and-reduce fun dfg min-label min-var))
+ (lambda (folded? folded-values reduced-terms)
(define (label->idx label) (- label min-label))
(define (var->idx var) (- var min-var))
(define (visit-cont cont)
,(visit-term body label)))
(($ $continue k src (and fun ($ $fun)))
($continue k src ,(visit-fun fun)))
- (($ $continue k src (and primcall ($ $primcall)))
- ,(if (and folded?
- (bitvector-ref folded? (label->idx label)))
- (let ((val (vector-ref folded-values (label->idx label))))
- ;; Uncomment for debugging.
- ;; (pk 'folded src primcall val)
- (match (lookup-cont k dfg)
- (($ $kargs)
- (let-fresh (k*) (v*)
- ;; Rely on DCE to elide this expression, if
- ;; possible.
- (build-cps-term
- ($letk ((k* ($kargs (#f) (v*)
- ($continue k src ($const val)))))
- ($continue k* src ,primcall)))))
- (($ $kif kt kf)
- ;; Folded branch.
- (build-cps-term
- ($continue (if val kt kf) src ($values ()))))))
- term))
+ (($ $continue k src (and primcall ($ $primcall name args)))
+ ,(cond
+ ((bitvector-ref folded? (label->idx label))
+ (let ((val (vector-ref folded-values (label->idx label))))
+ ;; Uncomment for debugging.
+ ;; (pk 'folded src primcall val)
+ (let-fresh (k*) (v*)
+ ;; Rely on DCE to elide this expression, if
+ ;; possible.
+ (build-cps-term
+ ($letk ((k* ($kargs (#f) (v*)
+ ($continue k src ($const val)))))
+ ($continue k* src ,primcall))))))
+ (else
+ (or (vector-ref reduced-terms (label->idx label))
+ term))))
(($ $continue kf src ($ $branch kt ($ $primcall)))
- ,(if (and folded?
- (bitvector-ref folded? (label->idx label)))
+ ,(if (bitvector-ref folded? (label->idx label))
;; Folded branch.
(let ((val (vector-ref folded-values (label->idx label))))
(build-cps-term