Import Upstream version 20180207
[hcoop/debian/mlton.git] / lib / mlton / basic / merge-sort.sml
1 (* Copyright (C) 1999-2005 Henry Cejtin, Matthew Fluet, Suresh
2 * Jagannathan, and Stephen Weeks.
3 *
4 * MLton is released under a BSD-style license.
5 * See the file MLton-LICENSE for details.
6 *)
7
functor MergeSort (S:
                   sig
                      type 'a t
                      val make: ('a * 'a -> bool) -> {isSorted: 'a t -> bool,
                                                       merge: 'a t * 'a t -> 'a t,
                                                       sort: 'a t -> 'a t}
                   end): MERGE_SORT =
struct
   open S

   (* The MERGE_SORT entry points take the ordering as a trailing argument.
    * Each one asks [make] for the operation record specialised to that
    * ordering and applies the matching field to the remaining arguments. *)

   (* Does [xs] satisfy [le] between every adjacent pair? *)
   fun isSorted (xs, le) =
      let
         val {isSorted = check, ...} = make le
      in
         check xs
      end

   (* Combine two collections that are already ordered by [le]. *)
   fun merge (xs, ys, le) =
      let
         val {merge = combine, ...} = make le
      in
         combine (xs, ys)
      end

   (* Order [xs] by [le]. *)
   fun sort (xs, le) =
      let
         val {sort = order, ...} = make le
      in
         order xs
      end
end
22
structure MergeSortList: MERGE_SORT =
   MergeSort
   (type 'a t = 'a list

    (* Bottom-up mergesort on lists, O(n log n) comparisons.
     * Elements are fed one at a time into an array of buckets that acts
     * like a binary counter: bucket i is either empty or holds a sorted
     * list of exactly 2^i elements.  Adding an element merges full buckets
     * upward (a "carry"); at the end all non-empty buckets are merged. *)
    fun make (op <= : 'a * 'a -> bool) =
       let
          fun assert f = Assert.assert ("MergeSort.assert", f)

          (* True iff x <= x' holds for every adjacent pair x, x' of [l]. *)
          fun isSorted l =
             case l of
                [] => true
              | x :: l =>
                   let
                      fun loop (x, l) =
                         case l of
                            [] => true
                          | x' :: l => x <= x' andalso loop (x', l)
                   in
                      loop (x, l)
                   end

          (* Merge two lists sorted by <=.  The head of [l1] is taken
           * whenever x1 <= x2 holds.  The sortedness precondition is checked
           * once here rather than on every recursive step; checking it
           * per step made merging quadratic when assertions are enabled. *)
          fun merge (l1, l2) =
             let
                val _ = assert (fn () => isSorted l1 andalso isSorted l2)
                fun loop (l1, l2) =
                   case (l1, l2) of
                      ([], _) => l2
                    | (_, []) => l1
                    | (x1 :: l1', x2 :: l2') =>
                         if x1 <= x2
                            then x1 :: loop (l1', l2)
                         else x2 :: loop (l1, l2')
             in
                loop (l1, l2)
             end

          fun sort l =
             let
                (* Number of bits in n, i.e. the smallest k with 2^k - 1 >= n.
                 * With that many buckets any element count up to n is
                 * representable, so the carry in mergeIn never indexes past
                 * the end of [a].  The previous fixed size of 25 raised
                 * Subscript on lists of 2^25 or more elements. *)
                fun numBits (n, k) =
                   if n = 0 then k else numBits (n div 2, k + 1)
                val numBuckets = numBits (length l, 0)
                val a: 'a list array = Array.new (numBuckets, [])
                (* Insert a sorted run of length 2^i at bucket i, carrying
                 * upward while the target bucket is already occupied. *)
                fun mergeIn (i: int, l: 'a list): unit =
                   (assert (fn () => length l = Int.pow (2, i))
                    ; (case Array.sub (a, i) of
                          [] => Array.update (a, i, l)
                        | l' => (Array.update (a, i, [])
                                 ; mergeIn (i + 1, merge (l, l')))))
                val _ = List.foreach (l, fn x => mergeIn (0, [x]))
                (* Collapse the remaining non-empty buckets into one list. *)
                val l = Array.fold (a, [], fn (l, l') =>
                                    case l of
                                       [] => l'
                                     | _ => merge (l, l'))
                val _ = assert (fn () => isSorted l)
             in
                l
             end
       in
          {isSorted = isSorted,
           merge = merge,
           sort = sort}
       end)
81
structure MergeSortVector: MERGE_SORT =
   MergeSort
   (type 'a t = 'a vector

    (* Builds vector sorting operations for the ordering <=.  Note that the
     * local isSorted / merge / sort have the arities required by [make]
     * (unary / binary / unary); the (v, le)-style forms belong to the
     * MERGE_SORT structure produced by the functor, not to these locals.
     * The previous version called the locals with the functor-level
     * arities and used unqualified length/sub/tabulate, so it did not
     * type-check. *)
    fun make (op <=) =
       let
          fun isSorted v = Vector.isSorted (v, op <=)

          (* Merge two vectors sorted by <= into one of length |v| + |v'|.
           * The element of [v] is taken whenever a <= a' holds. *)
          fun merge (v, v') =
             let
                val _ = Assert.assert ("MergeSortVector.merge: pre", fn () =>
                                       isSorted v andalso isSorted v')
                val n = Vector.length v
                val n' = Vector.length v'
                val r = ref 0
                val r' = ref 0
                (* Produce the next merged element and advance the cursor it
                 * came from.  Relies on Vector.tabulate applying its function
                 * in order of increasing index. *)
                fun next () =
                   let
                      val i = !r
                      val i' = !r'
                      (* 0 <= i <= n andalso 0 <= i' <= n' *)
                   in
                      if i = n
                         then
                            let
                               val res = Vector.sub (v', i')
                               val _ = Int.inc r'
                            in
                               res
                            end
                      else if i' = n'
                         then
                            let
                               val res = Vector.sub (v, i)
                               val _ = Int.inc r
                            in
                               res
                            end
                      else
                         let
                            val a = Vector.sub (v, i)
                            val a' = Vector.sub (v', i')
                         in
                            if a <= a'
                               then (Int.inc r; a)
                            else (Int.inc r'; a')
                         end
                   end
                val merged = Vector.tabulate (n + n', fn _ => next ())
                val _ = Assert.assert ("MergeSortVector.merge: post", fn () =>
                                       isSorted merged)
             in
                merged
             end

          (* Sort by splitting into even- and odd-indexed halves, sorting each
           * recursively and merging.  Already-sorted inputs (including every
           * vector of length <= 1) are returned unchanged, which also ends
           * the recursion. *)
          fun sort v =
             let
                fun loop v =
                   if isSorted v
                      then v
                   else
                      let
                         val n = Vector.length v
                         val m = n div 2
                         val m' = n - m
                         (* The k elements of v at indices start, start+2, ... *)
                         fun get (k, start) =
                            loop (Vector.tabulate
                                  (k, fn j => Vector.sub (v, start + 2 * j)))
                      in
                         merge (get (m', 0), get (m, 1))
                      end
                val v = loop v
                val _ = Assert.assert ("MergeSortVector.sort", fn () =>
                                       isSorted v)
             in
                v
             end
       in
          {isSorted = isSorted,
           merge = merge,
           sort = sort}
       end)