| 1 | (* Copyright (C) 1999-2006 Henry Cejtin, Matthew Fluet, Suresh |
| 2 | * Jagannathan, and Stephen Weeks. |
| 3 | * |
| 4 | * MLton is released under a BSD-style license. |
| 5 | * See the file MLton-LICENSE for details. |
| 6 | *) |
| 7 | |
| 8 | functor AlphaBeta (S: ALPHA_BETA_STRUCTS): ALPHA_BETA = |
| 9 | struct |
| 10 | |
| 11 | open S |
| 12 | |
(* Tracing wrapper used by the uncached search below.  It is registered
 * under the name "AlphaBeta.alphaBeta" and is given one layout function per
 * argument (a state and two values) plus one for the value result.
 *)
val traceAlphaBeta =
   let
      val name = "AlphaBeta.alphaBeta"
   in
      Trace.trace3
      (name, State.layout, Value.layout, Value.layout, Value.layout)
   end
| 16 | |
(* messenger () returns a fresh progress reporter.
 *
 * Each call of the returned function bumps a private call counter.  When the
 * counter reaches the current threshold (1, 2, 4, 8, ... -- the threshold
 * starts at 1 and doubles every time it is hit), a line is written to
 * Out.error consisting of the call count left-justified in 10 columns, a
 * space, and the layout produced by the thunk f.  The thunk is only forced
 * when a line is actually printed.
 *)
fun messenger () =
   let
      val calls = ref 0
      val threshold = ref 1
      (* Emit one progress line and schedule the next report. *)
      fun report f =
         let
            val () = threshold := 2 * !threshold
            val callsField =
               Justify.justify (Int.toString (!calls), 10, Justify.Left)
            val line =
               Layout.seq [Layout.str callsField, Layout.str " ", f ()]
         in
            Layout.output (line, Out.error)
            ; Out.newline Out.error
         end
      fun count f =
         let
            val () = Int.inc calls
         in
            if !calls <> !threshold
               then ()
            else report f
         end
   in
      count
   end
| 40 | |
(* alphaBeta (s, a, b): windowed game-tree search without a cache.
 *
 * The argument is a triple of a state and two values a and b delimiting the
 * search window.  Every visited node is passed to a fresh messenger for
 * progress output, and every recursive call goes through traceAlphaBeta.
 *
 *  - A leaf's value is clamped into the window: a if the value is <= a,
 *    b if b <= the value, otherwise the value itself.
 *  - A non-leaf whose static upper bound is <= a yields a; one whose static
 *    lower bound is >= b yields b.
 *  - Otherwise the successors are searched in the window obtained by
 *    applying Value.move to the bounds (swapped), and the result is mapped
 *    back with Value.unmove.  (Value.move presumably switches to the other
 *    player's point of view -- it is defined by the functor argument.)
 *)
fun alphaBeta arg =
   let
      val count = messenger ()
      fun search arg : Value.t =
         let
            fun visit (s: State.t, a: Value.t, b: Value.t) =
               let
                  (* Progress line: the window followed by the state. *)
                  fun describe () =
                     Layout.align
                     [Layout.tuple [Value.layout a, Value.layout b],
                      State.layout s]
                  val () = count describe
                  (* Search the successors of s in the moved window. *)
                  fun searchChildren () =
                     let
                        val floor = Value.move b
                        val start = Value.move a
                        (* inv: floor <= best <= start *)
                        fun loop (children, best) =
                           if Value.equals (floor, best)
                              then best (* window closed: nothing can improve *)
                           else
                              case children of
                                 [] => best
                               | child :: rest =>
                                    loop (rest, search (child, floor, best))
                     in
                        Value.unmove (loop (State.succ s, start))
                     end
               in
                  case State.evaluate s of
                     State.Leaf v =>
                        if Value.<= (v, a)
                           then a
                        else if Value.<= (b, v)
                           then b
                        else v
                   | State.NonLeaf {lower = stateLo, upper = stateHi} =>
                        if Value.<= (stateHi, a)
                           then a
                        else if Value.<= (b, stateLo)
                           then b
                        else searchChildren ()
               end
         in
            traceAlphaBeta visit arg
         end
   in
      search arg
   end
| 81 | |
(* The uncached search under an explicit name; it simply forwards to
 * alphaBeta.
 *)
fun alphaBetaNoCache arg = alphaBeta arg
| 83 | |
(* Closed intervals of values, represented by their two endpoints. *)
structure Interval =
   struct
      datatype t = T of {lower: Value.t,
                         upper: Value.t}

      (* Build an interval from a record of its endpoints. *)
      val make = T

      (* Endpoint accessors. *)
      fun lower (T {lower = lo, ...}) = lo
      fun upper (T {upper = hi, ...}) = hi

      (* A degenerate interval is shown as a single value, any other one as
       * "lower-upper".
       *)
      fun layout (T {lower = lo, upper = hi}) =
         if Value.equals (lo, hi)
            then Value.layout lo
         else Layout.seq [Value.layout lo, Layout.str "-", Value.layout hi]

      (* The interval from Value.smallest to Value.largest. *)
      val all = T {lower = Value.smallest,
                   upper = Value.largest}

      (* Everything from v up to Value.largest. *)
      fun above (v: Value.t): t = T {lower = v, upper = Value.largest}

      (* Everything from Value.smallest up to v. *)
      fun below (v: Value.t): t = T {lower = Value.smallest, upper = v}

      (* The interval containing exactly v. *)
      fun point v = T {lower = v, upper = v}

      (* Do the two endpoints coincide? *)
      fun isPoint (T {lower = lo, upper = hi}) = Value.equals (lo, hi)

      (* The endpoint nearest to v when v is at or beyond an endpoint,
       * otherwise v itself.
       *)
      fun closest (T {lower = lo, upper = hi}, v: Value.t): Value.t =
         if Value.<= (v, lo)
            then lo
         else if Value.>= (v, hi)
            then hi
         else v

      (* Is lower <= v <= upper? *)
      fun contains (T {lower = lo, upper = hi}, v: Value.t): bool =
         if Value.<= (lo, v)
            then Value.<= (v, hi)
         else false

      (* Apply Value.move to both endpoints and swap them: the moved upper
       * endpoint becomes the new lower endpoint and vice versa.
       *)
      fun move (T {lower = lo, upper = hi}): t =
         let
            val newLower = Value.move hi
            val newUpper = Value.move lo
         in
            T {lower = newLower, upper = newUpper}
         end

      (* The overlap of two intervals.  An empty overlap is reported through
       * Error.bug.
       *)
      fun intersect (i: t, i': t): t =
         let
            val lo = Value.max (lower i, lower i')
            val hi = Value.min (upper i, upper i')
         in
            case Value.> (lo, hi) of
               true =>
                  Error.bug "AlphaBeta.Interval.intersect: empty intersection"
             | false => T {lower = lo, upper = hi}
         end
      (* val intersect = Trace.trace2 ("intersect", layout, layout, layout) intersect *)
   end
| 137 | |
| 138 | (* val trace = |
| 139 | * Trace.trace2 ("alphaBetaCache", State.layout, Interval.layout, Value.layout) |
| 140 | * |
| 141 | * fun traceAlphaBeta f (s, i) = |
| 142 | * let val v = trace f (s, i) |
| 143 | * val v' = alphaBetaNoCache (s, Interval.lower i, Interval.upper i) |
| 144 | * in if Value.equals (v, v') |
| 145 | * then () |
| 146 | * else Misc.bug (let open Layout |
| 147 | * in align [str "v = ", Value.layout v, |
| 148 | * str "v' = ", Value.layout v'] |
| 149 | * end); |
| 150 | * v |
| 151 | * end |
| 152 | * |
| 153 | * val traceSearch = Trace.trace ("search", Interval.layout, Value.layout) |
| 154 | *) |
| 155 | |
(* alphaBetaCache (s, i, c): windowed game-tree search that consults and
 * refines a cache c mapping states to the interval known to contain their
 * value.
 *
 *  s  the state to search from
 *  i  the search window
 *  c  the cache; Cache.peek (c, state) yields the cached interval (if any)
 *     together with an update function for that state's entry
 *
 * Leaves are answered with Interval.closest.  A non-leaf whose static bounds
 * lie entirely on one side of the window returns the matching window edge.
 * Otherwise the cached interval is consulted: if it already lies on one side
 * of the window the matching edge is returned, and if not the successors are
 * searched inside the intersection of the window and the cached interval,
 * after which the cache entry is narrowed with what was learned.
 *)
fun alphaBetaCache (s: State.t, i: Interval.t, c: Interval.t Cache.t): Value.t =
   let
      val count = messenger ()
      fun visit (state: State.t, window: Interval.t): Value.t =
         let
            val Interval.T {lower = windowLo, upper = windowHi} = window
            (* Progress line: the window followed by the state. *)
            val () =
               count (fn () =>
                      Layout.align [Interval.layout window,
                                    State.layout state])
            (* Full treatment of a non-leaf whose static bounds straddle
             * the window.
             *)
            fun expand () =
               let
                  val {update, value = cached} = Cache.peek (c, state)
                  (* Search the successors given that the value of state
                   * is known to lie in `known`, then record the refined
                   * knowledge in the cache.
                   *)
                  fun search known =
                     let
                        val target = Interval.intersect (window, known)
                        val Interval.T {lower = movedLo, upper = movedHi} =
                           Interval.move target
                        (* inv: movedLo <= best <= movedHi *)
                        fun loop (children, best) =
                           if Value.equals (movedLo, best)
                              then best
                           else
                              case children of
                                 [] => best
                               | child :: rest =>
                                    loop
                                    (rest,
                                     visit
                                     (child,
                                      Interval.T {lower = movedLo,
                                                  upper = best}))
                        val result =
                           Value.unmove (loop (State.succ state, movedHi))
                        val Interval.T {lower = targetLo, upper = targetHi} =
                           target
                        (* A result on an edge of the searched interval
                         * only bounds the true value from that side; a
                         * result strictly inside is exact.
                         *)
                        val learned =
                           if Value.equals (result, targetHi)
                              then Interval.above targetHi
                           else if Value.equals (result, targetLo)
                              then Interval.below targetLo
                           else Interval.point result
                        val refined = Interval.intersect (known, learned)
                     in (*Misc.assert (fn () =>
                          Interval.contains
                          (refined,
                           alphaBetaNoCache (state, Value.smallest,
                                             Value.largest))); *)
                        update refined; result
                     end
               in
                  case cached of
                     NONE => search Interval.all
                   | SOME (known as Interval.T {lower = knownLo,
                                                upper = knownHi}) =>
                        if Value.<= (knownHi, windowLo)
                           then windowLo
                        else if Value.>= (knownLo, windowHi)
                           then windowHi
                        else search known
               end
         in
            case State.evaluate state of
               State.Leaf v => Interval.closest (window, v)
             | State.NonLeaf {lower = stateLo, upper = stateHi} =>
                  if Value.<= (stateHi, windowLo)
                     then windowLo
                  else if Value.<= (windowHi, stateLo)
                     then windowHi
                  else expand ()
         end
   in
      visit (s, i)
   end
| 226 | |
| 227 | end |