]> git.cworth.org Git - apitrace/blob - retrace/retrace.py
Fix D3D11 tracing with D3D11_CREATE_DEVICE_DEBUG flag.
[apitrace] / retrace / retrace.py
1 ##########################################################################
2 #
3 # Copyright 2010 VMware, Inc.
4 # All Rights Reserved.
5 #
6 # Permission is hereby granted, free of charge, to any person obtaining a copy
7 # of this software and associated documentation files (the "Software"), to deal
8 # in the Software without restriction, including without limitation the rights
9 # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 # copies of the Software, and to permit persons to whom the Software is
11 # furnished to do so, subject to the following conditions:
12 #
13 # The above copyright notice and this permission notice shall be included in
14 # all copies or substantial portions of the Software.
15 #
16 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
22 # THE SOFTWARE.
23 #
24 ##########################################################################/
25
26
27 """Generic retracing code generator."""
28
29
30 # Adjust path
31 import os.path
32 import sys
33 sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
34
35
36 import specs.stdapi as stdapi
37 import specs.glapi as glapi
38
39
40 class UnsupportedType(Exception):
41     pass
42
43
44 def lookupHandle(handle, value):
45     if handle.key is None:
46         return "_%s_map[%s]" % (handle.name, value)
47     else:
48         key_name, key_type = handle.key
49         return "_%s_map[%s][%s]" % (handle.name, key_name, value)
50
51
52 class ValueAllocator(stdapi.Visitor):
53
54     def visitLiteral(self, literal, lvalue, rvalue):
55         pass
56
57     def visitConst(self, const, lvalue, rvalue):
58         self.visit(const.type, lvalue, rvalue)
59
60     def visitAlias(self, alias, lvalue, rvalue):
61         self.visit(alias.type, lvalue, rvalue)
62
63     def visitEnum(self, enum, lvalue, rvalue):
64         pass
65
66     def visitBitmask(self, bitmask, lvalue, rvalue):
67         pass
68
69     def visitArray(self, array, lvalue, rvalue):
70         print '    %s = _allocator.alloc<%s>(&%s);' % (lvalue, array.type, rvalue)
71
72     def visitPointer(self, pointer, lvalue, rvalue):
73         print '    %s = _allocator.alloc<%s>(&%s);' % (lvalue, pointer.type, rvalue)
74
75     def visitIntPointer(self, pointer, lvalue, rvalue):
76         pass
77
78     def visitObjPointer(self, pointer, lvalue, rvalue):
79         pass
80
81     def visitLinearPointer(self, pointer, lvalue, rvalue):
82         pass
83
84     def visitReference(self, reference, lvalue, rvalue):
85         self.visit(reference.type, lvalue, rvalue);
86
87     def visitHandle(self, handle, lvalue, rvalue):
88         pass
89
90     def visitBlob(self, blob, lvalue, rvalue):
91         pass
92
93     def visitString(self, string, lvalue, rvalue):
94         pass
95
96     def visitStruct(self, struct, lvalue, rvalue):
97         pass
98
99     def visitPolymorphic(self, polymorphic, lvalue, rvalue):
100         self.visit(polymorphic.defaultType, lvalue, rvalue)
101
102     def visitOpaque(self, opaque, lvalue, rvalue):
103         pass
104
105
106 class ValueDeserializer(stdapi.Visitor):
107
108     def visitLiteral(self, literal, lvalue, rvalue):
109         print '    %s = (%s).to%s();' % (lvalue, rvalue, literal.kind)
110
111     def visitConst(self, const, lvalue, rvalue):
112         self.visit(const.type, lvalue, rvalue)
113
114     def visitAlias(self, alias, lvalue, rvalue):
115         self.visit(alias.type, lvalue, rvalue)
116     
117     def visitEnum(self, enum, lvalue, rvalue):
118         print '    %s = static_cast<%s>((%s).toSInt());' % (lvalue, enum, rvalue)
119
120     def visitBitmask(self, bitmask, lvalue, rvalue):
121         self.visit(bitmask.type, lvalue, rvalue)
122
123     def visitArray(self, array, lvalue, rvalue):
124
125         tmp = '_a_' + array.tag + '_' + str(self.seq)
126         self.seq += 1
127
128         print '    if (%s) {' % (lvalue,)
129         print '        const trace::Array *%s = dynamic_cast<const trace::Array *>(&%s);' % (tmp, rvalue)
130         length = '%s->values.size()' % (tmp,)
131         index = '_j' + array.tag
132         print '        for (size_t {i} = 0; {i} < {length}; ++{i}) {{'.format(i = index, length = length)
133         try:
134             self.visit(array.type, '%s[%s]' % (lvalue, index), '*%s->values[%s]' % (tmp, index))
135         finally:
136             print '        }'
137             print '    }'
138     
139     def visitPointer(self, pointer, lvalue, rvalue):
140         tmp = '_a_' + pointer.tag + '_' + str(self.seq)
141         self.seq += 1
142
143         print '    if (%s) {' % (lvalue,)
144         print '        const trace::Array *%s = dynamic_cast<const trace::Array *>(&%s);' % (tmp, rvalue)
145         try:
146             self.visit(pointer.type, '%s[0]' % (lvalue,), '*%s->values[0]' % (tmp,))
147         finally:
148             print '    }'
149
150     def visitIntPointer(self, pointer, lvalue, rvalue):
151         print '    %s = static_cast<%s>((%s).toPointer());' % (lvalue, pointer, rvalue)
152
153     def visitObjPointer(self, pointer, lvalue, rvalue):
154         old_lvalue = '(%s).toUIntPtr()' % (rvalue,)
155         new_lvalue = '_obj_map[%s]' % (old_lvalue,)
156         print '    if (retrace::verbosity >= 2) {'
157         print '        std::cout << std::hex << "obj 0x" << size_t(%s) << " <- 0x" << size_t(%s) << std::dec <<"\\n";' % (old_lvalue, new_lvalue)
158         print '    }'
159         print '    %s = static_cast<%s>(%s);' % (lvalue, pointer, new_lvalue)
160
161     def visitLinearPointer(self, pointer, lvalue, rvalue):
162         print '    %s = static_cast<%s>(retrace::toPointer(%s));' % (lvalue, pointer, rvalue)
163
164     def visitReference(self, reference, lvalue, rvalue):
165         self.visit(reference.type, lvalue, rvalue);
166
167     def visitHandle(self, handle, lvalue, rvalue):
168         #OpaqueValueDeserializer().visit(handle.type, lvalue, rvalue);
169         self.visit(handle.type, lvalue, rvalue);
170         new_lvalue = lookupHandle(handle, lvalue)
171         print '    if (retrace::verbosity >= 2) {'
172         print '        std::cout << "%s " << size_t(%s) << " <- " << size_t(%s) << "\\n";' % (handle.name, lvalue, new_lvalue)
173         print '    }'
174         print '    %s = %s;' % (lvalue, new_lvalue)
175     
176     def visitBlob(self, blob, lvalue, rvalue):
177         print '    %s = static_cast<%s>((%s).toPointer());' % (lvalue, blob, rvalue)
178     
179     def visitString(self, string, lvalue, rvalue):
180         print '    %s = (%s)((%s).toString());' % (lvalue, string.expr, rvalue)
181
182     seq = 0
183
184     def visitStruct(self, struct, lvalue, rvalue):
185         tmp = '_s_' + struct.tag + '_' + str(self.seq)
186         self.seq += 1
187
188         print '    const trace::Struct *%s = dynamic_cast<const trace::Struct *>(&%s);' % (tmp, rvalue)
189         print '    assert(%s);' % (tmp)
190         for i in range(len(struct.members)):
191             member_type, member_name = struct.members[i]
192             self.visit(member_type, '%s.%s' % (lvalue, member_name), '*%s->members[%s]' % (tmp, i))
193
194     def visitPolymorphic(self, polymorphic, lvalue, rvalue):
195         self.visit(polymorphic.defaultType, lvalue, rvalue)
196     
197     def visitOpaque(self, opaque, lvalue, rvalue):
198         raise UnsupportedType
199
200
201 class OpaqueValueDeserializer(ValueDeserializer):
202     '''Value extractor that also understands opaque values.
203
204     Normally opaque values can't be retraced, unless they are being extracted
205     in the context of handles.'''
206
207     def visitOpaque(self, opaque, lvalue, rvalue):
208         print '    %s = static_cast<%s>(retrace::toPointer(%s));' % (lvalue, opaque, rvalue)
209
210
211 class SwizzledValueRegistrator(stdapi.Visitor):
212     '''Type visitor which will register (un)swizzled value pairs, to later be
213     swizzled.'''
214
215     def visitLiteral(self, literal, lvalue, rvalue):
216         pass
217
218     def visitAlias(self, alias, lvalue, rvalue):
219         self.visit(alias.type, lvalue, rvalue)
220     
221     def visitEnum(self, enum, lvalue, rvalue):
222         pass
223
224     def visitBitmask(self, bitmask, lvalue, rvalue):
225         pass
226
227     def visitArray(self, array, lvalue, rvalue):
228         print '    const trace::Array *_a%s = dynamic_cast<const trace::Array *>(&%s);' % (array.tag, rvalue)
229         print '    if (_a%s) {' % (array.tag)
230         length = '_a%s->values.size()' % array.tag
231         index = '_j' + array.tag
232         print '        for (size_t {i} = 0; {i} < {length}; ++{i}) {{'.format(i = index, length = length)
233         try:
234             self.visit(array.type, '%s[%s]' % (lvalue, index), '*_a%s->values[%s]' % (array.tag, index))
235         finally:
236             print '        }'
237             print '    }'
238     
239     def visitPointer(self, pointer, lvalue, rvalue):
240         print '    const trace::Array *_a%s = dynamic_cast<const trace::Array *>(&%s);' % (pointer.tag, rvalue)
241         print '    if (_a%s) {' % (pointer.tag)
242         try:
243             self.visit(pointer.type, '%s[0]' % (lvalue,), '*_a%s->values[0]' % (pointer.tag,))
244         finally:
245             print '    }'
246     
247     def visitIntPointer(self, pointer, lvalue, rvalue):
248         pass
249     
250     def visitObjPointer(self, pointer, lvalue, rvalue):
251         print r'    _obj_map[(%s).toUIntPtr()] = %s;' % (rvalue, lvalue)
252     
253     def visitLinearPointer(self, pointer, lvalue, rvalue):
254         assert pointer.size is not None
255         if pointer.size is not None:
256             print r'    retrace::addRegion((%s).toUIntPtr(), %s, %s);' % (rvalue, lvalue, pointer.size)
257
258     def visitReference(self, reference, lvalue, rvalue):
259         pass
260     
261     def visitHandle(self, handle, lvalue, rvalue):
262         print '    %s _origResult;' % handle.type
263         OpaqueValueDeserializer().visit(handle.type, '_origResult', rvalue);
264         if handle.range is None:
265             rvalue = "_origResult"
266             entry = lookupHandle(handle, rvalue) 
267             print "    %s = %s;" % (entry, lvalue)
268             print '    if (retrace::verbosity >= 2) {'
269             print '        std::cout << "{handle.name} " << {rvalue} << " -> " << {lvalue} << "\\n";'.format(**locals())
270             print '    }'
271         else:
272             i = '_h' + handle.tag
273             lvalue = "%s + %s" % (lvalue, i)
274             rvalue = "_origResult + %s" % (i,)
275             entry = lookupHandle(handle, rvalue) 
276             print '    for ({handle.type} {i} = 0; {i} < {handle.range}; ++{i}) {{'.format(**locals())
277             print '        {entry} = {lvalue};'.format(**locals())
278             print '        if (retrace::verbosity >= 2) {'
279             print '            std::cout << "{handle.name} " << ({rvalue}) << " -> " << ({lvalue}) << "\\n";'.format(**locals())
280             print '        }'
281             print '    }'
282     
283     def visitBlob(self, blob, lvalue, rvalue):
284         pass
285     
286     def visitString(self, string, lvalue, rvalue):
287         pass
288
289     seq = 0
290
291     def visitStruct(self, struct, lvalue, rvalue):
292         tmp = '_s_' + struct.tag + '_' + str(self.seq)
293         self.seq += 1
294
295         print '    const trace::Struct *%s = dynamic_cast<const trace::Struct *>(&%s);' % (tmp, rvalue)
296         print '    assert(%s);' % (tmp,)
297         print '    (void)%s;' % (tmp,)
298         for i in range(len(struct.members)):
299             member_type, member_name = struct.members[i]
300             self.visit(member_type, '%s.%s' % (lvalue, member_name), '*%s->members[%s]' % (tmp, i))
301     
302     def visitPolymorphic(self, polymorphic, lvalue, rvalue):
303         self.visit(polymorphic.defaultType, lvalue, rvalue)
304     
305     def visitOpaque(self, opaque, lvalue, rvalue):
306         pass
307
308
309 class Retracer:
310
311     def retraceFunction(self, function):
312         print 'static void retrace_%s(trace::Call &call) {' % function.name
313         self.retraceFunctionBody(function)
314         print '}'
315         print
316
317     def retraceInterfaceMethod(self, interface, method):
318         print 'static void retrace_%s__%s(trace::Call &call) {' % (interface.name, method.name)
319         self.retraceInterfaceMethodBody(interface, method)
320         print '}'
321         print
322
323     def retraceFunctionBody(self, function):
324         assert function.sideeffects
325
326         if function.type is not stdapi.Void:
327             self.checkOrigResult(function)
328
329         self.deserializeArgs(function)
330         
331         self.invokeFunction(function)
332
333         self.swizzleValues(function)
334
335     def retraceInterfaceMethodBody(self, interface, method):
336         assert method.sideeffects
337
338         if method.type is not stdapi.Void:
339             self.checkOrigResult(method)
340
341         self.deserializeThisPointer(interface)
342
343         self.deserializeArgs(method)
344         
345         self.invokeInterfaceMethod(interface, method)
346
347         self.swizzleValues(method)
348
349     def checkOrigResult(self, function):
350         '''Hook for checking the original result, to prevent succeeding now
351         where the original did not, which would cause diversion and potentially
352         unpredictable results.'''
353
354         assert function.type is not stdapi.Void
355
356         if str(function.type) == 'HRESULT':
357             print r'    if (call.ret && FAILED(call.ret->toSInt())) {'
358             print r'        return;'
359             print r'    }'
360
361     def deserializeThisPointer(self, interface):
362         print r'    %s *_this;' % (interface.name,)
363         print r'    _this = static_cast<%s *>(_obj_map[call.arg(0).toUIntPtr()]);' % (interface.name,)
364         print r'    if (!_this) {'
365         print r'        retrace::warning(call) << "NULL this pointer\n";'
366         print r'        return;'
367         print r'    }'
368
369     def deserializeArgs(self, function):
370         print '    retrace::ScopedAllocator _allocator;'
371         print '    (void)_allocator;'
372         success = True
373         for arg in function.args:
374             arg_type = arg.type.mutable()
375             print '    %s %s;' % (arg_type, arg.name)
376             rvalue = 'call.arg(%u)' % (arg.index,)
377             lvalue = arg.name
378             try:
379                 self.extractArg(function, arg, arg_type, lvalue, rvalue)
380             except UnsupportedType:
381                 success =  False
382                 print '    memset(&%s, 0, sizeof %s); // FIXME' % (arg.name, arg.name)
383             print
384
385         if not success:
386             print '    if (1) {'
387             self.failFunction(function)
388             if function.name[-1].islower():
389                 sys.stderr.write('warning: unsupported %s call\n' % function.name)
390             print '    }'
391
392     def swizzleValues(self, function):
393         for arg in function.args:
394             if arg.output:
395                 arg_type = arg.type.mutable()
396                 rvalue = 'call.arg(%u)' % (arg.index,)
397                 lvalue = arg.name
398                 try:
399                     self.regiterSwizzledValue(arg_type, lvalue, rvalue)
400                 except UnsupportedType:
401                     print '    // XXX: %s' % arg.name
402         if function.type is not stdapi.Void:
403             rvalue = '*call.ret'
404             lvalue = '_result'
405             try:
406                 self.regiterSwizzledValue(function.type, lvalue, rvalue)
407             except UnsupportedType:
408                 raise
409                 print '    // XXX: result'
410
411     def failFunction(self, function):
412         print '    if (retrace::verbosity >= 0) {'
413         print '        retrace::unsupported(call);'
414         print '    }'
415         print '    return;'
416
417     def extractArg(self, function, arg, arg_type, lvalue, rvalue):
418         ValueAllocator().visit(arg_type, lvalue, rvalue)
419         if arg.input:
420             ValueDeserializer().visit(arg_type, lvalue, rvalue)
421     
422     def extractOpaqueArg(self, function, arg, arg_type, lvalue, rvalue):
423         try:
424             ValueAllocator().visit(arg_type, lvalue, rvalue)
425         except UnsupportedType:
426             pass
427         OpaqueValueDeserializer().visit(arg_type, lvalue, rvalue)
428
429     def regiterSwizzledValue(self, type, lvalue, rvalue):
430         visitor = SwizzledValueRegistrator()
431         visitor.visit(type, lvalue, rvalue)
432
433     def invokeFunction(self, function):
434         arg_names = ", ".join(function.argNames())
435         if function.type is not stdapi.Void:
436             print '    %s _result;' % (function.type)
437             print '    _result = %s(%s);' % (function.name, arg_names)
438             print '    (void)_result;'
439         else:
440             print '    %s(%s);' % (function.name, arg_names)
441
442     def invokeInterfaceMethod(self, interface, method):
443         # On release our reference when we reach Release() == 0 call in the
444         # trace.
445         if method.name == 'Release':
446             print '    if (call.ret->toUInt()) {'
447             print '        return;'
448             print '    }'
449             print '    _obj_map.erase(call.arg(0).toUIntPtr());'
450
451         arg_names = ", ".join(method.argNames())
452         if method.type is not stdapi.Void:
453             print '    %s _result;' % (method.type)
454             print '    _result = _this->%s(%s);' % (method.name, arg_names)
455             print '    (void)_result;'
456         else:
457             print '    _this->%s(%s);' % (method.name, arg_names)
458
459     def filterFunction(self, function):
460         return True
461
462     table_name = 'retrace::callbacks'
463
464     def retraceApi(self, api):
465
466         print '#include "os_time.hpp"'
467         print '#include "trace_parser.hpp"'
468         print '#include "retrace.hpp"'
469         print
470
471         types = api.getAllTypes()
472         handles = [type for type in types if isinstance(type, stdapi.Handle)]
473         handle_names = set()
474         for handle in handles:
475             if handle.name not in handle_names:
476                 if handle.key is None:
477                     print 'static retrace::map<%s> _%s_map;' % (handle.type, handle.name)
478                 else:
479                     key_name, key_type = handle.key
480                     print 'static std::map<%s, retrace::map<%s> > _%s_map;' % (key_type, handle.type, handle.name)
481                 handle_names.add(handle.name)
482         print
483
484         print 'static std::map<unsigned long long, void *> _obj_map;'
485         print
486
487         functions = filter(self.filterFunction, api.functions)
488         for function in functions:
489             if function.sideeffects and not function.internal:
490                 self.retraceFunction(function)
491         interfaces = api.getAllInterfaces()
492         for interface in interfaces:
493             for method in interface.iterMethods():
494                 if method.sideeffects and not method.internal:
495                     self.retraceInterfaceMethod(interface, method)
496
497         print 'const retrace::Entry %s[] = {' % self.table_name
498         for function in functions:
499             if not function.internal:
500                 if function.sideeffects:
501                     print '    {"%s", &retrace_%s},' % (function.name, function.name)
502                 else:
503                     print '    {"%s", &retrace::ignore},' % (function.name,)
504         for interface in interfaces:
505             for method in interface.iterMethods():                
506                 if method.sideeffects:
507                     print '    {"%s::%s", &retrace_%s__%s},' % (interface.name, method.name, interface.name, method.name)
508                 else:
509                     print '    {"%s::%s", &retrace::ignore},' % (interface.name, method.name)
510         print '    {NULL, NULL}'
511         print '};'
512         print
513