Release coccinelle-0.1.6
[bpt/coccinelle.git] / parsing_cocci / type_infer.ml
1 (*
2 * Copyright 2005-2009, Ecole des Mines de Nantes, University of Copenhagen
3 * Yoann Padioleau, Julia Lawall, Rene Rydhof Hansen, Henrik Stuart, Gilles Muller
4 * This file is part of Coccinelle.
5 *
6 * Coccinelle is free software: you can redistribute it and/or modify
7 * it under the terms of the GNU General Public License as published by
8 * the Free Software Foundation, according to version 2 of the License.
9 *
10 * Coccinelle is distributed in the hope that it will be useful,
11 * but WITHOUT ANY WARRANTY; without even the implied warranty of
12 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 * GNU General Public License for more details.
14 *
15 * You should have received a copy of the GNU General Public License
16 * along with Coccinelle. If not, see <http://www.gnu.org/licenses/>.
17 *
18 * The authors reserve the right to distribute this or future versions of
19 * Coccinelle under other licenses.
20 *)
21
22
23 module T = Type_cocci
24 module Ast = Ast_cocci
25 module Ast0 = Ast0_cocci
26 module V0 = Visitor_ast0
27
28 (* Type inference:
29 Just propagates information based on declarations. Could try to infer
30 more precise information about expression metavariables, but not sure it is
31 worth it. The most obvious goal is to distinguish between test expressions
32 that have pointer, integer, and boolean type when matching isomorphisms,
33 but perhaps other needs will become apparent. *)
34
(* Identifiers that, when they appear in function-call position, are given
   a boolean result type even though nothing declares them. *)
let bool_functions =
  [ "likely";
    "unlikely" ]
37
(* [err wrapped ty s] reports a typing error: it prints [ty] with [T.typeC],
   emits a newline through [Format], and then raises [Failure] with a
   message made of the line number recorded on [wrapped] followed by [s].
   It never returns normally. *)
let err wrapped ty s =
  T.typeC ty;
  Format.print_newline ();
  let line = Ast0.get_line wrapped in
  let message = Printf.sprintf "line %d: %s" line s in
  failwith message
41
(* Key of the type environment.  The parentheses around the tuple are
   significant: [Meta] takes ONE argument that is a pair, so that it can be
   applied directly to a pair-valued expression. *)
type id =
  | Id of string                 (* plain identifier, keyed by its name *)
  | Meta of (string * string)    (* metavariable, keyed by a pair of strings *)
43
(* Shorthands for the base types that the inference code builds repeatedly. *)
let int_type   = T.BaseType T.IntType
let bool_type  = T.BaseType T.BoolType
let char_type  = T.BaseType T.CharType
let float_type = T.BaseType T.FloatType
48
(* [lub_type t1 t2] combines two optional types into one.
   - If only one side is known, that side is returned unchanged.
   - If both are known, [merge] picks the more informative of the two,
     descending through matching [ConstVol], [Pointer] and [Array]
     constructors.
   The result is [None] only when both inputs are [None]. *)
let lub_type t1 t2 =
  let rec merge left right =
    match left, right with
    (* an unknown type gives way to whatever is on the other side *)
    | T.Unknown, other -> other
    | other, T.Unknown -> other
    (* same qualifier on both sides: keep it and merge underneath *)
    | T.ConstVol (cv_l, inner_l), T.ConstVol (cv_r, inner_r)
      when cv_l = cv_r ->
        T.ConstVol (cv_l, merge inner_l inner_r)
    (* pad: in pointer arithmetic, as in ptr+1, the lub must be ptr *)
    | T.Pointer inner_l, T.Pointer inner_r ->
        T.Pointer (merge inner_l inner_r)
    | _, T.Pointer inner_r -> T.Pointer inner_r
    | T.Pointer inner_l, _ -> T.Pointer inner_l
    | T.Array inner_l, T.Array inner_r -> T.Array (merge inner_l inner_r)
    (* a bare type name is less informative than anything else *)
    | T.TypeName _, other -> other
    | other, T.TypeName _ -> other
    (* arbitrarily pick the first, assume type correct *)
    | first, _ -> first
  in
  match t1, t2 with
  | None, known | known, None -> known
  | Some left, Some right -> Some (merge left right)
72
(* [lub_envs envs] merges a list of type environments (association lists
   from variable key to type) into a single environment.  Environments are
   folded in from left to right, starting from the empty list.  A variable
   that is already present in the accumulated result has its type replaced
   by the [lub_type] of the new and the old type; a variable that is not
   yet present is simply added.
   Raises [Failure "bad type environment"] if the accumulated result ever
   holds more than one binding for the same variable. *)
let lub_envs envs =
  let merge_binding merged (var, ty) =
    let (same_var, others) =
      List.partition (fun (name, _) -> name = var) merged in
    match same_var with
    | [] -> (var, ty) :: merged
    | [(_, previous_ty)] ->
        (match lub_type (Some ty) (Some previous_ty) with
         | Some joined -> (var, joined) :: others
         | None -> others)
    | _ -> failwith "bad type environment"
  in
  let merge_env merged env = List.fold_left merge_binding merged env in
  List.fold_left merge_env [] envs
91
(* [propagate_types env] builds the visitor that performs type inference.
   [env] is an association list from identifier keys ([Id] for plain names,
   [Meta] for metavariables) to types; it is consulted with [List.assoc].
   The callbacks defined below return the inferred type as an option and
   record it on expression nodes with [Ast0.set_type].  Nested scopes are
   handled by calling [propagate_types] again with an extended environment,
   which is why the definition is recursive. *)
92 let rec propagate_types env =
93 let option_default = None in
94 let bind x y = option_default in (* no generic way of combining types *)
95
(* Tokens contribute no type: every mcode callback returns [None]. *)
96 let mcode x = option_default in
97
(* Identifier callback: plain and metavariable identifiers are looked up in
   [env]; an unbound name yields [None].  Every other kind of identifier is
   handed to the continuation [k]. *)
98 let ident r k i =
99 match Ast0.unwrap i with
100 Ast0.Id(id) ->
101 (try Some(List.assoc (Id(Ast0.unwrap_mcode id)) env)
102 with Not_found -> None)
103 | Ast0.MetaId(id,_,_) ->
104 (try Some(List.assoc (Meta(Ast0.unwrap_mcode id)) env)
105 with Not_found -> None)
106 | _ -> k i in
107
(* Removes one outer [T.ConstVol] wrapper from an optional type, if there
   is one; anything else is returned unchanged. *)
108 let strip_cv = function
109 Some (T.ConstVol(_,t)) -> Some t
110 | t -> t in
111
112 (* types that might be integer types. should char be allowed? *)
(* Metavariable types, type names and enum names are accepted as well,
   since they may turn out to be integers.  A signed/unsigned wrapper is
   an integer type when it has no base type or when its base type is. *)
113 let rec is_int_type = function
114 T.BaseType(T.IntType)
115 | T.BaseType(T.LongType)
116 | T.BaseType(T.ShortType)
117 | T.MetaType(_,_,_)
118 | T.TypeName _
119 | T.EnumName _
120 | T.SignedT(_,None) -> true
121 | T.SignedT(_,Some ty) -> is_int_type ty
122 | _ -> false in
123
(* Expression callback.  The continuation [k e] runs first; the cases below
   then read the types of subexpressions back with [Ast0.get_type], which
   presumes that [k] has already annotated them (NOTE(review): depends on
   the traversal implemented in Visitor_ast0 - confirm).  The type computed
   for [e] is stored on [e] with [Ast0.set_type] and returned. *)
124 let expression r k e =
125 let res = k e in
126 let ty =
127 match Ast0.unwrap e with
128 (* pad: the type of id is set in the ident visitor *)
129 Ast0.Ident(id) -> Ast0.set_type e res; res
130 | Ast0.Constant(const) ->
131 (match Ast0.unwrap_mcode const with
132 Ast.String(_) -> Some (T.Pointer(char_type))
133 | Ast.Char(_) -> Some (char_type)
134 | Ast.Int(_) -> Some (int_type)
135 | Ast.Float(_) -> Some (float_type))
136 (* pad: note that in C can do either ptr(...) or ( *ptr)(...)
137 * so I am not sure this code is enough.
138 *)
(* A call gets its type from a function-pointer type recorded on [fn];
   failing that, only calls to the names in [bool_functions] are typed. *)
139 | Ast0.FunCall(fn,lp,args,rp) ->
140 (match Ast0.get_type fn with
141 Some (T.FunctionPointer(ty)) -> Some ty
142 | _ ->
143 (match Ast0.unwrap fn with
144 Ast0.Ident(id) ->
145 (match Ast0.unwrap id with
146 Ast0.Id(id) ->
147 if List.mem (Ast0.unwrap_mcode id) bool_functions
148 then Some(bool_type)
149 else None
150 | _ -> None)
151 | _ -> None))
(* Both sides of an assignment are given the lub of their two types. *)
152 | Ast0.Assignment(exp1,op,exp2,_) ->
153 let ty = lub_type (Ast0.get_type exp1) (Ast0.get_type exp2) in
154 Ast0.set_type exp1 ty; Ast0.set_type exp2 ty; ty
(* Likewise for the two branches of a conditional expression; when the
   middle operand is absent the type of the last operand is used. *)
155 | Ast0.CondExpr(exp1,why,Some exp2,colon,exp3) ->
156 let ty = lub_type (Ast0.get_type exp2) (Ast0.get_type exp3) in
157 Ast0.set_type exp2 ty; Ast0.set_type exp3 ty; ty
158 | Ast0.CondExpr(exp1,why,None,colon,exp3) -> Ast0.get_type exp3
159 | Ast0.Postfix(exp,op) | Ast0.Infix(exp,op) -> (* op is dec or inc *)
160 Ast0.get_type exp
161 | Ast0.Unary(exp,op) ->
162 (match Ast0.unwrap_mcode op with
(* Address-of wraps the operand type in a pointer, using [T.Unknown] when
   the operand has no type; dereference only yields a type when the
   operand is known to be a pointer. *)
163 Ast.GetRef ->
164 (match Ast0.get_type exp with
165 None -> Some (T.Pointer(T.Unknown))
166 | Some t -> Some (T.Pointer(t)))
167 | Ast.DeRef ->
168 (match Ast0.get_type exp with
169 Some (T.Pointer(t)) -> Some t
170 | _ -> None)
171 | Ast.UnPlus -> Ast0.get_type exp
172 | Ast.UnMinus -> Ast0.get_type exp
173 | Ast.Tilde -> Ast0.get_type exp
174 | Ast.Not -> Some(bool_type))
175 | Ast0.Nested(exp1,op,exp2) -> failwith "nested in type inf not possible"
176 | Ast0.Binary(exp1,op,exp2) ->
177 let ty1 = Ast0.get_type exp1 in
178 let ty2 = Ast0.get_type exp2 in
(* [same_type] types an arithmetic operation: int when neither operand is
   typed, the pointer type for pointer-plus-integer in either order, and
   otherwise the lub of the operand types, which is also written back onto
   both operands. *)
179 let same_type = function
180 (None,None) -> Some (int_type)
181
182 (* pad: pointer arithmetic handling as in ptr+1 *)
183 | (Some (T.Pointer ty1),Some ty2) when is_int_type ty2 ->
184 Some (T.Pointer ty1)
185 | (Some ty1,Some (T.Pointer ty2)) when is_int_type ty1 ->
186 Some (T.Pointer ty2)
187
188 | (t1,t2) ->
189 let ty = lub_type t1 t2 in
190 Ast0.set_type exp1 ty; Ast0.set_type exp2 ty; ty in
191 (match Ast0.unwrap_mcode op with
192 Ast.Arith(op) -> same_type (ty1, ty2)
(* A logical operation is always boolean, but its operands are still
   unified with each other. *)
193 | Ast.Logical(op) ->
194 let ty = lub_type ty1 ty2 in
195 Ast0.set_type exp1 ty; Ast0.set_type exp2 ty;
196 Some(bool_type))
197 | Ast0.Paren(lp,exp,rp) -> Ast0.get_type exp
(* Array access: an untyped index is defaulted to int and a typed index
   must pass [is_int_type], otherwise [err] raises.  The element type is
   taken from an array or pointer base type; an untyped or metavariable
   typed base gives no result, and any other base type is an error. *)
198 | Ast0.ArrayAccess(exp1,lb,exp2,rb) ->
199 (match strip_cv (Ast0.get_type exp2) with
200 None -> Ast0.set_type exp2 (Some(int_type))
201 | Some(ty) when is_int_type ty -> ()
202 | Some ty -> err exp2 ty "bad type for an array index");
203 (match strip_cv (Ast0.get_type exp1) with
204 None -> None
205 | Some (T.Array(ty)) -> Some ty
206 | Some (T.Pointer(ty)) -> Some ty
207 | Some (T.MetaType(_,_,_)) -> None
208 | Some x -> err exp1 x "ill-typed array reference")
209 (* pad: should handle structure one day and look 'field' in environment *)
(* Field accesses never produce a type here; the matches only check that
   the base expression could be a structure (or a pointer to one for the
   arrow form) and raise through [err] when it cannot. *)
210 | Ast0.RecordAccess(exp,pt,field) ->
211 (match strip_cv (Ast0.get_type exp) with
212 None -> None
213 | Some (T.StructUnionName(_,_,_)) -> None
214 | Some (T.TypeName(_)) -> None
215 | Some (T.MetaType(_,_,_)) -> None
216 | Some x -> err exp x "non-structure type in field ref")
217 | Ast0.RecordPtAccess(exp,ar,field) ->
218 (match strip_cv (Ast0.get_type exp) with
219 None -> None
220 | Some (T.Pointer(t)) ->
221 (match strip_cv (Some t) with
222 | Some (T.Unknown) -> None
223 | Some (T.MetaType(_,_,_)) -> None
224 | Some (T.TypeName(_)) -> None
225 | Some (T.StructUnionName(_,_,_)) -> None
226 | Some x ->
227 err exp (T.Pointer(t))
228 "non-structure pointer type in field ref"
(* [strip_cv] applied to [Some t] always returns [Some _], so the
   remaining case cannot be reached. *)
229 | _ -> failwith "not possible")
230 | Some (T.MetaType(_,_,_)) -> None
231 | Some (T.TypeName(_)) -> None
232 | Some x -> err exp x "non-structure pointer type in field ref")
233 | Ast0.Cast(lp,ty,rp,exp) -> Some(Ast0.ast0_type_to_type ty)
234 | Ast0.SizeOfExpr(szf,exp) -> Some(int_type)
235 | Ast0.SizeOfType(szf,lp,ty,rp) -> Some(int_type)
236 | Ast0.TypeExp(ty) -> None
237 | Ast0.MetaErr(name,_,_) -> None
(* An expression metavariable is typed only when it was declared with
   exactly one candidate type. *)
238 | Ast0.MetaExpr(name,_,Some [ty],_,_) -> Some ty
239 | Ast0.MetaExpr(name,_,ty,_,_) -> None
240 | Ast0.MetaExprList(name,_,_) -> None
241 | Ast0.EComma(cm) -> None
(* A disjunction takes the lub of all its branches; when that is known it
   is written back onto every branch. *)
242 | Ast0.DisjExpr(_,exp_list,_,_) ->
243 let types = List.map Ast0.get_type exp_list in
244 let combined = List.fold_left lub_type None types in
245 (match combined with
246 None -> None
247 | Some t ->
248 List.iter (function e -> Ast0.set_type e (Some t)) exp_list;
249 Some t)
(* Nests and dots have no type of their own; the visitor is run on their
   contents and when-expressions only for its side effects. *)
250 | Ast0.NestExpr(starter,expr_dots,ender,None,multi) ->
251 let _ = r.V0.combiner_expression_dots expr_dots in None
252 | Ast0.NestExpr(starter,expr_dots,ender,Some e,multi) ->
253 let _ = r.V0.combiner_expression_dots expr_dots in
254 let _ = r.V0.combiner_expression e in None
255 | Ast0.Edots(_,None) | Ast0.Ecircles(_,None) | Ast0.Estars(_,None) ->
256 None
257 | Ast0.Edots(_,Some e) | Ast0.Ecircles(_,Some e)
258 | Ast0.Estars(_,Some e) ->
259 let _ = r.V0.combiner_expression e in None
260 | Ast0.OptExp(exp) -> Ast0.get_type exp
261 | Ast0.UniqueExp(exp) -> Ast0.get_type exp in
262 Ast0.set_type e ty;
263 ty in
264
(* Callback for node kinds that need no special treatment: just continue
   with [k]. *)
265 let donothing r k e = k e in
266
(* Turns an identifier node into an environment key, looking through the
   optional and unique wrappers. *)
267 let rec strip id =
268 match Ast0.unwrap id with
269 Ast0.Id(name) -> Id(Ast0.unwrap_mcode name)
270 | Ast0.MetaId(name,_,_) -> Meta(Ast0.unwrap_mcode name)
271 | Ast0.MetaFunc(name,_,_) -> Meta(Ast0.unwrap_mcode name)
272 | Ast0.MetaLocalFunc(name,_,_) -> Meta(Ast0.unwrap_mcode name)
273 | Ast0.OptIdent(id) -> strip id
274 | Ast0.UniqueIdent(id) -> strip id in
275
(* Applies the matching visitor function to the payload of a when clause
   and discards the result; the call is made only for its side effects.
   [notfn] handles WhenNot, [allfn] handles WhenAlways and [exp] handles
   the two expression forms. *)
276 let process_whencode notfn allfn exp = function
277 Ast0.WhenNot(x) -> let _ = notfn x in ()
278 | Ast0.WhenAlways(x) -> let _ = allfn x in ()
279 | Ast0.WhenModifier(_) -> ()
280 | Ast0.WhenNotTrue(x) -> let _ = exp x in ()
281 | Ast0.WhenNotFalse(x) -> let _ = exp x in () in
282
283 (* assume that all of the declarations are at the beginning of a statement
284 list, which is required by C, but not actually required by the cocci
285 parser *)
(* Walks a statement list from left to right while threading the
   environment [acc], and returns the environment reached at the end.
   - A declaration adds its bindings in front of [acc]; an initialiser is
     visited with the environment as it was before the declaration.
   - Dots only have their when clauses visited, using the visitor [r].
   - A statement disjunction processes each branch starting from [acc] and
     merges the resulting environments with [lub_envs].
   - Any other statement is visited with a fresh visitor built from [acc]. *)
286 let rec process_statement_list r acc = function
287 [] -> acc
288 | (s::ss) ->
289 (match Ast0.unwrap s with
290 Ast0.Decl(_,decl) ->
291 let rec process_decl decl =
292 match Ast0.unwrap decl with
293 Ast0.Init(_,ty,id,_,exp,_) ->
294 let _ =
295 (propagate_types acc).V0.combiner_initialiser exp in
296 [(strip id,Ast0.ast0_type_to_type ty)]
297 | Ast0.UnInit(_,ty,id,_) ->
298 [(strip id,Ast0.ast0_type_to_type ty)]
299 | Ast0.MacroDecl(_,_,_,_,_) -> []
300 | Ast0.TyDecl(ty,_) -> []
301 (* pad: should handle typedef one day and add a binding *)
302 | Ast0.Typedef(_,_,_,_) -> []
303 | Ast0.DisjDecl(_,disjs,_,_) ->
304 List.concat(List.map process_decl disjs)
305 | Ast0.Ddots(_,_) -> [] (* not in a statement list anyway *)
306 | Ast0.OptDecl(decl) -> process_decl decl
307 | Ast0.UniqueDecl(decl) -> process_decl decl in
308 let new_acc = (process_decl decl)@acc in
309 process_statement_list r new_acc ss
310 | Ast0.Dots(_,wc) ->
311 (* why is this case here? why is there none for nests? *)
312 List.iter
313 (process_whencode r.V0.combiner_statement_dots
314 r.V0.combiner_statement r.V0.combiner_expression)
315 wc;
316 process_statement_list r acc ss
317 | Ast0.Disj(_,statement_dots_list,_,_) ->
318 let new_acc =
319 lub_envs
320 (List.map
321 (function x -> process_statement_list r acc (Ast0.undots x))
322 statement_dots_list) in
323 process_statement_list r new_acc ss
324 | _ ->
325 let _ = (propagate_types acc).V0.combiner_statement s in
326 process_statement_list r acc ss) in
327
(* Statement-sequence callback: processes the statements in order starting
   from [env].  The final environment is discarded and [None] is returned;
   the continuation [k] is not called. *)
328 let statement_dots r k d =
329 match Ast0.unwrap d with
330 Ast0.DOTS(l) | Ast0.CIRCLES(l) | Ast0.STARS(l) ->
331 let _ = process_statement_list r env l in option_default in
(* Statement callback.
   - Function definition: named parameters are bound in front of [env] and
     the body is visited with a visitor built from that environment.
   - Statements with a test expression (if, while, do, for with a
     condition, switch): after the default traversal, a test that is still
     untyped is given type int, except for dots, nests and expression
     metavariables.  Parentheses are looked through when deciding, but the
     type is set on the test expression itself.
   - Everything else goes to the continuation [k]. *)
332 let statement r k s =
333 match Ast0.unwrap s with
334 Ast0.FunDecl(_,fninfo,name,lp,params,rp,lbrace,body,rbrace) ->
335 let rec get_binding p =
336 match Ast0.unwrap p with
337 Ast0.Param(ty,Some id) ->
338 [(strip id,Ast0.ast0_type_to_type ty)]
339 | Ast0.OptParam(param) -> get_binding param
340 | _ -> [] in
341 let fenv = List.concat (List.map get_binding (Ast0.undots params)) in
342 (propagate_types (fenv@env)).V0.combiner_statement_dots body
343 | Ast0.IfThen(_,_,exp,_,_,_) | Ast0.IfThenElse(_,_,exp,_,_,_,_,_)
344 | Ast0.While(_,_,exp,_,_,_) | Ast0.Do(_,_,_,_,exp,_,_)
345 | Ast0.For(_,_,_,_,Some exp,_,_,_,_,_) | Ast0.Switch(_,_,exp,_,_,_,_) ->
346 let _ = k s in
347 let rec process_test exp =
348 match (Ast0.unwrap exp,Ast0.get_type exp) with
349 (Ast0.Edots(_,_),_) -> None
350 | (Ast0.NestExpr(_,_,_,_,_),_) -> None
351 | (Ast0.MetaExpr(_,_,_,_,_),_) ->
352 (* if a type is known, it is specified in the decl *)
353 None
354 | (Ast0.Paren(lp,exp,rp),None) -> process_test exp
355 | (_,None) -> Some (int_type)
356 | _ -> None in
357 let new_expty = process_test exp in
358 (match new_expty with
359 None -> () (* leave things as they are *)
360 | Some ty -> Ast0.set_type exp new_expty);
361 None
362 | _ -> k s
363
(* Case-line callback: after the default traversal, a case expression that
   is still untyped is given type int.  Always returns [None] for default
   and case lines. *)
364 and case_line r k c =
365 match Ast0.unwrap c with
366 Ast0.Default(def,colon,code) -> let _ = k c in None
367 | Ast0.Case(case,exp,colon,code) ->
368 let _ = k c in
369 (match Ast0.get_type exp with
370 None -> Ast0.set_type exp (Some (int_type))
371 | _ -> ());
372 None
373 | Ast0.OptCase(case) -> k c in
374
(* Assemble the visitor.  The arguments are positional, so their order
   must follow the signature of [V0.combiner], which is defined in
   Visitor_ast0 and not visible here. *)
375 V0.combiner bind option_default
376 mcode mcode mcode mcode mcode mcode mcode mcode mcode mcode mcode mcode
377 donothing donothing donothing statement_dots donothing donothing
378 ident expression donothing donothing donothing donothing statement
379 case_line donothing
380
(* [type_infer code] runs type propagation over every top-level element of
   [code], in order, purely for its side effect of annotating expressions
   with types.  The initial environment binds only the identifier NULL, to
   a pointer to an unknown type.  Returns unit. *)
let type_infer code =
  let initial_env = [(Id "NULL", T.Pointer T.Unknown)] in
  let visitor = propagate_types initial_env in
  let visit_top_level = visitor.V0.combiner_top_level in
  List.iter (fun top -> ignore (visit_top_level top)) code