Commit | Line | Data |
---|---|---|
7f918cf1 CE |
1 | (* Copyright (C) 2009 Matthew Fluet. |
2 | * Copyright (C) 1999-2007 Henry Cejtin, Matthew Fluet, Suresh | |
3 | * Jagannathan, and Stephen Weeks. | |
4 | * | |
5 | * MLton is released under a BSD-style license. | |
6 | * See the file MLton-LICENSE for details. | |
7 | *) | |
8 | ||
functor HashedUniqueSet(structure Set : SET
                        structure Element : sig include T val hash : t -> word end
                        sharing type Set.Element.t = Element.t) : SET =
struct

(* A SET implemented as a vector of hash buckets, each bucket being a value
 * of the underlying [Set].
 *
 * Representation invariants (kept by every constructor below):
 *  - the vector has [mask + 1] buckets, a power of two that is at least 2;
 *  - element x lives in bucket [index (hash x, mask)].
 * Lookups only consult the bucket selected by [hash x], so [Element.hash]
 * must give the same word for any two elements that [Set] treats as equal.
 *
 * The representation sits behind a ref so that [coerce] can re-bucket an
 * operand in place; see the note on [coerce].
 *)

structure Set = Set
structure Element = Element
val hash = Element.hash

(* Bucket index selected by hash value [w] under [mask]. *)
fun index (w: word, mask: word): int = Word.toInt (Word.andb (w, mask))

datatype t = T of {buckets: Set.t vector,
                   mask: word} ref

(* A vector of the same length as [buckets] whose i-th entry is
 * [f (i, Vector.sub (buckets, i))]. *)
fun mapBuckets (buckets, f) =
   (#1 o Vector.unfoldi)
   (Vector.length buckets, (),
    fn (i, ()) => (f (i, Vector.sub (buckets, i)), ()))

(* Returns (total element count, smallest bucket size, largest bucket size).
 * The two size bounds are NONE only for a vector with no buckets. *)
fun stats' {buckets, mask = _} =
   Vector.fold
   (buckets, (0, NONE, NONE),
    fn (bucket, (total, smallest, largest)) =>
       let
          val n = Set.size bucket
       in
          (total + n,
           SOME (Option.fold (smallest, n, Int.min)),
           SOME (Option.fold (largest, n, Int.max)))
       end)

fun stats (T (ref rep)) = stats' rep

(* Double the bucket count.  The new mask gains one extra low bit, [high].
 * Old bucket i is split on that bit: elements with the bit clear stay at
 * index i, and the rest move to index i + n. *)
fun grow {buckets, mask} =
   let
      val mask' = Word.orb (0wx1, Word.<< (mask, 0wx1))
      val high = Word.andb (mask', Word.notb mask)
      val n = Vector.length buckets
      val halves =
         mapBuckets
         (buckets, fn (_, bucket) =>
          Set.partition (bucket, fn x => Word.andb (high, hash x) = 0wx0))
      val buckets' =
         (#1 o Vector.unfoldi)
         (2 * n, (),
          fn (i, ()) =>
             let
                val {yes, no} =
                   Vector.sub (halves, if i < n then i else i - n)
             in
                (if i < n then yes else no, ())
             end)
   in
      {buckets = buckets', mask = mask'}
   end

(* Halve the bucket count by merging bucket i with bucket i + n.  Dropping
 * the top mask bit sends the elements of both into index i. *)
fun shrink {buckets, mask} =
   let
      val n = Vector.length buckets div 2
      val merged =
         (#1 o Vector.unfoldi)
         (n, (),
          fn (i, ()) =>
             (Set.+ (Vector.sub (buckets, i), Vector.sub (buckets, i + n)), ()))
   in
      {buckets = merged, mask = Word.>> (mask, 0wx1)}
   end

(* Wrap a representation as a set, rebalancing it at most once.  It grows
 * when the largest bucket holds more elements than there are buckets.  It
 * shrinks when the largest bucket holds fewer than half as many and there
 * are more than two buckets. *)
fun T' (rep as {buckets, mask = _}) =
   let
      val (_, _, largest) = stats' rep
      val largest = case largest of SOME m => m | NONE => ~1
      val n = Vector.length buckets
   in
      T (ref (if largest > n
                 then grow rep
              else if largest < n div 2 andalso n > 2
                 then shrink rep
              else rep))
   end

(* Bring two sets to the same mask by repeatedly growing whichever has the
 * smaller mask.
 * NOTE: the smaller operand is grown *in place* through its ref.  Every
 * holder of that same value sees the new bucket layout, including the
 * single shared [empty] value below.  The set's contents are unchanged;
 * only its number of buckets increases. *)
fun coerce (s1 as T (r1 as ref (rep1 as {mask = mask1, ...})),
            s2 as T (r2 as ref (rep2 as {mask = mask2, ...}))) =
   if mask1 = mask2
      then ()
   else if mask1 < mask2
      then (r1 := grow rep1; coerce (s1, s2))
   else (r2 := grow rep2; coerce (s1, s2))

(* The empty set: two empty buckets under mask 0wx1.  This one value is
 * shared by every use of [empty] (see the note on [coerce]). *)
val empty =
   T (ref {buckets = Vector.new2 (Set.empty, Set.empty),
           mask = 0wx1})

fun singleton x =
   let
      val mask = 0wx1
      val here = Set.singleton x
   in
      T (ref {buckets = if index (hash x, mask) = 0
                           then Vector.new2 (here, Set.empty)
                        else Vector.new2 (Set.empty, here),
              mask = mask})
   end

(* Lift a bucket-level traversal [sw], driven by the vector combinator [vw],
 * to a whole set. *)
fun walk1 (vw, sw) (T (ref {buckets, mask = _})) = vw (buckets, sw)

(* Coerce [s1] and [s2] to a common mask.  Returns both bucket vectors
 * (bucket i of each covers the same hash values) and that mask. *)
fun aligned (s1, s2) =
   let
      val () = coerce (s1, s2)
      val T (ref {buckets = buckets1, mask}) = s1
      val T (ref {buckets = buckets2, mask = _}) = s2
   in
      (buckets1, buckets2, mask)
   end

(* Lift a pairwise bucket traversal [sw], driven by [vw], to two sets. *)
fun walk2 (vw, sw) (s1, s2) =
   let
      val (buckets1, buckets2, _) = aligned (s1, s2)
   in
      vw (buckets1, buckets2, sw)
   end

val areDisjoint = walk2 (Vector.forall2, Set.areDisjoint)
val equals = walk2 (Vector.forall2, Set.equals)
fun exists (s, p) = walk1 (Vector.exists, fn bucket => Set.exists (bucket, p)) s
fun forall (s, p) = walk1 (Vector.forall, fn bucket => Set.forall (bucket, p)) s
fun foreach (s, f) = walk1 (Vector.foreach, fn bucket => Set.foreach (bucket, f)) s

(* Build a new (rebalanced) set by applying [sb] to every bucket of [s]. *)
fun build1 sb (T (ref {buckets, mask})) =
   T' {buckets = mapBuckets (buckets, fn (_, bucket) => sb bucket),
       mask = mask}

(* Build a new (rebalanced) set by applying [sb] to each pair of aligned
 * buckets of [s1] and [s2]. *)
fun build2 sb (s1, s2) =
   let
      val (buckets1, buckets2, mask) = aligned (s1, s2)
   in
      T' {buckets = mapBuckets (buckets1, fn (i, b1) =>
                                   sb (b1, Vector.sub (buckets2, i))),
          mask = mask}
   end

val difference = build2 Set.-
val intersect = build2 Set.intersect
fun subset (s, p) = build1 (fn bucket => Set.subset (bucket, p)) s
val union = build2 Set.+

(* Left-to-right union: unions [s1, s2, s3] = union (union (s1, s2), s3). *)
fun unions [] = empty
  | unions (s :: ss) = List.fold (ss, s, fn (s', acc) => union (acc, s'))

fun contains (T (ref {buckets, mask}), x) =
   Set.contains (Vector.sub (buckets, index (hash x, mask)), x)

(* New (rebalanced) set in which the bucket that [x] hashes to is replaced
 * by [f] of it; all other buckets are shared unchanged. *)
fun updateBucket (T (ref {buckets, mask}), x, f) =
   let
      val ix = index (hash x, mask)
   in
      T' {buckets = mapBuckets (buckets, fn (i, bucket) =>
                                   if i = ix then f bucket else bucket),
          mask = mask}
   end

fun add (s, x) =
   if contains (s, x)
      then s
   else updateBucket (s, x, fn bucket => Set.add (bucket, x))

fun remove (s, x) =
   if contains (s, x)
      then updateBucket (s, x, fn bucket => Set.remove (bucket, x))
   else s

(* Partition bucket-by-bucket.  Both results keep the original mask and are
 * then rebalanced independently by [T']. *)
fun partition (T (ref {buckets, mask}), p) =
   let
      val halves =
         mapBuckets (buckets, fn (_, bucket) => Set.partition (bucket, p))
      fun collect select =
         T' {buckets = mapBuckets (halves, fn (_, half) => select half),
             mask = mask}
   in
      {yes = collect (fn {yes, no = _} => yes),
       no = collect (fn {yes = _, no} => no)}
   end

(* Fold [f] over every element, bucket by bucket, in vector order. *)
fun fold (T (ref {buckets, mask = _}), b, f) =
   Vector.fold (buckets, b, fn (bucket, acc) => Set.fold (bucket, acc, f))

fun fromList l = List.fold (l, empty, fn (x, s) => add (s, x))
fun toList s = fold (s, [], op ::)
fun map (s, f) = fold (s, empty, fn (x, acc) => add (acc, f x))
fun replace (s, f) =
   fold (s, empty, fn (x, acc) => case f x of
                                     NONE => acc
                                   | SOME x' => add (acc, x'))
fun subsetSize (s, p) =
   fold (s, 0: int, fn (x, n) => if p x then n + 1 else n)

(* Sum of the bucket sizes; asks each bucket for its size instead of
 * visiting every element here. *)
fun size (T (ref {buckets, mask = _})) =
   Vector.fold (buckets, 0, fn (bucket, n) => n + Set.size bucket)

fun layout s = List.layout Element.layout (toList s)

(* Not supported by this implementation. *)
fun power _ = Error.bug "HashedUniqueSet.power"
fun subsets (_, _) = Error.bug "HashedUniqueSet.subsets"

fun isEmpty (T (ref {buckets, mask = _})) = Vector.forall (buckets, Set.isEmpty)

(* After alignment, every element of [s1] can only appear in the
 * same-index bucket of [s2].  So s1 <= s2 exactly when that holds bucket
 * by bucket, and no difference set has to be built. *)
fun isSubsetEq (s1, s2) = walk2 (Vector.forall2, Set.isSubsetEq) (s1, s2)
fun isSubset (s1, s2) = size s1 <> size s2 andalso isSubsetEq (s1, s2)
fun isSupersetEq (s1, s2) = isSubsetEq (s2, s1)
fun isSuperset (s1, s2) = isSubset (s2, s1)

(* Defined last: from here on +, -, <, <=, >, >= refer to set operations. *)
val op + = union
val op - = difference
val op < = isSubset
val op <= = isSubsetEq
val op > = isSuperset
val op >= = isSupersetEq

end