Import Upstream version 20180207
[hcoop/debian/mlton.git] / mlton / ssa /
1(* Copyright (C) 2009,2014 Matthew Fluet.
2 * Copyright (C) 1999-2008 Henry Cejtin, Matthew Fluet, Suresh
3 * Jagannathan, and Stephen Weeks.
4 * Copyright (C) 1997-2000 NEC Research Institute.
5 *
6 * MLton is released under a BSD-style license.
7 * See the file MLton-LICENSE for details.
8 *)
13open S
(*
16 * This pass implements polymorphic equality.
17 *
18 * For each datatype tycon and vector type, it builds an equality function and
19 * translates calls to MLton_equal into calls to that function.
20 *
21 * Also generates calls to primitive wordEqual.
22 *
23 * For tuples, it does the equality test inline. I.e., it does not create
24 * a separate equality function for each tuple type.
25 *
26 * All equality functions are only created if necessary, i.e. if equality
27 * is actually used at a type.
28 *
29 * Optimizations:
30 * - For datatype tycons that are enumerations, do not build a case dispatch,
31 * just use eq, since you know the backend will represent these as ints.
32 * - Deep equality always does an eq test first.
33 * - If one argument to = is a constant and the type will get translated to
34 * an IntOrPointer, then just use eq instead of the full equality. This is
35 * important for implementing code like the following efficiently:
36 * if x = 0 ... (where x is an IntInf.int)
37 *
38 * Also convert pointer equality on scalar types to type specific primitives.
39 *)
41open Exp Transfer
(* Dexp: DirectExp extended with the boolean connectives and word
 * primitives used when building equality functions.
 *)
structure Dexp =
   struct
      open DirectExp

      local
         (* Two-way dispatch on a boolean-valued expression: the result is
          * `onTrue` when `scrutinee` evaluates to true, `onFalse` otherwise.
          *)
         fun boolCase (scrutinee: t, onTrue: t, onFalse: t): t =
            let
               fun arm (c, result) =
                  {con = c,
                   args = Vector.new0 (),
                   body = result}
            in
               casee {test = scrutinee,
                      cases = Con (Vector.new2 (arm (Con.truee, onTrue),
                                                arm (Con.falsee, onFalse))),
                      default = NONE,
                      ty = Type.bool}
            end
      in
         (* Short-circuit "and": e2 is only reached when e1 is true. *)
         fun conjoin (e1: t, e2: t): t = boolCase (e1, e2, falsee)

         (* Short-circuit "or": e2 is only reached when e1 is false. *)
         fun disjoin (e1: t, e2:t): t = boolCase (e1, truee, e2)
      end

      local
         (* Binary word primitive at word size `s`; the result is a word of
          * the same size.
          *)
         fun wordBinop mkPrim (e1: t, e2: t, s) =
            let
               val operands = Vector.new2 (e1, e2)
            in
               primApp {prim = mkPrim s,
                        targs = Vector.new0 (),
                        args = operands,
                        ty = Type.word s}
            end
      in
         val add = wordBinop Prim.wordAdd
         val andb = wordBinop Prim.wordAndb
         val orb = wordBinop Prim.wordOrb
      end

      (* Word comparison at word size `s`; the result is a bool. *)
      fun wordEqual (e1: t, e2: t, s): t =
         let
            val operands = Vector.new2 (e1, e2)
         in
            primApp {prim = Prim.wordEqual s,
                     targs = Vector.new0 (),
                     args = operands,
                     ty = Type.bool}
         end
   end
(* transform: rewrites a program so that uses of the polymorphic equality
 * primitives are replaced by type-specific code, and returns the rewritten
 * program (see the end of this function).
 *
 * The declarations below set up the per-entity state used by the pass.
 *)
89fun transform (Program.T {datatypes, globals, functions, main}) =
90 let
(* hasEqual (default false) on a function: set by the analysis pass below
 * when some block of the function contains an MLton_eq / MLton_equal
 * primitive application.
 *)
91 val {get = funcInfo: Func.t -> {hasEqual: bool},
92 set = setFuncInfo, ...} =
93 Property.getSet (Func.plist, Property.initConst {hasEqual = false})
(* Same flag, tracked per block label. *)
94 val {get = labelInfo: Label.t -> {hasEqual: bool},
95 set = setLabelInfo, ...} =
96 Property.getSet (Label.plist, Property.initConst {hasEqual = false})
(* isConst (default false) on a variable: set by setBind for variables
 * bound to the constants it recognises.
 *)
97 val {get = varInfo: Var.t -> {isConst: bool},
98 set = setVarInfo, ...} =
99 Property.getSetOnce (Var.plist, Property.initConst {isConst = false})
(* Per-tycon constructor information.  There is no default: the property is
 * created with initRaise, so it is populated from `datatypes` below before
 * it is read.
 *)
100 val {get = tyconInfo: Tycon.t -> {isEnum: bool,
101 cons: {con: Con.t,
102 args: Type.t vector} vector},
103 set = setTyconInfo, ...} =
104 Property.getSetOnce
105 (Tycon.plist, Property.initRaise ("PolyEqual.tyconInfo", Tycon.layout))
106 val isEnum = #isEnum o tyconInfo
107 val tyconCons = #cons o tyconInfo
(* Memo tables for the generated equality functions (NONE until the function
 * for that tycon / vector element type has been requested).
 *)
108 val {get = getTyconEqualFunc: Tycon.t -> Func.t option,
109 set = setTyconEqualFunc, ...} =
110 Property.getSet (Tycon.plist, Property.initConst NONE)
(* destroyVectorEqualFunc is called at the end of transform. *)
111 val {get = getVectorEqualFunc: Type.t -> Func.t option,
112 set = setVectorEqualFunc,
113 destroy = destroyVectorEqualFunc} =
114 Property.destGetSet (Type.plist, Property.initConst NONE)
(* The IntInf equality function is a single value, so it is memoised in a
 * plain ref cell rather than in a property.
 *)
115 val (getIntInfEqualFunc: unit -> Func.t option,
116 setIntInfEqualFunc: Func.t option -> unit) =
117 let
118 val r = ref NONE
119 in
120 (fn () => !r, fn fo => r := fo)
121 end
(* Every generated equality function returns a single bool. *)
122 val returns = SOME (Vector.new1 Type.bool)
(* Word size / type used for vector lengths and indices. *)
123 val seqIndexWordSize = WordSize.seqIndex ()
124 val seqIndexTy = Type.word seqIndexWordSize
(* Functions generated by this pass; prepended to the program's function
 * list at the end of transform.
 *)
125 val newFunctions: Function.t list ref = ref []
(* newFunction z
 *
 * Builds an SSA function from the record `z` (args, blocks, mayInline, name,
 * raises, returns, start), attaches the polyEqual source info via
 * Function.profile, and pushes it onto `newFunctions`.
 *
 * The record has to be turned into a Function.t with Function.new before it
 * can be handed to Function.profile.
 *)
fun newFunction z =
   List.push (newFunctions,
              Function.profile (Function.new z,
                                SourceInfo.polyEqual))
(* equalTyconFunc tycon
 *
 * Returns the name of the equality function for datatype `tycon`, generating
 * the function (and queueing it with newFunction) the first time it is
 * requested.
 *
 * Shape of the generated function for arguments a1, a2:
 *   eq (a1, a2)
 *   orelse (case a1 of
 *              C xs => (case a2 of
 *                          C ys => x1 = y1 andalso ... andalso xn = yn
 *                        | _ => false)
 *            | _ => false)
 * Only constructors that carry arguments get an arm.  Nullary constructors
 * fall through to the default, which yields false; this is reached only
 * after the eq test has already failed.
 *)
fun equalTyconFunc (tycon: Tycon.t): Func.t =
   case getTyconEqualFunc tycon of
      SOME f => f
    | NONE =>
         let
            val name =
               Func.newString (concat ["equal_", Tycon.originalName tycon])
            (* Record the name before building the body: building the body
             * calls `equal`, which may ask for this same tycon again
             * (recursive datatypes).
             *)
            val _ = setTyconEqualFunc (tycon, SOME name)
            val ty = Type.datatypee tycon
            val arg1 = (Var.newNoname (), ty)
            val arg2 = (Var.newNoname (), ty)
            val args = Vector.new2 (arg1, arg2)
            val darg1 = Dexp.var arg1
            val darg2 = Dexp.var arg2
            val cons = tyconCons tycon
            val body =
               Dexp.disjoin
               (Dexp.eq (Dexp.var arg1, Dexp.var arg2, ty),
                Dexp.casee
                {test = darg1,
                 ty = Type.bool,
                 (* A default is needed exactly when some constructor is
                  * nullary, because those constructors get no arm below.
                  *)
                 default = (if Vector.exists (cons, fn {args, ...} =>
                                              Vector.isEmpty args)
                               then SOME Dexp.falsee
                            else NONE),
                 cases =
                 Dexp.Con
                 (Vector.keepAllMap
                  (cons, fn {con, args} =>
                   if Vector.isEmpty args
                      then NONE
                   else
                      let
                         (* Fresh (variable, type) binders, one per
                          * constructor argument.
                          *)
                         fun makeArgs () =
                            Vector.map
                            (args, fn ty =>
                             (Var.newNoname (), ty))
                         val xs = makeArgs ()
                         val ys = makeArgs ()
                      in
                         SOME
                         {con = con,
                          args = xs,
                          body =
                          Dexp.casee
                          {test = darg2,
                           ty = Type.bool,
                           (* With a single constructor the inner case is
                            * already exhaustive.
                            *)
                           default = if 1 = Vector.length cons
                                        then NONE
                                     else SOME Dexp.falsee,
                           cases =
                           Dexp.Con
                           (Vector.new1
                            {con = con,
                             args = ys,
                             body =
                             Vector.fold2
                             (xs, ys, Dexp.truee,
                              fn ((x, ty), (y, _), de) =>
                              Dexp.conjoin (de, equal (x, y, ty)))})}}
                      end))})
            val (start, blocks) = Dexp.linearize (body, Handler.Caller)
            val blocks = Vector.fromList blocks
            val _ =
               newFunction {args = args,
                            blocks = blocks,
                            mayInline = true,
                            name = name,
                            raises = NONE,
                            returns = returns,
                            start = start}
         in
            name
         end
(* mkVectorEqualFunc {name, ty, doEq}
 *
 * Generates two functions for comparing vectors with element type `ty`:
 *   - `name` (v1, v2): compares the lengths and, when they agree, calls the
 *     loop function starting at index zero.  When `doEq` is true an eq test
 *     on the two vectors is tried first.
 *   - `name ^ "Loop"` (v1, v2, len, i): returns true when i = len, otherwise
 *     compares the elements at index i and recurses with i + 1.
 * Both are queued with newFunction.
 *)
and mkVectorEqualFunc {name: Func.t,
                       ty: Type.t, doEq: bool}: unit =
   let
      val loop = Func.newString (Func.originalName name ^ "Loop")
      (* Build two functions, one that checks the lengths and the
       * other that loops.
       *)
      val vty = Type.vector ty
      (* Entry function: length check, then hand over to the loop. *)
      local
         val vec1 = (Var.newNoname (), vty)
         val vec2 = (Var.newNoname (), vty)
         val args = Vector.new2 (vec1, vec2)
         val dvec1 = Dexp.var vec1
         val dvec2 = Dexp.var vec2
         val len1 = (Var.newNoname (), seqIndexTy)
         val dlen1 = Dexp.var len1
         val len2 = (Var.newNoname (), seqIndexTy)
         val dlen2 = Dexp.var len2

         val body =
            let
               fun length dvec =
                  Dexp.primApp {prim = Prim.vectorLength,
                                targs = Vector.new1 ty,
                                args = Vector.new1 dvec,
                                ty = Type.word seqIndexWordSize}
               val body =
                  Dexp.lett
                  {decs = [{var = #1 len1, exp = length dvec1},
                           {var = #1 len2, exp = length dvec2}],
                   body =
                   Dexp.conjoin
                   (Dexp.wordEqual (dlen1, dlen2, seqIndexWordSize),
                    Dexp.call
                    {func = loop,
                     args = Vector.new4
                            (dvec1, dvec2, dlen1,
                             (* start index *)
                             Dexp.word (WordX.zero seqIndexWordSize)),
                     ty = Type.bool})}
            in
               if doEq
                  then Dexp.disjoin (Dexp.eq (dvec1, dvec2, vty), body)
               else body
            end
         val (start, blocks) = Dexp.linearize (body, Handler.Caller)
         val blocks = Vector.fromList blocks
      in
         val _ =
            newFunction {args = args,
                         blocks = blocks,
                         mayInline = true,
                         name = name,
                         raises = NONE,
                         returns = returns,
                         start = start}
      end
      (* Loop function: element-wise comparison from index i up to len. *)
      local
         val vec1 = (Var.newNoname (), vty)
         val vec2 = (Var.newNoname (), vty)
         val len = (Var.newNoname (), seqIndexTy)
         val i = (Var.newNoname (), seqIndexTy)
         val args = Vector.new4 (vec1, vec2, len, i)
         val dvec1 = Dexp.var vec1
         val dvec2 = Dexp.var vec2
         val dlen = Dexp.var len
         val di = Dexp.var i
         val body =
            let
               fun sub (dvec, di) =
                  Dexp.primApp {prim = Prim.vectorSub,
                                targs = Vector.new1 ty,
                                args = Vector.new2 (dvec, di),
                                ty = ty}
               (* Arguments for the recursive call: same vectors and length,
                * index advanced by one.
                *)
               val args =
                  Vector.new4
                  (dvec1, dvec2, dlen,
                   Dexp.add
                   (di, Dexp.word (WordX.one seqIndexWordSize),
                    seqIndexWordSize))
            in
               Dexp.disjoin
               (Dexp.wordEqual
                (di, dlen, seqIndexWordSize),
                Dexp.conjoin
                (equalExp (sub (dvec1, di), sub (dvec2, di), ty),
                 Dexp.call {args = args,
                            func = loop,
                            ty = Type.bool}))
            end
         val (start, blocks) = Dexp.linearize (body, Handler.Caller)
         val blocks = Vector.fromList blocks
      in
         val _ =
            newFunction {args = args,
                         blocks = blocks,
                         mayInline = true,
                         name = loop,
                         raises = NONE,
                         returns = returns,
                         start = start}
      end
   in
      ()
   end
(* vectorEqualFunc ty
 *
 * Looks up the equality function for vectors whose element type is `ty`.
 * On a miss, a fresh "vectorEqual" name is recorded for `ty` first and the
 * function is then generated with the eq pre-test enabled.
 *)
and vectorEqualFunc (ty: Type.t): Func.t =
   let
      fun create () =
         let
            val fresh = Func.newString "vectorEqual"
         in
            setVectorEqualFunc (ty, SOME fresh)
            ; mkVectorEqualFunc {doEq = true, name = fresh, ty = ty}
            ; fresh
         end
   in
      case getVectorEqualFunc ty of
         NONE => create ()
       | SOME existing => existing
   end
(* intInfEqualFunc ()
 *
 * Returns the name of the IntInf equality function, generating it (together
 * with a "bigIntInfEqual" vector comparison without the eq pre-test) on
 * first use.
 *
 * Shape of the generated function for arguments a1, a2:
 *   eq (a1, a2)
 *   orelse (if ((toWord a1 orb toWord a2) andb 1) = 1
 *              then false
 *           else bigIntInfEqual (toVector a1, toVector a2))
 * NOTE(review): presumably the low bit of the word view marks the small
 * (non-vector) representation, so the vector comparison is only reached when
 * neither argument is small -- confirm against the IntInf representation.
 *)
and intInfEqualFunc (): Func.t =
   case getIntInfEqualFunc () of
      SOME f => f
    | NONE =>
         let
            val intInfEqual = Func.newString "intInfEqual"
            val _ = setIntInfEqualFunc (SOME intInfEqual)

            val bws = WordSize.bigIntInfWord ()
            val sws = WordSize.smallIntInfWord ()

            val bigIntInfEqual = Func.newString "bigIntInfEqual"
            val () = mkVectorEqualFunc {name = bigIntInfEqual,
                                        ty = Type.word bws,
                                        doEq = false}

            local
               val arg1 = (Var.newNoname (), Type.intInf)
               val arg2 = (Var.newNoname (), Type.intInf)
               val args = Vector.new2 (arg1, arg2)
               val darg1 = Dexp.var arg1
               val darg2 = Dexp.var arg2
               (* View an IntInf as a word of the small word size. *)
               fun toWord dx =
                  Dexp.primApp
                  {prim = Prim.intInfToWord,
                   targs = Vector.new0 (),
                   args = Vector.new1 dx,
                   ty = Type.word sws}
               (* View an IntInf as a vector of big-size words. *)
               fun toVector dx =
                  Dexp.primApp
                  {prim = Prim.intInfToVector,
                   targs = Vector.new0 (),
                   args = Vector.new1 dx,
                   ty = Type.vector (Type.word bws)}
               val one = Dexp.word (WordX.one sws)
               val body =
                  Dexp.disjoin
                  (Dexp.eq (darg1, darg2, Type.intInf),
                   Dexp.casee
                   {test = Dexp.wordEqual (Dexp.andb (Dexp.orb (toWord darg1, toWord darg2, sws), one, sws), one, sws),
                    ty = Type.bool,
                    default = NONE,
                    cases =
                    (Dexp.Con o Vector.new2)
                    ({con = Con.truee,
                      args = Vector.new0 (),
                      body = Dexp.falsee},
                     {con = Con.falsee,
                      args = Vector.new0 (),
                      body =
                      Dexp.call
                      {func = bigIntInfEqual,
                       args = Vector.new2 (toVector darg1, toVector darg2),
                       ty = Type.bool}})})
               val (start, blocks) = Dexp.linearize (body, Handler.Caller)
               val blocks = Vector.fromList blocks
            in
               val _ =
                  newFunction {args = args,
                               blocks = blocks,
                               mayInline = true,
                               name = intInfEqual,
                               raises = NONE,
                               returns = returns,
                               start = start}
            end
         in
            intInfEqual
         end
(* equalExp (e1, e2, ty)
 *
 * Equality test on two expressions of type `ty`: each expression is first
 * bound to a variable with Dexp.name, and the variables are then compared
 * with `equal`.
 *)
and equalExp (e1: Dexp.t, e2: Dexp.t, ty: Type.t): Dexp.t =
   Dexp.name (e1, fn x1 =>
              Dexp.name (e2, fn x2 => equal (x1, x2, ty)))
(* equal (x1, x2, ty)
 *
 * Builds the expression that tests variables x1 and x2, both of type `ty`,
 * for equality.  Dispatch on the type:
 *   - Array, Ref, Thread, Weak: the eq primitive;
 *   - CPointer, Word: the type-specific equality primitive;
 *   - Real: cast both sides to words of the same bit width and compare;
 *   - Datatype: eq when the tycon is an enumeration or one side is a known
 *     constant, otherwise a call to the tycon's equality function;
 *   - IntInf: eq when one side is a known constant, otherwise a call to the
 *     IntInf equality function;
 *   - Tuple: component-wise comparison, inline;
 *   - Vector: a call to the vector equality function for the element type.
 *)
and equal (x1: Var.t, x2: Var.t, ty: Type.t): Dexp.t =
   let
      val dx1 = Dexp.var (x1, ty)
      val dx2 = Dexp.var (x2, ty)
      (* Boolean-valued primitive applied to the two operands. *)
      fun prim (p, targs) =
         Dexp.primApp {prim = p,
                       targs = targs,
                       args = Vector.new2 (dx1, dx2),
                       ty = Type.bool}
      fun eq () = prim (Prim.eq, Vector.new1 ty)
      (* Call to a generated equality function on the two operands. *)
      fun call (f: Func.t) =
         Dexp.call {func = f,
                    args = Vector.new2 (dx1, dx2),
                    ty = Type.bool}
      fun hasConstArg () = #isConst (varInfo x1) orelse #isConst (varInfo x2)
   in
      case Type.dest ty of
         Type.Array _ => eq ()
       | Type.CPointer => prim (Prim.cpointerEqual, Vector.new0 ())
       | Type.Datatype tycon =>
            if isEnum tycon orelse hasConstArg ()
               then eq ()
            else call (equalTyconFunc tycon)
       | Type.IntInf =>
            if hasConstArg ()
               then eq ()
            else call (intInfEqualFunc ())
       | Type.Real rs =>
            let
               val ws = WordSize.fromBits (RealSize.bits rs)
               fun toWord dx =
                  Dexp.primApp
                  {prim = Prim.realCastToWord (rs, ws),
                   targs = Vector.new0 (),
                   args = Vector.new1 dx,
                   ty = Type.word ws}
            in
               Dexp.wordEqual (toWord dx1, toWord dx2, ws)
            end
       | Type.Ref _ => eq ()
       | Type.Thread => eq ()
       | Type.Tuple tys =>
            let
               val max = Vector.length tys - 1
               (* test components i, i+1, ... *)
               fun loop (i: int): Dexp.t =
                  if i > max
                     then Dexp.truee
                  else let
                          val ty = Vector.sub (tys, i)
                          fun select dx =
                             Dexp.select {tuple = dx,
                                          offset = i,
                                          ty = ty}
                       in
                          Dexp.conjoin
                          (equalExp (select dx1, select dx2, ty),
                           loop (i + 1))
                       end
            in
               loop 0
            end
       | Type.Vector ty => call (vectorEqualFunc ty)
       | Type.Weak _ => eq ()
       | Type.Word ws => prim (Prim.wordEqual ws, Vector.new0 ())
   end
(* Record, for every datatype, its constructors and whether it is an
 * enumeration (i.e. every constructor takes no arguments).
 *)
val _ =
   Vector.foreach
   (datatypes, fn Datatype.T {tycon, cons} =>
    let
       val allNullary =
          Vector.forall (cons, fn {args, ...} => Vector.isEmpty args)
    in
       setTyconInfo (tycon, {isEnum = allNullary, cons = cons})
    end)
(* setBind stmt
 *
 * Marks the variable bound by `stmt` (if any) as isConst when the bound
 * expression is a Word constant, an IntInf constant whose representation is
 * Small, or a constructor application without arguments.
 *)
fun setBind (Statement.T {exp, var, ...}) =
   let
      val bindsConstant =
         case exp of
            Const (Const.IntInf i) =>
               (case Const.IntInfRep.fromIntInf i of
                   Const.IntInfRep.Big _ => false
                 | Const.IntInfRep.Small _ => true)
          | Const (Const.Word _) => true
          | ConApp {args, ...} => Vector.isEmpty args
          | _ => false
   in
      case (bindsConstant, var) of
         (true, SOME x) => setVarInfo (x, {isConst = true})
       | _ => ()
   end
(* Analysis pass.
 *
 * Runs setBind over every global and every statement of every function, and
 * marks each block -- and its enclosing function -- that contains an
 * application of the MLton_eq or MLton_equal primitive, so that only those
 * are rewritten later.
 *)
val _ = Vector.foreach (globals, setBind)
val () =
   List.foreach
   (functions, fn f =>
    let
       val {name, blocks, ...} = Function.dest f
    in
       Vector.foreach
       (blocks, fn Block.T {label, statements, ...} =>
        let
           fun setHasEqual () =
              (setFuncInfo (name, {hasEqual = true})
               ; setLabelInfo (label, {hasEqual = true}))
        in
           Vector.foreach
           (statements, fn stmt as Statement.T {exp, ...} =>
            (setBind stmt;
             case exp of
                PrimApp {prim, ...} =>
                   (* The Prim.Name constructors classify the primitive's
                    * name, so project the name out of the primitive before
                    * matching.
                    *)
                   (case Prim.name prim of
                       Prim.Name.MLton_eq => setHasEqual ()
                     | Prim.Name.MLton_equal => setHasEqual ()
                     | _ => ())
              | _ => ()))
        end)
    end)
(* doit blocks
 *
 * Rewrites the blocks of one function.  Blocks not marked hasEqual are kept
 * unchanged.  In a marked block every statement is visited in order:
 *   - MLton_eq at CPointer, Real or Word type is replaced in place by the
 *     type-specific primitive (for Real, both operands are first cast to
 *     words of the same bit width);
 *   - MLton_equal splits the block: the statements seen so far are closed
 *     off with a Goto into the linearized equality code, and a fresh block,
 *     whose single bool argument is the variable the original statement
 *     bound, continues with the remaining statements;
 *   - everything else is kept.
 * While a block is being rebuilt its statements are accumulated in reverse
 * order; `finish` restores the order.  The resulting block list is built by
 * consing, so block order is not preserved.
 *)
fun doit blocks =
   let
      val blocks =
         Vector.fold
         (blocks, [],
          fn (block as Block.T {label, args, statements, transfer}, blocks) =>
          if not (#hasEqual (labelInfo label))
             then block::blocks
          else
             let
                (* Close off a partially built block with `transfer`. *)
                fun finish ({label, args, statements}, transfer) =
                   Block.T {label = label,
                            args = args,
                            statements = Vector.fromListRev statements,
                            transfer = transfer}
                val (blocks, las) =
                   Vector.fold
                   (statements,
                    (blocks, {label = label, args = args, statements = []}),
                    fn (stmt as Statement.T {exp, var, ...},
                        (blocks, las as {label, args, statements})) =>
                    let
                       (* Keep the statement as is. *)
                       fun normal () = (blocks,
                                        {label = label,
                                         args = args,
                                         statements = stmt::statements})
                       (* Replace the statement by `ss`, which must be given
                        * in reverse order (last statement first).
                        *)
                       fun adds ss = (blocks,
                                      {label = label,
                                       args = args,
                                       statements = ss @ statements})
                       (* Statement binding the original variable to the
                        * result of boolean primitive `p` on `operands`.
                        *)
                       fun boolPrimStmt (p, operands) =
                          Statement.T
                          {var = var,
                           ty = Type.bool,
                           exp = Exp.PrimApp
                                 {prim = p,
                                  targs = Vector.new0 (),
                                  args = operands}}
                    in
                       case exp of
                          PrimApp {prim, targs, args = operands, ...} =>
                             let
                                fun arg i = Vector.sub (operands, i)
                             in
                                case (Prim.name prim, Vector.length targs) of
                                   (Prim.Name.MLton_eq, 1) =>
                                      (case Type.dest (Vector.first targs) of
                                          Type.CPointer =>
                                             adds [boolPrimStmt
                                                   (Prim.cpointerEqual,
                                                    Vector.new2 (arg 0, arg 1))]
                                        | Type.Real rs =>
                                             let
                                                val ws = WordSize.fromBits (RealSize.bits rs)
                                                val wt = Type.word ws
                                                val r0 = arg 0
                                                val r1 = arg 1
                                                val w0 = Var.newNoname ()
                                                val w1 = Var.newNoname ()
                                                fun realCastToWordStmt (r, w) =
                                                   Statement.T
                                                   {var = SOME w,
                                                    ty = wt,
                                                    exp = Exp.PrimApp
                                                          {prim = Prim.realCastToWord (rs, ws),
                                                           targs = Vector.new0 (),
                                                           args = Vector.new1 r}}
                                                val wordEqStmt =
                                                   boolPrimStmt
                                                   (Prim.wordEqual ws,
                                                    Vector.new2 (w0, w1))
                                             in
                                                (* reverse order: the casts
                                                 * run before the comparison
                                                 *)
                                                adds [wordEqStmt,
                                                      realCastToWordStmt (r1, w1),
                                                      realCastToWordStmt (r0, w0)]
                                             end
                                        | Type.Word ws =>
                                             adds [boolPrimStmt
                                                   (Prim.wordEqual ws,
                                                    Vector.new2 (arg 0, arg 1))]
                                        | _ => normal ())
                                 | (Prim.Name.MLton_equal, 1) =>
                                      let
                                         val ty = Vector.sub (targs, 0)
                                         (* continuation block receiving the
                                          * result of the comparison
                                          *)
                                         val l = Label.newNoname ()
                                         val (start',bs') =
                                            Dexp.linearizeGoto
                                            (equal (arg 0, arg 1, ty),
                                             Handler.Dead,
                                             l)
                                      in
                                         (finish (las,
                                                  Goto {dst = start',
                                                        args = Vector.new0 ()})
                                          :: (bs' @ blocks),
                                          {label = l,
                                           args = Vector.new1 (valOf var, Type.bool),
                                           statements = []})
                                      end
                                 | _ => normal ()
                             end
                        | _ => normal ()
                    end)
             in
                finish (las, transfer)
                :: blocks
             end)
   in
      Vector.fromList blocks
   end
(* Rewrite the functions marked hasEqual with `doit`; all other functions
 * are kept as they are.  Function.clear is run on every resulting function.
 *
 * Both branches of the conditional must produce a Function.t, so the
 * rebuilt record is wrapped with Function.new.
 *)
val functions =
   List.revMap
   (functions, fn f =>
    let
       val {args, blocks, mayInline, name, raises, returns, start} =
          Function.dest f
       val f =
          if #hasEqual (funcInfo name)
             then Function.new {args = args,
                                blocks = doit blocks,
                                mayInline = mayInline,
                                name = name,
                                raises = raises,
                                returns = returns,
                                start = start}
          else f
       val () = Function.clear f
    in
       f
    end)
(* Assemble the result: datatypes, globals and main are passed through
 * unchanged; the generated equality functions are placed in front of the
 * rewritten functions.
 *)
657 val program =
658 Program.T {datatypes = datatypes,
659 globals = globals,
660 functions = (!newFunctions) @ functions,
661 main = main}
(* Release the vector-equality memo table created with destGetSet above. *)
662 val _ = destroyVectorEqualFunc ()
663 val _ = Program.clearTop program
664 in
665 program
666 end