5d5f68b9bfce80424a7ec84d672474a1b76f9892
[bpt/coccinelle.git] / parsing_cocci / type_infer.ml
1 (*
2 * Copyright 2010, INRIA, University of Copenhagen
3 * Julia Lawall, Rene Rydhof Hansen, Gilles Muller, Nicolas Palix
4 * Copyright 2005-2009, Ecole des Mines de Nantes, University of Copenhagen
5 * Yoann Padioleau, Julia Lawall, Rene Rydhof Hansen, Henrik Stuart, Gilles Muller, Nicolas Palix
6 * This file is part of Coccinelle.
7 *
8 * Coccinelle is free software: you can redistribute it and/or modify
9 * it under the terms of the GNU General Public License as published by
10 * the Free Software Foundation, according to version 2 of the License.
11 *
12 * Coccinelle is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with Coccinelle. If not, see <http://www.gnu.org/licenses/>.
19 *
20 * The authors reserve the right to distribute this or future versions of
21 * Coccinelle under other licenses.
22 *)
23
24
25 module T = Type_cocci
26 module Ast = Ast_cocci
27 module Ast0 = Ast0_cocci
28 module V0 = Visitor_ast0
29 module VT0 = Visitor_ast0_types
30
31 (* Type inference:
32 Just propagates information based on declarations. Could try to infer
33 more precise information about expression metavariables, but not sure it is
34 worth it. The most obvious goal is to distinguish between test expressions
35 that have pointer, integer, and boolean type when matching isomorphisms,
36 but perhaps other needs will become apparent. *)
37
(* Names of called functions whose result is typed as boolean by the
   function-call case of the type inference, even though no declaration for
   them is available. *)
let bool_functions = [ "likely"; "unlikely" ]
40
(* Report an ill-typed construct and abort.  [ty] is displayed via [T.typeC]
   followed by a newline; the failure message is prefixed with the source
   line of the AST node [wrapped].  The result type is polymorphic because
   the function never returns.
   @raise Failure always. *)
let err wrapped ty s =
  T.typeC ty;
  Format.print_newline ();
  let line = Ast0.get_line wrapped in
  failwith (Printf.sprintf "line %d: %s" line s)
44
(* Keys of the type environment: either a concrete identifier name or the
   name of an identifier metavariable. *)
type id =
  | Id of string
  | Meta of Ast.meta_name
46
(* Base types that the inference assigns directly to constants, operators
   and builtin constructs. *)
let int_type = T.BaseType T.IntType
let void_type = T.BaseType T.VoidType
let bool_type = T.BaseType T.BoolType
let char_type = T.BaseType T.CharType
let float_type = T.BaseType T.FloatType
let size_type = T.BaseType T.SizeType
let ssize_type = T.BaseType T.SSizeType
let ptrdiff_type = T.BaseType T.PtrDiffType
55
(* Heuristic "least upper bound" of two optional types.  It is used to
   reconcile the types of the two sides of an assignment, binary operator or
   conditional, and of the branches of a disjunction.

   This is not a real lattice join:
   - [None] and [T.Unknown] are absorbed by the other type.
   - Pointers win over non-pointers (pointer arithmetic such as ptr+1).
   - The other type wins over a typedef name.
   - Otherwise the first type is kept; the input is assumed type correct. *)
let lub_type t1 t2 =
  (* [loop] merges two known types structurally. *)
  let rec loop = function
    | (T.Unknown, t2) -> t2
    | (t1, T.Unknown) -> t1
    | (T.ConstVol (cv1, ty1), T.ConstVol (cv2, ty2)) when cv1 = cv2 ->
        T.ConstVol (cv1, loop (ty1, ty2))
    (* pad: in pointer arithmetic, as in ptr+1, the lub must be ptr *)
    | (T.Pointer ty1, T.Pointer ty2) -> T.Pointer (loop (ty1, ty2))
    | (_, T.Pointer ty2) -> T.Pointer ty2
    | (T.Pointer ty1, _) -> T.Pointer ty1
    | (T.Array ty1, T.Array ty2) -> T.Array (loop (ty1, ty2))
    (* prefer the other type over a typedef name *)
    | (T.TypeName _, t2) -> t2
    | (t1, T.TypeName _) -> t1
    (* arbitrarily pick the first, assume type correct *)
    | (t1, _) -> t1
  in
  match (t1, t2) with
  | (None, None) -> None
  | (None, Some _) -> t2
  | (Some _, None) -> t1
  | (Some t1, Some t2) -> Some (loop (t1, t2))
79
(* Merge a list of type environments (association lists from [id] to type)
   into a single environment.
   - The accumulator holds at most one binding per variable.
   - A variable already present has its type replaced by the [lub_type] of
     the old and new bindings.
   - Used to join the environments computed for the branches of a statement
     disjunction.
   @raise Failure "bad type environment" if the accumulator somehow holds two
   bindings for the same variable (not expected by construction). *)
let lub_envs envs =
  let add_binding acc (var, ty) =
    match List.partition (fun (x, _) -> x = var) acc with
    | ([], _) -> (var, ty) :: acc
    | ([ (_, ty1) ], others) ->
        (match lub_type (Some ty) (Some ty1) with
         | Some new_ty -> (var, new_ty) :: others
         | None -> others)
    | (_, _) -> failwith "bad type environment"
  in
  List.fold_left (fun acc env -> List.fold_left add_binding acc env) [] envs
98
(* Build a [Visitor_ast0] combiner that walks SmPL code and records an
   inferred type on each expression node with [Ast0.set_type].

   [env] is an association list from [id] to type used to type identifiers.
   The environment is extended as follows:
   - Declarations at the start of a statement list extend it for the
     statements that follow them.
   - Function parameters extend it for the function body.
   - The declarations of a switch body extend it for the case lines.

   The combined value is not meaningful: [bind] always yields [None]; the
   useful result is the side effect on the AST.

   @raise Failure on an ill-typed array index or array reference, on a field
   access through a non-structure (pointer) type, and on [Ast0.Nested]
   expressions, which are not expected at this stage. *)
let rec propagate_types env =
  let option_default = None in
  let bind _ _ = option_default in (* no generic way of combining types *)

  (* Identifiers are typed by looking them up in [env]; a disjunction of
     identifiers gets the lub of its branches, which is also pushed back into
     every branch. *)
  let ident r k i =
    match Ast0.unwrap i with
    | Ast0.Id id ->
        (try Some (List.assoc (Id (Ast0.unwrap_mcode id)) env)
         with Not_found -> None)
    | Ast0.MetaId (id, _, _, _) ->
        (try Some (List.assoc (Meta (Ast0.unwrap_mcode id)) env)
         with Not_found -> None)
    | Ast0.DisjId (_, id_list, _, _) ->
        let types = List.map Ast0.get_type id_list in
        let combined = List.fold_left lub_type None types in
        (match combined with
         | None -> None
         | Some t ->
             List.iter (fun i -> Ast0.set_type i (Some t)) id_list;
             Some t)
    | _ -> k i
  in

  (* Drop one outer const/volatile qualifier, if any. *)
  let strip_cv = function
    | Some (T.ConstVol (_, t)) -> Some t
    | t -> t
  in

  (* types that might be integer types. should char be allowed? *)
  let rec is_int_type = function
    | T.BaseType T.IntType
    | T.BaseType T.LongType
    | T.BaseType T.ShortType
    | T.MetaType (_, _, _)
    | T.TypeName _
    | T.EnumName _
    | T.SignedT (_, None) -> true
    | T.SignedT (_, Some ty) -> is_int_type ty
    | _ -> false
  in

  (* Subexpressions are visited first (via [k]) so that their types are
     available; the type computed here is then recorded on [e]. *)
  let expression r k e =
    let res = k e in
    let ty =
      match Ast0.unwrap e with
      (* pad: the type of id is set in the ident visitor *)
      | Ast0.Ident _ -> Ast0.set_type e res; res
      | Ast0.Constant const ->
          (match Ast0.unwrap_mcode const with
           | Ast.String _ -> Some (T.Pointer char_type)
           | Ast.Char _ -> Some char_type
           | Ast.Int _ -> Some int_type
           | Ast.Float _ -> Some float_type)
      (* pad: note that in C can do either ptr(...) or ( *ptr)(...)
         so I am not sure this code is enough. *)
      | Ast0.FunCall (fn, _, _, _) ->
          (match Ast0.get_type fn with
           | Some (T.FunctionPointer ty) -> Some ty
           | _ ->
               (match Ast0.unwrap fn with
                | Ast0.Ident id ->
                    (match Ast0.unwrap id with
                     | Ast0.Id id ->
                         if List.mem (Ast0.unwrap_mcode id) bool_functions
                         then Some bool_type
                         else None
                     | _ -> None)
                | _ -> None))
      | Ast0.Assignment (exp1, _, exp2, _) ->
          let ty = lub_type (Ast0.get_type exp1) (Ast0.get_type exp2) in
          Ast0.set_type exp1 ty;
          Ast0.set_type exp2 ty;
          ty
      | Ast0.CondExpr (_, _, Some exp2, _, exp3) ->
          let ty = lub_type (Ast0.get_type exp2) (Ast0.get_type exp3) in
          Ast0.set_type exp2 ty;
          Ast0.set_type exp3 ty;
          ty
      | Ast0.CondExpr (_, _, None, _, exp3) -> Ast0.get_type exp3
      | Ast0.Postfix (exp, _) | Ast0.Infix (exp, _) -> (* op is dec or inc *)
          Ast0.get_type exp
      | Ast0.Unary (exp, op) ->
          (match Ast0.unwrap_mcode op with
           | Ast.GetRef ->
               (match Ast0.get_type exp with
                | None -> Some (T.Pointer T.Unknown)
                | Some t -> Some (T.Pointer t))
           | Ast.GetRefLabel -> Some (T.Pointer void_type)
           | Ast.DeRef ->
               (match Ast0.get_type exp with
                | Some (T.Pointer t) -> Some t
                | _ -> None)
           | Ast.UnPlus -> Ast0.get_type exp
           | Ast.UnMinus -> Ast0.get_type exp
           | Ast.Tilde -> Ast0.get_type exp
           | Ast.Not -> Some bool_type)
      | Ast0.Nested (_, _, _) -> failwith "nested in type inf not possible"
      | Ast0.Binary (exp1, op, exp2) ->
          let ty1 = Ast0.get_type exp1 in
          let ty2 = Ast0.get_type exp2 in
          let same_type = function
            | (None, None) -> Some int_type
            (* pad: pointer arithmetic handling as in ptr+1 *)
            | (Some (T.Pointer ty1), Some ty2) when is_int_type ty2 ->
                Some (T.Pointer ty1)
            | (Some ty1, Some (T.Pointer ty2)) when is_int_type ty1 ->
                Some (T.Pointer ty2)
            | (t1, t2) ->
                let ty = lub_type t1 t2 in
                Ast0.set_type exp1 ty;
                Ast0.set_type exp2 ty;
                ty
          in
          (match Ast0.unwrap_mcode op with
           | Ast.Arith _ -> same_type (ty1, ty2)
           | Ast.Logical Ast.AndLog | Ast.Logical Ast.OrLog -> Some bool_type
           | Ast.Logical _ ->
               (* comparison: operands share a type, the result is boolean *)
               let ty = lub_type ty1 ty2 in
               Ast0.set_type exp1 ty;
               Ast0.set_type exp2 ty;
               Some bool_type)
      | Ast0.Paren (_, exp, _) -> Ast0.get_type exp
      | Ast0.ArrayAccess (exp1, _, exp2, _) ->
          (match strip_cv (Ast0.get_type exp2) with
           | None -> Ast0.set_type exp2 (Some int_type)
           | Some ty when is_int_type ty -> ()
           | Some T.Unknown ->
               (* unknown comes from param types, not sure why this
                  is not just None... *)
               Ast0.set_type exp2 (Some int_type)
           | Some ty -> err exp2 ty "bad type for an array index");
          (match strip_cv (Ast0.get_type exp1) with
           | None -> None
           | Some (T.Array ty) -> Some ty
           | Some (T.Pointer ty) -> Some ty
           | Some (T.MetaType (_, _, _)) -> None
           | Some x -> err exp1 x "ill-typed array reference")
      (* pad: should handle structure one day and look 'field' in environment *)
      | Ast0.RecordAccess (exp, _, _) ->
          (match strip_cv (Ast0.get_type exp) with
           | None -> None
           | Some (T.StructUnionName (_, _)) -> None
           | Some (T.TypeName _) -> None
           | Some (T.MetaType (_, _, _)) -> None
           | Some x -> err exp x "non-structure type in field ref")
      | Ast0.RecordPtAccess (exp, _, _) ->
          (match strip_cv (Ast0.get_type exp) with
           | None -> None
           | Some (T.Pointer t) ->
               (match strip_cv (Some t) with
                | Some T.Unknown -> None
                | Some (T.MetaType (_, _, _)) -> None
                | Some (T.TypeName _) -> None
                | Some (T.StructUnionName (_, _)) -> None
                | Some _ ->
                    err exp (T.Pointer t)
                      "non-structure pointer type in field ref"
                | _ -> failwith "not possible")
           | Some (T.MetaType (_, _, _)) -> None
           | Some (T.TypeName _) -> None
           | Some x -> err exp x "non-structure pointer type in field ref")
      | Ast0.Cast (_, ty, _, _) -> Some (Ast0.ast0_type_to_type ty)
      | Ast0.SizeOfExpr (_, _) -> Some int_type
      | Ast0.SizeOfType (_, _, _, _) -> Some int_type
      | Ast0.TypeExp _ -> None
      | Ast0.Constructor (_, ty, _, _) -> Some (Ast0.ast0_type_to_type ty)
      | Ast0.MetaErr (_, _, _) -> None
      | Ast0.MetaExpr (_, _, Some [ ty ], _, _) -> Some ty
      | Ast0.MetaExpr (_, _, _, _, _) -> None
      | Ast0.MetaExprList (_, _, _) -> None
      | Ast0.EComma _ -> None
      | Ast0.DisjExpr (_, exp_list, _, _) ->
          let types = List.map Ast0.get_type exp_list in
          let combined = List.fold_left lub_type None types in
          (match combined with
           | None -> None
           | Some t ->
               List.iter (fun e -> Ast0.set_type e (Some t)) exp_list;
               Some t)
      | Ast0.NestExpr (_, expr_dots, _, None, _) ->
          let _ = r.VT0.combiner_rec_expression_dots expr_dots in
          None
      | Ast0.NestExpr (_, expr_dots, _, Some whn, _) ->
          let _ = r.VT0.combiner_rec_expression_dots expr_dots in
          let _ = r.VT0.combiner_rec_expression whn in
          None
      | Ast0.Edots (_, None) | Ast0.Ecircles (_, None) | Ast0.Estars (_, None) ->
          None
      | Ast0.Edots (_, Some whn) | Ast0.Ecircles (_, Some whn)
      | Ast0.Estars (_, Some whn) ->
          let _ = r.VT0.combiner_rec_expression whn in
          None
      | Ast0.OptExp exp -> Ast0.get_type exp
      | Ast0.UniqueExp exp -> Ast0.get_type exp
    in
    Ast0.set_type e ty;
    ty
  in

  (* The environment keys bound by a declared identifier; a disjunction of
     identifiers binds every alternative. *)
  let rec strip id =
    match Ast0.unwrap id with
    | Ast0.Id name -> [ Id (Ast0.unwrap_mcode name) ]
    | Ast0.MetaId (name, _, _, _) -> [ Meta (Ast0.unwrap_mcode name) ]
    | Ast0.MetaFunc (name, _, _) -> [ Meta (Ast0.unwrap_mcode name) ]
    | Ast0.MetaLocalFunc (name, _, _) -> [ Meta (Ast0.unwrap_mcode name) ]
    | Ast0.DisjId (_, id_list, _, _) -> List.concat (List.map strip id_list)
    | Ast0.OptIdent id -> strip id
    | Ast0.UniqueIdent id -> strip id
  in

  (* Visit the code carried by one when clause with the matching visitor. *)
  let process_whencode notfn allfn exp = function
    | Ast0.WhenNot x -> let _ = notfn x in ()
    | Ast0.WhenAlways x -> let _ = allfn x in ()
    | Ast0.WhenModifier _ -> ()
    | Ast0.WhenNotTrue x -> let _ = exp x in ()
    | Ast0.WhenNotFalse x -> let _ = exp x in ()
  in

  (* assume that all of the declarations are at the beginning of a statement
     list, which is required by C, but not actually required by the cocci
     parser.  Returns the environment [acc] extended with the bindings of the
     declarations seen; the branches of a statement disjunction are joined
     with [lub_envs]. *)
  let rec process_statement_list r acc = function
    | [] -> acc
    | s :: ss ->
        (match Ast0.unwrap s with
         | Ast0.Decl (_, decl) ->
             let new_acc = process_decl acc decl @ acc in
             process_statement_list r new_acc ss
         | Ast0.Dots (_, wc) ->
             (* why is this case here? why is there none for nests? *)
             (* NOTE(review): whencode is visited with [r], i.e. with the
                environment this combiner was built from, not with [acc], so
                variables declared earlier in this statement list are not
                visible to it.  Confirm this is intended. *)
             List.iter
               (process_whencode r.VT0.combiner_rec_statement_dots
                  r.VT0.combiner_rec_statement r.VT0.combiner_rec_expression)
               wc;
             process_statement_list r acc ss
         | Ast0.Disj (_, statement_dots_list, _, _) ->
             let new_acc =
               lub_envs
                 (List.map
                    (fun x -> process_statement_list r acc (Ast0.undots x))
                    statement_dots_list)
             in
             process_statement_list r new_acc ss
         | _ ->
             let _ = (propagate_types acc).VT0.combiner_rec_statement s in
             process_statement_list r acc ss)

  (* The bindings introduced by one declaration; initialisers are typed in
     [env]. *)
  and process_decl env decl =
    match Ast0.unwrap decl with
    | Ast0.MetaDecl (_, _) | Ast0.MetaField (_, _)
    | Ast0.MetaFieldList (_, _, _) -> []
    | Ast0.Init (_, ty, id, _, exp, _) ->
        let _ = (propagate_types env).VT0.combiner_rec_initialiser exp in
        let ty = Ast0.ast0_type_to_type ty in
        List.map (fun i -> (i, ty)) (strip id)
    | Ast0.UnInit (_, ty, id, _) ->
        let ty = Ast0.ast0_type_to_type ty in
        List.map (fun i -> (i, ty)) (strip id)
    | Ast0.MacroDecl (_, _, _, _, _) -> []
    | Ast0.TyDecl (_, _) -> []
    (* pad: should handle typedef one day and add a binding *)
    | Ast0.Typedef (_, _, _, _) -> []
    | Ast0.DisjDecl (_, disjs, _, _) ->
        List.concat (List.map (process_decl env) disjs)
    | Ast0.Ddots (_, _) -> [] (* not in a statement list anyway *)
    | Ast0.OptDecl decl -> process_decl env decl
    | Ast0.UniqueDecl decl -> process_decl env decl
  in

  let statement_dots r _ d =
    match Ast0.unwrap d with
    | Ast0.DOTS l | Ast0.CIRCLES l | Ast0.STARS l ->
        let _ = process_statement_list r env l in
        option_default
  in

  (* Hook for refining the type of a test expression.  With the (_, None)
     fallback disabled, [process_test] never returns [Some], so this
     currently leaves all types unchanged. *)
  let post_bool exp =
    let rec process_test exp =
      match (Ast0.unwrap exp, Ast0.get_type exp) with
      | (Ast0.Edots (_, _), _) -> None
      | (Ast0.NestExpr (_, _, _, _, _), _) -> None
      | (Ast0.MetaExpr (_, _, _, _, _), _) ->
          (* if a type is known, it is specified in the decl *)
          None
      | (Ast0.Paren (_, inner, _), None) -> process_test inner
      (* the following doesn't seem like a good idea - triggers int isos
         on all test expressions *)
      (*| (_,None) -> Some (int_type) *)
      | _ -> None
    in
    let new_expty = process_test exp in
    match new_expty with
    | None -> () (* leave things as they are *)
    | Some _ -> Ast0.set_type exp new_expty
  in

  let statement r k s =
    match Ast0.unwrap s with
    | Ast0.FunDecl (_, _, _, _, params, _, _, body, _) ->
        (* named parameters are visible in the body, ahead of [env] *)
        let rec get_binding p =
          match Ast0.unwrap p with
          | Ast0.Param (ty, Some id) ->
              let ty = Ast0.ast0_type_to_type ty in
              List.map (fun i -> (i, ty)) (strip id)
          | Ast0.OptParam param -> get_binding param
          | _ -> []
        in
        let fenv = List.concat (List.map get_binding (Ast0.undots params)) in
        (propagate_types (fenv @ env)).VT0.combiner_rec_statement_dots body
    | Ast0.IfThen (_, _, exp, _, _, _)
    | Ast0.IfThenElse (_, _, exp, _, _, _, _, _)
    | Ast0.While (_, _, exp, _, _, _)
    | Ast0.Do (_, _, _, _, exp, _, _)
    | Ast0.For (_, _, _, _, Some exp, _, _, _, _, _) ->
        let _ = k s in
        post_bool exp;
        None
    | Ast0.Switch (_, _, exp, _, _, decls, cases, _) ->
        (* The default traversal [k s] is bypassed here, so the switch
           expression has to be visited explicitly; otherwise its type (and
           the types of its subexpressions) would never be recorded. *)
        let _ = r.VT0.combiner_rec_expression exp in
        let senv = process_statement_list r env (Ast0.undots decls) in
        let res =
          (propagate_types (senv @ env)).VT0.combiner_rec_case_line_dots cases
        in
        post_bool exp;
        res
    | _ -> k s

  (* A case label with no inferred type is assumed to be an integer. *)
  and case_line _ k c =
    match Ast0.unwrap c with
    | Ast0.Case (_, exp, _, _) ->
        let _ = k c in
        (match Ast0.get_type exp with
         | None -> Ast0.set_type exp (Some int_type)
         | _ -> ());
        None
    | _ -> k c
  in

  V0.combiner bind option_default
    { V0.combiner_functions with
      VT0.combiner_dotsstmtfn = statement_dots;
      VT0.combiner_identfn = ident;
      VT0.combiner_exprfn = expression;
      VT0.combiner_stmtfn = statement;
      VT0.combiner_casefn = case_line }
419
(* Run type inference over every top-level element of [code], in order,
   recording inferred types on the AST with [Ast0.set_type].  The initial
   environment only binds NULL, as a pointer to an unknown type.  The
   combiner's result for each element carries no information and is
   discarded. *)
let type_infer code =
  let prop = propagate_types [ (Id "NULL", T.Pointer T.Unknown) ] in
  let fn = prop.VT0.combiner_rec_top_level in
  List.iter (fun top -> ignore (fn top)) code