ba66ec3ff1212d71de9edf0cd2ba4ac6d3b05a4f
[bpt/guile.git] / module / language / cps / type-fold.scm
1 ;;; Abstract constant folding on CPS
2 ;;; Copyright (C) 2014, 2015 Free Software Foundation, Inc.
3 ;;;
4 ;;; This library is free software: you can redistribute it and/or modify
5 ;;; it under the terms of the GNU Lesser General Public License as
6 ;;; published by the Free Software Foundation, either version 3 of the
7 ;;; License, or (at your option) any later version.
8 ;;;
9 ;;; This library is distributed in the hope that it will be useful, but
10 ;;; WITHOUT ANY WARRANTY; without even the implied warranty of
11 ;;; MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
12 ;;; Lesser General Public License for more details.
13 ;;;
14 ;;; You should have received a copy of the GNU Lesser General Public
15 ;;; License along with this program. If not, see
16 ;;; <http://www.gnu.org/licenses/>.
17
18 ;;; Commentary:
19 ;;;
20 ;;; This pass uses the abstract interpretation provided by type analysis
21 ;;; to fold constant values and type predicates. It is most profitably
22 ;;; run after CSE, to take advantage of scalar replacement.
23 ;;;
24 ;;; Code:
25
26 (define-module (language cps type-fold)
27 #:use-module (ice-9 match)
28 #:use-module (language cps)
29 #:use-module (language cps dfg)
30 #:use-module (language cps renumber)
31 #:use-module (language cps types)
32 #:use-module (system base target)
33 #:export (type-fold))
34
35
36 \f
37
38 ;; Branch folders.
39
;; Type bits for which one type bit plus a collapsed range (min = max)
;; identifies exactly one value.  Both the `eq?' branch folder and value
;; folding in `fold-and-reduce' test membership in this set.
(define &scalar-types
  (logior &null &nil &true &false &unspecified
          &char &flonum &exact-integer))
42
;; Table mapping a branching primcall's name (a symbol, compared with
;; `eq?') to its folder procedure.  A folder returns two values: whether
;; the branch outcome is known, and if so whether it is true.
(define *branch-folders* (make-hash-table))

;; Register F as the branch folder for primcall NAME.
(define-syntax define-branch-folder
  (syntax-rules ()
    ((_ name f)
     (hashq-set! *branch-folders* 'name f))))

;; Make primcall TO use whatever folder is currently registered for FROM.
(define-syntax define-branch-folder-alias
  (syntax-rules ()
    ((_ to from)
     (hashq-set! *branch-folders* 'to (hashq-ref *branch-folders* 'from)))))

;; Define a folder for a one-argument branch.  The body sees the
;; argument's inferred type bits and range bounds.
(define-syntax define-unary-branch-folder
  (syntax-rules ()
    ((_ (name arg min max) body ...)
     (define-branch-folder name (lambda (arg min max) body ...)))))

;; Define a folder for a two-argument branch; the body sees both
;; arguments' type bits and range bounds.
(define-syntax define-binary-branch-folder
  (syntax-rules ()
    ((_ (name arg0 min0 max0 arg1 min1 max1) body ...)
     (define-branch-folder name
       (lambda (arg0 min0 max0 arg1 min1 max1) body ...)))))
58
;; Define the branch folder for a type predicate NAME that tests &TYPE.
;; If the argument's possible types share no bit with &TYPE, the
;; predicate is known false.  If every possible type lies within &TYPE,
;; it is known true.  Otherwise the branch cannot be folded.
(define-syntax define-unary-type-predicate-folder
  (syntax-rules ()
    ((_ name &type)
     (define-unary-branch-folder (name type min max)
       (let ((overlap (logand type &type)))
         (if (zero? overlap)
             (values #t #f)
             (if (eqv? overlap type)
                 (values #t #t)
                 (values #f #f))))))))
66
;; One folder per unary type-predicate branch handled by compile-bytecode,
;; each paired with the type bit it tests.

;; Pairs and list terminators.
(define-unary-type-predicate-folder pair? &pair)
(define-unary-type-predicate-folder null? &null)
(define-unary-type-predicate-folder nil? &nil)

;; Other object kinds.
(define-unary-type-predicate-folder struct? &struct)
(define-unary-type-predicate-folder vector? &vector)
(define-unary-type-predicate-folder variable? &box)
(define-unary-type-predicate-folder string? &string)
(define-unary-type-predicate-folder symbol? &symbol)

;; Characters and numbers.
(define-unary-type-predicate-folder char? &char)
(define-unary-type-predicate-folder number? &number)
78
;; Fold `eq?' when the operands' types or ranges decide the outcome.
;; - Disjoint type sets, or ranges that do not overlap: known false.
;; - Both operands have the same single scalar type and the same
;;   single-point range: known true.
;; NOTE(review): for flonums and bignums, equal values are not
;; necessarily `eq?' at run time.  Folding to #t relies on `eq?' being
;; unspecified for numbers; confirm this is intended.
;; `eqv?' and `equal?' share this folder.
(define-binary-branch-folder (eq? type0 min0 max0 type1 min1 max1)
  (define (provably-distinct?)
    (or (zero? (logand type0 type1))
        (< max0 min1)
        (< max1 min0)))
  (define (same-scalar-constant?)
    (and (eqv? type0 type1)
         (eqv? min0 min1 max0 max1)
         (zero? (logand type0 (1- type0)))
         (not (zero? (logand type0 &scalar-types)))))
  (if (provably-distinct?)
      (values #t #f)
      (if (same-scalar-constant?)
          (values #t #t)
          (values #f #f))))
(define-branch-folder-alias eqv? eq?)
(define-branch-folder-alias equal? eq?)
92
;; Classify two operands using only their inferred ranges.
;; Returns one of the symbols <, >, =, <=, >= when every pair of values
;; drawn from the two ranges satisfies that relation.  Returns #f when
;; the ranges do not decide it, or when either operand may be something
;; other than a real number.
(define (compare-ranges type0 min0 max0 type1 min1 max1)
  (let ((non-real-bits (logand (logior type0 type1) (lognot &real))))
    (cond
     ((not (zero? non-real-bits)) #f)
     ((< max0 min1) '<)
     ((> min0 max1) '>)
     ((= min0 max0 min1 max1) '=)
     ((<= max0 min1) '<=)
     ((>= min0 max1) '>=)
     (else #f))))
101
;; Numeric comparison folders.  Each one classifies its operands with
;; `compare-ranges' and maps the relation to an outcome.  A relation in
;; the first list proves the comparison true, one in the second list
;; proves it false, and anything else (including #f) leaves the branch
;; unfolded.
(define-syntax-rule (define-range-comparison-folder op
                      (true-relation ...) (false-relation ...))
  (define-binary-branch-folder (op type0 min0 max0 type1 min1 max1)
    (case (compare-ranges type0 min0 max0 type1 min1 max1)
      ((true-relation ...) (values #t #t))
      ((false-relation ...) (values #t #f))
      (else (values #f #f)))))

(define-range-comparison-folder <  (<)       (= >= >))
(define-range-comparison-folder <= (< <= =)  (>))
(define-range-comparison-folder =  (=)       (< >))
(define-range-comparison-folder >= (> >= =)  (<))
(define-range-comparison-folder >  (>)       (= <= <))
131
;; Fold a `logtest' branch when both operands are exact integers with
;; single-point ranges.  The outcome is then (logtest MIN0 MIN1).
;; Anything less precise is left unfolded.
;; (This replaces two unused local helpers, `logand-min' and
;; `logand-max', which were never called.)
(define-binary-branch-folder (logtest type0 min0 max0 type1 min1 max1)
  (if (and (eqv? type0 type1 &exact-integer)
           (= min0 max0)
           (= min1 max1))
      (values #t (logtest min0 min1))
      (values #f #f)))
144
145
146 \f
147
148 ;; Strength reduction.
149
;; Table mapping a primcall's name (a symbol, compared with `eq?') to
;; its reducer.  A reducer is called with the DFG, the primcall's
;; continuation K and source SRC, and for each argument its variable,
;; inferred type bits and range.  It returns a replacement CPS term, or
;; #f to leave the primcall as it is.
(define *primcall-reducers* (make-hash-table))

;; Register F as the reducer for primcall NAME.
(define-syntax define-primcall-reducer
  (syntax-rules ()
    ((_ name f)
     (hashq-set! *primcall-reducers* 'name f))))

;; Define a reducer for a one-argument primcall.
(define-syntax define-unary-primcall-reducer
  (syntax-rules ()
    ((_ (name dfg k src arg type min max) body ...)
     (define-primcall-reducer name
       (lambda (dfg k src arg type min max) body ...)))))

;; Define a reducer for a two-argument primcall.
(define-syntax define-binary-primcall-reducer
  (syntax-rules ()
    ((_ (name dfg k src arg0 type0 min0 max0 arg1 type1 min1 max1)
        body ...)
     (define-primcall-reducer name
       (lambda (dfg k src arg0 type0 min0 max0 arg1 type1 min1 max1)
         body ...)))))
167
;; Strength-reduce `mul' when at least one operand is a known constant.
;;
;; Both operands must be known to be numbers.  The rewrites are:
;;   (* x -1)  -> (- 0 x)
;;   (* x 0)   -> 0        only if x can be neither a flonum nor a complex
;;   (* x 1)   -> x
;;   (* x 2)   -> (+ x x)
;;   (* x 2^n) -> (ash x n) when both operands are exact integers
;; The constant must be an exact integer, or else have exactly the same
;; single type as the other operand.  Otherwise exactness could change:
;; a constant typed "exact integer or flonum" with range [1,1] might be
;; 1.0, and (* 1.0 3) is 3.0, not 3.
;; The constant from ARG0 is tried first, then the one from ARG1.
;; Returns the replacement term, or #f if no rewrite applies.
(define-binary-primcall-reducer (mul dfg k src
                                     arg0 type0 min0 max0
                                     arg1 type1 min1 max1)
  ;; Continue to K with (- 0 ARG).
  (define (negate arg)
    (let-fresh (kzero) (zero)
      (build-cps-term
        ($letk ((kzero ($kargs (#f) (zero)
                         ($continue k src ($primcall 'sub (zero arg))))))
          ($continue kzero src ($const 0))))))
  ;; Continue to K with exact 0.
  (define (zero)
    (build-cps-term ($continue k src ($const 0))))
  ;; Continue to K with ARG unchanged.
  (define (identity arg)
    (build-cps-term ($continue k src ($values (arg)))))
  ;; Continue to K with (+ ARG ARG).
  (define (double arg)
    (build-cps-term ($continue k src ($primcall 'add (arg arg)))))
  ;; Continue to K with (ash ARG n), where CONSTANT is 2^n.
  (define (power-of-two constant arg)
    (let ((n (let lp ((bits 0) (constant constant))
               (if (= constant 1) bits (lp (1+ bits) (ash constant -1))))))
      (let-fresh (kbits) (bits)
        (build-cps-term
          ($letk ((kbits ($kargs (#f) (bits)
                           ($continue k src ($primcall 'ash (arg bits))))))
            ($continue kbits src ($const n)))))))
  ;; True if TYPE has exactly one type bit set.
  (define (single-type? type)
    (and (not (zero? type))
         (zero? (logand type (1- type)))))
  ;; Try to rewrite (* ARG CONSTANT).  Returns a term or #f.
  (define (mul/constant constant constant-type arg arg-type)
    (and (or (= constant-type &exact-integer)
             (and (= constant-type arg-type)
                  (single-type? constant-type)))
         (case constant
           ((-1) (negate arg))
           ;; An exact 0 times an inexact number is inexact, so only fold
           ;; to exact 0 when ARG can be neither a flonum nor a complex.
           ((0) (and (= constant-type &exact-integer)
                     (zero? (logand arg-type (logior &flonum &complex)))
                     (zero)))
           ((1) (identity arg))
           ((2) (double arg))
           (else (and (= constant-type arg-type &exact-integer)
                      (positive? constant)
                      (zero? (logand constant (1- constant)))
                      (power-of-two constant arg))))))
  (and (not (logtest (logior type0 type1) (lognot &number)))
       (or (and (= min0 max0) (mul/constant min0 type0 arg1 type1))
           (and (= min1 max1) (mul/constant min1 type1 arg0 type0)))))
214
;; Reduce (logbit? ARG0 ARG1) to a `logtest' branch of ARG1 against the
;; mask (ash 1 ARG0).
;;
;; This applies only when ARG0 is an exact integer whose whole range lies
;; in [0, positive-fixnum-bits].  `logbit?' produces a value but
;; `logtest' only branches, so the branch targets materialize #t or #f.
;; That boolean is then delivered to K in the form K expects: a
;; $kreceive, the $ktail, or a $kargs binding one value.  For any other
;; continuation, or when the preconditions fail, this returns #f and the
;; primcall is left alone.
(define-binary-primcall-reducer (logbit? dfg k src
                                         arg0 type0 min0 max0
                                         arg1 type1 min1 max1)
  ;; Build the term: compute MASK, branch on (logtest MASK ARG1) to
  ;; continuations yielding #t or #f, and hand that boolean to the term
  ;; produced by (BOOL-TERM bool).
  (define (convert-to-logtest bool-term)
    (let-fresh (kt kf kmask kbool) (mask bool)
      (build-cps-term
        ($letk ((kt ($kargs () ()
                      ($continue kbool src ($const #t))))
                (kf ($kargs () ()
                      ($continue kbool src ($const #f))))
                (kbool ($kargs (#f) (bool)
                         ,(bool-term bool)))
                (kmask ($kargs (#f) (mask)
                         ($continue kf src
                           ($branch kt ($primcall 'logtest (mask arg1)))))))
          ,(if (= min0 max0)
               ;; Known bit index: the mask is a constant.  Inside this
               ;; unquote we are in plain Scheme, where `$continue' names
               ;; the record type rather than building a term, so the
               ;; term must be constructed with build-cps-term.
               (build-cps-term
                 ($continue kmask src ($const (ash 1 min0))))
               (let-fresh (kone) (one)
                 (build-cps-term
                   ($letk ((kone ($kargs (#f) (one)
                                   ($continue kmask src
                                     ($primcall 'ash (one arg0))))))
                     ($continue kone src ($const 1))))))))))
  ;; Hairiness because we are converting from a primcall with unknown
  ;; arity to a branching primcall.
  (let ((positive-fixnum-bits (- (* (target-word-size) 8) 3)))
    (and (= type0 &exact-integer)
         (<= 0 min0 positive-fixnum-bits)
         (<= 0 max0 positive-fixnum-bits)
         (match (lookup-cont k dfg)
           (($ $kreceive arity kargs)
            (match arity
              ;; One required value plus a rest list: pass the boolean and
              ;; an empty rest list directly to KARGS.
              (($ $arity (_) () (not #f) () #f)
               (convert-to-logtest
                (lambda (bool)
                  (let-fresh (knil) (nil)
                    (build-cps-term
                      ($letk ((knil ($kargs (#f) (nil)
                                      ($continue kargs src
                                        ($values (bool nil))))))
                        ($continue knil src ($const '()))))))))
              (_
               (convert-to-logtest
                (lambda (bool)
                  (build-cps-term
                    ($continue k src ($primcall 'values (bool)))))))))
           (($ $ktail)
            (convert-to-logtest
             (lambda (bool)
               (build-cps-term
                 ($continue k src ($primcall 'return (bool)))))))
           ;; A continuation binding exactly one value: pass the boolean.
           (($ $kargs (_) (_))
            (convert-to-logtest
             (lambda (bool)
               (build-cps-term
                 ($continue k src ($values (bool)))))))
           ;; Any other continuation: do not reduce, rather than raising
           ;; a match error.
           (_ #f)))))
266
267
268 \f
269
270 ;;
271
;; Decide, for each label in FUN, whether its primcall can be folded to a
;; constant or reduced to a cheaper term.  Nothing is rewritten here;
;; `fold-constants*' consumes the results.
;;
;; Type inference runs once over FUN.  Then each primcall is examined:
;;  - A primcall continuing to a $kargs that binds one variable, whose
;;    post-type is a single scalar type with min = max, is marked folded
;;    and its constant value is recorded.
;;  - Otherwise, if a reducer is registered for the primcall's name, the
;;    reducer's result (a term or #f) is recorded.
;;  - For a one- or two-argument branching primcall with a registered
;;    branch folder, the folder's (folded? value) outcome is recorded.
;;
;; All tables are indexed by (- label MIN-LABEL).  Returns three values:
;; the FOLDED? bitvector, the FOLDED-VALUES vector and the REDUCED-TERMS
;; vector.  MIN-VAR is kept for interface compatibility and is unused.
(define (fold-and-reduce fun dfg min-label min-var)
  ;; Materialize the single value described by scalar TYPE and VAL.
  (define (scalar-value type val)
    (cond
     ((eqv? type &exact-integer) val)
     ((eqv? type &flonum) (exact->inexact val))
     ((eqv? type &char) (integer->char val))
     ((eqv? type &unspecified) *unspecified*)
     ((eqv? type &false) #f)
     ((eqv? type &true) #t)
     ((eqv? type &nil) #nil)
     ((eqv? type &null) '())
     (else (error "unhandled type" type val))))
  (let* ((typev (infer-types fun dfg))
         (label-count ((make-local-cont-folder label-count)
                       (lambda (k cont label-count) (1+ label-count))
                       fun 0))
         (folded? (make-bitvector label-count #f))
         (folded-values (make-vector label-count #f))
         (reduced-terms (make-vector label-count #f)))
    (define (label->idx label) (- label min-label))
    ;; Store a branch folder's two results for LABEL.
    (define (record-branch-outcome! label f? v)
      (bitvector-set! folded? (label->idx label) f?)
      (vector-set! folded-values (label->idx label) v))
    (define (maybe-reduce-primcall! label k src name args)
      (let ((reducer (hashq-ref *primcall-reducers* name)))
        (when reducer
          (vector-set!
           reduced-terms
           (label->idx label)
           (match args
             ((arg0)
              (call-with-values (lambda () (lookup-pre-type typev label arg0))
                (lambda (type0 min0 max0)
                  (reducer dfg k src arg0 type0 min0 max0))))
             ((arg0 arg1)
              (call-with-values (lambda () (lookup-pre-type typev label arg0))
                (lambda (type0 min0 max0)
                  (call-with-values (lambda () (lookup-pre-type typev label arg1))
                    (lambda (type1 min1 max1)
                      (reducer dfg k src arg0 type0 min0 max0
                               arg1 type1 min1 max1))))))
             ;; Only one- and two-argument reductions are supported.
             (_ #f))))))
    ;; Fold LABEL to a constant if DEF's post-type pins it to one scalar
    ;; value.  Returns #t if folded, #f otherwise.
    (define (maybe-fold-value! label name def)
      (call-with-values (lambda () (lookup-post-type typev label def 0))
        (lambda (type min max)
          (cond
           ((and (not (zero? type))
                 (zero? (logand type (1- type)))
                 (zero? (logand type (lognot &scalar-types)))
                 (eqv? min max))
            (bitvector-set! folded? (label->idx label) #t)
            (vector-set! folded-values (label->idx label)
                         (scalar-value type min))
            #t)
           (else #f)))))
    (define (maybe-fold-unary-branch! label name arg)
      (let ((folder (hashq-ref *branch-folders* name)))
        (when folder
          (call-with-values (lambda () (lookup-pre-type typev label arg))
            (lambda (type min max)
              (call-with-values (lambda () (folder type min max))
                (lambda (f? v)
                  (record-branch-outcome! label f? v))))))))
    (define (maybe-fold-binary-branch! label name arg0 arg1)
      (let ((folder (hashq-ref *branch-folders* name)))
        (when folder
          (call-with-values (lambda () (lookup-pre-type typev label arg0))
            (lambda (type0 min0 max0)
              (call-with-values (lambda () (lookup-pre-type typev label arg1))
                (lambda (type1 min1 max1)
                  (call-with-values
                      (lambda () (folder type0 min0 max0 type1 min1 max1))
                    (lambda (f? v)
                      (record-branch-outcome! label f? v))))))))))
    (define (visit-cont cont)
      (match cont
        (($ $cont label ($ $kargs _ _ body))
         (visit-term body label))
        (($ $cont label ($ $kclause arity body alternate))
         (visit-cont body)
         ;; ALTERNATE may be #f, which falls through to the last clause.
         (visit-cont alternate))
        (_ #f)))
    (define (visit-term term label)
      (match term
        (($ $letk conts body)
         (for-each visit-cont conts)
         (visit-term body label))
        (($ $continue k src ($ $primcall name args))
         ;; A value-producing primcall: try constant folding first, and
         ;; strength reduction only if that fails.
         (match (lookup-cont k dfg)
           (($ $kargs (_) (def))
            (unless (maybe-fold-value! label name def)
              (maybe-reduce-primcall! label k src name args)))
           (_
            (maybe-reduce-primcall! label k src name args))))
        (($ $continue kf src ($ $branch kt ($ $primcall name args)))
         ;; A branching primcall: try to decide which way it goes.
         (match args
           ((arg) (maybe-fold-unary-branch! label name arg))
           ((arg0 arg1) (maybe-fold-binary-branch! label name arg0 arg1))
           ;; No folders exist for other arities; leave the branch alone
           ;; instead of raising a match error.
           (_ #f)))
        (_ #f)))
    ;; When type inference yields #f, nothing is visited and every table
    ;; keeps its initial contents.
    (when typev
      (match fun
        (($ $cont kfun ($ $kfun src meta self tail clause))
         (visit-cont clause))))
    (values folded? folded-values reduced-terms)))
382
;; Rewrite FUN using the decisions computed by `fold-and-reduce', and
;; recurse into nested functions ($fun and $rec).
;;  - A primcall folded to a constant is kept, bound to a fresh variable,
;;    and followed by a continuation to K with the constant.  This relies
;;    on DCE to remove the primcall where possible.
;;  - A primcall with a recorded reduction is replaced by the reduced term.
;;  - A folded branch becomes a direct continuation to the taken target.
;; Everything else is rebuilt unchanged.
(define (fold-constants* fun dfg)
  (match fun
    (($ $cont min-label ($ $kfun _ _ min-var))
     (call-with-values (lambda () (fold-and-reduce fun dfg min-label min-var))
       (lambda (folded? folded-values reduced-terms)
         (define (label->idx label) (- label min-label))
         (define (visit-cont cont)
           (rewrite-cps-cont cont
             (($ $cont label ($ $kargs names syms body))
              (label ($kargs names syms ,(visit-term body label))))
             (($ $cont label ($ $kclause arity body alternate))
              (label ($kclause ,arity ,(visit-cont body)
                               ,(and alternate (visit-cont alternate)))))
             (_ ,cont)))
         (define (visit-term term label)
           (rewrite-cps-term term
             (($ $letk conts body)
              ($letk ,(map visit-cont conts)
                ,(visit-term body label)))
             (($ $continue k src (and fun ($ $fun)))
              ($continue k src ,(visit-fun fun)))
             (($ $continue k src ($ $rec names vars funs))
              ($continue k src ($rec names vars (map visit-fun funs))))
             (($ $continue k src (and primcall ($ $primcall name args)))
              ,(cond
                ((bitvector-ref folded? (label->idx label))
                 (let ((val (vector-ref folded-values (label->idx label))))
                   ;; Uncomment for debugging.
                   ;; (pk 'folded src primcall val)
                   (let-fresh (k*) (v*)
                     ;; Rely on DCE to elide this expression, if
                     ;; possible.
                     (build-cps-term
                       ($letk ((k* ($kargs (#f) (v*)
                                     ($continue k src ($const val)))))
                         ($continue k* src ,primcall))))))
                (else
                 (or (vector-ref reduced-terms (label->idx label))
                     term))))
             (($ $continue kf src ($ $branch kt ($ $primcall)))
              ,(if (bitvector-ref folded? (label->idx label))
                   ;; Folded branch: jump straight to the known target.
                   (let ((val (vector-ref folded-values (label->idx label))))
                     (build-cps-term
                       ($continue (if val kt kf) src ($values ()))))
                   term))
             (_ ,term)))
         ;; Nested functions are processed independently, with their own
         ;; label base.
         (define (visit-fun fun)
           (rewrite-cps-exp fun
             (($ $fun body)
              ($fun ,(fold-constants* body dfg)))))
         (rewrite-cps-cont fun
           (($ $cont kfun ($ $kfun src meta self tail clause))
            (kfun ($kfun src meta self ,tail ,(visit-cont clause))))))))))
438
;; Entry point.  Renumber FUN, build its DFG, then run type-driven
;; constant folding and strength reduction under a fresh-name state
;; derived from that DFG.  Returns the rewritten function.
(define (type-fold fun)
  (let ((renumbered (renumber fun)))
    (let ((dfg (compute-dfg renumbered)))
      (with-fresh-name-state-from-dfg dfg
        (fold-constants* renumbered dfg)))))