1 (* Copyright (C) 2009,2014 Matthew Fluet.
2 * Copyright (C) 1999-2008 Henry Cejtin, Matthew Fluet, Suresh
3 * Jagannathan, and Stephen Weeks.
4 * Copyright (C) 1997-2000 NEC Research Institute.
6 * MLton is released under a BSD-style license.
7 * See the file MLton-LICENSE for details.
10 functor PolyEqual (S: SSA_TRANSFORM_STRUCTS): SSA_TRANSFORM =
16 * This pass implements polymorphic equality.
18 * For each datatype tycon and vector type, it builds an equality function and
19 * translates calls to MLton_equal into calls to that function.
21 * Also generates calls to primitive wordEqual.
23 * For tuples, it does the equality test inline. I.E. it does not create
24 * a separate equality function for each tuple type.
26 * All equality functions are only created if necessary, i.e. if equality
27 * is actually used at a type.
30 * - For datatype tycons that are enumerations, do not build a case dispatch,
31 * just use eq, since you know the backend will represent these as ints.
32 * - Deep equality always does an eq test first.
33 * - If one argument to = is a constant and the type will get translated to
34 * an IntOrPointer, then just use eq instead of the full equality. This is
35 * important for implementing code like the following efficiently:
36 * if x = 0 ... (where x is an IntInf.int)
38 * Also convert pointer equality on scalar types to type specific primitives.
47 fun conjoin (e1: t, e2: t): t =
49 cases = Con (Vector.new2 ({con = Con.truee,
50 args = Vector.new0 (),
53 args = Vector.new0 (),
58 fun disjoin (e1: t, e2:t): t =
60 cases = Con (Vector.new2 ({con = Con.truee,
61 args = Vector.new0 (),
64 args = Vector.new0 (),
71 fn (e1: t, e2: t, s) =>
72 primApp {prim = prim s,
73 targs = Vector.new0 (),
74 args = Vector.new2 (e1, e2),
77 val add = mk Prim.wordAdd
78 val andb = mk Prim.wordAndb
79 val orb = mk Prim.wordOrb
82 fun wordEqual (e1: t, e2: t, s): t =
83 primApp {prim = Prim.wordEqual s,
84 targs = Vector.new0 (),
85 args = Vector.new2 (e1, e2),
89 fun transform (Program.T {datatypes, globals, functions, main}) =
91 val {get = funcInfo: Func.t -> {hasEqual: bool},
92 set = setFuncInfo, ...} =
93 Property.getSet (Func.plist, Property.initConst {hasEqual = false})
94 val {get = labelInfo: Label.t -> {hasEqual: bool},
95 set = setLabelInfo, ...} =
96 Property.getSet (Label.plist, Property.initConst {hasEqual = false})
97 val {get = varInfo: Var.t -> {isConst: bool},
98 set = setVarInfo, ...} =
99 Property.getSetOnce (Var.plist, Property.initConst {isConst = false})
100 val {get = tyconInfo: Tycon.t -> {isEnum: bool,
102 args: Type.t vector} vector},
103 set = setTyconInfo, ...} =
105 (Tycon.plist, Property.initRaise ("PolyEqual.tyconInfo", Tycon.layout))
106 val isEnum = #isEnum o tyconInfo
107 val tyconCons = #cons o tyconInfo
108 val {get = getTyconEqualFunc: Tycon.t -> Func.t option,
109 set = setTyconEqualFunc, ...} =
110 Property.getSet (Tycon.plist, Property.initConst NONE)
111 val {get = getVectorEqualFunc: Type.t -> Func.t option,
112 set = setVectorEqualFunc,
113 destroy = destroyVectorEqualFunc} =
114 Property.destGetSet (Type.plist, Property.initConst NONE)
115 val (getIntInfEqualFunc: unit -> Func.t option,
116 setIntInfEqualFunc: Func.t option -> unit) =
120 (fn () => !r, fn fo => r := fo)
122 val returns = SOME (Vector.new1 Type.bool)
123 val seqIndexWordSize = WordSize.seqIndex ()
124 val seqIndexTy = Type.word seqIndexWordSize
125 val newFunctions: Function.t list ref = ref []
127 List.push (newFunctions,
128 Function.profile (Function.new z,
129 SourceInfo.polyEqual))
130 fun equalTyconFunc (tycon: Tycon.t): Func.t =
131 case getTyconEqualFunc tycon of
136 Func.newString (concat ["equal_", Tycon.originalName tycon])
137 val _ = setTyconEqualFunc (tycon, SOME name)
138 val ty = Type.datatypee tycon
139 val arg1 = (Var.newNoname (), ty)
140 val arg2 = (Var.newNoname (), ty)
141 val args = Vector.new2 (arg1, arg2)
142 val darg1 = Dexp.var arg1
143 val darg2 = Dexp.var arg2
144 val cons = tyconCons tycon
147 (Dexp.eq (Dexp.var arg1, Dexp.var arg2, ty),
151 default = (if Vector.exists (cons, fn {args, ...} =>
153 then SOME Dexp.falsee
158 (cons, fn {con, args} =>
159 if Vector.isEmpty args
164 Vector.map (args, fn ty =>
165 (Var.newNoname (), ty))
176 default = if 1 = Vector.length cons
178 else SOME Dexp.falsee,
187 fn ((x, ty), (y, _), de) =>
188 Dexp.conjoin (de, equal (x, y, ty)))})}}
190 val (start, blocks) = Dexp.linearize (body, Handler.Caller)
191 val blocks = Vector.fromList blocks
193 newFunction {args = args,
203 and mkVectorEqualFunc {name: Func.t,
204 ty: Type.t, doEq: bool}: unit =
206 val loop = Func.newString (Func.originalName name ^ "Loop")
207 (* Build two functions, one that checks the lengths and the
210 val vty = Type.vector ty
212 val vec1 = (Var.newNoname (), vty)
213 val vec2 = (Var.newNoname (), vty)
214 val args = Vector.new2 (vec1, vec2)
215 val dvec1 = Dexp.var vec1
216 val dvec2 = Dexp.var vec2
217 val len1 = (Var.newNoname (), seqIndexTy)
218 val dlen1 = Dexp.var len1
219 val len2 = (Var.newNoname (), seqIndexTy)
220 val dlen2 = Dexp.var len2
225 Dexp.primApp {prim = Prim.vectorLength,
226 targs = Vector.new1 ty,
227 args = Vector.new1 dvec,
228 ty = Type.word seqIndexWordSize}
231 {decs = [{var = #1 len1, exp = length dvec1},
232 {var = #1 len2, exp = length dvec2}],
235 (Dexp.wordEqual (dlen1, dlen2, seqIndexWordSize),
239 (dvec1, dvec2, dlen1,
240 Dexp.word (WordX.zero seqIndexWordSize)),
244 then Dexp.disjoin (Dexp.eq (dvec1, dvec2, vty), body)
247 val (start, blocks) = Dexp.linearize (body, Handler.Caller)
248 val blocks = Vector.fromList blocks
251 newFunction {args = args,
260 val vec1 = (Var.newNoname (), vty)
261 val vec2 = (Var.newNoname (), vty)
262 val len = (Var.newNoname (), seqIndexTy)
263 val i = (Var.newNoname (), seqIndexTy)
264 val args = Vector.new4 (vec1, vec2, len, i)
265 val dvec1 = Dexp.var vec1
266 val dvec2 = Dexp.var vec2
267 val dlen = Dexp.var len
272 Dexp.primApp {prim = Prim.vectorSub,
273 targs = Vector.new1 ty,
274 args = Vector.new2 (dvec, di),
280 (di, Dexp.word (WordX.one seqIndexWordSize),
285 (di, dlen, seqIndexWordSize),
287 (equalExp (sub (dvec1, di), sub (dvec2, di), ty),
288 Dexp.call {args = args,
292 val (start, blocks) = Dexp.linearize (body, Handler.Caller)
293 val blocks = Vector.fromList blocks
296 newFunction {args = args,
307 and vectorEqualFunc (ty: Type.t): Func.t =
308 case getVectorEqualFunc ty of
312 val name = Func.newString "vectorEqual"
313 val _ = setVectorEqualFunc (ty, SOME name)
314 val () = mkVectorEqualFunc {name = name, ty = ty, doEq = true}
318 and intInfEqualFunc (): Func.t =
319 case getIntInfEqualFunc () of
323 val intInfEqual = Func.newString "intInfEqual"
324 val _ = setIntInfEqualFunc (SOME intInfEqual)
326 val bws = WordSize.bigIntInfWord ()
327 val sws = WordSize.smallIntInfWord ()
329 val bigIntInfEqual = Func.newString "bigIntInfEqual"
330 val () = mkVectorEqualFunc {name = bigIntInfEqual,
335 val arg1 = (Var.newNoname (), Type.intInf)
336 val arg2 = (Var.newNoname (), Type.intInf)
337 val args = Vector.new2 (arg1, arg2)
338 val darg1 = Dexp.var arg1
339 val darg2 = Dexp.var arg2
342 {prim = Prim.intInfToWord,
343 targs = Vector.new0 (),
344 args = Vector.new1 dx,
348 {prim = Prim.intInfToVector,
349 targs = Vector.new0 (),
350 args = Vector.new1 dx,
351 ty = Type.vector (Type.word bws)}
352 val one = Dexp.word (WordX.one sws)
355 (Dexp.eq (darg1, darg2, Type.intInf),
357 {test = Dexp.wordEqual (Dexp.andb (Dexp.orb (toWord darg1, toWord darg2, sws), one, sws), one, sws),
361 (Dexp.Con o Vector.new2)
363 args = Vector.new0 (),
366 args = Vector.new0 (),
368 Dexp.call {func = bigIntInfEqual,
369 args = Vector.new2 (toVector darg1, toVector darg2),
371 val (start, blocks) = Dexp.linearize (body, Handler.Caller)
372 val blocks = Vector.fromList blocks
375 newFunction {args = args,
386 and equalExp (e1: Dexp.t, e2: Dexp.t, ty: Type.t): Dexp.t =
387 Dexp.name (e1, fn x1 =>
388 Dexp.name (e2, fn x2 => equal (x1, x2, ty)))
389 and equal (x1: Var.t, x2: Var.t, ty: Type.t): Dexp.t =
391 val dx1 = Dexp.var (x1, ty)
392 val dx2 = Dexp.var (x2, ty)
393 fun prim (p, targs) =
394 Dexp.primApp {prim = p,
396 args = Vector.new2 (dx1, dx2),
398 fun eq () = prim (Prim.eq, Vector.new1 ty)
399 fun hasConstArg () = #isConst (varInfo x1) orelse #isConst (varInfo x2)
402 Type.Array _ => eq ()
403 | Type.CPointer => prim (Prim.cpointerEqual, Vector.new0 ())
404 | Type.Datatype tycon =>
405 if isEnum tycon orelse hasConstArg ()
407 else Dexp.call {func = equalTyconFunc tycon,
408 args = Vector.new2 (dx1, dx2),
413 else Dexp.call {func = intInfEqualFunc (),
414 args = Vector.new2 (dx1, dx2),
418 val ws = WordSize.fromBits (RealSize.bits rs)
421 {prim = Prim.realCastToWord (rs, ws),
422 targs = Vector.new0 (),
423 args = Vector.new1 dx,
426 Dexp.wordEqual (toWord dx1, toWord dx2, ws)
428 | Type.Ref _ => eq ()
429 | Type.Thread => eq ()
432 val max = Vector.length tys - 1
433 (* test components i, i+1, ... *)
434 fun loop (i: int): Dexp.t =
438 val ty = Vector.sub (tys, i)
440 Dexp.select {tuple = dx,
445 (equalExp (select dx1, select dx2, ty),
452 Dexp.call {func = vectorEqualFunc ty,
453 args = Vector.new2 (dx1, dx2),
455 | Type.Weak _ => eq ()
456 | Type.Word ws => prim (Prim.wordEqual ws, Vector.new0 ())
461 (datatypes, fn Datatype.T {tycon, cons} =>
463 {isEnum = Vector.forall (cons, fn {args, ...} =>
464 Vector.isEmpty args),
466 fun setBind (Statement.T {exp, var, ...}) =
471 | SOME x => setVarInfo (x, {isConst = true})
477 (case Const.IntInfRep.fromIntInf i of
478 Const.IntInfRep.Big _ => ()
479 | Const.IntInfRep.Small _ => const ())
480 | Const.Word _ => const ()
482 | ConApp {args, ...} =>
483 if Vector.isEmpty args then const () else ()
486 val _ = Vector.foreach (globals, setBind)
491 val {name, blocks, ...} = Function.dest f
494 (blocks, fn Block.T {label, statements, ...} =>
497 (setFuncInfo (name, {hasEqual = true})
498 ; setLabelInfo (label, {hasEqual = true}))
501 (statements, fn stmt as Statement.T {exp, ...} =>
504 PrimApp {prim, ...} =>
505 (case Prim.name prim of
506 Prim.Name.MLton_eq => setHasEqual ()
507 | Prim.Name.MLton_equal => setHasEqual ()
517 fn (block as Block.T {label, args, statements, transfer}, blocks) =>
518 if not (#hasEqual (labelInfo label))
522 fun finish ({label, args, statements}, transfer) =
523 Block.T {label = label,
525 statements = Vector.fromListRev statements,
530 (blocks, {label = label, args = args, statements = []}),
531 fn (stmt as Statement.T {exp, var, ...},
532 (blocks, las as {label, args, statements})) =>
534 fun normal () = (blocks,
537 statements = stmt::statements})
538 fun adds ss = (blocks,
541 statements = ss @ statements})
544 PrimApp {prim, targs, args, ...} =>
545 (case (Prim.name prim, Vector.length targs) of
546 (Prim.Name.MLton_eq, 1) =>
547 (case Type.dest (Vector.first targs) of
550 val cp0 = Vector.sub (args, 0)
551 val cp1 = Vector.sub (args, 1)
557 {prim = Prim.cpointerEqual,
558 targs = Vector.new0 (),
559 args = Vector.new2 (cp0,cp1)}}
561 adds [cpointerEqStmt]
565 val ws = WordSize.fromBits (RealSize.bits rs)
566 val wt = Type.word ws
567 val r0 = Vector.sub (args, 0)
568 val r1 = Vector.sub (args, 1)
569 val w0 = Var.newNoname ()
570 val w1 = Var.newNoname ()
571 fun realCastToWordStmt (r, w) =
576 {prim = Prim.realCastToWord (rs, ws),
577 targs = Vector.new0 (),
578 args = Vector.new1 r}}
584 {prim = Prim.wordEqual ws,
585 targs = Vector.new0 (),
586 args = Vector.new2 (w0,w1)}}
589 realCastToWordStmt (r1, w1),
590 realCastToWordStmt (r0, w0)]
594 val w0 = Vector.sub (args, 0)
595 val w1 = Vector.sub (args, 1)
601 {prim = Prim.wordEqual ws,
602 targs = Vector.new0 (),
603 args = Vector.new2 (w0,w1)}}
608 | (Prim.Name.MLton_equal, 1) =>
610 val ty = Vector.sub (targs, 0)
611 fun arg i = Vector.sub (args, i)
612 val l = Label.newNoname ()
615 (equal (arg 0, arg 1, ty),
621 args = Vector.new0 ()})
624 args = Vector.new1 (valOf var, Type.bool),
631 finish (las, transfer)
635 Vector.fromList blocks
641 val {args, blocks, mayInline, name, raises, returns, start} =
644 if #hasEqual (funcInfo name)
645 then Function.new {args = args,
646 blocks = doit blocks,
647 mayInline = mayInline,
653 val () = Function.clear f
658 Program.T {datatypes = datatypes,
660 functions = (!newFunctions) @ functions,
662 val _ = destroyVectorEqualFunc ()
663 val _ = Program.clearTop program