Transient intsets
[bpt/guile.git] / module / language / cps / intset.scm
index 8c5fef7..fb42a1f 100644 (file)
@@ -1,5 +1,5 @@
 ;;; Functional name maps
-;;; Copyright (C) 2014 Free Software Foundation, Inc.
+;;; Copyright (C) 2014, 2015 Free Software Foundation, Inc.
 ;;;
 ;;; This library is free software: you can redistribute it and/or modify
 ;;; it under the terms of the GNU Lesser General Public License as
   #:use-module (ice-9 match)
   #:export (empty-intset
             intset?
+            transient-intset?
+            persistent-intset
+            transient-intset
             intset-add
+            intset-add!
             intset-remove
             intset-ref
             intset-next
+            intset-fold
+            intset-fold2
             intset-union
-            intset-intersect))
+            intset-intersect
+            intset-subtract
+            bitvector->intset))
 
 (define-syntax-rule (define-inline name val)
   (define-syntax name (identifier-syntax val)))
 
-(define-inline *leaf-bits* 5)
+(eval-when (expand)
+  (use-modules (system base target))
+  (define-syntax compile-time-cond
+    (lambda (x)
+      (syntax-case x (else)
+        ((_ (test body ...) rest ...)
+         (if (primitive-eval (syntax->datum #'test))
+             #'(begin body ...)
+             #'(begin (compile-time-cond rest ...))))
+        ((_ (else body ...))
+         #'(begin body ...))
+        ((_)
+         (error "no compile-time-cond expression matched"))))))
+
+(compile-time-cond
+ ((eqv? (target-word-size) 4)
+  (define-inline *leaf-bits* 4))
+ ((eqv? (target-word-size) 8)
+  (define-inline *leaf-bits* 5)))
+
+;; FIXME: This should make an actual atomic reference.
+(define-inlinable (make-atomic-reference value)
+  (list value))
+(define-inlinable (get-atomic-reference reference)
+  (car reference))
+(define-inlinable (set-atomic-reference! reference value)
+  (set-car! reference value))
+
 (define-inline *leaf-size* (ash 1 *leaf-bits*))
 (define-inline *leaf-mask* (1- *leaf-size*))
 (define-inline *branch-bits* 3)
 (define-inline *branch-size* (ash 1 *branch-bits*))
+(define-inline *branch-size-with-edit* (1+ *branch-size*))
+(define-inline *edit-index* *branch-size*)
 (define-inline *branch-mask* (1- *branch-size*))
 
 (define-record-type <intset>
   (shift intset-shift)
   (root intset-root))
 
+(define-record-type <transient-intset>
+  (make-transient-intset min shift root edit)
+  transient-intset?
+  (min transient-intset-min set-transient-intset-min!)
+  (shift transient-intset-shift set-transient-intset-shift!)
+  (root transient-intset-root set-transient-intset-root!)
+  (edit transient-intset-edit set-transient-intset-edit!))
+
 (define (new-leaf) 0)
 (define-inlinable (clone-leaf-and-set leaf i val)
   (if val
 (define (leaf-empty? leaf)
   (zero? leaf))
 
-(define (new-branch)
-  (make-vector *branch-size* #f))
+(define-inlinable (new-branch edit)
+  (let ((vec (make-vector *branch-size-with-edit* #f)))
+    (when edit (vector-set! vec *edit-index* edit))
+    vec))
 (define (clone-branch-and-set branch i elt)
-  (let ((new (new-branch)))
+  (let ((new (new-branch #f)))
     (when branch (vector-move-left! branch 0 *branch-size* new 0))
     (vector-set! new i elt)
     new))
+(define-inlinable (assert-readable! root-edit)
+  (unless (eq? (get-atomic-reference root-edit) (current-thread))
+    (error "Transient intset owned by another thread" root-edit)))
+(define-inlinable (writable-branch branch root-edit)
+  (let ((edit (vector-ref branch *edit-index*)))
+    (if (eq? root-edit edit)
+        branch
+        (clone-branch-and-set branch *edit-index* root-edit))))
 (define (branch-empty? branch)
   (let lp ((i 0))
     (or (= i *branch-size*)
        ;; Shouldn't be reached...
        (else empty-intset))))))
 
+(define* (transient-intset #:optional (source empty-intset))
+  (match source
+    (($ <transient-intset> min shift root edit)
+     (assert-readable! edit)
+     source)
+    (($ <intset> min shift root)
+     (let ((edit (make-atomic-reference (current-thread))))
+       (make-transient-intset min shift root edit)))))
+
+(define* (persistent-intset #:optional (source empty-intset))
+  (match source
+    (($ <transient-intset> min shift root edit)
+     (assert-readable! edit)
+     ;; Make a fresh reference, causing any further operations on this
+     ;; transient to clone its root afresh.
+     (set-transient-intset-edit! source
+                                 (make-atomic-reference (current-thread)))
+     ;; Clear the reference to the current thread, causing our edited
+     ;; data structures to be persistent again.
+     (set-atomic-reference! edit #f)
+     (if min
+         (make-intset min shift root)
+         empty-intset))
+    (($ <intset>)
+     source)))
+
+(define (intset-add! bs i)
+  (define (adjoin-leaf i root)
+    (clone-leaf-and-set root (logand i *leaf-mask*) #t))
+  (define (ensure-branch! root idx)
+    (let ((edit (vector-ref root *edit-index*)))
+      (match (vector-ref root idx)
+        (#f (let ((v (new-branch edit)))
+              (vector-set! root idx v)
+              v))
+        (v (writable-branch v edit)))))
+  (define (adjoin-branch! i shift root)
+    (let* ((shift (- shift *branch-bits*))
+           (idx (logand (ash i (- shift)) *branch-mask*)))
+      (cond
+       ((= shift *leaf-bits*)
+        (vector-set! root idx (adjoin-leaf i (vector-ref root idx))))
+       (else
+        (adjoin-branch! i shift (ensure-branch! root idx))))))
+  (match bs
+    (($ <transient-intset> min shift root edit)
+     (assert-readable! edit)
+     (cond
+      ((< i 0)
+       ;; The power-of-two spanning trick doesn't work across 0.
+       (error "Intsets can only hold non-negative integers." i))
+      ((not root)
+       ;; Add first element.
+       (let ((min (round-down i shift)))
+         (set-transient-intset-min! bs min)
+         (set-transient-intset-shift! bs *leaf-bits*)
+         (set-transient-intset-root! bs (adjoin-leaf (- i min) root))))
+      ((and (<= min i) (< i (+ min (ash 1 shift))))
+       ;; Add element to set; level will not change.
+       (if (= shift *leaf-bits*)
+           (set-transient-intset-root! bs (adjoin-leaf (- i min) root))
+           (adjoin-branch! (- i min) shift root)))
+      (else
+       (let lp ((min min)
+                (shift shift)
+                (root (if (eqv? shift *leaf-bits*)
+                          root
+                          (writable-branch root edit))))
+         (let* ((shift* (+ shift *branch-bits*))
+                (min* (round-down min shift*))
+                (idx (logand (ash (- min min*) (- shift)) *branch-mask*))
+                (root* (new-branch edit)))
+           (vector-set! root* idx root)
+           (cond
+            ((and (<= min* i) (< i (+ min* (ash 1 shift*))))
+             (set-transient-intset-min! bs min*)
+             (set-transient-intset-shift! bs shift*)
+             (set-transient-intset-root! bs root*)
+             (adjoin-branch! (- i min*) shift* root*))
+            (else
+             (lp min* shift* root*)))))))
+     bs)
+    (($ <intset>)
+     (intset-add! (transient-intset bs) i))))
+
 (define (intset-add bs i)
   (define (adjoin i shift root)
     (cond
   (match bs
     (($ <intset> min shift root)
      (cond
+      ((< i 0)
+       ;; The power-of-two spanning trick doesn't work across 0.
+       (error "Intsets can only hold non-negative integers." i))
       ((not root)
        ;; Add first element.
        (let ((min (round-down i shift)))
       (else bs)))))
 
 (define (intset-ref bs i)
+  (define (ref min shift root)
+    (and (<= min i) (< i (+ min (ash 1 shift)))
+         (let ((i (- i min)))
+           (let lp ((node root) (shift shift))
+             (and node
+                  (if (= shift *leaf-bits*)
+                      (logbit? (logand i *leaf-mask*) node)
+                      (let* ((shift (- shift *branch-bits*))
+                             (idx (logand (ash i (- shift)) *branch-mask*)))
+                        (lp (vector-ref node idx) shift))))))))
   (match bs
     (($ <intset> min shift root)
-     (and (<= min i) (< i (+ min (ash 1 shift)))
-          (let ((i (- i min)))
-            (let lp ((node root) (shift shift))
-              (and node
-                   (if (= shift *leaf-bits*)
-                       (logbit? (logand i *leaf-mask*) node)
-                       (let* ((shift (- shift *branch-bits*))
-                              (idx (logand (ash i (- shift)) *branch-mask*)))
-                         (lp (vector-ref node idx) shift))))))))))
+     (ref min shift root))
+    (($ <transient-intset> min shift root edit)
+     (assert-readable! edit)
+     (ref min shift root))))
 
 (define (intset-next bs i)
   (define (visit-leaf node i)
   (define (visit-branch node shift i)
     (let lp ((i i) (idx (logand (ash i (- shift)) *branch-mask*)))
       (and (< idx *branch-size*)
-           (or (visit-node (vector-ref node idx) shift i)
+           (or (let ((node (vector-ref node idx)))
+                 (and node (visit-node node shift i)))
                (let ((inc (ash 1 shift)))
                  (lp (+ (round-down i shift) inc) (1+ idx)))))))
   (define (visit-node node shift i)
-    (and node
-         (if (= shift *leaf-bits*)
-             (visit-leaf node i)
-             (visit-branch node (- shift *branch-bits*) i))))
+    (if (= shift *leaf-bits*)
+        (visit-leaf node i)
+        (visit-branch node (- shift *branch-bits*) i)))
+  (define (next min shift root)
+    (let ((i (if (and i (< min i))
+                 (- i min)
+                 0)))
+      (and root (< i (ash 1 shift))
+           (let ((i (visit-node root shift i)))
+             (and i (+ min i))))))
   (match bs
     (($ <intset> min shift root)
-     (let ((i (if (and i (< min i))
-                  (- i min)
-                  0)))
-       (and (< i (ash 1 shift))
-            (let ((i (visit-node root shift i)))
-              (and i (+ min i))))))))
+     (next min shift root))
+    (($ <transient-intset> min shift root edit)
+     (assert-readable! edit)
+     (next min shift root))))
+
+(define (intset-fold f set seed)
+  (define (visit-branch node shift min seed)
+    (cond
+     ((= shift *leaf-bits*)
+      (let lp ((i 0) (seed seed))
+        (if (< i *leaf-size*)
+            (lp (1+ i)
+                (if (logbit? i node)
+                    (f (+ i min) seed)
+                    seed))
+            seed)))
+     (else
+      (let ((shift (- shift *branch-bits*)))
+        (let lp ((i 0) (seed seed))
+          (if (< i *branch-size*)
+              (let ((elt (vector-ref node i)))
+                (lp (1+ i)
+                    (if elt
+                        (visit-branch elt shift (+ min (ash i shift)) seed)
+                        seed)))
+              seed))))))
+  (match set
+    (($ <intset> min shift root)
+     (cond
+      ((not root) seed)
+      (else (visit-branch root shift min seed))))
+    (($ <transient-intset>)
+     (intset-fold f (persistent-intset set) seed))))
+
+(define (intset-fold2 f set s0 s1)
+  (define (visit-branch node shift min s0 s1)
+    (cond
+     ((= shift *leaf-bits*)
+      (let lp ((i 0) (s0 s0) (s1 s1))
+        (if (< i *leaf-size*)
+            (if (logbit? i node)
+                (call-with-values (lambda () (f (+ i min) s0 s1))
+                  (lambda (s0 s1)
+                    (lp (1+ i) s0 s1)))
+                (lp (1+ i) s0 s1))
+            (values s0 s1))))
+     (else
+      (let ((shift (- shift *branch-bits*)))
+        (let lp ((i 0) (s0 s0) (s1 s1))
+          (if (< i *branch-size*)
+              (let ((elt (vector-ref node i)))
+                (if elt
+                    (call-with-values
+                        (lambda ()
+                          (visit-branch elt shift (+ min (ash i shift)) s0 s1))
+                      (lambda (s0 s1)
+                        (lp (1+ i) s0 s1)))
+                    (lp (1+ i) s0 s1)))
+              (values s0 s1)))))))
+  (match set
+    (($ <intset> min shift root)
+     (cond
+      ((not root) (values s0 s1))
+      (else (visit-branch root shift min s0 s1))))
+    (($ <transient-intset>)
+     (intset-fold2 f (persistent-intset set) s0 s1))))
 
 (define (intset-size shift root)
   (cond
      ((eq? a-node b-node) a-node)
      ((= shift *leaf-bits*) (intersect-leaves a-node b-node))
      (else (intersect-branches (- shift *branch-bits*) a-node b-node))))
+
+  (define (different-mins lo-min lo-shift lo-root hi-min hi-shift hi lo-is-a?)
+    (cond
+     ((<= lo-shift hi-shift)
+      ;; If LO has a lower shift and a lower min, it is disjoint.  If
+      ;; it has the same shift and a different min, it is also
+      ;; disjoint.
+      empty-intset)
+     (else
+      (let* ((lo-shift (- lo-shift *branch-bits*))
+             (lo-idx (ash (- hi-min lo-min) (- lo-shift))))
+        (cond
+         ((>= lo-idx *branch-size*)
+          ;; HI has a lower shift, but it not within LO.
+          empty-intset)
+         ((vector-ref lo-root lo-idx)
+          => (lambda (lo-root)
+               (let ((lo (make-intset (+ lo-min (ash lo-idx lo-shift))
+                                      lo-shift
+                                      lo-root)))
+                 (if lo-is-a?
+                     (intset-intersect lo hi)
+                     (intset-intersect hi lo)))))
+         (else empty-intset))))))
+
+  (define (different-shifts-same-min min hi-shift hi-root lo lo-is-a?)
+    (cond
+     ((vector-ref hi-root 0)
+      => (lambda (hi-root)
+           (let ((hi (make-intset min
+                                  (- hi-shift *branch-bits*)
+                                  hi-root)))
+             (if lo-is-a?
+                 (intset-intersect lo hi)
+                 (intset-intersect hi lo)))))
+     (else empty-intset)))
+
   (match (cons a b)
     ((($ <intset> a-min a-shift a-root) . ($ <intset> b-min b-shift b-root))
      (cond
       ((< a-min b-min)
-       ;; Make A have the higher min.
-       (intset-intersect b a))
+       (different-mins a-min a-shift a-root b-min b-shift b #t))
       ((< b-min a-min)
-       (cond
-        ((<= b-shift a-shift)
-         ;; If B has a lower shift and a lower min, it is disjoint.  If
-         ;; it has the same shift and a different min, it is also
-         ;; disjoint.
-         empty-intset)
-        (else
-         (let* ((b-shift (- b-shift *branch-bits*))
-                (b-idx (ash (- a-min b-min) (- b-shift))))
-           (if (>= b-idx *branch-size*)
-               ;; A has a lower shift, but it not within B.
-               empty-intset
-               (intset-intersect a
-                                 (make-intset (+ b-min (ash b-idx b-shift))
-                                              b-shift
-                                              (vector-ref b-root b-idx))))))))
-      ((< b-shift a-shift)
-       ;; Make A have the lower shift.
-       (intset-intersect b a))
+       (different-mins b-min b-shift b-root a-min a-shift a #f))
       ((< a-shift b-shift)
-       ;; A and B have the same min but a different shift.  Recurse down.
-       (intset-intersect a
-                         (make-intset b-min
-                                      (- b-shift *branch-bits*)
-                                      (vector-ref b-root 0))))
+       (different-shifts-same-min b-min b-shift b-root a #t))
+      ((< b-shift a-shift)
+       (different-shifts-same-min a-min a-shift a-root b #f))
       (else
        ;; At this point, A and B cover the same range.
        (let ((root (intersect a-shift a-root b-root)))
           ((eq? root a-root) a)
           ((eq? root b-root) b)
           (else (make-intset/prune a-min a-shift root)))))))))
+
+(define (intset-subtract a b)
+  (define tmp (new-leaf))
+  ;; Intersect leaves.
+  (define (subtract-leaves a b)
+    (logand a (lognot b)))
+  ;; Subtract B from A starting at index I; the result will be fresh.
+  (define (subtract-branches/fresh shift a b i fresh)
+    (let lp ((i 0))
+      (cond
+       ((< i *branch-size*)
+        (let* ((a-child (vector-ref a i))
+               (b-child (vector-ref b i)))
+          (vector-set! fresh i (subtract-nodes shift a-child b-child))
+          (lp (1+ i))))
+       ((branch-empty? fresh) #f)
+       (else fresh))))
+  ;; Subtract B from A.  The result may be eq? to A.
+  (define (subtract-branches shift a b)
+    (let lp ((i 0))
+      (cond
+       ((< i *branch-size*)
+        (let* ((a-child (vector-ref a i))
+               (b-child (vector-ref b i)))
+          (let ((child (subtract-nodes shift a-child b-child)))
+            (cond
+             ((eq? a-child child)
+              (lp (1+ i)))
+             (else
+              (let ((result (clone-branch-and-set a i child)))
+                (subtract-branches/fresh shift a b (1+ i) result)))))))
+       (else a))))
+  (define (subtract-nodes shift a-node b-node)
+    (cond
+     ((or (not a-node) (not b-node)) a-node)
+     ((eq? a-node b-node) #f)
+     ((= shift *leaf-bits*) (subtract-leaves a-node b-node))
+     (else (subtract-branches (- shift *branch-bits*) a-node b-node))))
+
+  (match (cons a b)
+    ((($ <intset> a-min a-shift a-root) . ($ <intset> b-min b-shift b-root))
+     (define (return root)
+       (cond
+        ((eq? root a-root) a)
+        (else (make-intset/prune a-min a-shift root))))
+     (cond
+      ((<= a-shift b-shift)
+       (let lp ((b-min b-min) (b-shift b-shift) (b-root b-root))
+         (if (= a-shift b-shift)
+             (if (= a-min b-min)
+                 (return (subtract-nodes a-shift a-root b-root))
+                 a)
+             (let* ((b-shift (- b-shift *branch-bits*))
+                    (b-idx (ash (- a-min b-min) (- b-shift)))
+                    (b-min (+ b-min (ash b-idx b-shift)))
+                    (b-root (and b-root
+                                 (<= 0 b-idx)
+                                 (< b-idx *branch-size*)
+                                 (vector-ref b-root b-idx))))
+               (lp b-min b-shift b-root)))))
+      (else
+       (return
+        (let lp ((a-min a-min) (a-shift a-shift) (a-root a-root))
+          (if (= a-shift b-shift)
+              (if (= a-min b-min)
+                  (subtract-nodes a-shift a-root b-root)
+                  a-root)
+              (let* ((a-shift (- a-shift *branch-bits*))
+                     (a-idx (ash (- b-min a-min) (- a-shift)))
+                     (a-min (+ a-min (ash a-idx a-shift)))
+                     (old (and a-root
+                               (<= 0 a-idx)
+                               (< a-idx *branch-size*)
+                               (vector-ref a-root a-idx)))
+                     (new (lp a-min a-shift old)))
+                (if (eq? old new)
+                    a-root
+                    (clone-branch-and-set a-root a-idx new)))))))))))
+
+(define (bitvector->intset bv)
+  (define (finish-tail out min tail)
+    (if (zero? tail)
+        out
+        (intset-union out (make-intset min *leaf-bits* tail))))
+  (let lp ((out empty-intset) (min 0) (pos 0) (tail 0))
+    (let ((pos (bit-position #t bv pos)))
+      (cond
+       ((not pos)
+        (finish-tail out min tail))
+       ((< pos (+ min *leaf-size*))
+        (lp out min (1+ pos) (logior tail (ash 1 (- pos min)))))
+       (else
+        (let ((min* (round-down pos *leaf-bits*)))
+          (lp (finish-tail out min tail)
+              min* pos (ash 1 (- pos min*)))))))))