Import Upstream version 20180207
[hcoop/debian/mlton.git] / mlton / ssa / poly-equal.fun
CommitLineData
7f918cf1
CE
1(* Copyright (C) 2009,2014 Matthew Fluet.
2 * Copyright (C) 1999-2008 Henry Cejtin, Matthew Fluet, Suresh
3 * Jagannathan, and Stephen Weeks.
4 * Copyright (C) 1997-2000 NEC Research Institute.
5 *
6 * MLton is released under a BSD-style license.
7 * See the file MLton-LICENSE for details.
8 *)
9
10functor PolyEqual (S: SSA_TRANSFORM_STRUCTS): SSA_TRANSFORM =
11struct
12
13open S
14
(*
 * This pass implements polymorphic equality.
 *
 * For each datatype tycon and vector type, it builds an equality function and
 * translates calls to MLton_equal into calls to that function.  A single
 * equality function is likewise built for IntInf.
 *
 * Also generates calls to primitive wordEqual.
 *
 * For tuples, it does the equality test inline, i.e., it does not create
 * a separate equality function for each tuple type.
 *
 * All equality functions are only created if necessary, i.e., if equality
 * is actually used at a type.
 *
 * Optimizations:
 * - For datatype tycons that are enumerations, do not build a case dispatch;
 *   just use eq, since the backend will represent these as ints.
 * - Deep equality always does an eq test first.
 * - If one argument to = is a constant and the type will get translated to
 *   an IntOrPointer, then just use eq instead of the full equality. This is
 *   important for implementing code like the following efficiently:
 *      if x = 0 ...   (where x is an IntInf.int)
 *
 * Also converts pointer equality (MLton_eq) on scalar types (CPointer, Real,
 * and Word) to type-specific primitives.
 *)
40
41open Exp Transfer
42
(* DirectExp extended with the boolean connectives and word primitives that
 * the equality code generators below need.
 *)
structure Dexp =
   struct
      open DirectExp

      local
         (* A boolean-valued case on the boolean [test]: [onTrue] is the
          * body of the Con.truee arm, [onFalse] the body of the Con.falsee
          * arm.  No default arm is needed since both constructors are listed.
          *)
         fun boolCase (test: t, onTrue: t, onFalse: t): t =
            casee {test = test,
                   cases = Con (Vector.new2 ({con = Con.truee,
                                              args = Vector.new0 (),
                                              body = onTrue},
                                             {con = Con.falsee,
                                              args = Vector.new0 (),
                                              body = onFalse})),
                   default = NONE,
                   ty = Type.bool}
      in
         (* Short-circuit conjunction: e2 appears only in the true arm of e1. *)
         fun conjoin (e1: t, e2: t): t = boolCase (e1, e2, falsee)

         (* Short-circuit disjunction: e2 appears only in the false arm of e1. *)
         fun disjoin (e1: t, e2: t): t = boolCase (e1, truee, e2)
      end

      local
         (* Application of a two-argument, type-argument-free primitive. *)
         fun binary (prim, resultTy: Type.t, e1: t, e2: t): t =
            primApp {prim = prim,
                     targs = Vector.new0 (),
                     args = Vector.new2 (e1, e2),
                     ty = resultTy}
      in
         (* Word arithmetic/bit operations at word size s; result is a word s. *)
         fun add (e1: t, e2: t, s) = binary (Prim.wordAdd s, Type.word s, e1, e2)
         fun andb (e1: t, e2: t, s) = binary (Prim.wordAndb s, Type.word s, e1, e2)
         fun orb (e1: t, e2: t, s) = binary (Prim.wordOrb s, Type.word s, e1, e2)

         (* Word equality at word size s; result is a bool. *)
         fun wordEqual (e1: t, e2: t, s): t =
            binary (Prim.wordEqual s, Type.bool, e1, e2)
      end
   end
88
(* Rewrite the program so that no polymorphic equality remains:
 *  - every MLton_equal statement is replaced by inline, monomorphic
 *    equality code (possibly calling generated helper functions), and
 *  - every MLton_eq at CPointer, Real, or Word type is replaced by the
 *    corresponding type-specific primitive.
 * Helper functions (one per datatype tycon, one per vector element type, and
 * one for IntInf) are created on demand and prepended to the function list.
 *)
fun transform (Program.T {datatypes, globals, functions, main}) =
   let
      (* Whether a function / block contains an MLton_eq or MLton_equal
       * primitive; only those are rebuilt by doit below.
       *)
      val {get = funcInfo: Func.t -> {hasEqual: bool},
           set = setFuncInfo, ...} =
         Property.getSet (Func.plist, Property.initConst {hasEqual = false})
      val {get = labelInfo: Label.t -> {hasEqual: bool},
           set = setLabelInfo, ...} =
         Property.getSet (Label.plist, Property.initConst {hasEqual = false})
      (* isConst is set by setBind for variables bound to a small IntInf
       * constant, a word constant, or a nullary constructor application.
       *)
      val {get = varInfo: Var.t -> {isConst: bool},
           set = setVarInfo, ...} =
         Property.getSetOnce (Var.plist, Property.initConst {isConst = false})
      val {get = tyconInfo: Tycon.t -> {isEnum: bool,
                                        cons: {con: Con.t,
                                               args: Type.t vector} vector},
           set = setTyconInfo, ...} =
         Property.getSetOnce
         (Tycon.plist, Property.initRaise ("PolyEqual.tyconInfo", Tycon.layout))
      val isEnum = #isEnum o tyconInfo
      val tyconCons = #cons o tyconInfo
      (* Memo tables for the generated equality functions. *)
      val {get = getTyconEqualFunc: Tycon.t -> Func.t option,
           set = setTyconEqualFunc, ...} =
         Property.getSet (Tycon.plist, Property.initConst NONE)
      val {get = getVectorEqualFunc: Type.t -> Func.t option,
           set = setVectorEqualFunc,
           destroy = destroyVectorEqualFunc} =
         Property.destGetSet (Type.plist, Property.initConst NONE)
      val (getIntInfEqualFunc: unit -> Func.t option,
           setIntInfEqualFunc: Func.t option -> unit) =
         let
            val r = ref NONE
         in
            (fn () => !r, fn fo => r := fo)
         end
      (* Every generated equality function returns a single bool. *)
      val returns = SOME (Vector.new1 Type.bool)
      val seqIndexWordSize = WordSize.seqIndex ()
      val seqIndexTy = Type.word seqIndexWordSize
      (* Generated helper functions; added to the program at the end. *)
      val newFunctions: Function.t list ref = ref []
      fun newFunction z =
         List.push (newFunctions,
                    Function.profile (Function.new z,
                                      SourceInfo.polyEqual))
      (* The equality function for datatype [tycon], created on first use.
       * It tests eq first; otherwise it dispatches on the constructor of the
       * first argument and, for a constructor with arguments, requires the
       * second argument to have the same constructor and compares the
       * constructor arguments pairwise.
       *)
      fun equalTyconFunc (tycon: Tycon.t): Func.t =
         case getTyconEqualFunc tycon of
            SOME f => f
          | NONE =>
               let
                  val name =
                     Func.newString (concat ["equal_", Tycon.originalName tycon])
                  (* Memoize before building the body, so recursive uses of
                   * the tycon refer back to this same function.
                   *)
                  val _ = setTyconEqualFunc (tycon, SOME name)
                  val ty = Type.datatypee tycon
                  val arg1 = (Var.newNoname (), ty)
                  val arg2 = (Var.newNoname (), ty)
                  val args = Vector.new2 (arg1, arg2)
                  val darg1 = Dexp.var arg1
                  val darg2 = Dexp.var arg2
                  val cons = tyconCons tycon
                  val body =
                     Dexp.disjoin
                     (Dexp.eq (darg1, darg2, ty),
                      Dexp.casee
                      {test = darg1,
                       ty = Type.bool,
                       (* Nullary constructors get no case arm: two nullary
                        * values that already failed the eq test are unequal.
                        *)
                       default = (if Vector.exists (cons, fn {args, ...} =>
                                                    Vector.isEmpty args)
                                     then SOME Dexp.falsee
                                  else NONE),
                       cases =
                       Dexp.Con
                       (Vector.keepAllMap
                        (cons, fn {con, args} =>
                         if Vector.isEmpty args
                            then NONE
                         else
                            let
                               fun makeArgs () =
                                  Vector.map (args, fn ty =>
                                              (Var.newNoname (), ty))
                               val xs = makeArgs ()
                               val ys = makeArgs ()
                            in
                               SOME
                               {con = con,
                                args = xs,
                                body =
                                Dexp.casee
                                {test = darg2,
                                 ty = Type.bool,
                                 (* With a single constructor the inner
                                  * case is already exhaustive.
                                  *)
                                 default = if 1 = Vector.length cons
                                              then NONE
                                           else SOME Dexp.falsee,
                                 cases =
                                 Dexp.Con
                                 (Vector.new1
                                  {con = con,
                                   args = ys,
                                   body =
                                   Vector.fold2
                                   (xs, ys, Dexp.truee,
                                    fn ((x, ty), (y, _), de) =>
                                    Dexp.conjoin (de, equal (x, y, ty)))})}}
                            end))})
                  val (start, blocks) = Dexp.linearize (body, Handler.Caller)
                  val blocks = Vector.fromList blocks
                  val _ =
                     newFunction {args = args,
                                  blocks = blocks,
                                  mayInline = true,
                                  name = name,
                                  raises = NONE,
                                  returns = returns,
                                  start = start}
               in
                  name
               end
      (* Emit the vector equality function [name] for element type [ty]
       * together with its helper [name ^ "Loop"].  [name] compares lengths
       * and then calls the loop from index 0; if [doEq], it first tries an
       * eq test on the two vectors.
       *)
      and mkVectorEqualFunc {name: Func.t,
                             ty: Type.t, doEq: bool}: unit =
         let
            val loop = Func.newString (Func.originalName name ^ "Loop")
            (* Build two functions, one that checks the lengths and the
             * other that loops.
             *)
            val vty = Type.vector ty
            local
               val vec1 = (Var.newNoname (), vty)
               val vec2 = (Var.newNoname (), vty)
               val args = Vector.new2 (vec1, vec2)
               val dvec1 = Dexp.var vec1
               val dvec2 = Dexp.var vec2
               val len1 = (Var.newNoname (), seqIndexTy)
               val dlen1 = Dexp.var len1
               val len2 = (Var.newNoname (), seqIndexTy)
               val dlen2 = Dexp.var len2

               val body =
                  let
                     fun length dvec =
                        Dexp.primApp {prim = Prim.vectorLength,
                                      targs = Vector.new1 ty,
                                      args = Vector.new1 dvec,
                                      ty = Type.word seqIndexWordSize}
                     val body =
                        Dexp.lett
                        {decs = [{var = #1 len1, exp = length dvec1},
                                 {var = #1 len2, exp = length dvec2}],
                         body =
                         Dexp.conjoin
                         (Dexp.wordEqual (dlen1, dlen2, seqIndexWordSize),
                          Dexp.call
                          {func = loop,
                           args = Vector.new4
                                  (dvec1, dvec2, dlen1,
                                   Dexp.word (WordX.zero seqIndexWordSize)),
                           ty = Type.bool})}
                  in
                     if doEq
                        then Dexp.disjoin (Dexp.eq (dvec1, dvec2, vty), body)
                     else body
                  end
               val (start, blocks) = Dexp.linearize (body, Handler.Caller)
               val blocks = Vector.fromList blocks
            in
               val _ =
                  newFunction {args = args,
                               blocks = blocks,
                               mayInline = true,
                               name = name,
                               raises = NONE,
                               returns = returns,
                               start = start}
            end
            (* loop (vec1, vec2, len, i): true once i = len; otherwise
             * element i must be equal and the loop continues at i + 1.
             *)
            local
               val vec1 = (Var.newNoname (), vty)
               val vec2 = (Var.newNoname (), vty)
               val len = (Var.newNoname (), seqIndexTy)
               val i = (Var.newNoname (), seqIndexTy)
               val args = Vector.new4 (vec1, vec2, len, i)
               val dvec1 = Dexp.var vec1
               val dvec2 = Dexp.var vec2
               val dlen = Dexp.var len
               val di = Dexp.var i
               val body =
                  let
                     fun sub (dvec, di) =
                        Dexp.primApp {prim = Prim.vectorSub,
                                      targs = Vector.new1 ty,
                                      args = Vector.new2 (dvec, di),
                                      ty = ty}
                     val args =
                        Vector.new4
                        (dvec1, dvec2, dlen,
                         Dexp.add
                         (di, Dexp.word (WordX.one seqIndexWordSize),
                          seqIndexWordSize))
                  in
                     Dexp.disjoin
                     (Dexp.wordEqual
                      (di, dlen, seqIndexWordSize),
                      Dexp.conjoin
                      (equalExp (sub (dvec1, di), sub (dvec2, di), ty),
                       Dexp.call {args = args,
                                  func = loop,
                                  ty = Type.bool}))
                  end
               val (start, blocks) = Dexp.linearize (body, Handler.Caller)
               val blocks = Vector.fromList blocks
            in
               val _ =
                  newFunction {args = args,
                               blocks = blocks,
                               mayInline = true,
                               name = loop,
                               raises = NONE,
                               returns = returns,
                               start = start}
            end
         in
            ()
         end
      (* The equality function for vectors of element type [ty], created on
       * first use (with an initial eq test).
       *)
      and vectorEqualFunc (ty: Type.t): Func.t =
         case getVectorEqualFunc ty of
            SOME f => f
          | NONE =>
               let
                  val name = Func.newString "vectorEqual"
                  val _ = setVectorEqualFunc (ty, SOME name)
                  val () = mkVectorEqualFunc {name = name, ty = ty, doEq = true}
               in
                  name
               end
      (* The single IntInf equality function, created on first use. *)
      and intInfEqualFunc (): Func.t =
         case getIntInfEqualFunc () of
            SOME f => f
          | NONE =>
               let
                  val intInfEqual = Func.newString "intInfEqual"
                  val _ = setIntInfEqualFunc (SOME intInfEqual)

                  val bws = WordSize.bigIntInfWord ()
                  val sws = WordSize.smallIntInfWord ()

                  (* Compares the word vectors produced by intInfToVector.
                   * No eq test here (doEq = false): intInfEqual already
                   * performs one before calling this.
                   *)
                  val bigIntInfEqual = Func.newString "bigIntInfEqual"
                  val () = mkVectorEqualFunc {name = bigIntInfEqual,
                                              ty = Type.word bws,
                                              doEq = false}

                  local
                     val arg1 = (Var.newNoname (), Type.intInf)
                     val arg2 = (Var.newNoname (), Type.intInf)
                     val args = Vector.new2 (arg1, arg2)
                     val darg1 = Dexp.var arg1
                     val darg2 = Dexp.var arg2
                     fun toWord dx =
                        Dexp.primApp
                        {prim = Prim.intInfToWord,
                         targs = Vector.new0 (),
                         args = Vector.new1 dx,
                         ty = Type.word sws}
                     fun toVector dx =
                        Dexp.primApp
                        {prim = Prim.intInfToVector,
                         targs = Vector.new0 (),
                         args = Vector.new1 dx,
                         ty = Type.vector (Type.word bws)}
                     val one = Dexp.word (WordX.one sws)
                     (* If the low bit of either argument's word form is set,
                      * the result is false; otherwise both are compared as
                      * vectors.  Presumably a set low bit marks a small
                      * (non-bignum) IntInf -- confirm against the backend's
                      * IntInf representation.
                      *)
                     val body =
                        Dexp.disjoin
                        (Dexp.eq (darg1, darg2, Type.intInf),
                         Dexp.casee
                         {test = Dexp.wordEqual
                                 (Dexp.andb
                                  (Dexp.orb (toWord darg1, toWord darg2, sws),
                                   one, sws),
                                  one, sws),
                          ty = Type.bool,
                          default = NONE,
                          cases =
                          (Dexp.Con o Vector.new2)
                          ({con = Con.truee,
                            args = Vector.new0 (),
                            body = Dexp.falsee},
                           {con = Con.falsee,
                            args = Vector.new0 (),
                            body =
                            Dexp.call {func = bigIntInfEqual,
                                       args = Vector.new2 (toVector darg1,
                                                           toVector darg2),
                                       ty = Type.bool}})})
                     val (start, blocks) = Dexp.linearize (body, Handler.Caller)
                     val blocks = Vector.fromList blocks
                  in
                     val _ =
                        newFunction {args = args,
                                     blocks = blocks,
                                     mayInline = true,
                                     name = intInfEqual,
                                     raises = NONE,
                                     returns = returns,
                                     start = start}
                  end
               in
                  intInfEqual
               end
      (* Equality of two arbitrary expressions of type [ty]: bind each to a
       * variable, then compare with [equal].
       *)
      and equalExp (e1: Dexp.t, e2: Dexp.t, ty: Type.t): Dexp.t =
         Dexp.name (e1, fn x1 =>
         Dexp.name (e2, fn x2 => equal (x1, x2, ty)))
      (* Monomorphic equality code for two variables of type [ty]. *)
      and equal (x1: Var.t, x2: Var.t, ty: Type.t): Dexp.t =
         let
            val dx1 = Dexp.var (x1, ty)
            val dx2 = Dexp.var (x2, ty)
            fun prim (p, targs) =
               Dexp.primApp {prim = p,
                             targs = targs,
                             args = Vector.new2 (dx1, dx2),
                             ty = Type.bool}
            fun eq () = prim (Prim.eq, Vector.new1 ty)
            fun hasConstArg () = #isConst (varInfo x1) orelse #isConst (varInfo x2)
         in
            case Type.dest ty of
               Type.Array _ => eq ()
             | Type.CPointer => prim (Prim.cpointerEqual, Vector.new0 ())
             | Type.Datatype tycon =>
                  (* Enumerations, or a comparison against a constant, need
                   * only an eq test.
                   *)
                  if isEnum tycon orelse hasConstArg ()
                     then eq ()
                  else Dexp.call {func = equalTyconFunc tycon,
                                  args = Vector.new2 (dx1, dx2),
                                  ty = Type.bool}
             | Type.IntInf =>
                  if hasConstArg ()
                     then eq ()
                  else Dexp.call {func = intInfEqualFunc (),
                                  args = Vector.new2 (dx1, dx2),
                                  ty = Type.bool}
             | Type.Real rs =>
                  (* Reals are compared as words of the same bit width, via
                   * realCastToWord.
                   *)
                  let
                     val ws = WordSize.fromBits (RealSize.bits rs)
                     fun toWord dx =
                        Dexp.primApp
                        {prim = Prim.realCastToWord (rs, ws),
                         targs = Vector.new0 (),
                         args = Vector.new1 dx,
                         ty = Type.word ws}
                  in
                     Dexp.wordEqual (toWord dx1, toWord dx2, ws)
                  end
             | Type.Ref _ => eq ()
             | Type.Thread => eq ()
             | Type.Tuple tys =>
                  let
                     val max = Vector.length tys - 1
                     (* test components i, i+1, ... *)
                     fun loop (i: int): Dexp.t =
                        if i > max
                           then Dexp.truee
                        else let
                                val ty = Vector.sub (tys, i)
                                fun select dx =
                                   Dexp.select {tuple = dx,
                                                offset = i,
                                                ty = ty}
                             in
                                Dexp.conjoin
                                (equalExp (select dx1, select dx2, ty),
                                 loop (i + 1))
                             end
                  in
                     loop 0
                  end
             | Type.Vector ty =>
                  Dexp.call {func = vectorEqualFunc ty,
                             args = Vector.new2 (dx1, dx2),
                             ty = Type.bool}
             | Type.Weak _ => eq ()
             | Type.Word ws => prim (Prim.wordEqual ws, Vector.new0 ())
         end

      (* A datatype is an enumeration when none of its constructors take
       * arguments.
       *)
      val _ =
         Vector.foreach
         (datatypes, fn Datatype.T {tycon, cons} =>
          setTyconInfo (tycon,
                        {isEnum = Vector.forall (cons, fn {args, ...} =>
                                                 Vector.isEmpty args),
                         cons = cons}))
      (* Mark the statement's variable (if any) as a constant when it is bound
       * to a small IntInf constant, a word constant, or a nullary
       * constructor application.
       *)
      fun setBind (Statement.T {exp, var, ...}) =
         let
            fun const () =
               case var of
                  NONE => ()
                | SOME x => setVarInfo (x, {isConst = true})
         in
            case exp of
               Const c =>
                  (case c of
                      Const.IntInf i =>
                         (case Const.IntInfRep.fromIntInf i of
                             Const.IntInfRep.Big _ => ()
                           | Const.IntInfRep.Small _ => const ())
                    | Const.Word _ => const ()
                    | _ => ())
             | ConApp {args, ...} =>
                  if Vector.isEmpty args then const () else ()
             | _ => ()
         end
      val _ = Vector.foreach (globals, setBind)
      (* Record constant bindings, and flag every function and block that
       * contains an MLton_eq or MLton_equal.
       *)
      val () =
         List.foreach
         (functions, fn f =>
          let
             val {name, blocks, ...} = Function.dest f
          in
             Vector.foreach
             (blocks, fn Block.T {label, statements, ...} =>
              let
                 fun setHasEqual () =
                    (setFuncInfo (name, {hasEqual = true})
                     ; setLabelInfo (label, {hasEqual = true}))
              in
                 Vector.foreach
                 (statements, fn stmt as Statement.T {exp, ...} =>
                  (setBind stmt;
                   case exp of
                      PrimApp {prim, ...} =>
                         (case Prim.name prim of
                             Prim.Name.MLton_eq => setHasEqual ()
                           | Prim.Name.MLton_equal => setHasEqual ()
                           | _ => ())
                    | _ => ()))
              end)
          end)
      (* Rewrite the flagged blocks of a function.  Statements are
       * accumulated in reverse (hence Vector.fromListRev in finish).  An
       * MLton_equal statement splits its block: the prefix ends with a Goto
       * into the linearized equality code, which jumps to a fresh block that
       * receives the boolean result as its argument.
       *)
      fun doit blocks =
         let
            val blocks =
               Vector.fold
               (blocks, [],
                fn (block as Block.T {label, args, statements, transfer}, blocks) =>
                if not (#hasEqual (labelInfo label))
                   then block::blocks
                else
                   let
                      fun finish ({label, args, statements}, transfer) =
                         Block.T {label = label,
                                  args = args,
                                  statements = Vector.fromListRev statements,
                                  transfer = transfer}
                      val (blocks, las) =
                         Vector.fold
                         (statements,
                          (blocks, {label = label, args = args, statements = []}),
                          fn (stmt as Statement.T {exp, var, ...},
                              (blocks, las as {label, args, statements})) =>
                          let
                             fun normal () = (blocks,
                                              {label = label,
                                               args = args,
                                               statements = stmt::statements})
                             (* ss is in reverse order, like statements. *)
                             fun adds ss = (blocks,
                                            {label = label,
                                             args = args,
                                             statements = ss @ statements})
                          in
                             case exp of
                                PrimApp {prim, targs, args, ...} =>
                                   (case (Prim.name prim, Vector.length targs) of
                                       (Prim.Name.MLton_eq, 1) =>
                                          (case Type.dest (Vector.first targs) of
                                              Type.CPointer =>
                                                 let
                                                    val cp0 = Vector.sub (args, 0)
                                                    val cp1 = Vector.sub (args, 1)
                                                    val cpointerEqStmt =
                                                       Statement.T
                                                       {var = var,
                                                        ty = Type.bool,
                                                        exp = Exp.PrimApp
                                                              {prim = Prim.cpointerEqual,
                                                               targs = Vector.new0 (),
                                                               args = Vector.new2 (cp0,cp1)}}
                                                 in
                                                    adds [cpointerEqStmt]
                                                 end
                                            | Type.Real rs =>
                                                 let
                                                    val ws = WordSize.fromBits (RealSize.bits rs)
                                                    val wt = Type.word ws
                                                    val r0 = Vector.sub (args, 0)
                                                    val r1 = Vector.sub (args, 1)
                                                    val w0 = Var.newNoname ()
                                                    val w1 = Var.newNoname ()
                                                    fun realCastToWordStmt (r, w) =
                                                       Statement.T
                                                       {var = SOME w,
                                                        ty = wt,
                                                        exp = Exp.PrimApp
                                                              {prim = Prim.realCastToWord (rs, ws),
                                                               targs = Vector.new0 (),
                                                               args = Vector.new1 r}}
                                                    val wordEqStmt =
                                                       Statement.T
                                                       {var = var,
                                                        ty = Type.bool,
                                                        exp = Exp.PrimApp
                                                              {prim = Prim.wordEqual ws,
                                                               targs = Vector.new0 (),
                                                               args = Vector.new2 (w0,w1)}}
                                                 in
                                                    (* Reversed: casts of r0, r1 come
                                                     * first in the final block.
                                                     *)
                                                    adds [wordEqStmt,
                                                          realCastToWordStmt (r1, w1),
                                                          realCastToWordStmt (r0, w0)]
                                                 end
                                            | Type.Word ws =>
                                                 let
                                                    val w0 = Vector.sub (args, 0)
                                                    val w1 = Vector.sub (args, 1)
                                                    val wordEqStmt =
                                                       Statement.T
                                                       {var = var,
                                                        ty = Type.bool,
                                                        exp = Exp.PrimApp
                                                              {prim = Prim.wordEqual ws,
                                                               targs = Vector.new0 (),
                                                               args = Vector.new2 (w0,w1)}}
                                                 in
                                                    adds [wordEqStmt]
                                                 end
                                            | _ => normal ())
                                     | (Prim.Name.MLton_equal, 1) =>
                                          let
                                             val ty = Vector.sub (targs, 0)
                                             fun arg i = Vector.sub (args, i)
                                             (* The result arrives as the argument
                                              * of the continuation block.  If the
                                              * statement binds no variable, use a
                                              * fresh (unused) one rather than
                                              * failing in valOf.
                                              *)
                                             val res =
                                                case var of
                                                   NONE => Var.newNoname ()
                                                 | SOME x => x
                                             val l = Label.newNoname ()
                                             val (start',bs') =
                                                Dexp.linearizeGoto
                                                (equal (arg 0, arg 1, ty),
                                                 Handler.Dead,
                                                 l)
                                          in
                                             (finish (las,
                                                      Goto {dst = start',
                                                            args = Vector.new0 ()})
                                              :: (bs' @ blocks),
                                              {label = l,
                                               args = Vector.new1 (res, Type.bool),
                                               statements = []})
                                          end
                                     | _ => normal ())
                              | _ => normal ()
                          end)
                   in
                      finish (las, transfer)
                      :: blocks
                   end)
         in
            Vector.fromList blocks
         end
      val functions =
         List.revMap
         (functions, fn f =>
          let
             val {args, blocks, mayInline, name, raises, returns, start} =
                Function.dest f
             val f =
                if #hasEqual (funcInfo name)
                   then Function.new {args = args,
                                      blocks = doit blocks,
                                      mayInline = mayInline,
                                      name = name,
                                      raises = raises,
                                      returns = returns,
                                      start = start}
                else f
             val () = Function.clear f
          in
             f
          end)
      val program =
         Program.T {datatypes = datatypes,
                    globals = globals,
                    functions = (!newFunctions) @ functions,
                    main = main}
      val _ = destroyVectorEqualFunc ()
      val _ = Program.clearTop program
   in
      program
   end
662 val _ = destroyVectorEqualFunc ()
663 val _ = Program.clearTop program
664 in
665 program
666 end
667
668end