| 1 | (* Copyright (C) 2009 Matthew Fluet. |
| 2 | * Copyright (C) 1999-2007 Henry Cejtin, Matthew Fluet, Suresh |
| 3 | * Jagannathan, and Stephen Weeks. |
| 4 | * |
| 5 | * MLton is released under a BSD-style license. |
| 6 | * See the file MLton-LICENSE for details. |
| 7 | *) |
| 8 | |
(* HashedUniqueSet: a SET implemented as a hash table whose buckets are
 * themselves sets from the argument structure Set.
 *
 * Representation invariant (established by empty/singleton and preserved by
 * grow, shrink and every operation below): mask = 2^k - 1 and the bucket
 * vector has exactly mask + 1 = 2^k entries.  Element x lives in bucket
 * index (hash x, mask).  Two sets may have different masks; binary
 * operations first bring them to a common mask with coerce.
 *)
functor HashedUniqueSet(structure Set : SET
                        structure Element : sig include T val hash : t -> word end
                        sharing type Set.Element.t = Element.t) : SET =
struct

structure Set = Set
structure Element = Element
val hash = Element.hash

(* Bucket index of hash value w under the given mask. *)
fun index (w: word, mask: word): int
   = Word.toInt (Word.andb (w, mask))

(* The ref is retained so that t keeps its original representation (and
 * equality status); no operation in this functor assigns through it. *)
datatype t = T of {buckets: Set.t vector,
                   mask: word} ref

(* Summarize a table as (total element count, smallest bucket size, largest
 * bucket size).  The min/max components are NONE only for a table with no
 * buckets, which the constructors in this functor never produce. *)
fun stats' {buckets, mask = _}
   = Vector.fold
     (buckets,
      (0, NONE, NONE),
      fn (bucket, (total, lo, hi)) =>
         let
            val n = Set.size bucket
         in
            (total + n,
             SOME (Option.fold (lo, n, Int.min)),
             SOME (Option.fold (hi, n, Int.max)))
         end)

fun stats (T (ref r)) = stats' r

(* Double the number of buckets.  The new mask exposes one more hash bit
 * (high); old bucket i is split on that bit: elements with it clear stay at
 * index i, elements with it set move to index i + n. *)
fun grow {buckets, mask = oldMask}
   = let
        val mask = Word.orb (0wx1, Word.<< (oldMask, 0wx1))
        val high = Word.andb (mask, Word.notb oldMask)
        val n = Vector.length buckets
        (* Partition every old bucket exactly once. *)
        val halves =
           Vector.map
           (buckets, fn bucket =>
            Set.partition (bucket, fn x => Word.andb (high, hash x) = 0wx0))
     in
        {buckets = Vector.tabulate
                   (2 * n, fn i =>
                    if i < n
                       then #yes (Vector.sub (halves, i))
                    else #no (Vector.sub (halves, i - n))),
         mask = mask}
     end

(* Halve the number of buckets by merging bucket i with bucket i + n, and
 * drop the top bit of the mask. *)
fun shrink {buckets, mask}
   = let
        val n = Vector.length buckets div 2
     in
        {buckets = Vector.tabulate
                   (n, fn i => Set.+ (Vector.sub (buckets, i),
                                      Vector.sub (buckets, i + n))),
         mask = Word.>> (mask, 0wx1)}
     end

(* Wrap a table as a set, resizing it at most once: grow if the largest
 * bucket holds more elements than there are buckets, shrink if it holds
 * fewer than half as many (but never below 2 buckets). *)
fun T' (r as {buckets, mask = _})
   = let
        val (_, _, largestOpt) = stats' r
        val largest = case largestOpt of SOME m => m | NONE => ~1
        val n = Vector.length buckets
        val resized =
           if largest > n
              then grow r
           else if largest < n div 2 andalso n > 2
              then shrink r
           else r
     in
        T (ref resized)
     end

(* Bring two tables to the same mask by repeatedly growing the one with the
 * smaller mask; returns the (possibly grown) pair of tables.
 *
 * The grown tables are returned rather than written back through the refs:
 * `empty' below is a single shared ref, and assigning a grown table into it
 * would permanently enlarge the empty set's table.  Sets later built from it
 * with add would inherit that oversized table, since T' shrinks at most one
 * step per operation. *)
fun coerce (T (ref r1), T (ref r2))
   = let
        fun loop (r1 as {buckets = _, mask = m1},
                  r2 as {buckets = _, mask = m2})
           = if m1 = m2
                then (r1, r2)
             else if m1 < m2
                then loop (grow r1, r2)
             else loop (r1, grow r2)
     in
        loop (r1, r2)
     end


val empty
   = T (ref {buckets = Vector.new2 (Set.empty, Set.empty),
             mask = 0wx1})

fun singleton x
   = let
        val one = Set.singleton x
        val buckets =
           if index (hash x, 0wx1) = 0
              then Vector.new2 (one, Set.empty)
           else Vector.new2 (Set.empty, one)
     in
        T (ref {buckets = buckets, mask = 0wx1})
     end


(* Apply a vector traversal vw to the buckets of one set, using sw on each
 * bucket. *)
fun walk1 (vw, sw) (T (ref {buckets, ...})) = vw (buckets, sw)

(* Apply a pairwise vector traversal vw to the buckets of two sets after
 * bringing them to a common mask, using sw on corresponding buckets. *)
fun walk2 (vw, sw) (s1, s2)
   = let
        val ({buckets = b1, ...}, {buckets = b2, ...}) = coerce (s1, s2)
     in
        vw (b1, b2, sw)
     end

val areDisjoint = walk2 (Vector.forall2, Set.areDisjoint)
val equals = walk2 (Vector.forall2, Set.equals)
fun exists (s, p) = walk1 (Vector.exists, fn s' => Set.exists (s', p)) s
fun forall (s, p) = walk1 (Vector.forall, fn s' => Set.forall (s', p)) s
fun foreach (s, f) = walk1 (Vector.foreach, fn s' => Set.foreach (s', f)) s

(* Build a new set by transforming every bucket of s with sb. *)
fun build1 sb (T (ref {buckets, mask}))
   = T' {buckets = Vector.map (buckets, sb), mask = mask}

(* Build a new set by combining corresponding buckets of s1 and s2 (after
 * bringing them to a common mask) with sb. *)
fun build2 sb (s1, s2)
   = let
        val ({buckets = b1, mask}, {buckets = b2, ...}) = coerce (s1, s2)
     in
        T' {buckets = Vector.tabulate
                      (Vector.length b1, fn i =>
                       sb (Vector.sub (b1, i), Vector.sub (b2, i))),
            mask = mask}
     end

val difference = build2 Set.-
val intersect = build2 Set.intersect
fun subset (s, p) = build1 (fn s' => Set.subset (s', p)) s
val union = build2 Set.+
fun unions [] = empty
  | unions [s] = s
  | unions [s1, s2] = union (s1, s2)
  | unions (s1 :: s2 :: ss) = unions (union (s1, s2) :: ss)


fun contains (T (ref {buckets, mask}), x)
   = Set.contains (Vector.sub (buckets, index (hash x, mask)), x)

(* Copy of the bucket vector with position ix replaced by bucket. *)
fun updateBucket (buckets, ix, bucket)
   = Vector.tabulate
     (Vector.length buckets, fn i =>
      if i = ix then bucket else Vector.sub (buckets, i))

(* Returns s itself when x is already present; otherwise a new set with x
 * added to its bucket.  x is hashed only once. *)
fun add (s as T (ref {buckets, mask}), x)
   = let
        val ix = index (hash x, mask)
        val bucket = Vector.sub (buckets, ix)
     in
        if Set.contains (bucket, x)
           then s
        else T' {buckets = updateBucket (buckets, ix, Set.add (bucket, x)),
                 mask = mask}
     end

(* Returns s itself when x is absent; otherwise a new set with x removed
 * from its bucket.  x is hashed only once. *)
fun remove (s as T (ref {buckets, mask}), x)
   = let
        val ix = index (hash x, mask)
        val bucket = Vector.sub (buckets, ix)
     in
        if Set.contains (bucket, x)
           then T' {buckets = updateBucket (buckets, ix, Set.remove (bucket, x)),
                    mask = mask}
        else s
     end

(* Split s by predicate p bucket by bucket; both results keep s's mask
 * before T' rebalances them. *)
fun partition (T (ref {buckets, mask}), p)
   = let
        val parts = Vector.map (buckets, fn bucket => Set.partition (bucket, p))
     in
        {yes = T' {buckets = Vector.map (parts, #yes), mask = mask},
         no = T' {buckets = Vector.map (parts, #no), mask = mask}}
     end


fun fold (T (ref {buckets, ...}), b, f)
   = Vector.fold (buckets, b, fn (bucket, acc) => Set.fold (bucket, acc, f))

fun fromList l = List.fold (l, empty, fn (x, s) => add (s, x))
fun toList s = fold (s, [], op ::)
fun map (s, f) = fold (s, empty, fn (x, s) => add (s, f x))
fun replace (s, f)
   = fold (s, empty, fn (x, s) => case f x
                                    of NONE => s
                                     | SOME x' => add (s, x'))
fun subsetSize (s, p)
   = fold (s, 0: int, fn (x, n) => if p x then n + 1 else n)

(* Number of elements: the sum of the bucket sizes. *)
fun size (T (ref {buckets, ...}))
   = Vector.fold (buckets, 0: int, fn (bucket, n) => n + Set.size bucket)


fun layout s = List.layout Element.layout (toList s)

fun power _ = Error.bug "HashedUniqueSet.power"
fun subsets _ = Error.bug "HashedUniqueSet.subsets"

fun isEmpty s = size s = 0
fun isSubsetEq (s1, s2) = size (difference (s1, s2)) = 0
fun isSubset (s1, s2) = (size s1 <> size s2) andalso isSubsetEq (s1, s2)
fun isSupersetEq (s1, s2) = isSubsetEq (s2, s1)
fun isSuperset (s1, s2) = isSubset (s2, s1)

val op + = union
val op - = difference
val op < = isSubset
val op <= = isSubsetEq
val op > = isSuperset
val op >= = isSupersetEq

end