]> git.cworth.org Git - apitrace/blob - retrace.py
Warn on null this pointer.
[apitrace] / retrace.py
1 ##########################################################################
2 #
3 # Copyright 2010 VMware, Inc.
4 # All Rights Reserved.
5 #
6 # Permission is hereby granted, free of charge, to any person obtaining a copy
7 # of this software and associated documentation files (the "Software"), to deal
8 # in the Software without restriction, including without limitation the rights
9 # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 # copies of the Software, and to permit persons to whom the Software is
11 # furnished to do so, subject to the following conditions:
12 #
13 # The above copyright notice and this permission notice shall be included in
14 # all copies or substantial portions of the Software.
15 #
16 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
22 # THE SOFTWARE.
23 #
24 ##########################################################################/
25
26
27 """Generic retracing code generator."""
28
29
30 import sys
31
32 import specs.stdapi as stdapi
33 import specs.glapi as glapi
34
35
36 class MutableRebuilder(stdapi.Rebuilder):
37     '''Type visitor which derives a mutable type.'''
38
39     def visitConst(self, const):
40         # Strip out const qualifier
41         return const.type
42
43     def visitAlias(self, alias):
44         # Tear the alias on type changes
45         type = self.visit(alias.type)
46         if type is alias.type:
47             return alias
48         return type
49
50     def visitReference(self, reference):
51         # Strip out references
52         return reference.type
53
54     def visitOpaque(self, opaque):
55         # Don't recursule
56         return opaque
57
58
59 def lookupHandle(handle, value):
60     if handle.key is None:
61         return "__%s_map[%s]" % (handle.name, value)
62     else:
63         key_name, key_type = handle.key
64         return "__%s_map[%s][%s]" % (handle.name, key_name, value)
65
66
67 class ValueDeserializer(stdapi.Visitor):
68
69     def visitLiteral(self, literal, lvalue, rvalue):
70         print '    %s = (%s).to%s();' % (lvalue, rvalue, literal.kind)
71
72     def visitConst(self, const, lvalue, rvalue):
73         self.visit(const.type, lvalue, rvalue)
74
75     def visitAlias(self, alias, lvalue, rvalue):
76         self.visit(alias.type, lvalue, rvalue)
77     
78     def visitEnum(self, enum, lvalue, rvalue):
79         print '    %s = static_cast<%s>((%s).toSInt());' % (lvalue, enum, rvalue)
80
81     def visitBitmask(self, bitmask, lvalue, rvalue):
82         self.visit(bitmask.type, lvalue, rvalue)
83
84     allocated = False
85
86     def visitArray(self, array, lvalue, rvalue):
87         print '    const trace::Array *__a%s = dynamic_cast<const trace::Array *>(&%s);' % (array.tag, rvalue)
88         length = '__a%s->values.size()' % array.tag
89         allocated = self.allocated
90         if not allocated:
91             print '    if (__a%s) {' % (array.tag)
92             print '        %s = _allocator.alloc<%s>(%s);' % (lvalue, array.type, length)
93             self.allocated = True
94         index = '__j' + array.tag
95         print '        for (size_t {i} = 0; {i} < {length}; ++{i}) {{'.format(i = index, length = length)
96         try:
97             self.visit(array.type, '%s[%s]' % (lvalue, index), '*__a%s->values[%s]' % (array.tag, index))
98         finally:
99             print '        }'
100             if not allocated:
101                 print '    } else {'
102                 print '        %s = NULL;' % lvalue
103                 print '    }'
104     
105     def visitPointer(self, pointer, lvalue, rvalue):
106         print '    const trace::Array *__a%s = dynamic_cast<const trace::Array *>(&%s);' % (pointer.tag, rvalue)
107         allocated = self.allocated
108         if not allocated:
109             print '    if (__a%s) {' % (pointer.tag)
110             print '        %s = _allocator.alloc<%s>();' % (lvalue, pointer.type)
111             self.allocated = True
112         try:
113             self.visit(pointer.type, '%s[0]' % (lvalue,), '*__a%s->values[0]' % (pointer.tag,))
114         finally:
115             if not allocated:
116                 print '    } else {'
117                 print '        %s = NULL;' % lvalue
118                 print '    }'
119
120     def visitIntPointer(self, pointer, lvalue, rvalue):
121         print '    %s = static_cast<%s>((%s).toPointer());' % (lvalue, pointer, rvalue)
122
123     def visitObjPointer(self, pointer, lvalue, rvalue):
124         print '    %s = static_cast<%s>((%s).toPointer());' % (lvalue, pointer, rvalue)
125
126     def visitLinearPointer(self, pointer, lvalue, rvalue):
127         print '    %s = static_cast<%s>(retrace::toPointer(%s));' % (lvalue, pointer, rvalue)
128
129     def visitReference(self, reference, lvalue, rvalue):
130         self.visit(reference.type, lvalue, rvalue);
131
132     def visitHandle(self, handle, lvalue, rvalue):
133         #OpaqueValueDeserializer().visit(handle.type, lvalue, rvalue);
134         self.visit(handle.type, lvalue, rvalue);
135         new_lvalue = lookupHandle(handle, lvalue)
136         print '    if (retrace::verbosity >= 2) {'
137         print '        std::cout << "%s " << size_t(%s) << " <- " << size_t(%s) << "\\n";' % (handle.name, lvalue, new_lvalue)
138         print '    }'
139         print '    %s = %s;' % (lvalue, new_lvalue)
140     
141     def visitBlob(self, blob, lvalue, rvalue):
142         print '    %s = static_cast<%s>((%s).toPointer());' % (lvalue, blob, rvalue)
143     
144     def visitString(self, string, lvalue, rvalue):
145         print '    %s = (%s)((%s).toString());' % (lvalue, string.expr, rvalue)
146
147     seq = 0
148
149     def visitStruct(self, struct, lvalue, rvalue):
150         tmp = '__s_' + struct.tag + '_' + str(self.seq)
151         self.seq += 1
152
153         print '    const trace::Struct *%s = dynamic_cast<const trace::Struct *>(&%s);' % (tmp, rvalue)
154         print '    assert(%s);' % (tmp)
155         self.allocated = True
156         for i in range(len(struct.members)):
157             member_type, member_name = struct.members[i]
158             self.visit(member_type, '%s.%s' % (lvalue, member_name), '*%s->members[%s]' % (tmp, i))
159
160
161 class OpaqueValueDeserializer(ValueDeserializer):
162     '''Value extractor that also understands opaque values.
163
164     Normally opaque values can't be retraced, unless they are being extracted
165     in the context of handles.'''
166
167     def visitOpaque(self, opaque, lvalue, rvalue):
168         print '    %s = static_cast<%s>(retrace::toPointer(%s));' % (lvalue, opaque, rvalue)
169
170
171 class SwizzledValueRegistrator(stdapi.Visitor):
172     '''Type visitor which will register (un)swizzled value pairs, to later be
173     swizzled.'''
174
175     def visitLiteral(self, literal, lvalue, rvalue):
176         pass
177
178     def visitAlias(self, alias, lvalue, rvalue):
179         self.visit(alias.type, lvalue, rvalue)
180     
181     def visitEnum(self, enum, lvalue, rvalue):
182         pass
183
184     def visitBitmask(self, bitmask, lvalue, rvalue):
185         pass
186
187     def visitArray(self, array, lvalue, rvalue):
188         print '    const trace::Array *__a%s = dynamic_cast<const trace::Array *>(&%s);' % (array.tag, rvalue)
189         print '    if (__a%s) {' % (array.tag)
190         length = '__a%s->values.size()' % array.tag
191         index = '__j' + array.tag
192         print '        for (size_t {i} = 0; {i} < {length}; ++{i}) {{'.format(i = index, length = length)
193         try:
194             self.visit(array.type, '%s[%s]' % (lvalue, index), '*__a%s->values[%s]' % (array.tag, index))
195         finally:
196             print '        }'
197             print '    }'
198     
199     def visitPointer(self, pointer, lvalue, rvalue):
200         print '    const trace::Array *__a%s = dynamic_cast<const trace::Array *>(&%s);' % (pointer.tag, rvalue)
201         print '    if (__a%s) {' % (pointer.tag)
202         try:
203             self.visit(pointer.type, '%s[0]' % (lvalue,), '*__a%s->values[0]' % (pointer.tag,))
204         finally:
205             print '    }'
206     
207     def visitIntPointer(self, pointer, lvalue, rvalue):
208         pass
209     
210     def visitObjPointer(self, pointer, lvalue, rvalue):
211         print r'    _obj_map[(%s).toUIntPtr()] = %s;' % (rvalue, lvalue)
212     
213     def visitLinearPointer(self, pointer, lvalue, rvalue):
214         assert pointer.size is not None
215         if pointer.size is not None:
216             print r'    retrace::addRegion((%s).toUIntPtr(), %s, %s);' % (rvalue, lvalue, pointer.size)
217
218     def visitReference(self, reference, lvalue, rvalue):
219         pass
220     
221     def visitHandle(self, handle, lvalue, rvalue):
222         print '    %s __orig_result;' % handle.type
223         OpaqueValueDeserializer().visit(handle.type, '__orig_result', rvalue);
224         if handle.range is None:
225             rvalue = "__orig_result"
226             entry = lookupHandle(handle, rvalue) 
227             print "    %s = %s;" % (entry, lvalue)
228             print '    if (retrace::verbosity >= 2) {'
229             print '        std::cout << "{handle.name} " << {rvalue} << " -> " << {lvalue} << "\\n";'.format(**locals())
230             print '    }'
231         else:
232             i = '__h' + handle.tag
233             lvalue = "%s + %s" % (lvalue, i)
234             rvalue = "__orig_result + %s" % (i,)
235             entry = lookupHandle(handle, rvalue) 
236             print '    for ({handle.type} {i} = 0; {i} < {handle.range}; ++{i}) {{'.format(**locals())
237             print '        {entry} = {lvalue};'.format(**locals())
238             print '        if (retrace::verbosity >= 2) {'
239             print '            std::cout << "{handle.name} " << ({rvalue}) << " -> " << ({lvalue}) << "\\n";'.format(**locals())
240             print '        }'
241             print '    }'
242     
243     def visitBlob(self, blob, lvalue, rvalue):
244         pass
245     
246     def visitString(self, string, lvalue, rvalue):
247         pass
248
249
250 class Retracer:
251
252     def retraceFunction(self, function):
253         print 'static void retrace_%s(trace::Call &call) {' % function.name
254         self.retraceFunctionBody(function)
255         print '}'
256         print
257
258     def retraceInterfaceMethod(self, interface, method):
259         print 'static void retrace_%s__%s(trace::Call &call) {' % (interface.name, method.name)
260         self.retraceInterfaceMethodBody(interface, method)
261         print '}'
262         print
263
264     def retraceFunctionBody(self, function):
265         if not function.sideeffects:
266             print '    (void)call;'
267             return
268
269         self.deserializeArgs(function)
270         
271         self.invokeFunction(function)
272
273         self.swizzleValues(function)
274
275     def retraceInterfaceMethodBody(self, interface, method):
276         if not method.sideeffects:
277             print '    (void)call;'
278             return
279
280         self.deserializeThisPointer(interface)
281
282         self.deserializeArgs(method)
283         
284         self.invokeInterfaceMethod(interface, method)
285
286         self.swizzleValues(method)
287
288     def deserializeThisPointer(self, interface):
289         print r'    %s *_this;' % (interface.name,)
290         print r'    _this = static_cast<%s *>(_obj_map[call.arg(0).toUIntPtr()]);' % (interface.name,)
291         print r'    if (!_this) {'
292         print r'        retrace::warning(call) << "NULL this pointer\n";'
293         print r'        return;'
294         print r'    }'
295
296     def deserializeArgs(self, function):
297         print '    retrace::ScopedAllocator _allocator;'
298         print '    (void)_allocator;'
299         success = True
300         for arg in function.args:
301             arg_type = MutableRebuilder().visit(arg.type)
302             #print '    // %s ->  %s' % (arg.type, arg_type)
303             print '    %s %s;' % (arg_type, arg.name)
304             rvalue = 'call.arg(%u)' % (arg.index,)
305             lvalue = arg.name
306             try:
307                 self.extractArg(function, arg, arg_type, lvalue, rvalue)
308             except NotImplementedError:
309                 success =  False
310                 print '    memset(&%s, 0, sizeof %s); // FIXME' % (arg.name, arg.name)
311
312         if not success:
313             print '    if (1) {'
314             self.failFunction(function)
315             if function.name[-1].islower():
316                 sys.stderr.write('warning: unsupported %s call\n' % function.name)
317             print '    }'
318
319     def swizzleValues(self, function):
320         for arg in function.args:
321             if arg.output:
322                 arg_type = MutableRebuilder().visit(arg.type)
323                 rvalue = 'call.arg(%u)' % (arg.index,)
324                 lvalue = arg.name
325                 try:
326                     self.regiterSwizzledValue(arg_type, lvalue, rvalue)
327                 except NotImplementedError:
328                     print '    // XXX: %s' % arg.name
329         if function.type is not stdapi.Void:
330             rvalue = '*call.ret'
331             lvalue = '__result'
332             try:
333                 self.regiterSwizzledValue(function.type, lvalue, rvalue)
334             except NotImplementedError:
335                 raise
336                 print '    // XXX: result'
337
338     def failFunction(self, function):
339         print '    if (retrace::verbosity >= 0) {'
340         print '        retrace::unsupported(call);'
341         print '    }'
342         print '    return;'
343
344     def extractArg(self, function, arg, arg_type, lvalue, rvalue):
345         ValueDeserializer().visit(arg_type, lvalue, rvalue)
346     
347     def extractOpaqueArg(self, function, arg, arg_type, lvalue, rvalue):
348         OpaqueValueDeserializer().visit(arg_type, lvalue, rvalue)
349
350     def regiterSwizzledValue(self, type, lvalue, rvalue):
351         visitor = SwizzledValueRegistrator()
352         visitor.visit(type, lvalue, rvalue)
353
354     def invokeFunction(self, function):
355         arg_names = ", ".join(function.argNames())
356         if function.type is not stdapi.Void:
357             print '    %s __result;' % (function.type)
358             print '    __result = %s(%s);' % (function.name, arg_names)
359             print '    (void)__result;'
360         else:
361             print '    %s(%s);' % (function.name, arg_names)
362
363     def invokeInterfaceMethod(self, interface, method):
364         arg_names = ", ".join(method.argNames())
365         if method.type is not stdapi.Void:
366             print '    %s __result;' % (method.type)
367             print '    __result = _this->%s(%s);' % (method.name, arg_names)
368             print '    (void)__result;'
369         else:
370             print '    _this->%s(%s);' % (method.name, arg_names)
371
372     def filterFunction(self, function):
373         return True
374
375     table_name = 'retrace::callbacks'
376
377     def retraceApi(self, api):
378
379         print '#include "os_time.hpp"'
380         print '#include "trace_parser.hpp"'
381         print '#include "retrace.hpp"'
382         print
383
384         types = api.getAllTypes()
385         handles = [type for type in types if isinstance(type, stdapi.Handle)]
386         handle_names = set()
387         for handle in handles:
388             if handle.name not in handle_names:
389                 if handle.key is None:
390                     print 'static retrace::map<%s> __%s_map;' % (handle.type, handle.name)
391                 else:
392                     key_name, key_type = handle.key
393                     print 'static std::map<%s, retrace::map<%s> > __%s_map;' % (key_type, handle.type, handle.name)
394                 handle_names.add(handle.name)
395         print
396
397         print 'static std::map<unsigned long long, void *> _obj_map;'
398         print
399
400         functions = filter(self.filterFunction, api.functions)
401         for function in functions:
402             self.retraceFunction(function)
403         interfaces = api.getAllInterfaces()
404         for interface in interfaces:
405             for method in interface.iterMethods():
406                 self.retraceInterfaceMethod(interface, method)
407
408         print 'const retrace::Entry %s[] = {' % self.table_name
409         for function in functions:
410             print '    {"%s", &retrace_%s},' % (function.name, function.name)
411         for interface in interfaces:
412             for method in interface.iterMethods():
413                 print '    {"%s::%s", &retrace_%s__%s},' % (interface.name, method.name, interface.name, method.name)
414         print '    {NULL, NULL}'
415         print '};'
416         print
417