51e6ebed20e0b78bc55fbe9bfd513c01b3d6d3e7
[bpt/coccinelle.git] / parsing_cocci / type_infer.ml
1 (*
2 * Copyright 2010, INRIA, University of Copenhagen
3 * Julia Lawall, Rene Rydhof Hansen, Gilles Muller, Nicolas Palix
4 * Copyright 2005-2009, Ecole des Mines de Nantes, University of Copenhagen
5 * Yoann Padioleau, Julia Lawall, Rene Rydhof Hansen, Henrik Stuart, Gilles Muller, Nicolas Palix
6 * This file is part of Coccinelle.
7 *
8 * Coccinelle is free software: you can redistribute it and/or modify
9 * it under the terms of the GNU General Public License as published by
10 * the Free Software Foundation, according to version 2 of the License.
11 *
12 * Coccinelle is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with Coccinelle. If not, see <http://www.gnu.org/licenses/>.
19 *
20 * The authors reserve the right to distribute this or future versions of
21 * Coccinelle under other licenses.
22 *)
23
24
25 module T = Type_cocci
26 module Ast = Ast_cocci
27 module Ast0 = Ast0_cocci
28 module V0 = Visitor_ast0
29 module VT0 = Visitor_ast0_types
30
31 (* Type inference:
32 Just propagates information based on declarations. Could try to infer
33 more precise information about expression metavariables, but not sure it is
34 worth it. The most obvious goal is to distinguish between test expressions
35 that have pointer, integer, and boolean type when matching isomorphisms,
36 but perhaps other needs will become apparent. *)
37
(* Names of function-like identifiers whose calls are given boolean type
   by the FunCall case of the expression visitor. *)
let bool_functions =
  [ "likely";
    "unlikely" ]
40
(* Report a typing error.
   Prints the offending type [ty] with T.typeC, ends the line on the
   Format standard output, then raises Failure with a message made of the
   line number of [wrapped] (Ast0.get_line) followed by [s].
   Never returns normally. *)
let err wrapped ty s =
  T.typeC ty;
  Format.print_newline ();
  let message = Printf.sprintf "line %d: %s" (Ast0.get_line wrapped) s in
  failwith message
44
(* Keys of the type environment: either a concrete identifier, named by
   its string, or a metavariable, named by its Ast.meta_name. *)
type id =
  | Id of string
  | Meta of Ast.meta_name
46
(* Shorthands for the base types of Type_cocci.  Only int_type, bool_type,
   char_type and float_type are referenced by the code visible in this
   file; the others are kept because they are part of the module's
   top-level definitions. *)
let int_type = T.BaseType T.IntType
let bool_type = T.BaseType T.BoolType
let char_type = T.BaseType T.CharType
let float_type = T.BaseType T.FloatType
let size_type = T.BaseType T.SizeType
let ssize_type = T.BaseType T.SSizeType
let ptrdiff_type = T.BaseType T.PtrDiffType
54
(* Combine two optional types into the most informative one.
   - If either side is None, the other side is returned unchanged.
   - If both are known, they are merged structurally by [merge]; the cases
     are tried in the order written, so earlier cases take priority.
   The result is None only when both arguments are None. *)
let lub_type t1 t2 =
  let rec merge left right =
    match (left, right) with
      (T.Unknown, other) -> other
    | (other, T.Unknown) -> other
    | (T.ConstVol (q1, inner1), T.ConstVol (q2, inner2)) when q1 = q2 ->
        (* same qualifier on both sides: keep it, merge underneath *)
        T.ConstVol (q1, merge inner1 inner2)

    (* pad: in pointer arithmetic, as in ptr+1, the lub must be ptr *)
    | (T.Pointer inner1, T.Pointer inner2) ->
        T.Pointer (merge inner1 inner2)
    | (_, (T.Pointer _ as ptr)) -> ptr
    | ((T.Pointer _ as ptr), _) -> ptr

    | (T.Array elem1, T.Array elem2) -> T.Array (merge elem1 elem2)
    (* a type name gives way to whatever is on the other side *)
    | (T.TypeName _, other) -> other
    | (other, T.TypeName _) -> other
    (* arbitrarily pick the first, assume type correct *)
    | (first, _) -> first in
  match (t1, t2) with
    (Some known1, Some known2) -> Some (merge known1 known2)
  | (None, _) -> t2
  | (_, None) -> t1
78
(* Merge a list of type environments (association lists from id to type)
   into a single one.  A variable bound in several environments ends up
   with the lub_type of its types; a variable seen for the first time is
   simply added at the front.  Raises Failure "bad type environment" if
   the environment built so far binds the same variable more than once. *)
let lub_envs envs =
  let merge_binding merged (var, ty) =
    match List.partition (fun (name, _) -> name = var) merged with
      ([], _) -> (var, ty) :: merged
    | ([(_, previous)], others) ->
        (match lub_type (Some ty) (Some previous) with
          Some joined -> (var, joined) :: others
        | None -> others)
    | _ -> failwith "bad type environment" in
  let merge_env merged env = List.fold_left merge_binding merged env in
  List.fold_left merge_env [] envs
97
(* Build the type-propagating visitor for a given type environment.

   [env] is an association list from [id] (Id or Meta) to Type_cocci types;
   identifiers are looked up in it with List.assoc, so the first binding of
   a name wins.

   The result is the value returned by V0.combiner, instantiated with the
   local visitors below for statement dots, identifiers, expressions,
   statements and case lines.  Every visitor returns a type option, and
   [bind] throws both of its arguments away, so the only useful results are
   the ones computed explicitly here.  Types are recorded by side effect on
   the AST with Ast0.set_type.

   In each visitor, [r] is the record of recursive visitor entry points
   (fields r.VT0.combiner_rec_...) and [k] is the continuation supplied by
   V0.combiner (presumably the default traversal of the children - confirm
   against Visitor_ast0).

   Ill-typed array and field accesses are reported through [err], which
   raises Failure. *)
let rec propagate_types env =
  let option_default = None in
  let bind x y = option_default in (* no generic way of combining types *)

  (* Identifiers: look the name up in [env]; for a disjunction of
     identifiers, give every branch the lub of the types already stored on
     the branches. *)
  let ident r k i =
    match Ast0.unwrap i with
      Ast0.Id(id) ->
        (try Some(List.assoc (Id(Ast0.unwrap_mcode id)) env)
        with Not_found -> None)
    | Ast0.MetaId(id,_,_) ->
        (try Some(List.assoc (Meta(Ast0.unwrap_mcode id)) env)
        with Not_found -> None)
    | Ast0.DisjId(_,id_list,_,_) ->
        let types = List.map Ast0.get_type id_list in
        let combined = List.fold_left lub_type None types in
        (match combined with
          None -> None
        | Some t ->
            List.iter (function i -> Ast0.set_type i (Some t)) id_list;
            Some t)
    | _ -> k i in

  (* Remove one outermost const/volatile qualifier, if any. *)
  let strip_cv = function
      Some (T.ConstVol(_,t)) -> Some t
    | t -> t in

  (* types that might be integer types.  should char be allowed?
     size_t, ssize_t and ptrdiff_t are accepted as well: they are integer
     types, and without them an index or pointer offset declared with one
     of these types was rejected as an array index and was not recognised
     as the integer operand of pointer arithmetic. *)
  let rec is_int_type = function
      T.BaseType(T.IntType)
    | T.BaseType(T.LongType)
    | T.BaseType(T.ShortType)
    | T.BaseType(T.SizeType)
    | T.BaseType(T.SSizeType)
    | T.BaseType(T.PtrDiffType)
    | T.MetaType(_,_,_)
    | T.TypeName _
    | T.EnumName _
    | T.SignedT(_,None) -> true
    | T.SignedT(_,Some ty) -> is_int_type ty
    | _ -> false in

  (* Expressions: the children are visited first (k e), then the type of
     [e] is computed from the types already stored on its subterms.  The
     computed type is stored on [e] and returned. *)
  let expression r k e =
    let res = k e in
    let ty =
      match Ast0.unwrap e with
        (* pad: the type of id is set in the ident visitor *)
        Ast0.Ident(id) -> Ast0.set_type e res; res
      | Ast0.Constant(const) ->
          (match Ast0.unwrap_mcode const with
            Ast.String(_) -> Some (T.Pointer(char_type))
          | Ast.Char(_) -> Some (char_type)
          | Ast.Int(_) -> Some (int_type)
          | Ast.Float(_) -> Some (float_type))
      (* pad: note that in C can do either ptr(...) or ( *ptr)(...)
       * so I am not sure this code is enough.
       *)
      | Ast0.FunCall(fn,lp,args,rp) ->
          (match Ast0.get_type fn with
            Some (T.FunctionPointer(ty)) -> Some ty
          | _ ->
              (* otherwise only calls to the names in bool_functions get
                 a type *)
              (match Ast0.unwrap fn with
                Ast0.Ident(id) ->
                  (match Ast0.unwrap id with
                    Ast0.Id(id) ->
                      if List.mem (Ast0.unwrap_mcode id) bool_functions
                      then Some(bool_type)
                      else None
                  | _ -> None)
              | _ -> None))
      | Ast0.Assignment(exp1,op,exp2,_) ->
          (* both sides receive the lub of their types *)
          let ty = lub_type (Ast0.get_type exp1) (Ast0.get_type exp2) in
          Ast0.set_type exp1 ty; Ast0.set_type exp2 ty; ty
      | Ast0.CondExpr(exp1,why,Some exp2,colon,exp3) ->
          let ty = lub_type (Ast0.get_type exp2) (Ast0.get_type exp3) in
          Ast0.set_type exp2 ty; Ast0.set_type exp3 ty; ty
      | Ast0.CondExpr(exp1,why,None,colon,exp3) -> Ast0.get_type exp3
      | Ast0.Postfix(exp,op) | Ast0.Infix(exp,op) -> (* op is dec or inc *)
          Ast0.get_type exp
      | Ast0.Unary(exp,op) ->
          (match Ast0.unwrap_mcode op with
            Ast.GetRef ->
              (match Ast0.get_type exp with
                None -> Some (T.Pointer(T.Unknown))
              | Some t -> Some (T.Pointer(t)))
          | Ast.DeRef ->
              (match Ast0.get_type exp with
                Some (T.Pointer(t)) -> Some t
              | _ -> None)
          | Ast.UnPlus -> Ast0.get_type exp
          | Ast.UnMinus -> Ast0.get_type exp
          | Ast.Tilde -> Ast0.get_type exp
          | Ast.Not -> Some(bool_type))
      | Ast0.Nested(exp1,op,exp2) -> failwith "nested in type inf not possible"
      | Ast0.Binary(exp1,op,exp2) ->
          let ty1 = Ast0.get_type exp1 in
          let ty2 = Ast0.get_type exp2 in
          (* type of an arithmetic operation; in the last case both
             operands are also updated with the lub of their types *)
          let same_type = function
              (None,None) -> Some (int_type)

            (* pad: pointer arithmetic handling as in ptr+1 *)
            | (Some (T.Pointer ty1),Some ty2) when is_int_type ty2 ->
                Some (T.Pointer ty1)
            | (Some ty1,Some (T.Pointer ty2)) when is_int_type ty1 ->
                Some (T.Pointer ty2)

            | (t1,t2) ->
                let ty = lub_type t1 t2 in
                Ast0.set_type exp1 ty; Ast0.set_type exp2 ty; ty in
          (match Ast0.unwrap_mcode op with
            Ast.Arith(op) -> same_type (ty1, ty2)
          | Ast.Logical(Ast.AndLog) | Ast.Logical(Ast.OrLog) ->
              Some(bool_type)
          | Ast.Logical(op) ->
              (* remaining logical operators: the operands share a type,
                 the result is boolean *)
              let ty = lub_type ty1 ty2 in
              Ast0.set_type exp1 ty; Ast0.set_type exp2 ty;
              Some(bool_type))
      | Ast0.Paren(lp,exp,rp) -> Ast0.get_type exp
      | Ast0.ArrayAccess(exp1,lb,exp2,rb) ->
          (* first check (or default) the type of the index ... *)
          (match strip_cv (Ast0.get_type exp2) with
            None -> Ast0.set_type exp2 (Some(int_type))
          | Some(ty) when is_int_type ty -> ()
          | Some(Type_cocci.Unknown) ->
              (* unknown comes from param types, not sure why this
                 is not just None... *)
              Ast0.set_type exp2 (Some(int_type))
          | Some ty -> err exp2 ty "bad type for an array index");
          (* ... then derive the element type from the indexed term *)
          (match strip_cv (Ast0.get_type exp1) with
            None -> None
          | Some (T.Array(ty)) -> Some ty
          | Some (T.Pointer(ty)) -> Some ty
          | Some (T.MetaType(_,_,_)) -> None
          | Some x -> err exp1 x "ill-typed array reference")
      (* pad: should handle structure one day and look 'field' in environment *)
      | Ast0.RecordAccess(exp,pt,field) ->
          (match strip_cv (Ast0.get_type exp) with
            None -> None
          | Some (T.StructUnionName(_,_)) -> None
          | Some (T.TypeName(_)) -> None
          | Some (T.MetaType(_,_,_)) -> None
          | Some x -> err exp x "non-structure type in field ref")
      | Ast0.RecordPtAccess(exp,ar,field) ->
          (match strip_cv (Ast0.get_type exp) with
            None -> None
          | Some (T.Pointer(t)) ->
              (match strip_cv (Some t) with
              | Some (T.Unknown) -> None
              | Some (T.MetaType(_,_,_)) -> None
              | Some (T.TypeName(_)) -> None
              | Some (T.StructUnionName(_,_)) -> None
              | Some x ->
                  err exp (T.Pointer(t))
                    "non-structure pointer type in field ref"
              (* strip_cv never maps Some to None *)
              | _ -> failwith "not possible")
          | Some (T.MetaType(_,_,_)) -> None
          | Some (T.TypeName(_)) -> None
          | Some x -> err exp x "non-structure pointer type in field ref")
      | Ast0.Cast(lp,ty,rp,exp) -> Some(Ast0.ast0_type_to_type ty)
      | Ast0.SizeOfExpr(szf,exp) -> Some(int_type)
      | Ast0.SizeOfType(szf,lp,ty,rp) -> Some(int_type)
      | Ast0.TypeExp(ty) -> None
      | Ast0.MetaErr(name,_,_) -> None
      (* an expression metavariable is typed only when it is declared with
         exactly one type *)
      | Ast0.MetaExpr(name,_,Some [ty],_,_) -> Some ty
      | Ast0.MetaExpr(name,_,ty,_,_) -> None
      | Ast0.MetaExprList(name,_,_) -> None
      | Ast0.EComma(cm) -> None
      | Ast0.DisjExpr(_,exp_list,_,_) ->
          let types = List.map Ast0.get_type exp_list in
          let combined = List.fold_left lub_type None types in
          (match combined with
            None -> None
          | Some t ->
              List.iter (function e -> Ast0.set_type e (Some t)) exp_list;
              Some t)
      | Ast0.NestExpr(starter,expr_dots,ender,None,multi) ->
          let _ = r.VT0.combiner_rec_expression_dots expr_dots in None
      | Ast0.NestExpr(starter,expr_dots,ender,Some e,multi) ->
          let _ = r.VT0.combiner_rec_expression_dots expr_dots in
          let _ = r.VT0.combiner_rec_expression e in None
      | Ast0.Edots(_,None) | Ast0.Ecircles(_,None) | Ast0.Estars(_,None) ->
          None
      | Ast0.Edots(_,Some e) | Ast0.Ecircles(_,Some e)
      | Ast0.Estars(_,Some e) ->
          let _ = r.VT0.combiner_rec_expression e in None
      | Ast0.OptExp(exp) -> Ast0.get_type exp
      | Ast0.UniqueExp(exp) -> Ast0.get_type exp in
    Ast0.set_type e ty;
    ty in

  (* Environment keys named by a declared identifier; a disjunction
     contributes the keys of all of its branches. *)
  let rec strip id =
    match Ast0.unwrap id with
      Ast0.Id(name)                -> [Id(Ast0.unwrap_mcode name)]
    | Ast0.MetaId(name,_,_)        -> [Meta(Ast0.unwrap_mcode name)]
    | Ast0.MetaFunc(name,_,_)      -> [Meta(Ast0.unwrap_mcode name)]
    | Ast0.MetaLocalFunc(name,_,_) -> [Meta(Ast0.unwrap_mcode name)]
    | Ast0.DisjId(_,id_list,_,_)   -> List.concat (List.map strip id_list)
    | Ast0.OptIdent(id)            -> strip id
    | Ast0.UniqueIdent(id)         -> strip id in

  (* Run the appropriate visitor on the payload of a when clause, for its
     side effects only. *)
  let process_whencode notfn allfn exp = function
      Ast0.WhenNot(x) -> let _ = notfn x in ()
    | Ast0.WhenAlways(x) -> let _ = allfn x in ()
    | Ast0.WhenModifier(_) -> ()
    | Ast0.WhenNotTrue(x) -> let _ = exp x in ()
    | Ast0.WhenNotFalse(x) -> let _ = exp x in () in

  (* assume that all of the declarations are at the beginning of a statement
     list, which is required by C, but not actually required by the cocci
     parser *)
  (* Walk a statement list, threading the environment [acc]: declarations
     extend it, other statements are visited with the environment built so
     far.  Returns the final environment. *)
  let rec process_statement_list r acc = function
      [] -> acc
    | (s::ss) ->
        (match Ast0.unwrap s with
          Ast0.Decl(_,decl) ->
            let new_acc = (process_decl acc decl)@acc in
            process_statement_list r new_acc ss
        | Ast0.Dots(_,wc) ->
            (* why is this case here?  why is there none for nests? *)
            List.iter
              (process_whencode r.VT0.combiner_rec_statement_dots
                 r.VT0.combiner_rec_statement r.VT0.combiner_rec_expression)
              wc;
            process_statement_list r acc ss
        | Ast0.Disj(_,statement_dots_list,_,_) ->
            (* each branch starts from [acc]; the resulting environments
               are merged with lub_envs *)
            let new_acc =
              lub_envs
                (List.map
                   (function x -> process_statement_list r acc (Ast0.undots x))
                   statement_dots_list) in
            process_statement_list r new_acc ss
        | _ ->
            let _ = (propagate_types acc).VT0.combiner_rec_statement s in
            process_statement_list r acc ss)

  (* Bindings introduced by a declaration, as (id, type) pairs.  For an
     initialised declaration the initialiser is visited in [env] first. *)
  and process_decl env decl =
    match Ast0.unwrap decl with
      Ast0.MetaDecl(_,_) | Ast0.MetaField(_,_) -> []
    | Ast0.Init(_,ty,id,_,exp,_) ->
        let _ =
          (propagate_types env).VT0.combiner_rec_initialiser exp in
        let ty = Ast0.ast0_type_to_type ty in
        List.map (function i -> (i,ty)) (strip id)
    | Ast0.UnInit(_,ty,id,_) ->
        let ty = Ast0.ast0_type_to_type ty in
        List.map (function i -> (i,ty)) (strip id)
    | Ast0.MacroDecl(_,_,_,_,_) -> []
    | Ast0.TyDecl(ty,_) -> []
    (* pad: should handle typedef one day and add a binding *)
    | Ast0.Typedef(_,_,_,_) -> []
    | Ast0.DisjDecl(_,disjs,_,_) ->
        List.concat(List.map (process_decl env) disjs)
    | Ast0.Ddots(_,_) -> [] (* not in a statement list anyway *)
    | Ast0.OptDecl(decl) -> process_decl env decl
    | Ast0.UniqueDecl(decl) -> process_decl env decl in

  (* Statement sequences are processed with process_statement_list,
     starting from [env]; the resulting environment is dropped. *)
  let statement_dots r k d =
    match Ast0.unwrap d with
      Ast0.DOTS(l) | Ast0.CIRCLES(l) | Ast0.STARS(l) ->
        let _ = process_statement_list r env l in option_default in

  (* Hook for adjusting the type of a test expression after it has been
     visited.  As written, process_test returns None in every case, so no
     type is ever changed here. *)
  let post_bool exp =
    let rec process_test exp =
      match (Ast0.unwrap exp,Ast0.get_type exp) with
        (Ast0.Edots(_,_),_) -> None
      | (Ast0.NestExpr(_,_,_,_,_),_) -> None
      | (Ast0.MetaExpr(_,_,_,_,_),_) ->
          (* if a type is known, it is specified in the decl *)
          None
      | (Ast0.Paren(lp,exp,rp),None) -> process_test exp
      (* the following doesn't seem like a good idea - triggers int isos
         on all test expressions *)
      (*| (_,None) -> Some (int_type) *)
      | _ -> None in
    let new_expty = process_test exp in
    (match new_expty with
      None -> () (* leave things as they are *)
    | Some ty -> Ast0.set_type exp new_expty) in

  let statement r k s =
    match Ast0.unwrap s with
      Ast0.FunDecl(_,fninfo,name,lp,params,rp,lbrace,body,rbrace) ->
        (* named parameters are added in front of [env] for the body *)
        let rec get_binding p =
          match Ast0.unwrap p with
            Ast0.Param(ty,Some id) ->
              let ty = Ast0.ast0_type_to_type ty in
              List.map (function i -> (i,ty)) (strip id)
          | Ast0.OptParam(param) -> get_binding param
          | _ -> [] in
        let fenv = List.concat (List.map get_binding (Ast0.undots params)) in
        (propagate_types (fenv@env)).VT0.combiner_rec_statement_dots body
    | Ast0.IfThen(_,_,exp,_,_,_) | Ast0.IfThenElse(_,_,exp,_,_,_,_,_)
    | Ast0.While(_,_,exp,_,_,_) | Ast0.Do(_,_,_,_,exp,_,_)
    | Ast0.For(_,_,_,_,Some exp,_,_,_,_,_) ->
        let _ = k s in
        post_bool exp;
        None
    | Ast0.Switch(_,_,exp,_,_,decls,cases,_) ->
        (* NOTE(review): only the declarations and the case lines are
           traversed in this branch; the switch expression itself is
           handed to post_bool without being visited here - confirm that
           this is intended. *)
        let senv = process_statement_list r env (Ast0.undots decls) in
        let res =
          (propagate_types (senv@env)).VT0.combiner_rec_case_line_dots cases in
        post_bool exp;
        res
    | _ -> k s

  (* A case label with no inferred type defaults to int. *)
  and case_line r k c =
    match Ast0.unwrap c with
      Ast0.Case(case,exp,colon,code) ->
        let _ = k c in
        (match Ast0.get_type exp with
          None -> Ast0.set_type exp (Some (int_type))
        | _ -> ());
        None
    | _ -> k c in

  V0.combiner bind option_default
    {V0.combiner_functions with
      VT0.combiner_dotsstmtfn = statement_dots;
      VT0.combiner_identfn = ident;
      VT0.combiner_exprfn = expression;
      VT0.combiner_stmtfn = statement;
      VT0.combiner_casefn = case_line}
415
(* Entry point: run type propagation over every top-level element of
   [code], in list order.  The initial environment binds only NULL, as a
   pointer to an unknown type.  The work is done by side effect on the
   AST; the per-element results are discarded and unit is returned. *)
let type_infer code =
  let prop = propagate_types [(Id("NULL"),T.Pointer(T.Unknown))] in
  let fn = prop.VT0.combiner_rec_top_level in
  (* List.iter rather than List.map: no result list is needed *)
  List.iter (function top -> ignore (fn top)) code