]> git.cworth.org Git - apitrace/blob - retrace/retrace.py
Retrace IUnknown::AddRef/Release correctly.
[apitrace] / retrace / retrace.py
1 ##########################################################################
2 #
3 # Copyright 2010 VMware, Inc.
4 # All Rights Reserved.
5 #
6 # Permission is hereby granted, free of charge, to any person obtaining a copy
7 # of this software and associated documentation files (the "Software"), to deal
8 # in the Software without restriction, including without limitation the rights
9 # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 # copies of the Software, and to permit persons to whom the Software is
11 # furnished to do so, subject to the following conditions:
12 #
13 # The above copyright notice and this permission notice shall be included in
14 # all copies or substantial portions of the Software.
15 #
16 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
22 # THE SOFTWARE.
23 #
24 ##########################################################################/
25
26
27 """Generic retracing code generator."""
28
29
30 # Adjust path
31 import os.path
32 import sys
33 sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
34
35
36 import specs.stdapi as stdapi
37 import specs.glapi as glapi
38
39
40 class UnsupportedType(Exception):
41     pass
42
43
44 def lookupHandle(handle, value):
45     if handle.key is None:
46         return "__%s_map[%s]" % (handle.name, value)
47     else:
48         key_name, key_type = handle.key
49         return "__%s_map[%s][%s]" % (handle.name, key_name, value)
50
51
52 class ValueAllocator(stdapi.Visitor):
53
54     def visitLiteral(self, literal, lvalue, rvalue):
55         pass
56
57     def visitConst(self, const, lvalue, rvalue):
58         self.visit(const.type, lvalue, rvalue)
59
60     def visitAlias(self, alias, lvalue, rvalue):
61         self.visit(alias.type, lvalue, rvalue)
62
63     def visitEnum(self, enum, lvalue, rvalue):
64         pass
65
66     def visitBitmask(self, bitmask, lvalue, rvalue):
67         pass
68
69     def visitArray(self, array, lvalue, rvalue):
70         print '    %s = _allocator.alloc<%s>(&%s);' % (lvalue, array.type, rvalue)
71
72     def visitPointer(self, pointer, lvalue, rvalue):
73         print '    %s = _allocator.alloc<%s>(&%s);' % (lvalue, pointer.type, rvalue)
74
75     def visitIntPointer(self, pointer, lvalue, rvalue):
76         pass
77
78     def visitObjPointer(self, pointer, lvalue, rvalue):
79         pass
80
81     def visitLinearPointer(self, pointer, lvalue, rvalue):
82         pass
83
84     def visitReference(self, reference, lvalue, rvalue):
85         self.visit(reference.type, lvalue, rvalue);
86
87     def visitHandle(self, handle, lvalue, rvalue):
88         pass
89
90     def visitBlob(self, blob, lvalue, rvalue):
91         pass
92
93     def visitString(self, string, lvalue, rvalue):
94         pass
95
96     def visitStruct(self, struct, lvalue, rvalue):
97         pass
98
99     def visitPolymorphic(self, polymorphic, lvalue, rvalue):
100         self.visit(polymorphic.defaultType, lvalue, rvalue)
101
102     def visitOpaque(self, opaque, lvalue, rvalue):
103         pass
104
105
106 class ValueDeserializer(stdapi.Visitor):
107
108     def visitLiteral(self, literal, lvalue, rvalue):
109         print '    %s = (%s).to%s();' % (lvalue, rvalue, literal.kind)
110
111     def visitConst(self, const, lvalue, rvalue):
112         self.visit(const.type, lvalue, rvalue)
113
114     def visitAlias(self, alias, lvalue, rvalue):
115         self.visit(alias.type, lvalue, rvalue)
116     
117     def visitEnum(self, enum, lvalue, rvalue):
118         print '    %s = static_cast<%s>((%s).toSInt());' % (lvalue, enum, rvalue)
119
120     def visitBitmask(self, bitmask, lvalue, rvalue):
121         self.visit(bitmask.type, lvalue, rvalue)
122
123     def visitArray(self, array, lvalue, rvalue):
124
125         tmp = '__a_' + array.tag + '_' + str(self.seq)
126         self.seq += 1
127
128         print '    if (%s) {' % (lvalue,)
129         print '        const trace::Array *%s = dynamic_cast<const trace::Array *>(&%s);' % (tmp, rvalue)
130         length = '%s->values.size()' % (tmp,)
131         index = '__j' + array.tag
132         print '        for (size_t {i} = 0; {i} < {length}; ++{i}) {{'.format(i = index, length = length)
133         try:
134             self.visit(array.type, '%s[%s]' % (lvalue, index), '*%s->values[%s]' % (tmp, index))
135         finally:
136             print '        }'
137             print '    }'
138     
139     def visitPointer(self, pointer, lvalue, rvalue):
140         tmp = '__a_' + pointer.tag + '_' + str(self.seq)
141         self.seq += 1
142
143         print '    if (%s) {' % (lvalue,)
144         print '        const trace::Array *%s = dynamic_cast<const trace::Array *>(&%s);' % (tmp, rvalue)
145         try:
146             self.visit(pointer.type, '%s[0]' % (lvalue,), '*%s->values[0]' % (tmp,))
147         finally:
148             print '    }'
149
150     def visitIntPointer(self, pointer, lvalue, rvalue):
151         print '    %s = static_cast<%s>((%s).toPointer());' % (lvalue, pointer, rvalue)
152
153     def visitObjPointer(self, pointer, lvalue, rvalue):
154         old_lvalue = '(%s).toUIntPtr()' % (rvalue,)
155         new_lvalue = '_obj_map[%s]' % (old_lvalue,)
156         print '    if (retrace::verbosity >= 2) {'
157         print '        std::cout << std::hex << "obj 0x" << size_t(%s) << " <- 0x" << size_t(%s) << std::dec <<"\\n";' % (old_lvalue, new_lvalue)
158         print '    }'
159         print '    %s = static_cast<%s>(%s);' % (lvalue, pointer, new_lvalue)
160
161     def visitLinearPointer(self, pointer, lvalue, rvalue):
162         print '    %s = static_cast<%s>(retrace::toPointer(%s));' % (lvalue, pointer, rvalue)
163
164     def visitReference(self, reference, lvalue, rvalue):
165         self.visit(reference.type, lvalue, rvalue);
166
167     def visitHandle(self, handle, lvalue, rvalue):
168         #OpaqueValueDeserializer().visit(handle.type, lvalue, rvalue);
169         self.visit(handle.type, lvalue, rvalue);
170         new_lvalue = lookupHandle(handle, lvalue)
171         print '    if (retrace::verbosity >= 2) {'
172         print '        std::cout << "%s " << size_t(%s) << " <- " << size_t(%s) << "\\n";' % (handle.name, lvalue, new_lvalue)
173         print '    }'
174         print '    %s = %s;' % (lvalue, new_lvalue)
175     
176     def visitBlob(self, blob, lvalue, rvalue):
177         print '    %s = static_cast<%s>((%s).toPointer());' % (lvalue, blob, rvalue)
178     
179     def visitString(self, string, lvalue, rvalue):
180         print '    %s = (%s)((%s).toString());' % (lvalue, string.expr, rvalue)
181
182     seq = 0
183
184     def visitStruct(self, struct, lvalue, rvalue):
185         tmp = '__s_' + struct.tag + '_' + str(self.seq)
186         self.seq += 1
187
188         print '    const trace::Struct *%s = dynamic_cast<const trace::Struct *>(&%s);' % (tmp, rvalue)
189         print '    assert(%s);' % (tmp)
190         for i in range(len(struct.members)):
191             member_type, member_name = struct.members[i]
192             self.visit(member_type, '%s.%s' % (lvalue, member_name), '*%s->members[%s]' % (tmp, i))
193
194     def visitPolymorphic(self, polymorphic, lvalue, rvalue):
195         self.visit(polymorphic.defaultType, lvalue, rvalue)
196     
197     def visitOpaque(self, opaque, lvalue, rvalue):
198         raise UnsupportedType
199
200
201 class OpaqueValueDeserializer(ValueDeserializer):
202     '''Value extractor that also understands opaque values.
203
204     Normally opaque values can't be retraced, unless they are being extracted
205     in the context of handles.'''
206
207     def visitOpaque(self, opaque, lvalue, rvalue):
208         print '    %s = static_cast<%s>(retrace::toPointer(%s));' % (lvalue, opaque, rvalue)
209
210
211 class SwizzledValueRegistrator(stdapi.Visitor):
212     '''Type visitor which will register (un)swizzled value pairs, to later be
213     swizzled.'''
214
215     def visitLiteral(self, literal, lvalue, rvalue):
216         pass
217
218     def visitAlias(self, alias, lvalue, rvalue):
219         self.visit(alias.type, lvalue, rvalue)
220     
221     def visitEnum(self, enum, lvalue, rvalue):
222         pass
223
224     def visitBitmask(self, bitmask, lvalue, rvalue):
225         pass
226
227     def visitArray(self, array, lvalue, rvalue):
228         print '    const trace::Array *__a%s = dynamic_cast<const trace::Array *>(&%s);' % (array.tag, rvalue)
229         print '    if (__a%s) {' % (array.tag)
230         length = '__a%s->values.size()' % array.tag
231         index = '__j' + array.tag
232         print '        for (size_t {i} = 0; {i} < {length}; ++{i}) {{'.format(i = index, length = length)
233         try:
234             self.visit(array.type, '%s[%s]' % (lvalue, index), '*__a%s->values[%s]' % (array.tag, index))
235         finally:
236             print '        }'
237             print '    }'
238     
239     def visitPointer(self, pointer, lvalue, rvalue):
240         print '    const trace::Array *__a%s = dynamic_cast<const trace::Array *>(&%s);' % (pointer.tag, rvalue)
241         print '    if (__a%s) {' % (pointer.tag)
242         try:
243             self.visit(pointer.type, '%s[0]' % (lvalue,), '*__a%s->values[0]' % (pointer.tag,))
244         finally:
245             print '    }'
246     
247     def visitIntPointer(self, pointer, lvalue, rvalue):
248         pass
249     
250     def visitObjPointer(self, pointer, lvalue, rvalue):
251         print r'    _obj_map[(%s).toUIntPtr()] = %s;' % (rvalue, lvalue)
252     
253     def visitLinearPointer(self, pointer, lvalue, rvalue):
254         assert pointer.size is not None
255         if pointer.size is not None:
256             print r'    retrace::addRegion((%s).toUIntPtr(), %s, %s);' % (rvalue, lvalue, pointer.size)
257
258     def visitReference(self, reference, lvalue, rvalue):
259         pass
260     
261     def visitHandle(self, handle, lvalue, rvalue):
262         print '    %s __orig_result;' % handle.type
263         OpaqueValueDeserializer().visit(handle.type, '__orig_result', rvalue);
264         if handle.range is None:
265             rvalue = "__orig_result"
266             entry = lookupHandle(handle, rvalue) 
267             print "    %s = %s;" % (entry, lvalue)
268             print '    if (retrace::verbosity >= 2) {'
269             print '        std::cout << "{handle.name} " << {rvalue} << " -> " << {lvalue} << "\\n";'.format(**locals())
270             print '    }'
271         else:
272             i = '__h' + handle.tag
273             lvalue = "%s + %s" % (lvalue, i)
274             rvalue = "__orig_result + %s" % (i,)
275             entry = lookupHandle(handle, rvalue) 
276             print '    for ({handle.type} {i} = 0; {i} < {handle.range}; ++{i}) {{'.format(**locals())
277             print '        {entry} = {lvalue};'.format(**locals())
278             print '        if (retrace::verbosity >= 2) {'
279             print '            std::cout << "{handle.name} " << ({rvalue}) << " -> " << ({lvalue}) << "\\n";'.format(**locals())
280             print '        }'
281             print '    }'
282     
283     def visitBlob(self, blob, lvalue, rvalue):
284         pass
285     
286     def visitString(self, string, lvalue, rvalue):
287         pass
288
289     seq = 0
290
291     def visitStruct(self, struct, lvalue, rvalue):
292         tmp = '__s_' + struct.tag + '_' + str(self.seq)
293         self.seq += 1
294
295         print '    const trace::Struct *%s = dynamic_cast<const trace::Struct *>(&%s);' % (tmp, rvalue)
296         print '    assert(%s);' % (tmp,)
297         print '    (void)%s;' % (tmp,)
298         for i in range(len(struct.members)):
299             member_type, member_name = struct.members[i]
300             self.visit(member_type, '%s.%s' % (lvalue, member_name), '*%s->members[%s]' % (tmp, i))
301     
302     def visitPolymorphic(self, polymorphic, lvalue, rvalue):
303         self.visit(polymorphic.defaultType, lvalue, rvalue)
304     
305     def visitOpaque(self, opaque, lvalue, rvalue):
306         pass
307
308
309 class Retracer:
310
311     def retraceFunction(self, function):
312         print 'static void retrace_%s(trace::Call &call) {' % function.name
313         self.retraceFunctionBody(function)
314         print '}'
315         print
316
317     def retraceInterfaceMethod(self, interface, method):
318         print 'static void retrace_%s__%s(trace::Call &call) {' % (interface.name, method.name)
319         self.retraceInterfaceMethodBody(interface, method)
320         print '}'
321         print
322
323     def retraceFunctionBody(self, function):
324         assert function.sideeffects
325
326         self.deserializeArgs(function)
327         
328         self.invokeFunction(function)
329
330         self.swizzleValues(function)
331
332     def retraceInterfaceMethodBody(self, interface, method):
333         assert method.sideeffects
334
335         self.deserializeThisPointer(interface)
336
337         self.deserializeArgs(method)
338         
339         self.invokeInterfaceMethod(interface, method)
340
341         self.swizzleValues(method)
342
343     def deserializeThisPointer(self, interface):
344         print r'    %s *_this;' % (interface.name,)
345         print r'    _this = static_cast<%s *>(_obj_map[call.arg(0).toUIntPtr()]);' % (interface.name,)
346         print r'    if (!_this) {'
347         print r'        retrace::warning(call) << "NULL this pointer\n";'
348         print r'        return;'
349         print r'    }'
350
351     def deserializeArgs(self, function):
352         print '    retrace::ScopedAllocator _allocator;'
353         print '    (void)_allocator;'
354         success = True
355         for arg in function.args:
356             arg_type = arg.type.mutable()
357             print '    %s %s;' % (arg_type, arg.name)
358             rvalue = 'call.arg(%u)' % (arg.index,)
359             lvalue = arg.name
360             try:
361                 self.extractArg(function, arg, arg_type, lvalue, rvalue)
362             except UnsupportedType:
363                 success =  False
364                 print '    memset(&%s, 0, sizeof %s); // FIXME' % (arg.name, arg.name)
365             print
366
367         if not success:
368             print '    if (1) {'
369             self.failFunction(function)
370             if function.name[-1].islower():
371                 sys.stderr.write('warning: unsupported %s call\n' % function.name)
372             print '    }'
373
374     def swizzleValues(self, function):
375         for arg in function.args:
376             if arg.output:
377                 arg_type = arg.type.mutable()
378                 rvalue = 'call.arg(%u)' % (arg.index,)
379                 lvalue = arg.name
380                 try:
381                     self.regiterSwizzledValue(arg_type, lvalue, rvalue)
382                 except UnsupportedType:
383                     print '    // XXX: %s' % arg.name
384         if function.type is not stdapi.Void:
385             rvalue = '*call.ret'
386             lvalue = '__result'
387             try:
388                 self.regiterSwizzledValue(function.type, lvalue, rvalue)
389             except UnsupportedType:
390                 raise
391                 print '    // XXX: result'
392
393     def failFunction(self, function):
394         print '    if (retrace::verbosity >= 0) {'
395         print '        retrace::unsupported(call);'
396         print '    }'
397         print '    return;'
398
399     def extractArg(self, function, arg, arg_type, lvalue, rvalue):
400         ValueAllocator().visit(arg_type, lvalue, rvalue)
401         if arg.input:
402             ValueDeserializer().visit(arg_type, lvalue, rvalue)
403     
404     def extractOpaqueArg(self, function, arg, arg_type, lvalue, rvalue):
405         try:
406             ValueAllocator().visit(arg_type, lvalue, rvalue)
407         except UnsupportedType:
408             pass
409         OpaqueValueDeserializer().visit(arg_type, lvalue, rvalue)
410
411     def regiterSwizzledValue(self, type, lvalue, rvalue):
412         visitor = SwizzledValueRegistrator()
413         visitor.visit(type, lvalue, rvalue)
414
415     def invokeFunction(self, function):
416         arg_names = ", ".join(function.argNames())
417         if function.type is not stdapi.Void:
418             print '    %s __result;' % (function.type)
419             print '    __result = %s(%s);' % (function.name, arg_names)
420             print '    (void)__result;'
421         else:
422             print '    %s(%s);' % (function.name, arg_names)
423
424     def invokeInterfaceMethod(self, interface, method):
425         # On release our reference when we reach Release() == 0 call in the
426         # trace.
427         if method.name == 'Release':
428             print '    if (call.ret->toUInt()) {'
429             print '        return;'
430             print '    }'
431             print '    _obj_map.erase(call.arg(0).toUIntPtr());'
432
433         arg_names = ", ".join(method.argNames())
434         if method.type is not stdapi.Void:
435             print '    %s __result;' % (method.type)
436             print '    __result = _this->%s(%s);' % (method.name, arg_names)
437             print '    (void)__result;'
438         else:
439             print '    _this->%s(%s);' % (method.name, arg_names)
440
441     def filterFunction(self, function):
442         return True
443
444     table_name = 'retrace::callbacks'
445
446     def retraceApi(self, api):
447
448         print '#include "os_time.hpp"'
449         print '#include "trace_parser.hpp"'
450         print '#include "retrace.hpp"'
451         print
452
453         types = api.getAllTypes()
454         handles = [type for type in types if isinstance(type, stdapi.Handle)]
455         handle_names = set()
456         for handle in handles:
457             if handle.name not in handle_names:
458                 if handle.key is None:
459                     print 'static retrace::map<%s> __%s_map;' % (handle.type, handle.name)
460                 else:
461                     key_name, key_type = handle.key
462                     print 'static std::map<%s, retrace::map<%s> > __%s_map;' % (key_type, handle.type, handle.name)
463                 handle_names.add(handle.name)
464         print
465
466         print 'static std::map<unsigned long long, void *> _obj_map;'
467         print
468
469         functions = filter(self.filterFunction, api.functions)
470         for function in functions:
471             if function.sideeffects:
472                 self.retraceFunction(function)
473         interfaces = api.getAllInterfaces()
474         for interface in interfaces:
475             for method in interface.iterMethods():
476                 if method.sideeffects:
477                     self.retraceInterfaceMethod(interface, method)
478
479         print 'const retrace::Entry %s[] = {' % self.table_name
480         for function in functions:
481             if function.sideeffects:
482                 print '    {"%s", &retrace_%s},' % (function.name, function.name)
483             else:
484                 print '    {"%s", &retrace::ignore},' % (function.name,)
485         for interface in interfaces:
486             for method in interface.iterMethods():                
487                 if method.sideeffects:
488                     print '    {"%s::%s", &retrace_%s__%s},' % (interface.name, method.name, interface.name, method.name)
489                 else:
490                     print '    {"%s::%s", &retrace::ignore},' % (interface.name, method.name)
491         print '    {NULL, NULL}'
492         print '};'
493         print
494