]> git.cworth.org Git - apitrace/blob - retrace.py
Merge branch 'd3dretrace'
[apitrace] / retrace.py
1 ##########################################################################
2 #
3 # Copyright 2010 VMware, Inc.
4 # All Rights Reserved.
5 #
6 # Permission is hereby granted, free of charge, to any person obtaining a copy
7 # of this software and associated documentation files (the "Software"), to deal
8 # in the Software without restriction, including without limitation the rights
9 # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 # copies of the Software, and to permit persons to whom the Software is
11 # furnished to do so, subject to the following conditions:
12 #
13 # The above copyright notice and this permission notice shall be included in
14 # all copies or substantial portions of the Software.
15 #
16 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
22 # THE SOFTWARE.
23 #
24 ##########################################################################/
25
26
27 """Generic retracing code generator."""
28
29
30 import sys
31
32 import specs.stdapi as stdapi
33 import specs.glapi as glapi
34
35
36 class MutableRebuilder(stdapi.Rebuilder):
37     '''Type visitor which derives a mutable type.'''
38
39     def visitConst(self, const):
40         # Strip out const qualifier
41         return const.type
42
43     def visitAlias(self, alias):
44         # Tear the alias on type changes
45         type = self.visit(alias.type)
46         if type is alias.type:
47             return alias
48         return type
49
50     def visitReference(self, reference):
51         # Strip out references
52         return reference.type
53
54     def visitOpaque(self, opaque):
55         # Don't recursule
56         return opaque
57
58
59 def lookupHandle(handle, value):
60     if handle.key is None:
61         return "__%s_map[%s]" % (handle.name, value)
62     else:
63         key_name, key_type = handle.key
64         return "__%s_map[%s][%s]" % (handle.name, key_name, value)
65
66
67 class ValueAllocator(stdapi.Visitor):
68
69     def visitLiteral(self, literal, lvalue, rvalue):
70         pass
71
72     def visitConst(self, const, lvalue, rvalue):
73         self.visit(const.type, lvalue, rvalue)
74
75     def visitAlias(self, alias, lvalue, rvalue):
76         self.visit(alias.type, lvalue, rvalue)
77
78     def visitEnum(self, enum, lvalue, rvalue):
79         pass
80
81     def visitBitmask(self, bitmask, lvalue, rvalue):
82         pass
83
84     def visitArray(self, array, lvalue, rvalue):
85         print '    %s = _allocator.alloc<%s>(&%s);' % (lvalue, array.type, rvalue)
86
87     def visitPointer(self, pointer, lvalue, rvalue):
88         print '    %s = _allocator.alloc<%s>(&%s);' % (lvalue, pointer.type, rvalue)
89
90     def visitIntPointer(self, pointer, lvalue, rvalue):
91         pass
92
93     def visitObjPointer(self, pointer, lvalue, rvalue):
94         pass
95
96     def visitLinearPointer(self, pointer, lvalue, rvalue):
97         pass
98
99     def visitReference(self, reference, lvalue, rvalue):
100         self.visit(reference.type, lvalue, rvalue);
101
102     def visitHandle(self, handle, lvalue, rvalue):
103         pass
104
105     def visitBlob(self, blob, lvalue, rvalue):
106         pass
107
108     def visitString(self, string, lvalue, rvalue):
109         pass
110
111     def visitStruct(self, struct, lvalue, rvalue):
112         pass
113
114     def visitPolymorphic(self, polymorphic, lvalue, rvalue):
115         self.visit(polymorphic.defaultType, lvalue, rvalue)
116
117
118 class ValueDeserializer(stdapi.Visitor):
119
120     def visitLiteral(self, literal, lvalue, rvalue):
121         print '    %s = (%s).to%s();' % (lvalue, rvalue, literal.kind)
122
123     def visitConst(self, const, lvalue, rvalue):
124         self.visit(const.type, lvalue, rvalue)
125
126     def visitAlias(self, alias, lvalue, rvalue):
127         self.visit(alias.type, lvalue, rvalue)
128     
129     def visitEnum(self, enum, lvalue, rvalue):
130         print '    %s = static_cast<%s>((%s).toSInt());' % (lvalue, enum, rvalue)
131
132     def visitBitmask(self, bitmask, lvalue, rvalue):
133         self.visit(bitmask.type, lvalue, rvalue)
134
135     def visitArray(self, array, lvalue, rvalue):
136
137         tmp = '__a_' + array.tag + '_' + str(self.seq)
138         self.seq += 1
139
140         print '    if (%s) {' % (lvalue,)
141         print '        const trace::Array *%s = dynamic_cast<const trace::Array *>(&%s);' % (tmp, rvalue)
142         length = '%s->values.size()' % (tmp,)
143         index = '__j' + array.tag
144         print '        for (size_t {i} = 0; {i} < {length}; ++{i}) {{'.format(i = index, length = length)
145         try:
146             self.visit(array.type, '%s[%s]' % (lvalue, index), '*%s->values[%s]' % (tmp, index))
147         finally:
148             print '        }'
149             print '    }'
150     
151     def visitPointer(self, pointer, lvalue, rvalue):
152         tmp = '__a_' + pointer.tag + '_' + str(self.seq)
153         self.seq += 1
154
155         print '    if (%s) {' % (lvalue,)
156         print '        const trace::Array *%s = dynamic_cast<const trace::Array *>(&%s);' % (tmp, rvalue)
157         try:
158             self.visit(pointer.type, '%s[0]' % (lvalue,), '*%s->values[0]' % (tmp,))
159         finally:
160             print '    }'
161
162     def visitIntPointer(self, pointer, lvalue, rvalue):
163         print '    %s = static_cast<%s>((%s).toPointer());' % (lvalue, pointer, rvalue)
164
165     def visitObjPointer(self, pointer, lvalue, rvalue):
166         old_lvalue = '(%s).toUIntPtr()' % (rvalue,)
167         new_lvalue = '_obj_map[%s]' % (old_lvalue,)
168         print '    if (retrace::verbosity >= 2) {'
169         print '        std::cout << std::hex << "obj 0x" << size_t(%s) << " <- 0x" << size_t(%s) << std::dec <<"\\n";' % (old_lvalue, new_lvalue)
170         print '    }'
171         print '    %s = static_cast<%s>(%s);' % (lvalue, pointer, new_lvalue)
172
173     def visitLinearPointer(self, pointer, lvalue, rvalue):
174         print '    %s = static_cast<%s>(retrace::toPointer(%s));' % (lvalue, pointer, rvalue)
175
176     def visitReference(self, reference, lvalue, rvalue):
177         self.visit(reference.type, lvalue, rvalue);
178
179     def visitHandle(self, handle, lvalue, rvalue):
180         #OpaqueValueDeserializer().visit(handle.type, lvalue, rvalue);
181         self.visit(handle.type, lvalue, rvalue);
182         new_lvalue = lookupHandle(handle, lvalue)
183         print '    if (retrace::verbosity >= 2) {'
184         print '        std::cout << "%s " << size_t(%s) << " <- " << size_t(%s) << "\\n";' % (handle.name, lvalue, new_lvalue)
185         print '    }'
186         print '    %s = %s;' % (lvalue, new_lvalue)
187     
188     def visitBlob(self, blob, lvalue, rvalue):
189         print '    %s = static_cast<%s>((%s).toPointer());' % (lvalue, blob, rvalue)
190     
191     def visitString(self, string, lvalue, rvalue):
192         print '    %s = (%s)((%s).toString());' % (lvalue, string.expr, rvalue)
193
194     seq = 0
195
196     def visitStruct(self, struct, lvalue, rvalue):
197         tmp = '__s_' + struct.tag + '_' + str(self.seq)
198         self.seq += 1
199
200         print '    const trace::Struct *%s = dynamic_cast<const trace::Struct *>(&%s);' % (tmp, rvalue)
201         print '    assert(%s);' % (tmp)
202         for i in range(len(struct.members)):
203             member_type, member_name = struct.members[i]
204             self.visit(member_type, '%s.%s' % (lvalue, member_name), '*%s->members[%s]' % (tmp, i))
205
206     def visitPolymorphic(self, polymorphic, lvalue, rvalue):
207         self.visit(polymorphic.defaultType, lvalue, rvalue)
208
209
210 class OpaqueValueDeserializer(ValueDeserializer):
211     '''Value extractor that also understands opaque values.
212
213     Normally opaque values can't be retraced, unless they are being extracted
214     in the context of handles.'''
215
216     def visitOpaque(self, opaque, lvalue, rvalue):
217         print '    %s = static_cast<%s>(retrace::toPointer(%s));' % (lvalue, opaque, rvalue)
218
219
220 class SwizzledValueRegistrator(stdapi.Visitor):
221     '''Type visitor which will register (un)swizzled value pairs, to later be
222     swizzled.'''
223
224     def visitLiteral(self, literal, lvalue, rvalue):
225         pass
226
227     def visitAlias(self, alias, lvalue, rvalue):
228         self.visit(alias.type, lvalue, rvalue)
229     
230     def visitEnum(self, enum, lvalue, rvalue):
231         pass
232
233     def visitBitmask(self, bitmask, lvalue, rvalue):
234         pass
235
236     def visitArray(self, array, lvalue, rvalue):
237         print '    const trace::Array *__a%s = dynamic_cast<const trace::Array *>(&%s);' % (array.tag, rvalue)
238         print '    if (__a%s) {' % (array.tag)
239         length = '__a%s->values.size()' % array.tag
240         index = '__j' + array.tag
241         print '        for (size_t {i} = 0; {i} < {length}; ++{i}) {{'.format(i = index, length = length)
242         try:
243             self.visit(array.type, '%s[%s]' % (lvalue, index), '*__a%s->values[%s]' % (array.tag, index))
244         finally:
245             print '        }'
246             print '    }'
247     
248     def visitPointer(self, pointer, lvalue, rvalue):
249         print '    const trace::Array *__a%s = dynamic_cast<const trace::Array *>(&%s);' % (pointer.tag, rvalue)
250         print '    if (__a%s) {' % (pointer.tag)
251         try:
252             self.visit(pointer.type, '%s[0]' % (lvalue,), '*__a%s->values[0]' % (pointer.tag,))
253         finally:
254             print '    }'
255     
256     def visitIntPointer(self, pointer, lvalue, rvalue):
257         pass
258     
259     def visitObjPointer(self, pointer, lvalue, rvalue):
260         print r'    _obj_map[(%s).toUIntPtr()] = %s;' % (rvalue, lvalue)
261     
262     def visitLinearPointer(self, pointer, lvalue, rvalue):
263         assert pointer.size is not None
264         if pointer.size is not None:
265             print r'    retrace::addRegion((%s).toUIntPtr(), %s, %s);' % (rvalue, lvalue, pointer.size)
266
267     def visitReference(self, reference, lvalue, rvalue):
268         pass
269     
270     def visitHandle(self, handle, lvalue, rvalue):
271         print '    %s __orig_result;' % handle.type
272         OpaqueValueDeserializer().visit(handle.type, '__orig_result', rvalue);
273         if handle.range is None:
274             rvalue = "__orig_result"
275             entry = lookupHandle(handle, rvalue) 
276             print "    %s = %s;" % (entry, lvalue)
277             print '    if (retrace::verbosity >= 2) {'
278             print '        std::cout << "{handle.name} " << {rvalue} << " -> " << {lvalue} << "\\n";'.format(**locals())
279             print '    }'
280         else:
281             i = '__h' + handle.tag
282             lvalue = "%s + %s" % (lvalue, i)
283             rvalue = "__orig_result + %s" % (i,)
284             entry = lookupHandle(handle, rvalue) 
285             print '    for ({handle.type} {i} = 0; {i} < {handle.range}; ++{i}) {{'.format(**locals())
286             print '        {entry} = {lvalue};'.format(**locals())
287             print '        if (retrace::verbosity >= 2) {'
288             print '            std::cout << "{handle.name} " << ({rvalue}) << " -> " << ({lvalue}) << "\\n";'.format(**locals())
289             print '        }'
290             print '    }'
291     
292     def visitBlob(self, blob, lvalue, rvalue):
293         pass
294     
295     def visitString(self, string, lvalue, rvalue):
296         pass
297
298     seq = 0
299
300     def visitStruct(self, struct, lvalue, rvalue):
301         tmp = '__s_' + struct.tag + '_' + str(self.seq)
302         self.seq += 1
303
304         print '    const trace::Struct *%s = dynamic_cast<const trace::Struct *>(&%s);' % (tmp, rvalue)
305         print '    assert(%s);' % (tmp,)
306         print '    (void)%s;' % (tmp,)
307         for i in range(len(struct.members)):
308             member_type, member_name = struct.members[i]
309             self.visit(member_type, '%s.%s' % (lvalue, member_name), '*%s->members[%s]' % (tmp, i))
310     
311     def visitPolymorphic(self, polymorphic, lvalue, rvalue):
312         self.visit(polymorphic.defaultType, lvalue, rvalue)
313
314
315 class Retracer:
316
317     def retraceFunction(self, function):
318         print 'static void retrace_%s(trace::Call &call) {' % function.name
319         self.retraceFunctionBody(function)
320         print '}'
321         print
322
323     def retraceInterfaceMethod(self, interface, method):
324         print 'static void retrace_%s__%s(trace::Call &call) {' % (interface.name, method.name)
325         self.retraceInterfaceMethodBody(interface, method)
326         print '}'
327         print
328
329     def retraceFunctionBody(self, function):
330         assert function.sideeffects
331
332         self.deserializeArgs(function)
333         
334         self.invokeFunction(function)
335
336         self.swizzleValues(function)
337
338     def retraceInterfaceMethodBody(self, interface, method):
339         assert method.sideeffects
340
341         self.deserializeThisPointer(interface)
342
343         self.deserializeArgs(method)
344         
345         self.invokeInterfaceMethod(interface, method)
346
347         self.swizzleValues(method)
348
349     def deserializeThisPointer(self, interface):
350         print r'    %s *_this;' % (interface.name,)
351         print r'    _this = static_cast<%s *>(_obj_map[call.arg(0).toUIntPtr()]);' % (interface.name,)
352         print r'    if (!_this) {'
353         print r'        retrace::warning(call) << "NULL this pointer\n";'
354         print r'        return;'
355         print r'    }'
356
357     def deserializeArgs(self, function):
358         print '    retrace::ScopedAllocator _allocator;'
359         print '    (void)_allocator;'
360         success = True
361         for arg in function.args:
362             arg_type = MutableRebuilder().visit(arg.type)
363             print '    %s %s;' % (arg_type, arg.name)
364             rvalue = 'call.arg(%u)' % (arg.index,)
365             lvalue = arg.name
366             try:
367                 self.extractArg(function, arg, arg_type, lvalue, rvalue)
368             except NotImplementedError:
369                 success =  False
370                 print '    memset(&%s, 0, sizeof %s); // FIXME' % (arg.name, arg.name)
371             print
372
373         if not success:
374             print '    if (1) {'
375             self.failFunction(function)
376             if function.name[-1].islower():
377                 sys.stderr.write('warning: unsupported %s call\n' % function.name)
378             print '    }'
379
380     def swizzleValues(self, function):
381         for arg in function.args:
382             if arg.output:
383                 arg_type = MutableRebuilder().visit(arg.type)
384                 rvalue = 'call.arg(%u)' % (arg.index,)
385                 lvalue = arg.name
386                 try:
387                     self.regiterSwizzledValue(arg_type, lvalue, rvalue)
388                 except NotImplementedError:
389                     print '    // XXX: %s' % arg.name
390         if function.type is not stdapi.Void:
391             rvalue = '*call.ret'
392             lvalue = '__result'
393             try:
394                 self.regiterSwizzledValue(function.type, lvalue, rvalue)
395             except NotImplementedError:
396                 raise
397                 print '    // XXX: result'
398
399     def failFunction(self, function):
400         print '    if (retrace::verbosity >= 0) {'
401         print '        retrace::unsupported(call);'
402         print '    }'
403         print '    return;'
404
405     def extractArg(self, function, arg, arg_type, lvalue, rvalue):
406         ValueAllocator().visit(arg_type, lvalue, rvalue)
407         if arg.input:
408             ValueDeserializer().visit(arg_type, lvalue, rvalue)
409     
410     def extractOpaqueArg(self, function, arg, arg_type, lvalue, rvalue):
411         try:
412             ValueAllocator().visit(arg_type, lvalue, rvalue)
413         except NotImplementedError:
414             pass
415         OpaqueValueDeserializer().visit(arg_type, lvalue, rvalue)
416
417     def regiterSwizzledValue(self, type, lvalue, rvalue):
418         visitor = SwizzledValueRegistrator()
419         visitor.visit(type, lvalue, rvalue)
420
421     def invokeFunction(self, function):
422         arg_names = ", ".join(function.argNames())
423         if function.type is not stdapi.Void:
424             print '    %s __result;' % (function.type)
425             print '    __result = %s(%s);' % (function.name, arg_names)
426             print '    (void)__result;'
427         else:
428             print '    %s(%s);' % (function.name, arg_names)
429
430     def invokeInterfaceMethod(self, interface, method):
431         arg_names = ", ".join(method.argNames())
432         if method.type is not stdapi.Void:
433             print '    %s __result;' % (method.type)
434             print '    __result = _this->%s(%s);' % (method.name, arg_names)
435             print '    (void)__result;'
436         else:
437             print '    _this->%s(%s);' % (method.name, arg_names)
438
439     def filterFunction(self, function):
440         return True
441
442     table_name = 'retrace::callbacks'
443
444     def retraceApi(self, api):
445
446         print '#include "os_time.hpp"'
447         print '#include "trace_parser.hpp"'
448         print '#include "retrace.hpp"'
449         print
450
451         types = api.getAllTypes()
452         handles = [type for type in types if isinstance(type, stdapi.Handle)]
453         handle_names = set()
454         for handle in handles:
455             if handle.name not in handle_names:
456                 if handle.key is None:
457                     print 'static retrace::map<%s> __%s_map;' % (handle.type, handle.name)
458                 else:
459                     key_name, key_type = handle.key
460                     print 'static std::map<%s, retrace::map<%s> > __%s_map;' % (key_type, handle.type, handle.name)
461                 handle_names.add(handle.name)
462         print
463
464         print 'static std::map<unsigned long long, void *> _obj_map;'
465         print
466
467         functions = filter(self.filterFunction, api.functions)
468         for function in functions:
469             if function.sideeffects:
470                 self.retraceFunction(function)
471         interfaces = api.getAllInterfaces()
472         for interface in interfaces:
473             for method in interface.iterMethods():
474                 if method.sideeffects:
475                     self.retraceInterfaceMethod(interface, method)
476
477         print 'const retrace::Entry %s[] = {' % self.table_name
478         for function in functions:
479             if function.sideeffects:
480                 print '    {"%s", &retrace_%s},' % (function.name, function.name)
481             else:
482                 print '    {"%s", &retrace::ignore},' % (function.name,)
483         for interface in interfaces:
484             for method in interface.iterMethods():                
485                 if method.sideeffects:
486                     print '    {"%s::%s", &retrace_%s__%s},' % (interface.name, method.name, interface.name, method.name)
487                 else:
488                     print '    {"%s::%s", &retrace::ignore},' % (interface.name, method.name)
489         print '    {NULL, NULL}'
490         print '};'
491         print
492