]> git.cworth.org Git - apitrace/blob - retrace.py
Find the x64 DXSDK libraries.
[apitrace] / retrace.py
1 ##########################################################################
2 #
3 # Copyright 2010 VMware, Inc.
4 # All Rights Reserved.
5 #
6 # Permission is hereby granted, free of charge, to any person obtaining a copy
7 # of this software and associated documentation files (the "Software"), to deal
8 # in the Software without restriction, including without limitation the rights
9 # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 # copies of the Software, and to permit persons to whom the Software is
11 # furnished to do so, subject to the following conditions:
12 #
13 # The above copyright notice and this permission notice shall be included in
14 # all copies or substantial portions of the Software.
15 #
16 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
22 # THE SOFTWARE.
23 #
24 ##########################################################################/
25
26
27 """Generic retracing code generator."""
28
29
30 import sys
31
32 import specs.stdapi as stdapi
33 import specs.glapi as glapi
34
35
36 class MutableRebuilder(stdapi.Rebuilder):
37     '''Type visitor which derives a mutable type.'''
38
39     def visitConst(self, const):
40         # Strip out const qualifier
41         return const.type
42
43     def visitAlias(self, alias):
44         # Tear the alias on type changes
45         type = self.visit(alias.type)
46         if type is alias.type:
47             return alias
48         return type
49
50     def visitReference(self, reference):
51         # Strip out references
52         return reference.type
53
54     def visitOpaque(self, opaque):
55         # Don't recursule
56         return opaque
57
58
59 def lookupHandle(handle, value):
60     if handle.key is None:
61         return "__%s_map[%s]" % (handle.name, value)
62     else:
63         key_name, key_type = handle.key
64         return "__%s_map[%s][%s]" % (handle.name, key_name, value)
65
66
67 class ValueAllocator(stdapi.Visitor):
68
69     def visitLiteral(self, literal, lvalue, rvalue):
70         pass
71
72     def visitConst(self, const, lvalue, rvalue):
73         self.visit(const.type, lvalue, rvalue)
74
75     def visitAlias(self, alias, lvalue, rvalue):
76         self.visit(alias.type, lvalue, rvalue)
77
78     def visitEnum(self, enum, lvalue, rvalue):
79         pass
80
81     def visitBitmask(self, bitmask, lvalue, rvalue):
82         pass
83
84     def visitArray(self, array, lvalue, rvalue):
85         print '    %s = _allocator.alloc<%s>(&%s);' % (lvalue, array.type, rvalue)
86
87     def visitPointer(self, pointer, lvalue, rvalue):
88         print '    %s = _allocator.alloc<%s>(&%s);' % (lvalue, pointer.type, rvalue)
89
90     def visitIntPointer(self, pointer, lvalue, rvalue):
91         pass
92
93     def visitObjPointer(self, pointer, lvalue, rvalue):
94         pass
95
96     def visitLinearPointer(self, pointer, lvalue, rvalue):
97         pass
98
99     def visitReference(self, reference, lvalue, rvalue):
100         self.visit(reference.type, lvalue, rvalue);
101
102     def visitHandle(self, handle, lvalue, rvalue):
103         pass
104
105     def visitBlob(self, blob, lvalue, rvalue):
106         pass
107
108     def visitString(self, string, lvalue, rvalue):
109         pass
110
111     def visitStruct(self, struct, lvalue, rvalue):
112         pass
113
114     def visitPolymorphic(self, polymorphic, lvalue, rvalue):
115         self.visit(polymorphic.defaultType, lvalue, rvalue)
116
117
118 class ValueDeserializer(stdapi.Visitor):
119
120     def visitLiteral(self, literal, lvalue, rvalue):
121         print '    %s = (%s).to%s();' % (lvalue, rvalue, literal.kind)
122
123     def visitConst(self, const, lvalue, rvalue):
124         self.visit(const.type, lvalue, rvalue)
125
126     def visitAlias(self, alias, lvalue, rvalue):
127         self.visit(alias.type, lvalue, rvalue)
128     
129     def visitEnum(self, enum, lvalue, rvalue):
130         print '    %s = static_cast<%s>((%s).toSInt());' % (lvalue, enum, rvalue)
131
132     def visitBitmask(self, bitmask, lvalue, rvalue):
133         self.visit(bitmask.type, lvalue, rvalue)
134
135     def visitArray(self, array, lvalue, rvalue):
136
137         tmp = '__a_' + array.tag + '_' + str(self.seq)
138         self.seq += 1
139
140         print '    if (%s) {' % (lvalue,)
141         print '        const trace::Array *%s = dynamic_cast<const trace::Array *>(&%s);' % (tmp, rvalue)
142         length = '%s->values.size()' % (tmp,)
143         index = '__j' + array.tag
144         print '        for (size_t {i} = 0; {i} < {length}; ++{i}) {{'.format(i = index, length = length)
145         try:
146             self.visit(array.type, '%s[%s]' % (lvalue, index), '*%s->values[%s]' % (tmp, index))
147         finally:
148             print '        }'
149             print '    }'
150     
151     def visitPointer(self, pointer, lvalue, rvalue):
152         tmp = '__a_' + pointer.tag + '_' + str(self.seq)
153         self.seq += 1
154
155         print '    if (%s) {' % (lvalue,)
156         print '        const trace::Array *%s = dynamic_cast<const trace::Array *>(&%s);' % (tmp, rvalue)
157         try:
158             self.visit(pointer.type, '%s[0]' % (lvalue,), '*%s->values[0]' % (tmp,))
159         finally:
160             print '    }'
161
162     def visitIntPointer(self, pointer, lvalue, rvalue):
163         print '    %s = static_cast<%s>((%s).toPointer());' % (lvalue, pointer, rvalue)
164
165     def visitObjPointer(self, pointer, lvalue, rvalue):
166         old_lvalue = '(%s).toUIntPtr()' % (rvalue,)
167         new_lvalue = '_obj_map[%s]' % (old_lvalue,)
168         print '    if (retrace::verbosity >= 2) {'
169         print '        std::cout << std::hex << "obj 0x" << size_t(%s) << " <- 0x" << size_t(%s) << std::dec <<"\\n";' % (old_lvalue, new_lvalue)
170         print '    }'
171         print '    %s = static_cast<%s>(%s);' % (lvalue, pointer, new_lvalue)
172
173     def visitLinearPointer(self, pointer, lvalue, rvalue):
174         print '    %s = static_cast<%s>(retrace::toPointer(%s));' % (lvalue, pointer, rvalue)
175
176     def visitReference(self, reference, lvalue, rvalue):
177         self.visit(reference.type, lvalue, rvalue);
178
179     def visitHandle(self, handle, lvalue, rvalue):
180         #OpaqueValueDeserializer().visit(handle.type, lvalue, rvalue);
181         self.visit(handle.type, lvalue, rvalue);
182         new_lvalue = lookupHandle(handle, lvalue)
183         print '    if (retrace::verbosity >= 2) {'
184         print '        std::cout << "%s " << size_t(%s) << " <- " << size_t(%s) << "\\n";' % (handle.name, lvalue, new_lvalue)
185         print '    }'
186         print '    %s = %s;' % (lvalue, new_lvalue)
187     
188     def visitBlob(self, blob, lvalue, rvalue):
189         print '    %s = static_cast<%s>((%s).toPointer());' % (lvalue, blob, rvalue)
190     
191     def visitString(self, string, lvalue, rvalue):
192         print '    %s = (%s)((%s).toString());' % (lvalue, string.expr, rvalue)
193
194     seq = 0
195
196     def visitStruct(self, struct, lvalue, rvalue):
197         tmp = '__s_' + struct.tag + '_' + str(self.seq)
198         self.seq += 1
199
200         print '    const trace::Struct *%s = dynamic_cast<const trace::Struct *>(&%s);' % (tmp, rvalue)
201         print '    assert(%s);' % (tmp)
202         for i in range(len(struct.members)):
203             member_type, member_name = struct.members[i]
204             self.visit(member_type, '%s.%s' % (lvalue, member_name), '*%s->members[%s]' % (tmp, i))
205
206     def visitPolymorphic(self, polymorphic, lvalue, rvalue):
207         self.visit(polymorphic.defaultType, lvalue, rvalue)
208
209
210 class OpaqueValueDeserializer(ValueDeserializer):
211     '''Value extractor that also understands opaque values.
212
213     Normally opaque values can't be retraced, unless they are being extracted
214     in the context of handles.'''
215
216     def visitOpaque(self, opaque, lvalue, rvalue):
217         print '    %s = static_cast<%s>(retrace::toPointer(%s));' % (lvalue, opaque, rvalue)
218
219
220 class SwizzledValueRegistrator(stdapi.Visitor):
221     '''Type visitor which will register (un)swizzled value pairs, to later be
222     swizzled.'''
223
224     def visitLiteral(self, literal, lvalue, rvalue):
225         pass
226
227     def visitAlias(self, alias, lvalue, rvalue):
228         self.visit(alias.type, lvalue, rvalue)
229     
230     def visitEnum(self, enum, lvalue, rvalue):
231         pass
232
233     def visitBitmask(self, bitmask, lvalue, rvalue):
234         pass
235
236     def visitArray(self, array, lvalue, rvalue):
237         print '    const trace::Array *__a%s = dynamic_cast<const trace::Array *>(&%s);' % (array.tag, rvalue)
238         print '    if (__a%s) {' % (array.tag)
239         length = '__a%s->values.size()' % array.tag
240         index = '__j' + array.tag
241         print '        for (size_t {i} = 0; {i} < {length}; ++{i}) {{'.format(i = index, length = length)
242         try:
243             self.visit(array.type, '%s[%s]' % (lvalue, index), '*__a%s->values[%s]' % (array.tag, index))
244         finally:
245             print '        }'
246             print '    }'
247     
248     def visitPointer(self, pointer, lvalue, rvalue):
249         print '    const trace::Array *__a%s = dynamic_cast<const trace::Array *>(&%s);' % (pointer.tag, rvalue)
250         print '    if (__a%s) {' % (pointer.tag)
251         try:
252             self.visit(pointer.type, '%s[0]' % (lvalue,), '*__a%s->values[0]' % (pointer.tag,))
253         finally:
254             print '    }'
255     
256     def visitIntPointer(self, pointer, lvalue, rvalue):
257         pass
258     
259     def visitObjPointer(self, pointer, lvalue, rvalue):
260         print r'    _obj_map[(%s).toUIntPtr()] = %s;' % (rvalue, lvalue)
261     
262     def visitLinearPointer(self, pointer, lvalue, rvalue):
263         assert pointer.size is not None
264         if pointer.size is not None:
265             print r'    retrace::addRegion((%s).toUIntPtr(), %s, %s);' % (rvalue, lvalue, pointer.size)
266
267     def visitReference(self, reference, lvalue, rvalue):
268         pass
269     
270     def visitHandle(self, handle, lvalue, rvalue):
271         print '    %s __orig_result;' % handle.type
272         OpaqueValueDeserializer().visit(handle.type, '__orig_result', rvalue);
273         if handle.range is None:
274             rvalue = "__orig_result"
275             entry = lookupHandle(handle, rvalue) 
276             print "    %s = %s;" % (entry, lvalue)
277             print '    if (retrace::verbosity >= 2) {'
278             print '        std::cout << "{handle.name} " << {rvalue} << " -> " << {lvalue} << "\\n";'.format(**locals())
279             print '    }'
280         else:
281             i = '__h' + handle.tag
282             lvalue = "%s + %s" % (lvalue, i)
283             rvalue = "__orig_result + %s" % (i,)
284             entry = lookupHandle(handle, rvalue) 
285             print '    for ({handle.type} {i} = 0; {i} < {handle.range}; ++{i}) {{'.format(**locals())
286             print '        {entry} = {lvalue};'.format(**locals())
287             print '        if (retrace::verbosity >= 2) {'
288             print '            std::cout << "{handle.name} " << ({rvalue}) << " -> " << ({lvalue}) << "\\n";'.format(**locals())
289             print '        }'
290             print '    }'
291     
292     def visitBlob(self, blob, lvalue, rvalue):
293         pass
294     
295     def visitString(self, string, lvalue, rvalue):
296         pass
297
298     seq = 0
299
300     def visitStruct(self, struct, lvalue, rvalue):
301         tmp = '__s_' + struct.tag + '_' + str(self.seq)
302         self.seq += 1
303
304         print '    const trace::Struct *%s = dynamic_cast<const trace::Struct *>(&%s);' % (tmp, rvalue)
305         print '    assert(%s);' % (tmp,)
306         print '    (void)%s;' % (tmp,)
307         for i in range(len(struct.members)):
308             member_type, member_name = struct.members[i]
309             self.visit(member_type, '%s.%s' % (lvalue, member_name), '*%s->members[%s]' % (tmp, i))
310     
311     def visitPolymorphic(self, polymorphic, lvalue, rvalue):
312         self.visit(polymorphic.defaultType, lvalue, rvalue)
313
314
315 class Retracer:
316
317     def retraceFunction(self, function):
318         print 'static void retrace_%s(trace::Call &call) {' % function.name
319         self.retraceFunctionBody(function)
320         print '}'
321         print
322
323     def retraceInterfaceMethod(self, interface, method):
324         print 'static void retrace_%s__%s(trace::Call &call) {' % (interface.name, method.name)
325         self.retraceInterfaceMethodBody(interface, method)
326         print '}'
327         print
328
329     def retraceFunctionBody(self, function):
330         if not function.sideeffects:
331             print '    (void)call;'
332             return
333
334         self.deserializeArgs(function)
335         
336         self.invokeFunction(function)
337
338         self.swizzleValues(function)
339
340     def retraceInterfaceMethodBody(self, interface, method):
341         if not method.sideeffects:
342             print '    (void)call;'
343             return
344
345         self.deserializeThisPointer(interface)
346
347         self.deserializeArgs(method)
348         
349         self.invokeInterfaceMethod(interface, method)
350
351         self.swizzleValues(method)
352
353     def deserializeThisPointer(self, interface):
354         print r'    %s *_this;' % (interface.name,)
355         print r'    _this = static_cast<%s *>(_obj_map[call.arg(0).toUIntPtr()]);' % (interface.name,)
356         print r'    if (!_this) {'
357         print r'        retrace::warning(call) << "NULL this pointer\n";'
358         print r'        return;'
359         print r'    }'
360
361     def deserializeArgs(self, function):
362         print '    retrace::ScopedAllocator _allocator;'
363         print '    (void)_allocator;'
364         success = True
365         for arg in function.args:
366             arg_type = MutableRebuilder().visit(arg.type)
367             print '    %s %s;' % (arg_type, arg.name)
368             rvalue = 'call.arg(%u)' % (arg.index,)
369             lvalue = arg.name
370             try:
371                 self.extractArg(function, arg, arg_type, lvalue, rvalue)
372             except NotImplementedError:
373                 success =  False
374                 print '    memset(&%s, 0, sizeof %s); // FIXME' % (arg.name, arg.name)
375             print
376
377         if not success:
378             print '    if (1) {'
379             self.failFunction(function)
380             if function.name[-1].islower():
381                 sys.stderr.write('warning: unsupported %s call\n' % function.name)
382             print '    }'
383
384     def swizzleValues(self, function):
385         for arg in function.args:
386             if arg.output:
387                 arg_type = MutableRebuilder().visit(arg.type)
388                 rvalue = 'call.arg(%u)' % (arg.index,)
389                 lvalue = arg.name
390                 try:
391                     self.regiterSwizzledValue(arg_type, lvalue, rvalue)
392                 except NotImplementedError:
393                     print '    // XXX: %s' % arg.name
394         if function.type is not stdapi.Void:
395             rvalue = '*call.ret'
396             lvalue = '__result'
397             try:
398                 self.regiterSwizzledValue(function.type, lvalue, rvalue)
399             except NotImplementedError:
400                 raise
401                 print '    // XXX: result'
402
403     def failFunction(self, function):
404         print '    if (retrace::verbosity >= 0) {'
405         print '        retrace::unsupported(call);'
406         print '    }'
407         print '    return;'
408
409     def extractArg(self, function, arg, arg_type, lvalue, rvalue):
410         ValueAllocator().visit(arg_type, lvalue, rvalue)
411         if arg.input:
412             ValueDeserializer().visit(arg_type, lvalue, rvalue)
413     
414     def extractOpaqueArg(self, function, arg, arg_type, lvalue, rvalue):
415         try:
416             ValueAllocator().visit(arg_type, lvalue, rvalue)
417         except NotImplementedError:
418             pass
419         OpaqueValueDeserializer().visit(arg_type, lvalue, rvalue)
420
421     def regiterSwizzledValue(self, type, lvalue, rvalue):
422         visitor = SwizzledValueRegistrator()
423         visitor.visit(type, lvalue, rvalue)
424
425     def invokeFunction(self, function):
426         arg_names = ", ".join(function.argNames())
427         if function.type is not stdapi.Void:
428             print '    %s __result;' % (function.type)
429             print '    __result = %s(%s);' % (function.name, arg_names)
430             print '    (void)__result;'
431         else:
432             print '    %s(%s);' % (function.name, arg_names)
433
434     def invokeInterfaceMethod(self, interface, method):
435         arg_names = ", ".join(method.argNames())
436         if method.type is not stdapi.Void:
437             print '    %s __result;' % (method.type)
438             print '    __result = _this->%s(%s);' % (method.name, arg_names)
439             print '    (void)__result;'
440         else:
441             print '    _this->%s(%s);' % (method.name, arg_names)
442
443     def filterFunction(self, function):
444         return True
445
446     table_name = 'retrace::callbacks'
447
448     def retraceApi(self, api):
449
450         print '#include "os_time.hpp"'
451         print '#include "trace_parser.hpp"'
452         print '#include "retrace.hpp"'
453         print
454
455         types = api.getAllTypes()
456         handles = [type for type in types if isinstance(type, stdapi.Handle)]
457         handle_names = set()
458         for handle in handles:
459             if handle.name not in handle_names:
460                 if handle.key is None:
461                     print 'static retrace::map<%s> __%s_map;' % (handle.type, handle.name)
462                 else:
463                     key_name, key_type = handle.key
464                     print 'static std::map<%s, retrace::map<%s> > __%s_map;' % (key_type, handle.type, handle.name)
465                 handle_names.add(handle.name)
466         print
467
468         print 'static std::map<unsigned long long, void *> _obj_map;'
469         print
470
471         functions = filter(self.filterFunction, api.functions)
472         for function in functions:
473             self.retraceFunction(function)
474         interfaces = api.getAllInterfaces()
475         for interface in interfaces:
476             for method in interface.iterMethods():
477                 self.retraceInterfaceMethod(interface, method)
478
479         print 'const retrace::Entry %s[] = {' % self.table_name
480         for function in functions:
481             print '    {"%s", &retrace_%s},' % (function.name, function.name)
482         for interface in interfaces:
483             for method in interface.iterMethods():
484                 print '    {"%s::%s", &retrace_%s__%s},' % (interface.name, method.name, interface.name, method.name)
485         print '    {NULL, NULL}'
486         print '};'
487         print
488