Import Upstream version 20180207
[hcoop/debian/mlton.git] / lib / mlton / basic / unique-set.fun
1 (* Copyright (C) 1999-2006 Henry Cejtin, Matthew Fluet, Suresh
2 * Jagannathan, and Stephen Weeks.
3 *
4 * MLton is released under a BSD-style license.
5 * See the file MLton-LICENSE for details.
6 *)
7
(* Concrete representation of a unique set, kept outside the functor so it
 * is shared: a set is its list of members together with a property list.
 * UniqueSet uses the identity of that property list as the set's identity. *)
structure UniqueSetRep =
   struct
      datatype 'elem t =
         T of {plist: PropertyList.t,
               elements: 'elem list}
   end
14
(* UniqueSet: interned (hash-consed) finite sets of Element.t.
 *
 * Every set is interned: inserting a member list whose contents (as judged
 * by Element.equals) match an already-interned set returns that same set
 * value.  Set equality is therefore a comparison of the sets' property
 * lists rather than of their members.
 *
 * Interning: a set's hash (the xorb of its members' hashes) selects one of
 * 2^bits buckets in `table`.  Each bucket is a binary decision tree whose
 * internal nodes test membership of one element and whose leaves hold
 * previously interned sets.
 *
 * The binary operations (+, -, intersect) each keep a small cache of up to
 * cacheSize (argument, argument, result) triples.
 *
 * reset () discards all buckets and caches.  Sets created before a reset
 * are not `equals` to sets with the same members created after it.
 *)
functor UniqueSet (S: UNIQUE_SET_STRUCTS): UNIQUE_SET =
struct

open S

val _ = Assert.assert ("UniqueSet: cacheSize, bits", fn () =>
                       cacheSize >= 1 andalso bits >= 1)

type elements = Element.t list

(* Tree: the per-bucket interning structure.  Tree.insert (tree, l) returns
 * the unique Set.t in `tree` whose members are those of `l`, creating and
 * storing a new one if none exists yet. *)
structure Tree: sig
   structure Set:
      sig
         type t

         val equals: t * t -> bool
         val toList: t -> elements
         val plist: t -> PropertyList.t
      end

   type t

   val new: unit -> t
   val insert: t * elements -> Set.t
   val size: t -> int
end =
struct
   structure Set =
      struct
         open UniqueSetRep
         type t = Element.t t

         (* Each new set gets its own fresh property list. *)
         fun new elements = T {elements = elements,
                               plist = PropertyList.new ()}

         fun elements (T {elements, ...}) = elements
         fun plist (T {plist, ...}) = plist

         val toList = elements

         (* Sound only because sets are interned: equal member sets share
          * one Set.t, hence one property list. *)
         fun equals (s, s') = PropertyList.equals (plist s, plist s')
      end

   (* A Node splits on `element`: sets containing it are stored under
    * isIn, the others under isNotIn.  A Leaf holds one interned set. *)
   datatype node =
      Node of {element: Element.t,
               isIn: t,
               isNotIn: t}
    | Leaf of Set.t
   withtype t = node option ref

   fun new (): t = ref NONE

   (* Number of sets stored in the tree. *)
   fun size (t: t): int =
      case !t of
         NONE => 0
       | SOME (Leaf _) => 1
       | SOME (Node {isIn, isNotIn, ...}) => size isIn + size isNotIn

   fun contains (es, e) = List.exists (es, fn e' => Element.equals (e, e'))

   (* Presumably `elements` is duplicate-free; within this file it always
    * comes from fromList, which removes duplicates. *)
   fun insert (tree, elements) =
      let
         fun loop tree =
            case !tree of
               NONE => let val s = Set.new elements
                       in tree := SOME (Leaf s); s
                       end
             | SOME (Node {element, isIn, isNotIn}) =>
                  if contains (elements, element)
                     then loop isIn
                  else loop isNotIn
             | SOME (Leaf s') =>
                  let
                     (* Compare `elements` with the members of s' as sets.
                      * First component: not-yet-matched part of `elements`.
                      * Second component: not-yet-matched members of s'.
                      * On the first difference, the leaf is replaced by a
                      * Node splitting on the distinguishing element. *)
                     fun loop arg =
                        case arg of
                           ([], []) => s' (* same set *)
                         | ([], x' :: _) =>
                              (* x' is in s' but not in `elements`. *)
                              let val s = Set.new elements
                              in tree :=
                                    SOME (Node {element = x',
                                                isIn = ref (SOME (Leaf s')),
                                                isNotIn = ref (SOME (Leaf s))})
                                 ; s
                              end
                         | (x :: xs, xs') =>
                              let
                                 (* Search the unmatched members of s' for x.
                                  * accum holds the members skipped so far. *)
                                 fun loop2 (xs', accum) =
                                    case xs' of
                                       [] =>
                                          (* x is in `elements` but not in s'. *)
                                          let val s = Set.new elements
                                          in tree :=
                                                SOME (Node {element = x,
                                                            isIn = ref (SOME (Leaf s)),
                                                            isNotIn =
                                                            ref (SOME (Leaf s'))})
                                             ; s
                                          end
                                     | x' :: xs' =>
                                          if Element.equals (x, x')
                                             then loop (xs, accum @ xs')
                                          else loop2 (xs', x' :: accum)
                              in loop2 (xs', [])
                              end
                  in loop (elements, Set.elements s')
                  end
      in loop tree
      end
end

open Tree.Set

val tableSize = Int.pow (2, bits)

val maxIndex = tableSize - 1

(* tableSize is a power of two, so masking with tableSize - 1 maps any hash
 * to a valid bucket index. *)
val mask = Word.fromInt maxIndex

val table = Array.tabulate (tableSize, fn _ => Tree.new ())

fun hashToIndex (w: Word.t): int = Word.toInt (Word.andb (w, mask))

(* Intern the duplicate-free member list `l`, whose hash is `h`. *)
fun intern (l: Element.t list, h: Word.t) =
   Tree.insert (Array.sub (table, hashToIndex h), l)

(* The hash of a set is the xorb of the hashes of its members, so it does not
 * depend on the order of the list. *)
fun hash (l: Element.t list) =
   List.fold (l, 0w0, fn (e, w) => Word.xorb (w, Element.hash e))

(* Remove duplicates from `l`, then intern the result. *)
fun fromList l =
   let val l = List.fold (l, [], fn (x, l) =>
                          if List.exists (l, fn x' => Element.equals (x, x'))
                             then l
                          else x :: l)
   in intern (l, hash l)
   end

val empty = fromList []

(* Inspect the members directly instead of comparing with `empty`.
 * reset () replaces every bucket tree, so an empty set interned after a
 * reset is a different Set.t (with a different property list) from the
 * `empty` value built above.  Comparing with `equals` would then wrongly
 * report it as non-empty. *)
fun isEmpty s =
   case toList s of
      [] => true
    | _ => false

fun foreach (s, f) = List.foreach (toList s, f)

fun singleton x = fromList [x]

val cacheHits: int ref = ref 0
val cacheMisses: int ref = ref 0

(* Every result cache created by `binary`, so reset can empty them. *)
val caches: (t * t * t) option Array.t list ref = ref []

fun stats () = {hits = !cacheHits, misses = !cacheMisses}

(* Forget all interned sets and cached results, and zero the statistics. *)
fun reset () =
   (cacheHits := 0
    ; cacheMisses := 0
    ; Int.for (0, tableSize, fn i => Array.update (table, i, Tree.new ()))
    (* Drop cached triples too: they refer to pre-reset sets and would
     * otherwise keep them alive and occupy cache slots. *)
    ; List.foreach (!caches, fn cache =>
                    Int.for (0, cacheSize, fn i =>
                             Array.update (cache, i, NONE))))

(* Int.foreach(0, maxIndex, fn i =>
   let val n = Tree.size(Vector.sub(table, i))
   in if n > 0
   then Control.message(seq[Int.layout i,
   str " -> ",
   Int.layout n])
   else ()
   end)*)

local
   (* Lift a list-level operation to interned sets.  Results are memoized in
    * a cache of cacheSize slots, scanned linearly.  On a miss, the fresh
    * result replaces a randomly chosen slot. *)
   fun binary (oper: elements * elements -> elements) =
      let
         val cache = Array.new (cacheSize, NONE)
         val () = caches := cache :: !caches
      in
         fn (s: t, s': t) =>
         let
            fun loop i =
               if i >= cacheSize
                  then
                     let
                        val s'' = fromList (oper (toList s, toList s'))
                        val () = Int.inc cacheMisses
                        val () =
                           Array.update (cache,
                                         Random.natLessThan cacheSize,
                                         SOME (s, s', s''))
                     in
                        s''
                     end
               else case Array.sub (cache, i) of
                       NONE => loop (i + 1)
                     | SOME (s1, s1', s'') =>
                          if equals (s, s1) andalso equals (s', s1')
                             then (Int.inc cacheHits; s'')
                          else loop (i + 1)
         in loop 0
         end
      end

   val {+, -, intersect, layout, ...} =
      List.set {equals = Element.equals,
                layout = Element.layout}
in
   val op + = binary op +
   val op - = binary op -
   val op intersect = binary intersect

   val layout = layout o toList
end

(* val fromList = Trace.trace("fromList", List.layout Element.layout, layout) fromList *)

fun traceBinary (name, f) = Trace.trace2 (name, layout, layout, layout) f

val op + = traceBinary ("UniqueSet.+", op +)
val op - = traceBinary ("UniqueSet.-", op -)
val op intersect = traceBinary ("UniqueSet.intersect", intersect)

end