Import Upstream version 20180207
[hcoop/debian/mlton.git] / lib / mlton / basic / euclidean-ring.fun
CommitLineData
7f918cf1
CE
(* Copyright (C) 1999-2006 Henry Cejtin, Matthew Fluet, Suresh
 * Jagannathan, and Stephen Weeks.
 *
 * MLton is released under a BSD-style license.
 * See the file MLton-LICENSE for details.
 *)

(* EuclideanRing derives div, mod, divisibility, gcd, lcm, extended Euclid,
 * a stream of primes and factorization from the ring operations, divMod
 * and metric supplied by S.
 *)
functor EuclideanRing(S: EUCLIDEAN_RING_STRUCTS)
   :> EUCLIDEAN_RING where type t = S.t =
struct

open S

structure IntInf = Pervasive.IntInf

(* Traced divMod.
 * Precondition: q is not zero.
 * Postcondition: the result (d, m) satisfies p = q * d + m, and either m is
 * zero or metric m < metric q.
 *)
val divMod =
   Trace.traceAssert
   ("EuclideanRing.divMod",
    Layout.tuple2 (layout, layout),
    Layout.tuple2 (layout, layout),
    fn (p, q) => (not (equals (q, zero)),
                  fn (d, m) => (equals (p, q * d + m)
                                andalso (equals (m, zero)
                                         orelse IntInf.< (metric m, metric q)))))
   divMod

(* Quotient component of divMod (p, q). *)
fun p div q = #1 (divMod (p, q))

(* Remainder component of divMod (p, q). *)
fun p mod q = #2 (divMod (p, q))

(* divides (d, x) is true iff x mod d equals zero. *)
fun divides (d: t, x: t): bool = equals (x mod d, zero)

val divides =
   Trace.trace ("EuclideanRing.divides",
                Layout.tuple2 (layout, layout),
                Bool.layout)
   divides

(* Recursive extended Euclid, taken from page 812 of CLR.
 * Recurses on (b, a mod b) until done (a, b) holds, at which point it
 * returns (a, one, zero).  On the way back out, each level turns the
 * recursive result (g, x, y) into (g, y, x - q * y), where q is the
 * quotient computed at that level.  The supplied trace wraps the loop.
 *)
fun extendedEuclidTerm (a: t, b: t, done: t * t -> bool, trace): t * t * t =
   let
      fun loop (a, b) =
         if done (a, b)
            then (a, one, zero)
         else
            let
               val (q, r) = divMod (a, b)
               val (g, x, y) = loop (b, r)
            in
               (g, y, x - q * y)
            end
   in
      trace loop (a, b)
   end

(* Builds a traceAssert wrapper for an extended-Euclid implementation.
 * Precondition: both inputs are nonzero.
 * Postcondition: check ((a, b), (d, x, y)) holds and d = a * x + b * y.
 * The check receives the inputs as well as the result so that it can
 * relate the two.
 *)
fun makeTraceExtendedEuclidArgs check =
   Trace.traceAssert
   ("EuclideanRing.extendedEuclid",
    Layout.tuple2 (layout, layout),
    Layout.tuple3 (layout, layout, layout),
    fn (a, b) => (not (isZero a) andalso not (isZero b),
                  fn (d, x, y) => (check ((a, b), (d, x, y))
                                   andalso equals (d, a * x + b * y))))

(* Like makeTraceExtendedEuclidArgs, but the extra check f only sees the
 * result (d, x, y).
 *)
fun makeTraceExtendedEuclid f =
   makeTraceExtendedEuclidArgs (fn (_, result) => f result)

local
   (* The gcd d returned by extendedEuclid must divide both inputs.  It
    * need not divide the coefficients x and y: over the integers,
    * extendedEuclid (6, 4) returns (2, 1, ~1), and 2 does not divide 1.
    *)
   val trace =
      makeTraceExtendedEuclidArgs
      (fn ((a, b), (d, _, _)) => divides (d, a) andalso divides (d, b))
in
   (* Iterative extended Euclid from page 72 of Bach and Shallit, identical
    * to the algorithm on page 23 of Berlekamp.  It stores only a constant
    * number of ring elements, but it is about 2x slower than the recursive
    * extendedEuclidTerm above.  It is therefore shadowed by the definition
    * that follows and kept for reference.
    *)
   fun extendedEuclid (u0: t, u1: t): t * t * t =
      let
         (* Invariants, checked with Assert.assert on every iteration:
          *   u0 = m11 * u + m12 * v
          *   u1 = m21 * u + m22 * v
          *   m11 * m22 - m12 * m21 is one when nEven, negOne otherwise
          *)
         val rec loop =
            fn (r as {m11, m12, m21, m22, u, v, nEven}) =>
            (Assert.assert ("EuclideanRing.extendedEuclid", fn () =>
                            equals (u0, m11 * u + m12 * v)
                            andalso equals (u1, m21 * u + m22 * v)
                            andalso equals (if nEven then one else negOne,
                                            m11 * m22 - m12 * m21))
             ; if isZero v
                  then r
               else
                  let
                     val (q, r) = divMod (u, v)
                  in
                     loop {m11 = q * m11 + m12,
                           m12 = m11,
                           m21 = q * m21 + m22,
                           m22 = m21,
                           u = v,
                           v = r,
                           nEven = not nEven}
                  end)
         val {m12, m22, u, nEven, ...} =
            loop {m11 = one, m12 = zero, m21 = zero, m22 = one,
                  u = u0, v = u1, nEven = true}
         (* Invert the unimodular matrix to express u in terms of u0, u1. *)
         val (a, b) = if nEven then (m22, ~m12) else (~m22, m12)
      in
         (u, a, b)
      end

   val _ = extendedEuclid

   (* Exported extendedEuclid: the recursive algorithm run until the second
    * component is zero, wrapped in the gcd-checking trace above.
    *)
   fun extendedEuclid (a, b) =
      extendedEuclidTerm (a, b, fn (_, b) => equals (b, zero), trace)
end

local
   val trace = makeTraceExtendedEuclid (fn _ => true)
in
   (* Exported extendedEuclidTerm: the caller supplies the stopping
    * predicate.  The extra trace check is trivially true, so only
    * d = a * x + b * y is asserted.
    *)
   val extendedEuclidTerm =
      fn (a, b, done) => extendedEuclidTerm (a, b, done, trace)
end

(* The most recent prime forced from the primes stream below. *)
val lastPrime = ref one

(* Euclid's algorithm: gcd (a, zero) is a, otherwise gcd (b, a mod b). *)
fun gcd (a, b) = if isZero b then a else gcd (b, a mod b)

(* Least common multiple.  The gcd is zero only when both arguments are
 * zero; lcm is zero in that case rather than a division by zero.
 *)
fun lcm (a, b) =
   let
      val g = gcd (a, b)
   in
      if isZero g
         then zero
      else (a * b) div g
   end

(* Primes, by sieving the monics stream.  Each step takes the head p of
 * the remaining stream, records it in lastPrime, and drops every later
 * element that p divides.
 *)
val primes: t Stream.t =
   let
      fun sieve (s: t Stream.t) =
         Stream.delay
         (fn () =>
          let
             val (p, rest) = valOf (Stream.force s)
             val _ = lastPrime := p
          in
             Stream.cons
             (p, sieve (Stream.keep (rest, fn x => not (divides (p, x)))))
          end)
   in
      sieve monics
   end

structure Int =
   struct
      open Pervasive.Int
      type t = int
      val layout = Layout.str o toString
   end

(* A factorization: a list of (prime, exponent) pairs. *)
type factors = (t * Int.t) list

(* Trial division by the primes stream, stopping once the remaining
 * cofactor equals one.  Only primes with a nonzero exponent are recorded.
 * The list is built by consing, so later primes come first.
 * NOTE(review): if the cofactor is a unit other than one, no prime divides
 * it, so presumably this loop would not terminate.  Confirm that callers
 * only pass n whose unit part is one.
 *)
fun factor (n: t): factors =
   let
      fun loop (n: t, primes: t Stream.t, factors: factors) =
         if equals (n, one)
            then factors
         else
            let
               val (p, primes) = valOf (Stream.force primes)
               (* Divide p out of n as often as possible, counting the
                * number of times in k.
                *)
               fun strip (n, k) =
                  let
                     val (q, r) = divMod (n, p)
                  in
                     if isZero r
                        then strip (q, Int.+ (k, 1))
                     else (n, k)
                  end
               val (n, k) = strip (n, 0)
            in
               loop (n, primes, if k = 0 then factors else (p, k) :: factors)
            end
   in
      loop (n, primes, [])
   end

(* Traced factor.
 * Precondition: n is nonzero.
 * Postcondition: the product of p^k over the result equals n.
 *)
val factor =
   Trace.traceAssert
   ("EuclideanRing.factor",
    layout,
    List.layout (Layout.tuple2 (layout, Int.layout)),
    fn n => (not (isZero n),
             fn factors =>
             equals (n, List.fold (factors, one,
                                   fn ((p, k), prod) => prod * pow (p, k)))))
   factor

(* Walks the primes stream in order while metric p < m.  Returns true as
 * soon as f p holds, and false at the first prime whose metric is not below
 * m.  This presumably relies on prime metrics not decreasing along the
 * stream; confirm.
 *)
fun existsPrimeOfSmallerMetric (m: IntInf.int, f: t -> bool): bool =
   let
      fun loop primes =
         let
            val (p, primes) = valOf (Stream.force primes)
         in
            IntInf.< (metric p, m)
            andalso (f p orelse loop primes)
         end
   in
      loop primes
   end

(* r is prime iff unitEquivalent r equals some prime whose metric is at
 * most the metric of unitEquivalent r.
 *)
fun isPrime (r: t): bool =
   let
      val r = unitEquivalent r
   in
      existsPrimeOfSmallerMetric (IntInf.+ (metric r, 1),
                                  fn p => equals (r, p))
   end

(* r is composite iff some prime of strictly smaller metric divides it. *)
fun isComposite (r: t): bool =
   existsPrimeOfSmallerMetric (metric r, fn p => divides (p, r))

end