]> git.cworth.org Git - apitrace/blob - retrace/retrace.py
d3d10,d3d11: Rudimentary retrace support.
[apitrace] / retrace / retrace.py
1 ##########################################################################
2 #
3 # Copyright 2010 VMware, Inc.
4 # All Rights Reserved.
5 #
6 # Permission is hereby granted, free of charge, to any person obtaining a copy
7 # of this software and associated documentation files (the "Software"), to deal
8 # in the Software without restriction, including without limitation the rights
9 # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 # copies of the Software, and to permit persons to whom the Software is
11 # furnished to do so, subject to the following conditions:
12 #
13 # The above copyright notice and this permission notice shall be included in
14 # all copies or substantial portions of the Software.
15 #
16 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
22 # THE SOFTWARE.
23 #
24 ##########################################################################/
25
26
27 """Generic retracing code generator."""
28
29
30 # Adjust path
31 import os.path
32 import sys
33 sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
34
35
36 import specs.stdapi as stdapi
37
38
39 class UnsupportedType(Exception):
40     pass
41
42
43 def lookupHandle(handle, value):
44     if handle.key is None:
45         return "_%s_map[%s]" % (handle.name, value)
46     else:
47         key_name, key_type = handle.key
48         return "_%s_map[%s][%s]" % (handle.name, key_name, value)
49
50
51 class ValueAllocator(stdapi.Visitor):
52
53     def visitLiteral(self, literal, lvalue, rvalue):
54         pass
55
56     def visitConst(self, const, lvalue, rvalue):
57         self.visit(const.type, lvalue, rvalue)
58
59     def visitAlias(self, alias, lvalue, rvalue):
60         self.visit(alias.type, lvalue, rvalue)
61
62     def visitEnum(self, enum, lvalue, rvalue):
63         pass
64
65     def visitBitmask(self, bitmask, lvalue, rvalue):
66         pass
67
68     def visitArray(self, array, lvalue, rvalue):
69         print '    %s = _allocator.alloc<%s>(&%s);' % (lvalue, array.type, rvalue)
70
71     def visitPointer(self, pointer, lvalue, rvalue):
72         print '    %s = _allocator.alloc<%s>(&%s);' % (lvalue, pointer.type, rvalue)
73
74     def visitIntPointer(self, pointer, lvalue, rvalue):
75         pass
76
77     def visitObjPointer(self, pointer, lvalue, rvalue):
78         pass
79
80     def visitLinearPointer(self, pointer, lvalue, rvalue):
81         pass
82
83     def visitReference(self, reference, lvalue, rvalue):
84         self.visit(reference.type, lvalue, rvalue);
85
86     def visitHandle(self, handle, lvalue, rvalue):
87         pass
88
89     def visitBlob(self, blob, lvalue, rvalue):
90         pass
91
92     def visitString(self, string, lvalue, rvalue):
93         pass
94
95     def visitStruct(self, struct, lvalue, rvalue):
96         pass
97
98     def visitPolymorphic(self, polymorphic, lvalue, rvalue):
99         self.visit(polymorphic.defaultType, lvalue, rvalue)
100
101     def visitOpaque(self, opaque, lvalue, rvalue):
102         pass
103
104
105 class ValueDeserializer(stdapi.Visitor):
106
107     def visitLiteral(self, literal, lvalue, rvalue):
108         print '    %s = (%s).to%s();' % (lvalue, rvalue, literal.kind)
109
110     def visitConst(self, const, lvalue, rvalue):
111         self.visit(const.type, lvalue, rvalue)
112
113     def visitAlias(self, alias, lvalue, rvalue):
114         self.visit(alias.type, lvalue, rvalue)
115     
116     def visitEnum(self, enum, lvalue, rvalue):
117         print '    %s = static_cast<%s>((%s).toSInt());' % (lvalue, enum, rvalue)
118
119     def visitBitmask(self, bitmask, lvalue, rvalue):
120         self.visit(bitmask.type, lvalue, rvalue)
121
122     def visitArray(self, array, lvalue, rvalue):
123
124         tmp = '_a_' + array.tag + '_' + str(self.seq)
125         self.seq += 1
126
127         print '    if (%s) {' % (lvalue,)
128         print '        const trace::Array *%s = dynamic_cast<const trace::Array *>(&%s);' % (tmp, rvalue)
129         length = '%s->values.size()' % (tmp,)
130         index = '_j' + array.tag
131         print '        for (size_t {i} = 0; {i} < {length}; ++{i}) {{'.format(i = index, length = length)
132         try:
133             self.visit(array.type, '%s[%s]' % (lvalue, index), '*%s->values[%s]' % (tmp, index))
134         finally:
135             print '        }'
136             print '    }'
137     
138     def visitPointer(self, pointer, lvalue, rvalue):
139         tmp = '_a_' + pointer.tag + '_' + str(self.seq)
140         self.seq += 1
141
142         print '    if (%s) {' % (lvalue,)
143         print '        const trace::Array *%s = dynamic_cast<const trace::Array *>(&%s);' % (tmp, rvalue)
144         try:
145             self.visit(pointer.type, '%s[0]' % (lvalue,), '*%s->values[0]' % (tmp,))
146         finally:
147             print '    }'
148
149     def visitIntPointer(self, pointer, lvalue, rvalue):
150         print '    %s = static_cast<%s>((%s).toPointer());' % (lvalue, pointer, rvalue)
151
152     def visitObjPointer(self, pointer, lvalue, rvalue):
153         print '    %s = static_cast<%s>(retrace::toObjPointer(%s));' % (lvalue, pointer, rvalue)
154
155     def visitLinearPointer(self, pointer, lvalue, rvalue):
156         print '    %s = static_cast<%s>(retrace::toPointer(%s));' % (lvalue, pointer, rvalue)
157
158     def visitReference(self, reference, lvalue, rvalue):
159         self.visit(reference.type, lvalue, rvalue);
160
161     def visitHandle(self, handle, lvalue, rvalue):
162         #OpaqueValueDeserializer().visit(handle.type, lvalue, rvalue);
163         self.visit(handle.type, lvalue, rvalue);
164         new_lvalue = lookupHandle(handle, lvalue)
165         print '    if (retrace::verbosity >= 2) {'
166         print '        std::cout << "%s " << size_t(%s) << " <- " << size_t(%s) << "\\n";' % (handle.name, lvalue, new_lvalue)
167         print '    }'
168         print '    %s = %s;' % (lvalue, new_lvalue)
169     
170     def visitBlob(self, blob, lvalue, rvalue):
171         print '    %s = static_cast<%s>((%s).toPointer());' % (lvalue, blob, rvalue)
172     
173     def visitString(self, string, lvalue, rvalue):
174         print '    %s = (%s)((%s).toString());' % (lvalue, string.expr, rvalue)
175
176     seq = 0
177
178     def visitStruct(self, struct, lvalue, rvalue):
179         tmp = '_s_' + struct.tag + '_' + str(self.seq)
180         self.seq += 1
181
182         print '    const trace::Struct *%s = dynamic_cast<const trace::Struct *>(&%s);' % (tmp, rvalue)
183         print '    assert(%s);' % (tmp)
184         for i in range(len(struct.members)):
185             member_type, member_name = struct.members[i]
186             self.visit(member_type, '%s.%s' % (lvalue, member_name), '*%s->members[%s]' % (tmp, i))
187
188     def visitPolymorphic(self, polymorphic, lvalue, rvalue):
189         self.visit(polymorphic.defaultType, lvalue, rvalue)
190     
191     def visitOpaque(self, opaque, lvalue, rvalue):
192         raise UnsupportedType
193
194
195 class OpaqueValueDeserializer(ValueDeserializer):
196     '''Value extractor that also understands opaque values.
197
198     Normally opaque values can't be retraced, unless they are being extracted
199     in the context of handles.'''
200
201     def visitOpaque(self, opaque, lvalue, rvalue):
202         print '    %s = static_cast<%s>(retrace::toPointer(%s));' % (lvalue, opaque, rvalue)
203
204
205 class SwizzledValueRegistrator(stdapi.Visitor):
206     '''Type visitor which will register (un)swizzled value pairs, to later be
207     swizzled.'''
208
209     def visitLiteral(self, literal, lvalue, rvalue):
210         pass
211
212     def visitAlias(self, alias, lvalue, rvalue):
213         self.visit(alias.type, lvalue, rvalue)
214     
215     def visitEnum(self, enum, lvalue, rvalue):
216         pass
217
218     def visitBitmask(self, bitmask, lvalue, rvalue):
219         pass
220
221     def visitArray(self, array, lvalue, rvalue):
222         print '    const trace::Array *_a%s = dynamic_cast<const trace::Array *>(&%s);' % (array.tag, rvalue)
223         print '    if (_a%s) {' % (array.tag)
224         length = '_a%s->values.size()' % array.tag
225         index = '_j' + array.tag
226         print '        for (size_t {i} = 0; {i} < {length}; ++{i}) {{'.format(i = index, length = length)
227         try:
228             self.visit(array.type, '%s[%s]' % (lvalue, index), '*_a%s->values[%s]' % (array.tag, index))
229         finally:
230             print '        }'
231             print '    }'
232     
233     def visitPointer(self, pointer, lvalue, rvalue):
234         print '    const trace::Array *_a%s = dynamic_cast<const trace::Array *>(&%s);' % (pointer.tag, rvalue)
235         print '    if (_a%s) {' % (pointer.tag)
236         try:
237             self.visit(pointer.type, '%s[0]' % (lvalue,), '*_a%s->values[0]' % (pointer.tag,))
238         finally:
239             print '    }'
240     
241     def visitIntPointer(self, pointer, lvalue, rvalue):
242         pass
243     
244     def visitObjPointer(self, pointer, lvalue, rvalue):
245         print r'    retrace::addObj(%s, %s);' % (rvalue, lvalue)
246     
247     def visitLinearPointer(self, pointer, lvalue, rvalue):
248         assert pointer.size is not None
249         if pointer.size is not None:
250             print r'    retrace::addRegion((%s).toUIntPtr(), %s, %s);' % (rvalue, lvalue, pointer.size)
251
252     def visitReference(self, reference, lvalue, rvalue):
253         pass
254     
255     def visitHandle(self, handle, lvalue, rvalue):
256         print '    %s _origResult;' % handle.type
257         OpaqueValueDeserializer().visit(handle.type, '_origResult', rvalue);
258         if handle.range is None:
259             rvalue = "_origResult"
260             entry = lookupHandle(handle, rvalue) 
261             print "    %s = %s;" % (entry, lvalue)
262             print '    if (retrace::verbosity >= 2) {'
263             print '        std::cout << "{handle.name} " << {rvalue} << " -> " << {lvalue} << "\\n";'.format(**locals())
264             print '    }'
265         else:
266             i = '_h' + handle.tag
267             lvalue = "%s + %s" % (lvalue, i)
268             rvalue = "_origResult + %s" % (i,)
269             entry = lookupHandle(handle, rvalue) 
270             print '    for ({handle.type} {i} = 0; {i} < {handle.range}; ++{i}) {{'.format(**locals())
271             print '        {entry} = {lvalue};'.format(**locals())
272             print '        if (retrace::verbosity >= 2) {'
273             print '            std::cout << "{handle.name} " << ({rvalue}) << " -> " << ({lvalue}) << "\\n";'.format(**locals())
274             print '        }'
275             print '    }'
276     
277     def visitBlob(self, blob, lvalue, rvalue):
278         pass
279     
280     def visitString(self, string, lvalue, rvalue):
281         pass
282
283     seq = 0
284
285     def visitStruct(self, struct, lvalue, rvalue):
286         tmp = '_s_' + struct.tag + '_' + str(self.seq)
287         self.seq += 1
288
289         print '    const trace::Struct *%s = dynamic_cast<const trace::Struct *>(&%s);' % (tmp, rvalue)
290         print '    assert(%s);' % (tmp,)
291         print '    (void)%s;' % (tmp,)
292         for i in range(len(struct.members)):
293             member_type, member_name = struct.members[i]
294             self.visit(member_type, '%s.%s' % (lvalue, member_name), '*%s->members[%s]' % (tmp, i))
295     
296     def visitPolymorphic(self, polymorphic, lvalue, rvalue):
297         self.visit(polymorphic.defaultType, lvalue, rvalue)
298     
299     def visitOpaque(self, opaque, lvalue, rvalue):
300         pass
301
302
303 class Retracer:
304
305     def retraceFunction(self, function):
306         print 'static void retrace_%s(trace::Call &call) {' % function.name
307         self.retraceFunctionBody(function)
308         print '}'
309         print
310
311     def retraceInterfaceMethod(self, interface, method):
312         print 'static void retrace_%s__%s(trace::Call &call) {' % (interface.name, method.name)
313         self.retraceInterfaceMethodBody(interface, method)
314         print '}'
315         print
316
317     def retraceFunctionBody(self, function):
318         assert function.sideeffects
319
320         if function.type is not stdapi.Void:
321             self.checkOrigResult(function)
322
323         self.deserializeArgs(function)
324         
325         self.declareRet(function)
326         self.invokeFunction(function)
327
328         self.swizzleValues(function)
329
330     def retraceInterfaceMethodBody(self, interface, method):
331         assert method.sideeffects
332
333         if method.type is not stdapi.Void:
334             self.checkOrigResult(method)
335
336         self.deserializeThisPointer(interface)
337
338         self.deserializeArgs(method)
339         
340         self.declareRet(method)
341         self.invokeInterfaceMethod(interface, method)
342
343         self.swizzleValues(method)
344
345     def checkOrigResult(self, function):
346         '''Hook for checking the original result, to prevent succeeding now
347         where the original did not, which would cause diversion and potentially
348         unpredictable results.'''
349
350         assert function.type is not stdapi.Void
351
352         if str(function.type) == 'HRESULT':
353             print r'    if (call.ret && FAILED(call.ret->toSInt())) {'
354             print r'        return;'
355             print r'    }'
356
357     def deserializeThisPointer(self, interface):
358         print r'    %s *_this;' % (interface.name,)
359         print r'    _this = static_cast<%s *>(retrace::toObjPointer(call.arg(0)));' % (interface.name,)
360         print r'    if (!_this) {'
361         print r'        retrace::warning(call) << "NULL this pointer\n";'
362         print r'        return;'
363         print r'    }'
364
365     def deserializeArgs(self, function):
366         print '    retrace::ScopedAllocator _allocator;'
367         print '    (void)_allocator;'
368         success = True
369         for arg in function.args:
370             arg_type = arg.type.mutable()
371             print '    %s %s;' % (arg_type, arg.name)
372             rvalue = 'call.arg(%u)' % (arg.index,)
373             lvalue = arg.name
374             try:
375                 self.extractArg(function, arg, arg_type, lvalue, rvalue)
376             except UnsupportedType:
377                 success =  False
378                 print '    memset(&%s, 0, sizeof %s); // FIXME' % (arg.name, arg.name)
379             print
380
381         if not success:
382             print '    if (1) {'
383             self.failFunction(function)
384             if function.name[-1].islower():
385                 sys.stderr.write('warning: unsupported %s call\n' % function.name)
386             print '    }'
387
388     def swizzleValues(self, function):
389         for arg in function.args:
390             if arg.output:
391                 arg_type = arg.type.mutable()
392                 rvalue = 'call.arg(%u)' % (arg.index,)
393                 lvalue = arg.name
394                 try:
395                     self.regiterSwizzledValue(arg_type, lvalue, rvalue)
396                 except UnsupportedType:
397                     print '    // XXX: %s' % arg.name
398         if function.type is not stdapi.Void:
399             rvalue = '*call.ret'
400             lvalue = '_result'
401             try:
402                 self.regiterSwizzledValue(function.type, lvalue, rvalue)
403             except UnsupportedType:
404                 raise
405                 print '    // XXX: result'
406
407     def failFunction(self, function):
408         print '    if (retrace::verbosity >= 0) {'
409         print '        retrace::unsupported(call);'
410         print '    }'
411         print '    return;'
412
413     def extractArg(self, function, arg, arg_type, lvalue, rvalue):
414         ValueAllocator().visit(arg_type, lvalue, rvalue)
415         if arg.input:
416             ValueDeserializer().visit(arg_type, lvalue, rvalue)
417     
418     def extractOpaqueArg(self, function, arg, arg_type, lvalue, rvalue):
419         try:
420             ValueAllocator().visit(arg_type, lvalue, rvalue)
421         except UnsupportedType:
422             pass
423         OpaqueValueDeserializer().visit(arg_type, lvalue, rvalue)
424
425     def regiterSwizzledValue(self, type, lvalue, rvalue):
426         visitor = SwizzledValueRegistrator()
427         visitor.visit(type, lvalue, rvalue)
428
429     def declareRet(self, function):
430         if function.type is not stdapi.Void:
431             print '    %s _result;' % (function.type)
432
433     def invokeFunction(self, function):
434         arg_names = ", ".join(function.argNames())
435         if function.type is not stdapi.Void:
436             print '    _result = %s(%s);' % (function.name, arg_names)
437             print '    (void)_result;'
438         else:
439             print '    %s(%s);' % (function.name, arg_names)
440
441     def invokeInterfaceMethod(self, interface, method):
442         # On release our reference when we reach Release() == 0 call in the
443         # trace.
444         if method.name == 'Release':
445             print '    if (call.ret->toUInt()) {'
446             print '        return;'
447             print '    }'
448             print '    retrace::delObj(call.arg(0));'
449
450         arg_names = ", ".join(method.argNames())
451         if method.type is not stdapi.Void:
452             print '    _result = _this->%s(%s);' % (method.name, arg_names)
453             print '    (void)_result;'
454         else:
455             print '    _this->%s(%s);' % (method.name, arg_names)
456
457     def filterFunction(self, function):
458         return True
459
460     table_name = 'retrace::callbacks'
461
462     def retraceApi(self, api):
463
464         print '#include "os_time.hpp"'
465         print '#include "trace_parser.hpp"'
466         print '#include "retrace.hpp"'
467         print '#include "retrace_swizzle.hpp"'
468         print
469
470         types = api.getAllTypes()
471         handles = [type for type in types if isinstance(type, stdapi.Handle)]
472         handle_names = set()
473         for handle in handles:
474             if handle.name not in handle_names:
475                 if handle.key is None:
476                     print 'static retrace::map<%s> _%s_map;' % (handle.type, handle.name)
477                 else:
478                     key_name, key_type = handle.key
479                     print 'static std::map<%s, retrace::map<%s> > _%s_map;' % (key_type, handle.type, handle.name)
480                 handle_names.add(handle.name)
481         print
482
483         functions = filter(self.filterFunction, api.functions)
484         for function in functions:
485             if function.sideeffects and not function.internal:
486                 self.retraceFunction(function)
487         interfaces = api.getAllInterfaces()
488         for interface in interfaces:
489             for method in interface.iterMethods():
490                 if method.sideeffects and not method.internal:
491                     self.retraceInterfaceMethod(interface, method)
492
493         print 'const retrace::Entry %s[] = {' % self.table_name
494         for function in functions:
495             if not function.internal:
496                 if function.sideeffects:
497                     print '    {"%s", &retrace_%s},' % (function.name, function.name)
498                 else:
499                     print '    {"%s", &retrace::ignore},' % (function.name,)
500         for interface in interfaces:
501             for method in interface.iterMethods():                
502                 if method.sideeffects:
503                     print '    {"%s::%s", &retrace_%s__%s},' % (interface.name, method.name, interface.name, method.name)
504                 else:
505                     print '    {"%s::%s", &retrace::ignore},' % (interface.name, method.name)
506         print '    {NULL, NULL}'
507         print '};'
508         print
509