X-Git-Url: http://git.hcoop.net/bpt/guile.git/blobdiff_plain/92805e219789654115f741b7d621bc9947833379..543d9e1a6cc94982ec99044e7c63309b716c3fa8:/module/language/cps/type-fold.scm diff --git a/module/language/cps/type-fold.scm b/module/language/cps/type-fold.scm index 20abc36f3..6cc128456 100644 --- a/module/language/cps/type-fold.scm +++ b/module/language/cps/type-fold.scm @@ -29,10 +29,16 @@ #:use-module (language cps dfg) #:use-module (language cps renumber) #:use-module (language cps types) + #:use-module (system base target) #:export (type-fold)) + + + +;; Branch folders. + (define &scalar-types - (logior &exact-integer &flonum &char &unspecified &boolean &nil &null)) + (logior &exact-integer &flonum &char &unspecified &false &true &nil &null)) (define *branch-folders* (make-hash-table)) @@ -61,7 +67,6 @@ ;; All the cases that are in compile-bytecode. (define-unary-type-predicate-folder pair? &pair) (define-unary-type-predicate-folder null? &null) -(define-unary-type-predicate-folder nil? &nil) (define-unary-type-predicate-folder symbol? &symbol) (define-unary-type-predicate-folder variable? &box) (define-unary-type-predicate-folder vector? &vector) @@ -123,33 +128,199 @@ ((= <= <) (values #t #f)) (else (values #f #f)))) -(define (compute-folded fun dfg min-label min-var) +(define-binary-branch-folder (logtest type0 min0 max0 type1 min1 max1) + (define (logand-min a b) + (if (< a b 0) + (min a b) + 0)) + (define (logand-max a b) + (if (< a b 0) + 0 + (max a b))) + (if (and (= min0 max0) (= min1 max1) (eqv? type0 type1 &exact-integer)) + (values #t (logtest min0 min1)) + (values #f #f))) + + + + +;; Strength reduction. + +(define *primcall-reducers* (make-hash-table)) + +(define-syntax-rule (define-primcall-reducer name f) + (hashq-set! *primcall-reducers* 'name f)) + +(define-syntax-rule (define-unary-primcall-reducer (name dfg k src + arg type min max) + body ...) + (define-primcall-reducer name + (lambda (dfg k src arg type min max) body ...))) + +(define-syntax-rule (define-binary-primcall-reducer (name dfg k src + arg0 type0 min0 max0 + arg1 type1 min1 max1) + body ...) + (define-primcall-reducer name + (lambda (dfg k src arg0 type0 min0 max0 arg1 type1 min1 max1) body ...))) + +(define-binary-primcall-reducer (mul dfg k src + arg0 type0 min0 max0 + arg1 type1 min1 max1) + (define (negate arg) + (let-fresh (kzero) (zero) + (build-cps-term + ($letk ((kzero ($kargs (#f) (zero) + ($continue k src ($primcall 'sub (zero arg)))))) + ($continue kzero src ($const 0)))))) + (define (zero) + (build-cps-term ($continue k src ($const 0)))) + (define (identity arg) + (build-cps-term ($continue k src ($values (arg))))) + (define (double arg) + (build-cps-term ($continue k src ($primcall 'add (arg arg))))) + (define (power-of-two constant arg) + (let ((n (let lp ((bits 0) (constant constant)) + (if (= constant 1) bits (lp (1+ bits) (ash constant -1)))))) + (let-fresh (kbits) (bits) + (build-cps-term + ($letk ((kbits ($kargs (#f) (bits) + ($continue k src ($primcall 'ash (arg bits)))))) + ($continue kbits src ($const n))))))) + (define (mul/constant constant constant-type arg arg-type) + (and (or (= constant-type &exact-integer) (= constant-type arg-type)) + (case constant + ;; (* arg -1) -> (- 0 arg) + ((-1) (negate arg)) + ;; (* arg 0) -> 0 if arg is not a flonum or complex + ((0) (and (= constant-type &exact-integer) + (zero? (logand arg-type + (lognot (logior &flonum &complex)))) + (zero))) + ;; (* arg 1) -> arg + ((1) (identity arg)) + ;; (* arg 2) -> (+ arg arg) + ((2) (double arg)) + (else (and (= constant-type arg-type &exact-integer) + (positive? constant) + (zero? (logand constant (1- constant))) + (power-of-two constant arg)))))) + (cond + ((logtest (logior type0 type1) (lognot &number)) #f) + ((= min0 max0) (mul/constant min0 type0 arg1 type1)) + ((= min1 max1) (mul/constant min1 type1 arg0 type0)) + (else #f))) + +(define-binary-primcall-reducer (logbit? dfg k src + arg0 type0 min0 max0 + arg1 type1 min1 max1) + (define (convert-to-logtest bool-term) + (let-fresh (kt kf kmask kbool) (mask bool) + (build-cps-term + ($letk ((kt ($kargs () () + ($continue kbool src ($const #t)))) + (kf ($kargs () () + ($continue kbool src ($const #f)))) + (kbool ($kargs (#f) (bool) + ,(bool-term bool))) + (kmask ($kargs (#f) (mask) + ($continue kf src + ($branch kt ($primcall 'logtest (mask arg1))))))) + ,(if (eq? min0 max0) + ($continue kmask src ($const (ash 1 min0))) + (let-fresh (kone) (one) + (build-cps-term + ($letk ((kone ($kargs (#f) (one) + ($continue kmask src + ($primcall 'ash (one arg0)))))) + ($continue kone src ($const 1)))))))))) + ;; Hairiness because we are converting from a primcall with unknown + ;; arity to a branching primcall. + (let ((positive-fixnum-bits (- (* (target-word-size) 8) 3))) + (and (= type0 &exact-integer) + (<= 0 min0 positive-fixnum-bits) + (<= 0 max0 positive-fixnum-bits) + (match (lookup-cont k dfg) + (($ $kreceive arity kargs) + (match arity + (($ $arity (_) () (not #f) () #f) + (convert-to-logtest + (lambda (bool) + (let-fresh (knil) (nil) + (build-cps-term + ($letk ((knil ($kargs (#f) (nil) + ($continue kargs src + ($values (bool nil)))))) + ($continue knil src ($const '())))))))) + (_ + (convert-to-logtest + (lambda (bool) + (build-cps-term + ($continue k src ($primcall 'values (bool))))))))) + (($ $ktail) + (convert-to-logtest + (lambda (bool) + (build-cps-term + ($continue k src ($primcall 'return (bool))))))))))) + + + + +;; + +(define (fold-and-reduce fun dfg min-label min-var) (define (scalar-value type val) (cond ((eqv? type &exact-integer) val) ((eqv? type &flonum) (exact->inexact val)) ((eqv? type &char) (integer->char val)) ((eqv? type &unspecified) *unspecified*) - ((eqv? type &boolean) (not (zero? val))) + ((eqv? type &false) #f) + ((eqv? type &true) #t) ((eqv? type &nil) #nil) ((eqv? type &null) '()) (else (error "unhandled type" type val)))) - (let* ((typev (infer-types fun dfg #:max-label-count 3000)) - (folded? (and typev - (make-bitvector (/ (vector-length typev) 2) #f))) - (folded-values (and typev - (make-vector (bitvector-length folded?) #f)))) + (let* ((typev (infer-types fun dfg)) + (label-count ((make-local-cont-folder label-count) + (lambda (k cont label-count) (1+ label-count)) + fun 0)) + (folded? (make-bitvector label-count #f)) + (folded-values (make-vector label-count #f)) + (reduced-terms (make-vector label-count #f))) (define (label->idx label) (- label min-label)) (define (var->idx var) (- var min-var)) - (define (maybe-fold-value! label name k def) - (call-with-values (lambda () (lookup-post-type typev label def)) + (define (maybe-reduce-primcall! label k src name args) + (let* ((reducer (hashq-ref *primcall-reducers* name))) + (when reducer + (vector-set! + reduced-terms + (label->idx label) + (match args + ((arg0) + (call-with-values (lambda () (lookup-pre-type typev label arg0)) + (lambda (type0 min0 max0) + (reducer dfg k src arg0 type0 min0 max0)))) + ((arg0 arg1) + (call-with-values (lambda () (lookup-pre-type typev label arg0)) + (lambda (type0 min0 max0) + (call-with-values (lambda () (lookup-pre-type typev label arg1)) + (lambda (type1 min1 max1) + (reducer dfg k src arg0 type0 min0 max0 + arg1 type1 min1 max1)))))) + (_ #f)))))) + (define (maybe-fold-value! label name def) + (call-with-values (lambda () (lookup-post-type typev label def 0)) (lambda (type min max) - (when (and (not (zero? type)) - (zero? (logand type (1- type))) - (zero? (logand type (lognot &scalar-types))) - (eqv? min max)) - (bitvector-set! folded? label #t) - (vector-set! folded-values label (scalar-value type min)))))) + (cond + ((and (not (zero? type)) + (zero? (logand type (1- type))) + (zero? (logand type (lognot &scalar-types))) + (eqv? min max)) + (bitvector-set! folded? (label->idx label) #t) + (vector-set! folded-values (label->idx label) + (scalar-value type min)) + #t) + (else #f))))) (define (maybe-fold-unary-branch! label name arg) (let* ((folder (hashq-ref *branch-folders* name))) (when folder @@ -157,8 +328,8 @@ (lambda (type min max) (call-with-values (lambda () (folder type min max)) (lambda (f? v) - (bitvector-set! folded? label f?) - (vector-set! folded-values label v)))))))) + (bitvector-set! folded? (label->idx label) f?) + (vector-set! folded-values (label->idx label) v)))))))) (define (maybe-fold-binary-branch! label name arg0 arg1) (let* ((folder (hashq-ref *branch-folders* name))) (when folder @@ -169,8 +340,8 @@ (call-with-values (lambda () (folder type0 min0 max0 type1 min1 max1)) (lambda (f? v) - (bitvector-set! folded? label f?) - (vector-set! folded-values label v)))))))))) + (bitvector-set! folded? (label->idx label) f?) + (vector-set! folded-values (label->idx label) v)))))))))) (define (visit-cont cont) (match cont (($ $cont label ($ $kargs _ _ body)) @@ -187,42 +358,34 @@ (($ $letrec _ _ _ body) (visit-term body label)) (($ $continue k src ($ $primcall name args)) - ;; We might be able to fold primcalls that define a value or - ;; that branch. + ;; We might be able to fold primcalls that define a value. (match (lookup-cont k dfg) (($ $kargs (_) (def)) - (maybe-fold-value! (label->idx label) name (label->idx k) - (var->idx def))) - (($ $kif kt kf) - (match args - ((arg) - (maybe-fold-unary-branch! (label->idx label) name - (var->idx arg))) - ((arg0 arg1) - (maybe-fold-binary-branch! (label->idx label) name - (var->idx arg0) (var->idx arg1))))) - (_ #f))) + ;(pk 'maybe-fold-value src name args) + (unless (maybe-fold-value! label name def) + (maybe-reduce-primcall! label k src name args))) + (_ + (maybe-reduce-primcall! label k src name args)))) (($ $continue kf src ($ $branch kt ($ $primcall name args))) ;; We might be able to fold primcalls that branch. + ;(pk 'maybe-fold-branch label src name args) (match args ((arg) - (maybe-fold-unary-branch! (label->idx label) name - (var->idx arg))) + (maybe-fold-unary-branch! label name arg)) ((arg0 arg1) - (maybe-fold-binary-branch! (label->idx label) name - (var->idx arg0) (var->idx arg1))))) + (maybe-fold-binary-branch! label name arg0 arg1)))) (_ #f))) (when typev (match fun (($ $cont kfun ($ $kfun src meta self tail clause)) (visit-cont clause)))) - (values folded? folded-values))) + (values folded? folded-values reduced-terms))) (define (fold-constants* fun dfg) (match fun (($ $cont min-label ($ $kfun _ _ min-var)) - (call-with-values (lambda () (compute-folded fun dfg min-label min-var)) - (lambda (folded? folded-values) + (call-with-values (lambda () (fold-and-reduce fun dfg min-label min-var)) + (lambda (folded? folded-values reduced-terms) (define (label->idx label) (- label min-label)) (define (var->idx var) (- var min-var)) (define (visit-cont cont) @@ -243,29 +406,24 @@ ,(visit-term body label))) (($ $continue k src (and fun ($ $fun))) ($continue k src ,(visit-fun fun))) - (($ $continue k src (and primcall ($ $primcall))) - ,(if (and folded? - (bitvector-ref folded? (label->idx label))) - (let ((val (vector-ref folded-values (label->idx label)))) - ;; Uncomment for debugging. - ;; (pk 'folded src primcall val) - (match (lookup-cont k dfg) - (($ $kargs) - (let-fresh (k*) (v*) - ;; Rely on DCE to elide this expression, if - ;; possible. - (build-cps-term - ($letk ((k* ($kargs (#f) (v*) - ($continue k src ($const val))))) - ($continue k* src ,primcall))))) - (($ $kif kt kf) - ;; Folded branch. - (build-cps-term - ($continue (if val kt kf) src ($values ())))))) - term)) + (($ $continue k src (and primcall ($ $primcall name args))) + ,(cond + ((bitvector-ref folded? (label->idx label)) + (let ((val (vector-ref folded-values (label->idx label)))) + ;; Uncomment for debugging. + ;; (pk 'folded src primcall val) + (let-fresh (k*) (v*) + ;; Rely on DCE to elide this expression, if + ;; possible. + (build-cps-term + ($letk ((k* ($kargs (#f) (v*) + ($continue k src ($const val))))) + ($continue k* src ,primcall)))))) + (else + (or (vector-ref reduced-terms (label->idx label)) + term)))) (($ $continue kf src ($ $branch kt ($ $primcall))) - ,(if (and folded? - (bitvector-ref folded? (label->idx label))) + ,(if (bitvector-ref folded? (label->idx label)) ;; Folded branch. (let ((val (vector-ref folded-values (label->idx label)))) (build-cps-term