Commit | Line | Data |
---|---|---|
7f918cf1 CE |
1 | (* Copyright (C) 1999-2005 Henry Cejtin, Matthew Fluet, Suresh |
2 | * Jagannathan, and Stephen Weeks. | |
3 | * | |
4 | * MLton is released under a BSD-style license. | |
5 | * See the file MLton-LICENSE for details. | |
6 | *) | |
7 | ||
(* MergeSort (S) adapts a factory that builds sorting operations for a given
 * "less than or equal" predicate into the MERGE_SORT interface, where the
 * predicate is supplied as the last component of each call's argument tuple.
 *)
functor MergeSort (S:
                   sig
                      type 'a t
                      val make: ('a * 'a -> bool) -> {isSorted: 'a t -> bool,
                                                       merge: 'a t * 'a t -> 'a t,
                                                       sort: 'a t -> 'a t}
                   end): MERGE_SORT =
struct
   open S

   (* Every entry point builds the operations for le afresh, picks out the
    * one it needs and applies it to the data arguments. *)
   fun isSorted (l, le) =
      let
         val {isSorted = check, ...} = make le
      in
         check l
      end

   fun merge (l, l', le) =
      let
         val {merge = combine, ...} = make le
      in
         combine (l, l')
      end

   fun sort (l, le) =
      let
         val {sort = sorter, ...} = make le
      in
         sorter l
      end
end
22 | ||
structure MergeSortList: MERGE_SORT =
   MergeSort
   (type 'a t = 'a list

    (* Bottom-up mergesort on lists; runs in O (n log n) time.
     *
     * [make le] returns {isSorted, merge, sort} specialised to the ordering
     * [le], which should behave like "less than or equal".  sort is stable:
     * elements that compare equal appear in the output in input order.
     *)
    fun make (op <= : 'a * 'a -> bool) =
       let
          fun assert f = Assert.assert ("MergeSort.assert", f)

          (* True iff x <= x' holds for every adjacent pair x, x' of l. *)
          fun isSorted l =
             case l of
                [] => true
              | x :: l =>
                   let
                      fun loop (x, l) =
                         case l of
                            [] => true
                          | x' :: l => x <= x' andalso loop (x', l)
                   in loop (x, l)
                   end

          (* Merge two sorted lists.  On ties the element of l1 is emitted
           * first.  The sortedness precondition is checked once, on entry,
           * instead of on every recursive step (which made debug builds
           * quadratic per merge). *)
          fun merge (l1, l2) =
             let
                val _ = assert (fn () => isSorted l1 andalso isSorted l2)
                fun loop (l1, l2) =
                   case (l1, l2) of
                      ([], _) => l2
                    | (_, []) => l1
                    | (x1 :: l1', x2 :: l2') =>
                         if x1 <= x2
                            then x1 :: loop (l1', l2)
                         else x2 :: loop (l1, l2')
             in
                loop (l1, l2)
             end

          fun sort l =
             let
                (* Number of binary digits of k: 0 for 0, otherwise
                 * floor (log2 k) + 1. *)
                fun numBits (k, acc) =
                   if k = 0 then acc else numBits (k div 2, acc + 1)
                (* Bucket i is either [] or a sorted run of exactly 2^i
                 * elements, so the occupied buckets encode, in binary, how
                 * many elements have been inserted so far.  That count never
                 * exceeds length l, so numBits (length l) buckets always
                 * suffice and there is no fixed limit on the input size. *)
                val numBuckets = numBits (length l, 0)
                val a: 'a list array = Array.new (numBuckets, [])
                fun invariant () =
                   assert (fn () => Array.foralli (a, fn (i, l) =>
                                                   case l of
                                                      [] => true
                                                    | _ => (length l = Int.pow (2, i)
                                                            andalso isSorted l)))
                (* Add the sorted run l of 2^i elements, carrying into higher
                 * buckets like binary addition.  The run l' already in the
                 * bucket holds elements inserted before those of l, so it is
                 * passed to merge first to keep equal elements in order. *)
                fun mergeIn (i: int, l: 'a list): unit =
                   (assert (fn () => length l = Int.pow (2, i))
                    ; (case Array.sub (a, i) of
                          [] => Array.update (a, i, l)
                        | l' => (Array.update (a, i, [])
                                 ; mergeIn (i + 1, merge (l', l)))))
                val _ = List.foreach (l, fn x => mergeIn (0, [x]))
                val _ = invariant ()
                (* Combine the buckets from index 0 upward.  Higher buckets
                 * hold earlier input elements, so each run is merged in front
                 * of the accumulator, which preserves stability. *)
                fun collect (i, acc) =
                   if i = numBuckets
                      then acc
                   else collect (i + 1,
                                 case Array.sub (a, i) of
                                    [] => acc
                                  | run => merge (run, acc))
                val l = collect (0, [])
                val _ = assert (fn () => isSorted l)
             in
                l
             end
       in
          {isSorted = isSorted,
           merge = merge,
           sort = sort}
       end)
81 | ||
structure MergeSortVector: MERGE_SORT =
   MergeSort
   (type 'a t = 'a vector

    (* Top-down mergesort on vectors.
     *
     * [make le] returns {isSorted, merge, sort} specialised to the ordering
     * [le].  The local helpers take only their data arguments -- the
     * ordering is captured from make -- so they must not be passed le
     * again.  sort splits its input into contiguous halves, which keeps it
     * stable: equal elements stay in input order.
     *)
    fun make (op <=) =
       let
          fun isSorted v = Vector.isSorted (v, op <=)

          (* Merge two sorted vectors into a new sorted vector of their
           * combined length.  On ties the element of v is emitted first. *)
          fun merge (v, v') =
             let
                val _ = Assert.assert ("MergeSortVector.merge: pre", fn () =>
                                       isSorted v andalso isSorted v')
                val n = Vector.length v
                val n' = Vector.length v'
                (* Read cursors into v and v'. *)
                val r = ref 0
                val r' = ref 0
                (* Return the next element of the merged sequence and advance
                 * the cursor it came from.  Relies on Vector.tabulate below
                 * applying its function to indices 0, 1, ... in order. *)
                fun next () =
                   let
                      val i = !r
                      val i' = !r'
                      (* 0 <= i <= n andalso 0 <= i' <= n' *)
                   in
                      if i = n
                         then
                            let
                               val res = Vector.sub (v', i')
                               val _ = Int.inc r'
                            in res
                            end
                      else if i' = n'
                         then
                            let
                               val res = Vector.sub (v, i)
                               val _ = Int.inc r
                            in res
                            end
                      else
                         let
                            val a = Vector.sub (v, i)
                            val a' = Vector.sub (v', i')
                         in
                            if a <= a'
                               then (Int.inc r; a)
                            else (Int.inc r'; a')
                         end
                   end
                val v = Vector.tabulate (n + n', fn _ => next ())
                val _ = Assert.assert ("MergeSortVector.merge: post", fn () =>
                                       isSorted v)
             in
                v
             end

          fun sort v =
             let
                (* An already-sorted vector (in particular any vector of
                 * length 0 or 1) is returned as is; otherwise the two
                 * contiguous halves are sorted recursively and merged. *)
                fun loop v =
                   if isSorted v
                      then v
                   else
                      let
                         val n = Vector.length v
                         val m = n div 2
                         (* Sorted copy of the len elements of v starting at
                          * index start. *)
                         fun get (start, len) =
                            loop (Vector.tabulate (len, fn i =>
                                                   Vector.sub (v, start + i)))
                      in
                         merge (get (0, m), get (m, n - m))
                      end
                val v = loop v
                val _ = Assert.assert ("MergeSortVector.sort", fn () =>
                                       isSorted v)
             in
                v
             end
       in
          {isSorted = isSorted,
           merge = merge,
           sort = sort}
       end)