Commit | Line | Data |
---|---|---|
7f918cf1 CE |
1 | (* Copyright (C) 1999-2006 Henry Cejtin, Matthew Fluet, Suresh |
2 | * Jagannathan, and Stephen Weeks. | |
3 | * | |
4 | * MLton is released under a BSD-style license. | |
5 | * See the file MLton-LICENSE for details. | |
6 | *) | |
7 | ||
8 | functor AlphaBeta (S: ALPHA_BETA_STRUCTS): ALPHA_BETA = | |
9 | struct | |
10 | ||
11 | open S | |
12 | ||
(* Tracer wrapped around every call of the uncached search below.  The three
 * argument printers show the state and the two window bounds; the fourth
 * printer is presumably used by Trace.trace3 for the returned value. *)
val traceAlphaBeta =
   let
      val name = "AlphaBeta.alphaBeta"
      val showValue = Value.layout
   in
      Trace.trace3 (name, State.layout, showValue, showValue, showValue)
   end
16 | ||
(* Builds a fresh call counter for progress reporting.
 *
 * The returned function takes a thunk that produces a layout.  Each call
 * increments a private counter.  When the counter reaches the next power of
 * two (1, 2, 4, 8, ...), the thunk is forced and one line is written to
 * Out.error: the count left-justified in a 10-character field, a space, and
 * the thunk's layout.  Between those points the thunk is never forced, so
 * building the layout costs nothing. *)
fun messenger () =
   let
      val calls = ref 0
      val threshold = ref 1
      (* Print one progress line and double the next reporting point. *)
      fun report mkLayout =
         let
            val () = threshold := !threshold * 2
            val countField =
               Justify.justify (Int.toString (!calls), 10, Justify.Left)
         in
            Layout.output (Layout.seq [Layout.str countField,
                                       Layout.str " ",
                                       mkLayout ()],
                           Out.error)
            ; Out.newline Out.error
         end
      fun count mkLayout =
         (Int.inc calls
          ; if !calls = !threshold then report mkLayout else ())
   in
      count
   end
40 | ||
(* Uncached alpha-beta search over the window [a, b].
 *
 * alphaBeta (s, a, b): leaves return their static value clamped to [a, b].
 * A non-leaf whose static bounds already lie at or below a (or at or above b)
 * returns a (or b) without searching.  Otherwise its successors are searched
 * from the opponent's side: the window becomes [move b, move a], and its upper
 * end is narrowed child by child.  The scan stops early once the upper end
 * meets the lower end, and the final bound is mapped back with Value.unmove.
 *
 * Every (recursive) call goes through traceAlphaBeta and is counted by a
 * messenger that is created fresh for each top-level call. *)
fun alphaBeta arg =
   let
      val count = messenger ()
      fun search (s: State.t, a: Value.t, b: Value.t): Value.t =
         let
            val () =
               count (fn () =>
                      Layout.align
                      [Layout.tuple [Value.layout a, Value.layout b],
                       State.layout s])
            (* Search the successors of s with the moved window. *)
            fun searchChildren () =
               let
                  val winLo = Value.move b
                  val winHi = Value.move a
                  (* inv: winLo <= bound <= winHi *)
                  fun scan (ss, bound) =
                     if Value.equals (winLo, bound)
                        then bound
                     else
                        case ss of
                           [] => bound
                         | child :: rest =>
                              scan (rest, traced (child, winLo, bound))
               in
                  Value.unmove (scan (State.succ s, winHi))
               end
         in
            case State.evaluate s of
               State.Leaf v =>
                  if Value.<= (v, a) then a
                  else if Value.<= (b, v) then b
                  else v
             | State.NonLeaf {lower, upper} =>
                  if Value.<= (upper, a) then a
                  else if Value.<= (b, lower) then b
                  else searchChildren ()
         end
      (* The traced entry point; recursive calls in search go through it. *)
      and traced arg = traceAlphaBeta search arg
   in
      traced arg
   end
81 | ||
(* The uncached search under an explicit name; it simply forwards to
 * alphaBeta. *)
fun alphaBetaNoCache args = alphaBeta args
83 | ||
(* Closed intervals [lower, upper] of values.  alphaBetaCache uses them both as
 * search windows and as the cached knowledge about a state's value. *)
structure Interval =
struct
   datatype t = T of {lower: Value.t,
                      upper: Value.t}

   val make = T

   (* Field selectors. *)
   fun lower (T {lower = lo, ...}) = lo
   fun upper (T {upper = hi, ...}) = hi

   (* True when both ends are equal under Value.equals. *)
   fun isPoint (T {lower = lo, upper = hi}) = Value.equals (lo, hi)

   fun point v = T {lower = v, upper = v}

   (* A point prints as its single value; otherwise as "lower-upper". *)
   fun layout (i as T {lower = lo, upper = hi}) =
      if isPoint i
         then Value.layout lo
      else Layout.seq [Value.layout lo, Layout.str "-", Value.layout hi]

   (* The whole value range. *)
   val all = T {lower = Value.smallest,
                upper = Value.largest}

   (* [v, largest] and [smallest, v]. *)
   fun above (v: Value.t): t = T {lower = v, upper = Value.largest}
   fun below (v: Value.t): t = T {lower = Value.smallest, upper = v}

   (* v clamped into the interval. *)
   fun closest (T {lower = lo, upper = hi}, v: Value.t): Value.t =
      if Value.<= (v, lo) then lo
      else if Value.>= (v, hi) then hi
      else v

   fun contains (T {lower = lo, upper = hi}, v: Value.t): bool =
      Value.<= (lo, v) andalso Value.<= (v, hi)

   (* The interval as seen from the other side: both ends are moved and
    * swapped. *)
   fun move (T {lower = lo, upper = hi}): t =
      T {lower = Value.move hi,
         upper = Value.move lo}

   (* Intersection; an empty result is reported with Error.bug. *)
   fun intersect (i: t, i': t): t =
      let
         val lo = Value.max (lower i, lower i')
         val hi = Value.min (upper i, upper i')
      in
         if Value.> (lo, hi)
            then Error.bug "AlphaBeta.Interval.intersect: empty intersection"
         else T {lower = lo, upper = hi}
      end
   (* val intersect = Trace.trace2 ("intersect", layout, layout, layout) intersect *)
end
137 | ||
138 | (* val trace = | |
139 | * Trace.trace2 ("alphaBetaCache", State.layout, Interval.layout, Value.layout) | |
140 | * | |
141 | * fun traceAlphaBeta f (s, i) = | |
142 | * let val v = trace f (s, i) | |
143 | * val v' = alphaBetaNoCache (s, Interval.lower i, Interval.upper i) | |
144 | * in if Value.equals (v, v') | |
145 | * then () | |
146 | * else Misc.bug (let open Layout | |
147 | * in align [str "v = ", Value.layout v, | |
148 | * str "v' = ", Value.layout v'] | |
149 | * end); | |
150 | * v | |
151 | * end | |
152 | * | |
153 | * val traceSearch = Trace.trace ("search", Interval.layout, Value.layout) | |
154 | *) | |
155 | ||
(* Alpha-beta search over the window i that memoizes per-state knowledge in
 * the cache c.
 *
 * For each non-leaf state, the cache holds an interval believed to contain
 * that state's value.  A search with window [lo, hi] clamps its result into
 * [lo, hi], so:
 *   - a result equal to hi only shows the value is >= hi;
 *   - a result equal to lo only shows the value is <= lo;
 *   - anything strictly in between is exact.
 * After each search, the cached interval is narrowed by what the result
 * shows.
 *
 * Fix: when the effective search window collapses to a single point p, every
 * search returns p whatever the state's real value is, so the result proves
 * nothing.  The old code recorded "value >= p" in that case, which is false
 * when the real value is below p.  Because the cache persists across calls,
 * later queries could then return a wrong answer or fail in
 * Interval.intersect with "empty intersection".  Such searches now leave the
 * cached interval unchanged. *)
fun alphaBetaCache (s: State.t, i: Interval.t, c: Interval.t Cache.t): Value.t =
   let
      val count = messenger ()
      fun alphaBeta (s: State.t, i: Interval.t): Value.t =
         (count (fn () =>
                 let open Layout
                 in align [Interval.layout i, State.layout s]
                 end)
          ; (case State.evaluate s of
                State.Leaf v => Interval.closest (i, v)
              | State.NonLeaf {lower, upper} =>
                   if Value.<= (upper, Interval.lower i)
                      then Interval.lower i
                   else if Value.<= (Interval.upper i, lower)
                      then Interval.upper i
                   else
                      let
                         val {update, value} = Cache.peek (c, s)
                         (* Search s within the part of i that iKnown has
                          * not ruled out, store the narrowed knowledge via
                          * update, and return the clamped result. *)
                         fun search iKnown =
                            let
                               val iSearch = Interval.intersect (i, iKnown)
                               (* The window as seen from the successors. *)
                               val Interval.T {lower, upper} =
                                  Interval.move iSearch
                               (* inv: lower <= v <= upper *)
                               fun loop (ss, v) =
                                  if Value.equals (lower, v) then v
                                  else
                                     case ss of
                                        [] => v
                                      | s :: ss =>
                                           loop
                                           (ss,
                                            alphaBeta
                                            (s, Interval.T {lower = lower,
                                                            upper = v}))
                               val v =
                                  Value.unmove (loop (State.succ s, upper))
                               val Interval.T {lower, upper} = iSearch
                               val iKnown' =
                                  if Interval.isPoint iSearch
                                     (* A one-point window forces the result
                                      * to that point: nothing was learned. *)
                                     then iKnown
                                  else
                                     Interval.intersect
                                     (iKnown,
                                      if Value.equals (v, upper)
                                         then Interval.above upper
                                      else if Value.equals (v, lower)
                                         then Interval.below lower
                                      else Interval.point v)
                            in (*Misc.assert (fn () =>
                                 Interval.contains
                                 (iKnown',
                                  alphaBetaNoCache (s, Value.smallest,
                                                    Value.largest))); *)
                               update iKnown'; v
                            end
                      in
                         case value of
                            SOME i' =>
                               let
                                  val Interval.T {lower, upper} = i
                                  val Interval.T {lower = lower', upper = upper'} =
                                     i'
                               in
                                  (* The cached knowledge alone may already
                                   * decide the clamped result. *)
                                  if Value.<= (upper', lower)
                                     then lower
                                  else if Value.>= (lower', upper)
                                     then upper
                                  else search i'
                               end
                          | NONE => search Interval.all
                      end))
   in
      alphaBeta (s, i)
   end
226 | ||
227 | end |