Import Upstream version 20180207
[hcoop/debian/mlton.git] / lib / mlton / basic / alpha-beta.fun
CommitLineData
7f918cf1
CE
1(* Copyright (C) 1999-2006 Henry Cejtin, Matthew Fluet, Suresh
2 * Jagannathan, and Stephen Weeks.
3 *
4 * MLton is released under a BSD-style license.
5 * See the file MLton-LICENSE for details.
6 *)
7
(* AlphaBeta: alpha-beta game-tree search over the game described by S.
 *
 * Two searches are provided:
 *  - alphaBeta (s, a, b): uncached fail-hard search; leaf values are clamped
 *    into the window [a, b] and every result lies in [a, b].
 *  - alphaBetaCache (s, i, c): the same search with the window given as an
 *    Interval.t, memoizing per-state bounds in the cache c.
 *
 * Value.move / Value.unmove convert a value to / from the opponent's
 * perspective.  The window handling (swapping bounds under move) assumes
 * move reverses the order of values.
 * NOTE(review): e.g. negation, as in negamax -- confirm against
 * ALPHA_BETA_STRUCTS.
 *)
functor AlphaBeta (S: ALPHA_BETA_STRUCTS): ALPHA_BETA =
struct

open S

(* Trace wrapper for the uncached search: logs the state, both window
 * bounds and the result under the name "AlphaBeta.alphaBeta".
 *)
val traceAlphaBeta =
   Trace.trace3
   ("AlphaBeta.alphaBeta", State.layout, Value.layout, Value.layout, Value.layout)

(* messenger () returns a fresh call counter `count`.
 *
 * Each application `count f` increments the counter.  When the count hits
 * 1, 2, 4, 8, ... it writes the count (left-justified in 10 columns), a
 * space and the layout `f ()` as one line to Out.error.  `f` is only
 * evaluated when a line is actually printed.
 *)
fun messenger () =
   let
      val numCalls = ref 0
      val next = ref 1
      fun count f =
         (Int.inc numCalls
          ; if !numCalls = !next
               then
                  let
                     (* Double the threshold: output grows only
                      * logarithmically with the number of calls. *)
                     val _ = next := 2 * !next
                     open Layout
                  in
                     output (seq [str (Justify.justify (Int.toString (!numCalls),
                                                        10,
                                                        Justify.Left)),
                                  str " ",
                                  f ()],
                             Out.error)
                     ; Out.newline Out.error
                  end
            else ())
   in
      count
   end

(* alphaBeta (s, a, b): uncached fail-hard alpha-beta search of the tree
 * rooted at s with window [a, b] (presumably a <= b -- not checked).
 * The result always lies in [a, b]; a leaf's value is clamped into it.
 *)
fun alphaBeta arg =
   let
      val count = messenger ()
      (* Eta-expanded so that the recursive calls below also go through the
       * trace wrapper; traceAlphaBeta is re-applied on every call. *)
      fun alphaBeta arg : Value.t =
         traceAlphaBeta
         (fn (s: State.t, a: Value.t, b: Value.t) =>
          (count (fn () =>
                  let open Layout
                  in align [tuple [Value.layout a, Value.layout b],
                            State.layout s]
                  end)
           ; (case State.evaluate s of
                 State.Leaf v =>
                    (* Clamp the leaf value into [a, b]. *)
                    if Value.<= (v, a)
                       then a
                    else if Value.<= (b, v)
                            then b
                         else v
               | State.NonLeaf {lower, upper} =>
                    (* The state's static bounds may already decide the
                     * result without expanding s. *)
                    if Value.<= (upper, a)
                       then a
                    else if Value.<= (b, lower)
                            then b
                    else
                       let
                          (* The window seen by the opponent: both ends are
                           * moved and swapped. *)
                          val a' = Value.move b
                          val b' = Value.move a
                          (* inv: a' <= b'' <= b' *)
                          fun loop (ss, b'') =
                             (* Reaching a' (= move b) is a cutoff: the
                              * remaining successors cannot change the
                              * result. *)
                             if Value.equals (a', b'') then b''
                             else
                                case ss of
                                   [] => b''
                                 | s :: ss =>
                                      loop (ss, alphaBeta (s, a', b''))
                       in
                          Value.unmove (loop (State.succ s, b'))
                       end))) arg
   in
      alphaBeta arg
   end

(* The uncached search under a separate name (referenced by the
 * commented-out debugging checks below). *)
val alphaBetaNoCache = alphaBeta

(* Closed intervals [lower, upper] of values.  They are used both as search
 * windows and as cached bounds on a state's value.
 *)
structure Interval =
   struct
      datatype t = T of {lower: Value.t,
                         upper: Value.t}

      (* A one-point interval prints as its value, otherwise "lower-upper". *)
      fun layout (T {lower, upper}) =
         if Value.equals (lower, upper)
            then Value.layout lower
         else let open Layout
              in seq [Value.layout lower, str "-", Value.layout upper]
              end

      val make = T

      (* Field selectors. *)
      local
         fun make f (T r) = f r
      in
         val lower = make #lower
         val upper = make #upper
      end

      (* [Value.smallest, Value.largest]: no information. *)
      val all = T {lower = Value.smallest,
                   upper = Value.largest}

      (* [v, Value.largest] *)
      fun above (v: Value.t): t = T {lower = v, upper = Value.largest}

      (* [Value.smallest, v] *)
      fun below (v: Value.t): t = T {lower = Value.smallest, upper = v}

      fun isPoint (T {lower, upper}) = Value.equals (lower, upper)

      (* [v, v] *)
      fun point v = T {lower = v, upper = v}

      (* Clamp v into the interval. *)
      fun closest (T {lower, upper}, v: Value.t): Value.t =
         if Value.<= (v, lower) then lower
         else if Value.>= (v, upper) then upper
         else v

      fun contains (T {lower, upper}, v: Value.t): bool =
         Value.<= (lower, v) andalso Value.<= (v, upper)

      (* The interval in the opponent's perspective: move both ends and swap
       * them (assumes Value.move reverses the order of values). *)
      fun move (T {lower, upper}): t =
         T {lower = Value.move upper,
            upper = Value.move lower}

      (* Intersection of two intervals; fails via Error.bug if it is empty. *)
      fun intersect (i: t, i': t): t =
         let val lower = Value.max (lower i, lower i')
             val upper = Value.min (upper i, upper i')
         in if Value.> (lower, upper)
               then Error.bug "AlphaBeta.Interval.intersect: empty intersection"
            else T {lower = lower, upper = upper}
         end
(* val intersect = Trace.trace2 ("intersect", layout, layout, layout) intersect *)
   end

(* val trace =
 *    Trace.trace2 ("alphaBetaCache", State.layout, Interval.layout, Value.layout)
 *
 * fun traceAlphaBeta f (s, i) =
 *    let val v = trace f (s, i)
 *       val v' = alphaBetaNoCache (s, Interval.lower i, Interval.upper i)
 *    in if Value.equals (v, v')
 *          then ()
 *       else Misc.bug (let open Layout
 *                     in align [str "v = ", Value.layout v,
 *                              str "v' = ", Value.layout v']
 *                     end);
 *          v
 *    end
 *
 * val traceSearch = Trace.trace ("search", Interval.layout, Value.layout)
 *)

(* alphaBetaCache (s, i, c): fail-hard alpha-beta search of s with the window
 * i, returning a value in i.
 *
 * For each expanded non-leaf state, c holds an interval believed to contain
 * that state's value.  Cached bounds are used to answer or narrow later
 * searches, and are tightened after each search.
 *)
fun alphaBetaCache (s: State.t, i: Interval.t, c: Interval.t Cache.t): Value.t =
   let
      val count = messenger ()
      fun alphaBeta (s: State.t, i: Interval.t): Value.t =
         (count (fn () =>
                 let open Layout
                 in align [Interval.layout i, State.layout s]
                 end)
          ; (case State.evaluate s of
                State.Leaf v => Interval.closest (i, v)
              | State.NonLeaf {lower, upper} =>
                   (* The state's static bounds may already decide the
                    * result without expanding s. *)
                   if Value.<= (upper, Interval.lower i)
                      then Interval.lower i
                   else if Value.<= (Interval.upper i, lower)
                           then Interval.upper i
                   else
                      let
                         val {update, value} = Cache.peek (c, s)
                         (* search iKnown: iKnown is the interval currently
                          * known to contain the value of s.  Searches s with
                          * the window i /\ iKnown, records the bound learned
                          * from the result via update, and returns the
                          * result. *)
                         fun search iKnown =
                            let
                               val iSearch = Interval.intersect (i, iKnown)
                            in
                               if Interval.isPoint iSearch
                                  then
                                     (* A one-point window fixes the result
                                      * without exploring s, so nothing is
                                      * learned about the value of s.  Leave
                                      * the cache alone.  The update below
                                      * would treat this as a fail-high and
                                      * store an unjustified "above" bound.
                                      * That bound could later produce wrong
                                      * answers, or an empty intersection in
                                      * Interval.intersect. *)
                                     Interval.lower iSearch
                               else
                                  let
                                     val Interval.T {lower, upper} =
                                        Interval.move iSearch
                                     (* inv: lower <= v <= upper *)
                                     fun loop (ss, v) =
                                        (* Reaching lower is a cutoff. *)
                                        if Value.equals (lower, v) then v
                                        else
                                           case ss of
                                              [] => v
                                            | s :: ss =>
                                                 loop
                                                 (ss,
                                                  alphaBeta
                                                  (s, Interval.T {lower = lower,
                                                                  upper = v}))
                                     val v =
                                        Value.unmove (loop (State.succ s, upper))
                                     val Interval.T {lower, upper} = iSearch
                                     (* Hitting the window's upper end gives
                                      * a lower bound on the value of s.
                                      * Hitting the lower end gives an upper
                                      * bound.  Anything strictly inside is
                                      * taken as exact. *)
                                     val iKnown =
                                        Interval.intersect
                                        (iKnown,
                                         if Value.equals (v, upper)
                                            then Interval.above upper
                                         else if Value.equals (v, lower)
                                                 then Interval.below lower
                                              else Interval.point v)
                                  in
                                     (*Misc.assert (fn () =>
                                       Interval.contains
                                       (iKnown,
                                        alphaBetaNoCache (s, Value.smallest,
                                                          Value.largest))); *)
                                     update iKnown; v
                                  end
                            end
                      in
                         case value of
                            SOME i' =>
                               let
                                  val Interval.T {lower, upper} = i
                                  val Interval.T {lower = lower', upper = upper'} =
                                     i'
                               in
                                  (* Cached bounds entirely outside the window
                                   * decide the result without searching. *)
                                  if Value.<= (upper', lower)
                                     then lower
                                  else if Value.>= (lower', upper)
                                          then upper
                                       else search i'
                               end
                          | NONE => search Interval.all
                      end))
   in
      alphaBeta (s, i)
   end

end