1d566ea32cef6d71fa4b400fc2f05fab7ca2e0ef
[bpt/coccinelle.git] / parsing_cocci / type_infer.ml
1 (*
2 * Copyright 2005-2008, Ecole des Mines de Nantes, University of Copenhagen
3 * Yoann Padioleau, Julia Lawall, Rene Rydhof Hansen, Henrik Stuart, Gilles Muller
4 * This file is part of Coccinelle.
5 *
6 * Coccinelle is free software: you can redistribute it and/or modify
7 * it under the terms of the GNU General Public License as published by
8 * the Free Software Foundation, according to version 2 of the License.
9 *
10 * Coccinelle is distributed in the hope that it will be useful,
11 * but WITHOUT ANY WARRANTY; without even the implied warranty of
12 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 * GNU General Public License for more details.
14 *
15 * You should have received a copy of the GNU General Public License
16 * along with Coccinelle. If not, see <http://www.gnu.org/licenses/>.
17 *
18 * The authors reserve the right to distribute this or future versions of
19 * Coccinelle under other licenses.
20 *)
21
22
23 module T = Type_cocci
24 module Ast = Ast_cocci
25 module Ast0 = Ast0_cocci
26 module V0 = Visitor_ast0
27
28 (* Type inference:
29 Just propagates information based on declarations. Could try to infer
30 more precise information about expression metavariables, but not sure it is
31 worth it. The most obvious goal is to distinguish between test expressions
32 that have pointer, integer, and boolean type when matching isomorphisms,
33 but perhaps other needs will become apparent. *)
34
(* Names whose calls are given boolean type by the inference below even
   though no declaration for them is visible.  ("Functions" loosely: they
   may well be macros in the analysed code.) *)
let bool_functions =
  [ "likely";
    "unlikely" ]
37
(* Report a type error on the AST node [wrapped] and abort.
   [ty] is printed with [T.typeC] and [Format.print_newline] is called;
   then [Failure "line <n>: <s>"] is raised, where <n> is the line of
   [wrapped] and <s> is [s].  Never returns, so callers may use it at any
   result type (it is used both where unit and where a type option is
   expected). *)
let err wrapped ty s =
  T.typeC ty;
  Format.print_newline ();
  let line = Ast0.get_line wrapped in
  Printf.ksprintf failwith "line %d: %s" line s
41
(* Keys of a type environment: a plain C identifier, or a metavariable.
   A metavariable is keyed by the pair returned by [Ast0.unwrap_mcode] on
   its name (presumably rule name and variable name -- confirm).  Note that
   [Meta] takes ONE tuple argument, so an existing pair can be passed
   directly as [Meta p]. *)
type id =
  | Id of string
  | Meta of (string * string)
43
(* Least upper bound of two optional types, used to reconcile the types of
   expressions that must agree (e.g. the two sides of an assignment).
   - [None] carries no information and yields to the other argument.
   - [T.Unknown] likewise yields to the other type.
   - Equal const/volatile qualifiers, two pointers, or two arrays are
     joined componentwise.
   - A pointer wins over a non-pointer.
   - A typedef name yields to the other type.
   - Otherwise the first type is kept arbitrarily, on the assumption that
     the code is type correct. *)
let lub_type t1 t2 =
  let rec join a b =
    match (a, b) with
    | (T.Unknown, _) -> b
    | (_, T.Unknown) -> a
    | (T.ConstVol (cv1, inner1), T.ConstVol (cv2, inner2)) when cv1 = cv2 ->
        T.ConstVol (cv1, join inner1 inner2)
    | (T.Pointer inner1, T.Pointer inner2) -> T.Pointer (join inner1 inner2)
    | (_, T.Pointer _) -> b
    | (T.Pointer _, _) -> a
    | (T.Array inner1, T.Array inner2) -> T.Array (join inner1 inner2)
    | (T.TypeName _, _) -> b
    | (_, T.TypeName _) -> a
    | _ -> a
  in
  match (t1, t2) with
  | (None, other) | (other, None) -> other
  | (Some a, Some b) -> Some (join a b)
64
(* Merge a list of type environments (association lists from [id] to type)
   into a single one, processing them left to right.
   - A variable bound in only one environment keeps its type.
   - A variable bound in several gets the least upper bound ([lub_type]) of
     its bindings; the newly seen type is passed as the first argument.
   Each variable occurs at most once in the result, and merged bindings are
   moved to the front. *)
let lub_envs envs =
  let add_binding merged (var, ty) =
    let (same, others) = List.partition (fun (x, _) -> x = var) merged in
    match same with
    | [] -> (var, ty) :: merged
    | [(_, prev_ty)] ->
        (match lub_type (Some ty) (Some prev_ty) with
         | Some joined -> (var, joined) :: others
         | None -> others)
    | _ -> failwith "bad type environment"
  in
  List.fold_left (fun merged env -> List.fold_left add_binding merged env)
    [] envs
83
(* [propagate_types env] builds a Visitor_ast0 combiner that infers types
   for the expressions it visits and records them with [Ast0.set_type].

   [env] is an association list from [id] keys (plain identifiers and
   metavariables) to types.  It is extended for the rest of a statement
   list by that list's declarations, and for a function body by the
   function's parameters.

   The value the combiner computes for a node is that node's inferred type
   option.  There is no generic way to combine the results of siblings, so
   [bind] discards them.

   Raises [Failure]:
   - via [err], for an ill-typed array index or array reference;
   - via [err], for a field access on a non-structure (pointer) type;
   - with "nested in type inf not possible", on a [Nested] expression. *)
let rec propagate_types env =
  let option_default = None in
  let bind _ _ = option_default in (* no generic way of combining types *)
  let mcode _ = option_default in

  (* An identifier's type is whatever [env] binds it to, if anything. *)
  let ident _ k i =
    match Ast0.unwrap i with
    | Ast0.Id id ->
        (try Some (List.assoc (Id (Ast0.unwrap_mcode id)) env)
         with Not_found -> None)
    | Ast0.MetaId (id, _, _) ->
        (try Some (List.assoc (Meta (Ast0.unwrap_mcode id)) env)
         with Not_found -> None)
    | _ -> k i in

  (* Drop one outer const/volatile qualifier, if any. *)
  let strip_cv = function
    | Some (T.ConstVol (_, t)) -> Some t
    | t -> t in

  (* Infer the type of [e] from its already-visited subterms, possibly
     refining the subterms' types too, then record it on [e]. *)
  let expression r k e =
    let res = k e in
    let ty =
      match Ast0.unwrap e with
      | Ast0.Ident _ -> res
      | Ast0.Constant const ->
          (match Ast0.unwrap_mcode const with
           | Ast.String _ -> Some (T.Pointer (T.BaseType (T.CharType, None)))
           | Ast.Char _ -> Some (T.BaseType (T.CharType, None))
           | Ast.Int _ -> Some (T.BaseType (T.IntType, None))
           | Ast.Float _ -> Some (T.BaseType (T.FloatType, None)))
      | Ast0.FunCall (fn, _, _, _) ->
          (match Ast0.get_type fn with
           | Some (T.FunctionPointer ret) -> Some ret
           | _ ->
               (* No known signature: calls to a name listed in
                  [bool_functions] are typed bool, anything else unknown. *)
               (match Ast0.unwrap fn with
                | Ast0.Ident id ->
                    (match Ast0.unwrap id with
                     | Ast0.Id name ->
                         if List.mem (Ast0.unwrap_mcode name) bool_functions
                         then Some (T.BaseType (T.BoolType, None))
                         else None
                     | _ -> None)
                | _ -> None))
      | Ast0.Assignment (exp1, _, exp2, _) ->
          (* both sides receive their common type *)
          let ty = lub_type (Ast0.get_type exp1) (Ast0.get_type exp2) in
          Ast0.set_type exp1 ty;
          Ast0.set_type exp2 ty;
          ty
      | Ast0.CondExpr (_, _, Some exp2, _, exp3) ->
          let ty = lub_type (Ast0.get_type exp2) (Ast0.get_type exp3) in
          Ast0.set_type exp2 ty;
          Ast0.set_type exp3 ty;
          ty
      | Ast0.CondExpr (_, _, None, _, exp3) -> Ast0.get_type exp3
      | Ast0.Postfix (exp, _) | Ast0.Infix (exp, _) -> (* op is dec or inc *)
          Ast0.get_type exp
      | Ast0.Unary (exp, op) ->
          (match Ast0.unwrap_mcode op with
           | Ast.GetRef ->
               (match Ast0.get_type exp with
                | None -> Some (T.Pointer T.Unknown)
                | Some t -> Some (T.Pointer t))
           | Ast.DeRef ->
               (match Ast0.get_type exp with
                | Some (T.Pointer t) -> Some t
                | _ -> None)
           | Ast.UnPlus | Ast.UnMinus | Ast.Tilde -> Ast0.get_type exp
           | Ast.Not -> Some (T.BaseType (T.BoolType, None)))
      | Ast0.Nested (_, _, _) -> failwith "nested in type inf not possible"
      | Ast0.Binary (exp1, op, exp2) ->
          let ty1 = Ast0.get_type exp1 in
          let ty2 = Ast0.get_type exp2 in
          (* Arithmetic result type:
             - a pointer operand gives a pointer result (pointer
               arithmetic);
             - two untyped operands are assumed int;
             - otherwise both operands are unified with their least upper
               bound, which is also the result. *)
          let same_type = function
            | (None, None) -> Some (T.BaseType (T.IntType, None))
            | (Some (T.Pointer t), Some _) -> Some (T.Pointer t)
            | (Some _, Some (T.Pointer t)) -> Some (T.Pointer t)
            | (t1, t2) ->
                let ty = lub_type t1 t2 in
                Ast0.set_type exp1 ty;
                Ast0.set_type exp2 ty;
                ty in
          (match Ast0.unwrap_mcode op with
           | Ast.Arith _ -> same_type (ty1, ty2)
           | Ast.Logical _ ->
               (* operands are unified; the result is bool *)
               let ty = lub_type ty1 ty2 in
               Ast0.set_type exp1 ty;
               Ast0.set_type exp2 ty;
               Some (T.BaseType (T.BoolType, None)))
      | Ast0.Paren (_, exp, _) -> Ast0.get_type exp
      | Ast0.ArrayAccess (exp1, _, exp2, _) ->
          (* The index must be integral.  C accepts an index of any integer
             type, so int and char are accepted with any signedness
             (e.g. [unsigned int i]).  Typedef names and type metavariables
             hide their underlying type, and [T.Unknown] carries no
             information, so these are accepted too.  An untyped index is
             assumed to be int.
             NOTE(review): other integral base types of Type_cocci (short,
             long) presumably belong here as well -- add them once
             confirmed. *)
          (match strip_cv (Ast0.get_type exp2) with
           | None -> Ast0.set_type exp2 (Some (T.BaseType (T.IntType, None)))
           | Some (T.BaseType ((T.IntType | T.CharType), _)) -> ()
           | Some (T.MetaType (_, _, _)) -> ()
           | Some (T.TypeName _) -> ()
           | Some T.Unknown -> ()
           | Some ty -> err exp2 ty "bad type for an array index");
          (* The element type is known only for arrays and pointers.  A
             typedef name or unknown type may still be indexable, so it
             yields no information instead of an error (as in the field
             access cases below). *)
          (match strip_cv (Ast0.get_type exp1) with
           | None -> None
           | Some (T.Array ty) -> Some ty
           | Some (T.Pointer ty) -> Some ty
           | Some (T.MetaType (_, _, _)) -> None
           | Some (T.TypeName _) -> None
           | Some T.Unknown -> None
           | Some x -> err exp1 x "ill-typed array reference")
      | Ast0.RecordAccess (exp, _, _) ->
          (match strip_cv (Ast0.get_type exp) with
           | None -> None
           | Some (T.StructUnionName (_, _, _)) -> None
           | Some (T.TypeName _) -> None
           | Some (T.MetaType (_, _, _)) -> None
           | Some x -> err exp x "non-structure type in field ref")
      | Ast0.RecordPtAccess (exp, _, _) ->
          (match strip_cv (Ast0.get_type exp) with
           | None -> None
           | Some (T.Pointer t) ->
               (match strip_cv (Some t) with
                | Some T.Unknown -> None
                | Some (T.MetaType (_, _, _)) -> None
                | Some (T.TypeName _) -> None
                | Some (T.StructUnionName (_, _, _)) -> None
                | Some _ ->
                    err exp (T.Pointer t)
                      "non-structure pointer type in field ref"
                | _ -> failwith "not possible")
           | Some (T.MetaType (_, _, _)) -> None
           | Some (T.TypeName _) -> None
           | Some x -> err exp x "non-structure pointer type in field ref")
      | Ast0.Cast (_, ty, _, _) -> Some (Ast0.ast0_type_to_type ty)
      | Ast0.SizeOfExpr (_, _) -> Some (T.BaseType (T.IntType, None))
      | Ast0.SizeOfType (_, _, _, _) -> Some (T.BaseType (T.IntType, None))
      | Ast0.TypeExp _ -> None
      | Ast0.MetaErr (_, _, _) -> None
      | Ast0.MetaExpr (_, _, Some [ty], _, _) -> Some ty
      | Ast0.MetaExpr (_, _, _, _, _) -> None
      | Ast0.MetaExprList (_, _, _) -> None
      | Ast0.EComma _ -> None
      | Ast0.DisjExpr (_, exp_list, _, _) ->
          (* all branches of a disjunction receive their common type *)
          let types = List.map Ast0.get_type exp_list in
          (match List.fold_left lub_type None types with
           | None -> None
           | Some t ->
               List.iter (fun branch -> Ast0.set_type branch (Some t))
                 exp_list;
               Some t)
      | Ast0.NestExpr (_, expr_dots, _, None, _) ->
          let _ = r.V0.combiner_expression_dots expr_dots in
          None
      | Ast0.NestExpr (_, expr_dots, _, Some w, _) ->
          let _ = r.V0.combiner_expression_dots expr_dots in
          let _ = r.V0.combiner_expression w in
          None
      | Ast0.Edots (_, None) | Ast0.Ecircles (_, None)
      | Ast0.Estars (_, None) ->
          None
      | Ast0.Edots (_, Some w) | Ast0.Ecircles (_, Some w)
      | Ast0.Estars (_, Some w) ->
          let _ = r.V0.combiner_expression w in
          None
      | Ast0.OptExp exp -> Ast0.get_type exp
      | Ast0.UniqueExp exp -> Ast0.get_type exp in
    Ast0.set_type e ty;
    ty in

  let donothing _ k e = k e in

  (* Environment key of a declared identifier or metavariable. *)
  let rec strip id =
    match Ast0.unwrap id with
    | Ast0.Id name -> Id (Ast0.unwrap_mcode name)
    | Ast0.MetaId (name, _, _) -> Meta (Ast0.unwrap_mcode name)
    | Ast0.MetaFunc (name, _, _) -> Meta (Ast0.unwrap_mcode name)
    | Ast0.MetaLocalFunc (name, _, _) -> Meta (Ast0.unwrap_mcode name)
    | Ast0.OptIdent id -> strip id
    | Ast0.UniqueIdent id -> strip id in

  (* Visit the code of one "when" clause with the matching visitor. *)
  let process_whencode notfn allfn exp = function
    | Ast0.WhenNot x -> ignore (notfn x)
    | Ast0.WhenAlways x -> ignore (allfn x)
    | Ast0.WhenModifier _ -> ()
    | Ast0.WhenNotTrue x -> ignore (exp x)
    | Ast0.WhenNotFalse x -> ignore (exp x) in

  (* Walk a statement list, threading the environment [acc]: each
     declaration adds its bindings for the statements that follow it.
     This assumes all declarations are at the beginning of a statement
     list, which C requires but the cocci parser does not enforce. *)
  let rec process_statement_list r acc = function
    | [] -> acc
    | s :: ss ->
        (match Ast0.unwrap s with
         | Ast0.Decl (_, decl) ->
             let rec process_decl decl =
               match Ast0.unwrap decl with
               | Ast0.Init (_, ty, id, _, init, _) ->
                   (* the initialiser is typed in [acc], i.e. before the
                      declared variable is added *)
                   let _ = (propagate_types acc).V0.combiner_initialiser init in
                   [(strip id, Ast0.ast0_type_to_type ty)]
               | Ast0.UnInit (_, ty, id, _) ->
                   [(strip id, Ast0.ast0_type_to_type ty)]
               | Ast0.MacroDecl (_, _, _, _, _) -> []
               | Ast0.TyDecl (_, _) -> []
               | Ast0.Typedef (_, _, _, _) -> []
               | Ast0.DisjDecl (_, disjs, _, _) ->
                   List.concat (List.map process_decl disjs)
               | Ast0.Ddots (_, _) -> [] (* not in a statement list anyway *)
               | Ast0.OptDecl decl -> process_decl decl
               | Ast0.UniqueDecl decl -> process_decl decl in
             let new_acc = process_decl decl @ acc in
             process_statement_list r new_acc ss
         | Ast0.Dots (_, wc) ->
             (* NOTE: why is this case here? why is there none for nests? *)
             List.iter
               (process_whencode r.V0.combiner_statement_dots
                  r.V0.combiner_statement r.V0.combiner_expression)
               wc;
             process_statement_list r acc ss
         | Ast0.Disj (_, statement_dots_list, _, _) ->
             (* every branch starts from [acc]; the environments the
                branches produce are merged with [lub_envs] *)
             let new_acc =
               lub_envs
                 (List.map
                    (fun x -> process_statement_list r acc (Ast0.undots x))
                    statement_dots_list) in
             process_statement_list r new_acc ss
         | _ ->
             let _ = (propagate_types acc).V0.combiner_statement s in
             process_statement_list r acc ss) in

  let statement_dots r _ d =
    match Ast0.unwrap d with
    | Ast0.DOTS l | Ast0.CIRCLES l | Ast0.STARS l ->
        let _ = process_statement_list r env l in
        option_default in

  let statement _ k s =
    match Ast0.unwrap s with
    | Ast0.FunDecl (_, _, _, _, params, _, _, body, _) ->
        (* the parameters are in scope in the function body *)
        let rec get_binding p =
          match Ast0.unwrap p with
          | Ast0.Param (ty, Some id) -> [(strip id, Ast0.ast0_type_to_type ty)]
          | Ast0.OptParam param -> get_binding param
          | _ -> [] in
        let fenv = List.concat (List.map get_binding (Ast0.undots params)) in
        (propagate_types (fenv @ env)).V0.combiner_statement_dots body
    | Ast0.IfThen (_, _, exp, _, _, _)
    | Ast0.IfThenElse (_, _, exp, _, _, _, _, _)
    | Ast0.While (_, _, exp, _, _, _)
    | Ast0.Do (_, _, _, _, exp, _, _)
    | Ast0.For (_, _, _, _, Some exp, _, _, _, _, _)
    | Ast0.Switch (_, _, exp, _, _, _, _) ->
        let _ = k s in
        (* A test whose type is still unknown is assumed to be int.  Dots
           and nests are left alone; so are metavariables, whose type, if
           known, is specified in their declaration. *)
        let rec process_test exp =
          match (Ast0.unwrap exp, Ast0.get_type exp) with
          | (Ast0.Edots (_, _), _) -> None
          | (Ast0.NestExpr (_, _, _, _, _), _) -> None
          | (Ast0.MetaExpr (_, _, _, _, _), _) -> None
          | (Ast0.Paren (_, inner, _), None) -> process_test inner
          | (_, None) -> Some (T.BaseType (T.IntType, None))
          | _ -> None in
        (match process_test exp with
         | None -> () (* leave things as they are *)
         | Some _ as new_expty -> Ast0.set_type exp new_expty);
        None
    | _ -> k s

  and case_line _ k c =
    match Ast0.unwrap c with
    | Ast0.Default (_, _, _) ->
        let _ = k c in
        None
    | Ast0.Case (_, exp, _, _) ->
        let _ = k c in
        (* an untyped case label is assumed to be an int *)
        (match Ast0.get_type exp with
         | None -> Ast0.set_type exp (Some (T.BaseType (T.IntType, None)))
         | _ -> ());
        None
    | Ast0.OptCase _ -> k c in

  V0.combiner bind option_default
    mcode mcode mcode mcode mcode mcode mcode mcode mcode mcode mcode mcode
    mcode
    donothing donothing donothing statement_dots donothing donothing
    ident expression donothing donothing donothing donothing statement
    case_line donothing
354
(* Run type inference over every top-level element of [code], left to
   right.  The inferred types are recorded in the AST as a side effect, and
   the per-element results are discarded.  The initial environment binds
   only [NULL], as a pointer to an unknown type. *)
let type_infer code =
  let initial_env = [(Id "NULL", T.Pointer T.Unknown)] in
  let visitor = propagate_types initial_env in
  List.iter (fun top -> ignore (visitor.V0.combiner_top_level top)) code