Import Upstream version 20180207
[hcoop/debian/mlton.git] / lib / mlton / set / hashed-unique-set.fun
1 (* Copyright (C) 2009 Matthew Fluet.
2 * Copyright (C) 1999-2007 Henry Cejtin, Matthew Fluet, Suresh
3 * Jagannathan, and Stephen Weeks.
4 *
5 * MLton is released under a BSD-style license.
6 * See the file MLton-LICENSE for details.
7 *)
8
9 functor HashedUniqueSet(structure Set : SET
10 structure Element : sig include T val hash : t -> word end
11 sharing type Set.Element.t = Element.t) : SET =
12 struct
13
(* Re-export the two functor arguments under their own names. *)
structure Set = Set
structure Element = Element

(* Local shorthand for the element hash function of the functor argument. *)
val hash: Element.t -> word = Element.hash
17
(* Turn a hash value into a bucket position: keep only the bits of w that
 * are selected by mask and convert the result to an int. *)
fun index (w: word, mask: word): int =
   let
      val masked = Word.andb (w, mask)
   in
      Word.toInt masked
   end
20
(* A hashed set: a mutable cell holding the vector of underlying Set.t
 * buckets together with the word mask that index applies to an element's
 * hash to select a bucket. *)
datatype t =
   T of {mask: word,
         buckets: Set.t vector} ref
23
(* Summarize bucket occupancy.  Returns (total, least, most): the sum of
 * Set.size over all buckets, and the smallest and largest single bucket
 * size, each NONE only when the bucket vector has no entries.  The mask
 * field is accepted but not consulted. *)
fun stats' {buckets, mask = _} =
   let
      fun step (bucket, (total, least, most)) =
         let
            val n = Set.size bucket
            val least' = Option.fold (least, n, Int.min)
            val most' = Option.fold (most, n, Int.max)
         in
            (total + n, SOME least', SOME most')
         end
   in
      Vector.fold (buckets, (0, NONE, NONE), step)
   end
(* Occupancy statistics of a set: unwraps the cell and defers to stats'. *)
fun stats (T (ref {buckets, mask})) =
   stats' {buckets = buckets, mask = mask}
41
(* Double the number of buckets.
 *
 * The new mask is the old one shifted left by one bit with the low bit set.
 * Each old bucket i is split by Set.partition on the bits that are in the
 * new mask but not in the old one: elements whose hash has none of those
 * bits stay at position i, the others move to position i + n, where n is
 * the old bucket count.
 *
 * The result vector is produced in one Vector.unfoldi pass of length 2 * n.
 * During the first n steps the "stay" halves are emitted and the "move"
 * halves are stacked up; at step n the stack is reversed and from then on
 * its elements are emitted one per step. *)
fun grow {buckets, mask} =
   let
      val oldMask = mask
      val newMask = Word.orb (0wx1, Word.<< (oldMask, 0wx1))
      (* Bits set in the new mask but not in the old one. *)
      val high = Word.andb (newMask, Word.notb oldMask)

      val n = Vector.length buckets

      fun staysLow x = Word.andb (high, hash x) = 0wx0

      (* Emit the head of the pending list and switch to (or stay in)
       * draining mode; an exhausted list means the bookkeeping is wrong. *)
      fun pop pending =
         case pending of
            h :: t => (h, (t, true))
          | [] => Error.bug "HashedUniqueSet.grow"

      fun step (i, (pending, draining)) =
         if draining
            then pop pending
         else if i = n
            then pop (List.rev pending)
         else
            let
               val {yes, no} =
                  Set.partition (Vector.sub (buckets, i), staysLow)
            in
               (yes, (no :: pending, false))
            end

      val (buckets', _) = Vector.unfoldi (2 * n, ([], false), step)
   in
      {buckets = buckets', mask = newMask}
   end
74
(* Halve the number of buckets: new bucket i is the union (Set.+) of old
 * buckets i and i + half, and the mask is shifted right by one bit. *)
fun shrink {buckets, mask} =
   let
      val mask' = Word.>> (mask, 0wx1)
      val half = Vector.length buckets div 2
      fun merged (i, ()) =
         let
            val lower = Vector.sub (buckets, i)
            val upper = Vector.sub (buckets, i + half)
         in
            (Set.+ (lower, upper), ())
         end
      val (buckets', ()) = Vector.unfoldi (half, (), merged)
   in
      {buckets = buckets', mask = mask'}
   end
94
(* Smart constructor: wrap a bucket record in a fresh cell, resizing once
 * if the occupancy calls for it.
 *   - grow when the fullest bucket holds more elements than there are
 *     buckets;
 *   - shrink when there are more than two buckets and the fullest bucket
 *     holds fewer than half as many elements as there are buckets;
 *   - otherwise keep the record as is.
 * An empty bucket vector is treated as having a fullest bucket of size ~1. *)
fun T' (r as {buckets, mask = _}) =
   let
      val (_, _, fullest) = stats' r
      val fullest =
         case fullest of
            NONE => ~1
          | SOME m => m
      val n = Vector.length buckets
      val resized =
         if fullest > n
            then grow r
         else if n > 2 andalso fullest < n div 2
            then shrink r
         else r
   in
      T (ref resized)
   end
107
(* Bring two sets to the same mask by repeatedly growing, in place, the one
 * whose mask is smaller.  Returns unit once the masks are equal. *)
fun coerce (s1 as T (cell1 as ref (r1 as {mask = m1, ...})),
            s2 as T (cell2 as ref (r2 as {mask = m2, ...}))) =
   if m1 = m2
      then ()
   else
      let
         val () =
            if m1 < m2
               then cell1 := grow r1
            else cell2 := grow r2
      in
         (* Re-enter so the updated cell contents are read afresh. *)
         coerce (s1, s2)
      end
117
118
(* The empty set: two empty buckets selected by a one-bit mask. *)
val empty =
   T (ref {buckets = Vector.new2 (Set.empty, Set.empty),
           mask = 0wx1})
(* A one-element set: two buckets under a one-bit mask, with x placed in the
 * first bucket when the low bit of its hash is clear and in the second
 * bucket otherwise. *)
fun singleton x =
   let
      val mask = 0wx1
      val lowBitClear = Word.andb (mask, hash x) = 0wx0
      val only = Set.singleton x
      val pair =
         if lowBitClear
            then (only, Set.empty)
         else (Set.empty, only)
   in
      T (ref {buckets = Vector.new2 pair, mask = mask})
   end
138
139
(* Traverse one set: apply the vector-level walker vw to the bucket vector,
 * handing it the bucket-level function sw. *)
fun walk1 (vw, sw) (T (ref {buckets, mask = _})) =
   vw (buckets, sw)
(* Traverse two sets in lock step: after coerce has equalized their masks,
 * apply the vector-level walker vw to both bucket vectors with the
 * bucket-level function sw. *)
fun walk2 (vw, sw) (s1, s2) =
   let
      val () = coerce (s1, s2)
      (* Read the cells only after coerce, which may have replaced them. *)
      val T (ref {buckets = left, mask = _}) = s1
      val T (ref {buckets = right, mask = _}) = s2
   in
      vw (left, right, sw)
   end
154
(* Two-set queries: every pair of corresponding buckets must satisfy the
 * underlying Set predicate. *)
fun areDisjoint (s1, s2) = walk2 (Vector.forall2, Set.areDisjoint) (s1, s2)
fun equals (s1, s2) = walk2 (Vector.forall2, Set.equals) (s1, s2)

(* One-set queries and iteration, delegated to each bucket. *)
fun exists (s, p) =
   let
      fun inBucket bucket = Set.exists (bucket, p)
   in
      walk1 (Vector.exists, inBucket) s
   end
fun forall (s, p) =
   let
      fun inBucket bucket = Set.forall (bucket, p)
   in
      walk1 (Vector.forall, inBucket) s
   end
fun foreach (s, f) =
   let
      fun inBucket bucket = Set.foreach (bucket, f)
   in
      walk1 (Vector.foreach, inBucket) s
   end
160
(* Build a new set from one set by applying the bucket transformer sb to
 * every bucket; the result goes through T', which may resize it. *)
fun build1 sb (T (ref {buckets, mask})) =
   let
      fun rebuilt (i, ()) = (sb (Vector.sub (buckets, i)), ())
      val (buckets', ()) =
         Vector.unfoldi (Vector.length buckets, (), rebuilt)
   in
      T' {buckets = buckets', mask = mask}
   end
(* Build a new set from two sets by applying the bucket combiner sb to each
 * pair of corresponding buckets.  The arguments are first equalized with
 * coerce; the result goes through T', which may resize it. *)
fun build2 sb (s1, s2) =
   let
      val () = coerce (s1, s2)
      (* Read the cells only after coerce, which may have replaced them. *)
      val T (ref {buckets = left, mask = _}) = s1
      val T (ref {buckets = right, mask}) = s2
      fun combined (i, ()) =
         let
            val l = Vector.sub (left, i)
            val r = Vector.sub (right, i)
         in
            (sb (l, r), ())
         end
      val (buckets, ()) = Vector.unfoldi (Vector.length left, (), combined)
   in
      T' {buckets = buckets, mask = mask}
   end
196
(* Bucket-wise set algebra, each expressed through build1 / build2 and the
 * corresponding operation of the underlying Set. *)
fun difference (s1, s2) = build2 Set.- (s1, s2)
fun intersect (s1, s2) = build2 Set.intersect (s1, s2)
fun subset (s, p) =
   let
      fun keep bucket = Set.subset (bucket, p)
   in
      build1 keep s
   end
fun union (s1, s2) = build2 Set.+ (s1, s2)
(* Union of a list of sets, combined left to right.  The empty list yields
 * empty and a one-element list yields that element unchanged. *)
fun unions ss =
   case ss of
      [] => empty
    | first :: rest =>
         List.fold (rest, first, fn (next, acc) => union (acc, next))
205
206
(* Membership test: look only in the bucket selected by x's hash. *)
fun contains (T (ref {buckets, mask}), x) =
   let
      val bucket = Vector.sub (buckets, index (hash x, mask))
   in
      Set.contains (bucket, x)
   end
(* Insert x.  If x is already present the argument itself is returned;
 * otherwise a new bucket vector is built in which only the bucket selected
 * by x's hash differs, and the result goes through T', which may resize. *)
fun add (s, x) =
   if contains (s, x)
      then s
   else
      let
         val T (ref {buckets, mask}) = s
         val target = index (hash x, mask)
         fun step (i, ()) =
            let
               val bucket = Vector.sub (buckets, i)
               val bucket' =
                  if i = target
                     then Set.add (bucket, x)
                  else bucket
            in
               (bucket', ())
            end
         val (buckets', ()) =
            Vector.unfoldi (Vector.length buckets, (), step)
      in
         T' {buckets = buckets', mask = mask}
      end
(* Delete x.  If x is absent the argument itself is returned; otherwise a
 * new bucket vector is built in which only the bucket selected by x's hash
 * differs, and the result goes through T', which may resize. *)
fun remove (s, x) =
   if contains (s, x)
      then
         let
            val T (ref {buckets, mask}) = s
            val target = index (hash x, mask)
            fun step (i, ()) =
               let
                  val bucket = Vector.sub (buckets, i)
                  val bucket' =
                     if i = target
                        then Set.remove (bucket, x)
                     else bucket
               in
                  (bucket', ())
               end
            val (buckets', ()) =
               Vector.unfoldi (Vector.length buckets, (), step)
         in
            T' {buckets = buckets', mask = mask}
         end
   else s
(* Split a set by predicate p into {yes, no}.  Every bucket is partitioned
 * with Set.partition; the "yes" halves and the "no" halves are then each
 * reassembled, in bucket order, into a bucket vector of the original
 * length and passed through T', which may resize them independently. *)
fun partition (s, p) =
   let
      val T (ref {buckets, mask}) = s
      val n = Vector.length buckets

      (* Both accumulators end up in reverse bucket order. *)
      val (yesRev, noRev) =
         Vector.fold
         (buckets,
          ([], []),
          fn (bucket, (ys, ns)) =>
          let
             val {yes, no} = Set.partition (bucket, p)
          in
             (yes :: ys, no :: ns)
          end)

      (* Rebuild an n-element vector from a reversed list of buckets;
       * running out of list elements is reported through Error.bug. *)
      fun toVector (items, msg) =
         let
            val (v, _) =
               Vector.unfoldi
               (n,
                List.rev items,
                fn (_, h :: t) => (h, t)
                 | _ => Error.bug msg)
         in
            v
         end

      val yes = toVector (yesRev, "HashedUniqueSet.partition.yes")
      val no = toVector (noRev, "HashedUniqueSet.partition.no")
   in
      {yes = T' {buckets = yes, mask = mask},
       no = T' {buckets = no, mask = mask}}
   end
290
291
(* Fold f over all elements, threading the accumulator through the buckets
 * in vector order and through each bucket with Set.fold. *)
fun fold (T (ref {buckets, mask = _}), b, f) =
   let
      fun overBucket (bucket, acc) = Set.fold (bucket, acc, f)
   in
      Vector.fold (buckets, b, overBucket)
   end
301
(* Build a set by adding the list elements one at a time to empty. *)
fun fromList l =
   let
      fun insert (x, acc) = add (acc, x)
   in
      List.fold (l, empty, insert)
   end

(* Collect all elements into a list by consing during a fold. *)
fun toList s = fold (s, [], fn (x, xs) => x :: xs)

(* Build a new set from the images f x of all elements x. *)
fun map (s, f) =
   let
      fun insertImage (x, acc) = add (acc, f x)
   in
      fold (s, empty, insertImage)
   end
(* Like map, but f may drop an element: SOME x' contributes x' to the
 * result and NONE contributes nothing. *)
fun replace (s, f) =
   let
      fun keep (x, acc) =
         case f x of
            SOME x' => add (acc, x')
          | NONE => acc
   in
      fold (s, empty, keep)
   end
(* Number of elements satisfying p, counted with a fold. *)
fun subsetSize (s, p) =
   let
      fun count (x, n: int) = if p x then n + 1 else n
   in
      fold (s, 0, count)
   end

(* Total number of elements: subsetSize with a predicate that is always
 * true. *)
fun size s =
   let
      fun always _ = true
   in
      subsetSize (s, always)
   end
312
313
(* Lay out a set as the list of its elements, each rendered with
 * Element.layout. *)
fun layout s =
   let
      val elements = toList s
   in
      List.layout Element.layout elements
   end
315
(* Not implemented for hashed sets: both operations unconditionally report
 * a bug through Error.bug. *)
fun power _ = Error.bug "HashedUniqueSet.power"
fun subsets (_, _) = Error.bug "HashedUniqueSet.subsets"
318
(* Emptiness test.  A set is empty exactly when every element satisfies a
 * predicate that is always false.  Unlike comparing size with 0, this does
 * not count every element.
 * NOTE(review): the saving assumes Vector.forall / Set.forall stop at the
 * first failing element -- confirm against their implementations. *)
fun isEmpty s = forall (s, fn _ => false)

(* s1 is a subset of, or equal to, s2: nothing of s1 is left once s2 is
 * removed.  Uses isEmpty on the difference instead of counting it. *)
fun isSubsetEq (s1, s2) = isEmpty (difference (s1, s2))

(* Proper subset: a subset whose size differs from that of s2. *)
fun isSubset (s1, s2) = (size s1 <> size s2) andalso isSubsetEq (s1, s2)

(* The superset relations are the subset relations with arguments swapped. *)
fun isSupersetEq (s1, s2) = isSubsetEq (s2, s1)
fun isSuperset (s1, s2) = isSubset (s2, s1)
324
(* Operator spellings of the named operations above.  From this point on
 * the unqualified arithmetic and comparison symbols refer to sets. *)

(* Algebra. *)
val op - = difference
val op + = union

(* Ordering by inclusion. *)
val op <= = isSubsetEq
val op < = isSubset
val op >= = isSupersetEq
val op > = isSuperset
331
332 end