]> git.cworth.org Git - apitrace/blob - retrace/retrace.py
Replace dynamic_cast with virtual functions.
[apitrace] / retrace / retrace.py
1 ##########################################################################
2 #
3 # Copyright 2010 VMware, Inc.
4 # All Rights Reserved.
5 #
6 # Permission is hereby granted, free of charge, to any person obtaining a copy
7 # of this software and associated documentation files (the "Software"), to deal
8 # in the Software without restriction, including without limitation the rights
9 # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 # copies of the Software, and to permit persons to whom the Software is
11 # furnished to do so, subject to the following conditions:
12 #
13 # The above copyright notice and this permission notice shall be included in
14 # all copies or substantial portions of the Software.
15 #
16 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
22 # THE SOFTWARE.
23 #
24 ##########################################################################/
25
26
27 """Generic retracing code generator."""
28
29
30 # Adjust path
31 import os.path
32 import sys
33 sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
34
35
36 import specs.stdapi as stdapi
37
38
39 class UnsupportedType(Exception):
40     pass
41
42
43 def lookupHandle(handle, value, lval=False):
44     if handle.key is None:
45         return "_%s_map[%s]" % (handle.name, value)
46     else:
47         key_name, key_type = handle.key
48         if handle.name == "location" and lval == False:
49             return "_location_map[%s].lookupUniformLocation(%s)" % (key_name, value)
50         else:
51             return "_%s_map[%s][%s]" % (handle.name, key_name, value)
52
53
54 class ValueAllocator(stdapi.Visitor):
55
56     def visitLiteral(self, literal, lvalue, rvalue):
57         pass
58
59     def visitConst(self, const, lvalue, rvalue):
60         self.visit(const.type, lvalue, rvalue)
61
62     def visitAlias(self, alias, lvalue, rvalue):
63         self.visit(alias.type, lvalue, rvalue)
64
65     def visitEnum(self, enum, lvalue, rvalue):
66         pass
67
68     def visitBitmask(self, bitmask, lvalue, rvalue):
69         pass
70
71     def visitArray(self, array, lvalue, rvalue):
72         print '    %s = static_cast<%s *>(_allocator.alloc(&%s, sizeof *%s));' % (lvalue, array.type, rvalue, lvalue)
73
74     def visitPointer(self, pointer, lvalue, rvalue):
75         print '    %s = static_cast<%s *>(_allocator.alloc(&%s, sizeof *%s));' % (lvalue, pointer.type, rvalue, lvalue)
76
77     def visitIntPointer(self, pointer, lvalue, rvalue):
78         pass
79
80     def visitObjPointer(self, pointer, lvalue, rvalue):
81         pass
82
83     def visitLinearPointer(self, pointer, lvalue, rvalue):
84         pass
85
86     def visitReference(self, reference, lvalue, rvalue):
87         self.visit(reference.type, lvalue, rvalue);
88
89     def visitHandle(self, handle, lvalue, rvalue):
90         pass
91
92     def visitBlob(self, blob, lvalue, rvalue):
93         pass
94
95     def visitString(self, string, lvalue, rvalue):
96         pass
97
98     def visitStruct(self, struct, lvalue, rvalue):
99         pass
100
101     def visitPolymorphic(self, polymorphic, lvalue, rvalue):
102         assert polymorphic.defaultType is not None
103         self.visit(polymorphic.defaultType, lvalue, rvalue)
104
105     def visitOpaque(self, opaque, lvalue, rvalue):
106         pass
107
108
109 class ValueDeserializer(stdapi.Visitor, stdapi.ExpanderMixin):
110
111     def visitLiteral(self, literal, lvalue, rvalue):
112         print '    %s = (%s).to%s();' % (lvalue, rvalue, literal.kind)
113
114     def visitConst(self, const, lvalue, rvalue):
115         self.visit(const.type, lvalue, rvalue)
116
117     def visitAlias(self, alias, lvalue, rvalue):
118         self.visit(alias.type, lvalue, rvalue)
119     
120     def visitEnum(self, enum, lvalue, rvalue):
121         print '    %s = static_cast<%s>((%s).toSInt());' % (lvalue, enum, rvalue)
122
123     def visitBitmask(self, bitmask, lvalue, rvalue):
124         self.visit(bitmask.type, lvalue, rvalue)
125
126     def visitArray(self, array, lvalue, rvalue):
127
128         tmp = '_a_' + array.tag + '_' + str(self.seq)
129         self.seq += 1
130
131         print '    if (%s) {' % (lvalue,)
132         print '        const trace::Array *%s = (%s).toArray();' % (tmp, rvalue)
133         length = '%s->values.size()' % (tmp,)
134         index = '_j' + array.tag
135         print '        for (size_t {i} = 0; {i} < {length}; ++{i}) {{'.format(i = index, length = length)
136         try:
137             self.visit(array.type, '%s[%s]' % (lvalue, index), '*%s->values[%s]' % (tmp, index))
138         finally:
139             print '        }'
140             print '    }'
141     
142     def visitPointer(self, pointer, lvalue, rvalue):
143         tmp = '_a_' + pointer.tag + '_' + str(self.seq)
144         self.seq += 1
145
146         print '    if (%s) {' % (lvalue,)
147         print '        const trace::Array *%s = (%s).toArray();' % (tmp, rvalue)
148         try:
149             self.visit(pointer.type, '%s[0]' % (lvalue,), '*%s->values[0]' % (tmp,))
150         finally:
151             print '    }'
152
153     def visitIntPointer(self, pointer, lvalue, rvalue):
154         print '    %s = static_cast<%s>((%s).toPointer());' % (lvalue, pointer, rvalue)
155
156     def visitObjPointer(self, pointer, lvalue, rvalue):
157         print '    %s = static_cast<%s>(retrace::toObjPointer(call, %s));' % (lvalue, pointer, rvalue)
158
159     def visitLinearPointer(self, pointer, lvalue, rvalue):
160         print '    %s = static_cast<%s>(retrace::toPointer(%s));' % (lvalue, pointer, rvalue)
161
162     def visitReference(self, reference, lvalue, rvalue):
163         self.visit(reference.type, lvalue, rvalue);
164
165     def visitHandle(self, handle, lvalue, rvalue):
166         #OpaqueValueDeserializer().visit(handle.type, lvalue, rvalue);
167         self.visit(handle.type, lvalue, rvalue);
168         new_lvalue = lookupHandle(handle, lvalue)
169         print '    if (retrace::verbosity >= 2) {'
170         print '        std::cout << "%s " << size_t(%s) << " <- " << size_t(%s) << "\\n";' % (handle.name, lvalue, new_lvalue)
171         print '    }'
172         if (new_lvalue.startswith('_program_map') or new_lvalue.startswith('_shader_map')):
173             print 'if (glretrace::supportsARBShaderObjects) {'
174             print '    %s = _handleARB_map[%s];' % (lvalue, lvalue)
175             print '} else {'
176             print '    %s = %s;' % (lvalue, new_lvalue)
177             print '}'
178         else:
179             print '    %s = %s;' % (lvalue, new_lvalue)
180     
181     def visitBlob(self, blob, lvalue, rvalue):
182         print '    %s = static_cast<%s>((%s).toPointer());' % (lvalue, blob, rvalue)
183     
184     def visitString(self, string, lvalue, rvalue):
185         print '    %s = (%s)((%s).toString());' % (lvalue, string.expr, rvalue)
186
187     seq = 0
188
189     def visitStruct(self, struct, lvalue, rvalue):
190         tmp = '_s_' + struct.tag + '_' + str(self.seq)
191         self.seq += 1
192
193         print '    const trace::Struct *%s = (%s).toStruct();' % (tmp, rvalue)
194         print '    assert(%s);' % (tmp)
195         for i in range(len(struct.members)):
196             member = struct.members[i]
197             self.visitMember(member, lvalue, '*%s->members[%s]' % (tmp, i))
198
199     def visitPolymorphic(self, polymorphic, lvalue, rvalue):
200         if polymorphic.defaultType is None:
201             switchExpr = self.expand(polymorphic.switchExpr)
202             print r'    switch (%s) {' % switchExpr
203             for cases, type in polymorphic.iterSwitch():
204                 for case in cases:
205                     print r'    %s:' % case
206                 caseLvalue = lvalue
207                 if type.expr is not None:
208                     caseLvalue = 'static_cast<%s>(%s)' % (type, caseLvalue)
209                 print r'        {'
210                 try:
211                     self.visit(type, caseLvalue, rvalue)
212                 finally:
213                     print r'        }'
214                 print r'        break;'
215             if polymorphic.defaultType is None:
216                 print r'    default:'
217                 print r'        retrace::warning(call) << "unexpected polymorphic case" << %s << "\n";' % (switchExpr,)
218                 print r'        break;'
219             print r'    }'
220         else:
221             self.visit(polymorphic.defaultType, lvalue, rvalue)
222     
223     def visitOpaque(self, opaque, lvalue, rvalue):
224         raise UnsupportedType
225
226
227 class OpaqueValueDeserializer(ValueDeserializer):
228     '''Value extractor that also understands opaque values.
229
230     Normally opaque values can't be retraced, unless they are being extracted
231     in the context of handles.'''
232
233     def visitOpaque(self, opaque, lvalue, rvalue):
234         print '    %s = static_cast<%s>(retrace::toPointer(%s));' % (lvalue, opaque, rvalue)
235
236
237 class SwizzledValueRegistrator(stdapi.Visitor, stdapi.ExpanderMixin):
238     '''Type visitor which will register (un)swizzled value pairs, to later be
239     swizzled.'''
240
241     def visitLiteral(self, literal, lvalue, rvalue):
242         pass
243
244     def visitAlias(self, alias, lvalue, rvalue):
245         self.visit(alias.type, lvalue, rvalue)
246     
247     def visitEnum(self, enum, lvalue, rvalue):
248         pass
249
250     def visitBitmask(self, bitmask, lvalue, rvalue):
251         pass
252
253     def visitArray(self, array, lvalue, rvalue):
254         print '    const trace::Array *_a%s = (%s).toArray();' % (array.tag, rvalue)
255         print '    if (_a%s) {' % (array.tag)
256         length = '_a%s->values.size()' % array.tag
257         index = '_j' + array.tag
258         print '        for (size_t {i} = 0; {i} < {length}; ++{i}) {{'.format(i = index, length = length)
259         try:
260             self.visit(array.type, '%s[%s]' % (lvalue, index), '*_a%s->values[%s]' % (array.tag, index))
261         finally:
262             print '        }'
263             print '    }'
264     
265     def visitPointer(self, pointer, lvalue, rvalue):
266         print '    const trace::Array *_a%s = (%s).toArray();' % (pointer.tag, rvalue)
267         print '    if (_a%s) {' % (pointer.tag)
268         try:
269             self.visit(pointer.type, '%s[0]' % (lvalue,), '*_a%s->values[0]' % (pointer.tag,))
270         finally:
271             print '    }'
272     
273     def visitIntPointer(self, pointer, lvalue, rvalue):
274         pass
275     
276     def visitObjPointer(self, pointer, lvalue, rvalue):
277         print r'    retrace::addObj(call, %s, %s);' % (rvalue, lvalue)
278     
279     def visitLinearPointer(self, pointer, lvalue, rvalue):
280         assert pointer.size is not None
281         if pointer.size is not None:
282             print r'    retrace::addRegion((%s).toUIntPtr(), %s, %s);' % (rvalue, lvalue, pointer.size)
283
284     def visitReference(self, reference, lvalue, rvalue):
285         pass
286     
287     def visitHandle(self, handle, lvalue, rvalue):
288         print '    %s _origResult;' % handle.type
289         OpaqueValueDeserializer().visit(handle.type, '_origResult', rvalue);
290         if handle.range is None:
291             rvalue = "_origResult"
292             entry = lookupHandle(handle, rvalue, True)
293             if (entry.startswith('_program_map') or entry.startswith('_shader_map')):
294                 print 'if (glretrace::supportsARBShaderObjects) {'
295                 print '    _handleARB_map[%s] = %s;' % (rvalue, lvalue)
296                 print '} else {'
297                 print '    %s = %s;' % (entry, lvalue)
298                 print '}'
299             else:
300                 print "    %s = %s;" % (entry, lvalue)
301             print '    if (retrace::verbosity >= 2) {'
302             print '        std::cout << "{handle.name} " << {rvalue} << " -> " << {lvalue} << "\\n";'.format(**locals())
303             print '    }'
304         else:
305             i = '_h' + handle.tag
306             lvalue = "%s + %s" % (lvalue, i)
307             rvalue = "_origResult + %s" % (i,)
308             entry = lookupHandle(handle, rvalue) 
309             print '    for ({handle.type} {i} = 0; {i} < {handle.range}; ++{i}) {{'.format(**locals())
310             print '        {entry} = {lvalue};'.format(**locals())
311             print '        if (retrace::verbosity >= 2) {'
312             print '            std::cout << "{handle.name} " << ({rvalue}) << " -> " << ({lvalue}) << "\\n";'.format(**locals())
313             print '        }'
314             print '    }'
315     
316     def visitBlob(self, blob, lvalue, rvalue):
317         pass
318     
319     def visitString(self, string, lvalue, rvalue):
320         pass
321
322     seq = 0
323
324     def visitStruct(self, struct, lvalue, rvalue):
325         tmp = '_s_' + struct.tag + '_' + str(self.seq)
326         self.seq += 1
327
328         print '    const trace::Struct *%s = (%s).toStruct();' % (tmp, rvalue)
329         print '    assert(%s);' % (tmp,)
330         print '    (void)%s;' % (tmp,)
331         for i in range(len(struct.members)):
332             member = struct.members[i]
333             self.visitMember(member, lvalue, '*%s->members[%s]' % (tmp, i))
334     
335     def visitPolymorphic(self, polymorphic, lvalue, rvalue):
336         assert polymorphic.defaultType is not None
337         self.visit(polymorphic.defaultType, lvalue, rvalue)
338     
339     def visitOpaque(self, opaque, lvalue, rvalue):
340         pass
341
342
343 class Retracer:
344
345     def retraceFunction(self, function):
346         print 'static void retrace_%s(trace::Call &call) {' % function.name
347         self.retraceFunctionBody(function)
348         print '}'
349         print
350
351     def retraceInterfaceMethod(self, interface, method):
352         print 'static void retrace_%s__%s(trace::Call &call) {' % (interface.name, method.name)
353         self.retraceInterfaceMethodBody(interface, method)
354         print '}'
355         print
356
357     def retraceFunctionBody(self, function):
358         assert function.sideeffects
359
360         if function.type is not stdapi.Void:
361             self.checkOrigResult(function)
362
363         self.deserializeArgs(function)
364         
365         self.declareRet(function)
366         self.invokeFunction(function)
367
368         self.swizzleValues(function)
369
370     def retraceInterfaceMethodBody(self, interface, method):
371         assert method.sideeffects
372
373         if method.type is not stdapi.Void:
374             self.checkOrigResult(method)
375
376         self.deserializeThisPointer(interface)
377
378         self.deserializeArgs(method)
379         
380         self.declareRet(method)
381         self.invokeInterfaceMethod(interface, method)
382
383         self.swizzleValues(method)
384
385     def checkOrigResult(self, function):
386         '''Hook for checking the original result, to prevent succeeding now
387         where the original did not, which would cause diversion and potentially
388         unpredictable results.'''
389
390         assert function.type is not stdapi.Void
391
392         if str(function.type) == 'HRESULT':
393             print r'    if (call.ret && FAILED(call.ret->toSInt())) {'
394             print r'        return;'
395             print r'    }'
396
397     def deserializeThisPointer(self, interface):
398         print r'    %s *_this;' % (interface.name,)
399         print r'    _this = static_cast<%s *>(retrace::toObjPointer(call, call.arg(0)));' % (interface.name,)
400         print r'    if (!_this) {'
401         print r'        return;'
402         print r'    }'
403
404     def deserializeArgs(self, function):
405         print '    retrace::ScopedAllocator _allocator;'
406         print '    (void)_allocator;'
407         success = True
408         for arg in function.args:
409             arg_type = arg.type.mutable()
410             print '    %s %s;' % (arg_type, arg.name)
411             rvalue = 'call.arg(%u)' % (arg.index,)
412             lvalue = arg.name
413             try:
414                 self.extractArg(function, arg, arg_type, lvalue, rvalue)
415             except UnsupportedType:
416                 success =  False
417                 print '    memset(&%s, 0, sizeof %s); // FIXME' % (arg.name, arg.name)
418             print
419
420         if not success:
421             print '    if (1) {'
422             self.failFunction(function)
423             sys.stderr.write('warning: unsupported %s call\n' % function.name)
424             print '    }'
425
426     def swizzleValues(self, function):
427         for arg in function.args:
428             if arg.output:
429                 arg_type = arg.type.mutable()
430                 rvalue = 'call.arg(%u)' % (arg.index,)
431                 lvalue = arg.name
432                 try:
433                     self.regiterSwizzledValue(arg_type, lvalue, rvalue)
434                 except UnsupportedType:
435                     print '    // XXX: %s' % arg.name
436         if function.type is not stdapi.Void:
437             rvalue = '*call.ret'
438             lvalue = '_result'
439             try:
440                 self.regiterSwizzledValue(function.type, lvalue, rvalue)
441             except UnsupportedType:
442                 raise
443                 print '    // XXX: result'
444
445     def failFunction(self, function):
446         print '    if (retrace::verbosity >= 0) {'
447         print '        retrace::unsupported(call);'
448         print '    }'
449         print '    return;'
450
451     def extractArg(self, function, arg, arg_type, lvalue, rvalue):
452         ValueAllocator().visit(arg_type, lvalue, rvalue)
453         if arg.input:
454             ValueDeserializer().visit(arg_type, lvalue, rvalue)
455     
456     def extractOpaqueArg(self, function, arg, arg_type, lvalue, rvalue):
457         try:
458             ValueAllocator().visit(arg_type, lvalue, rvalue)
459         except UnsupportedType:
460             pass
461         OpaqueValueDeserializer().visit(arg_type, lvalue, rvalue)
462
463     def regiterSwizzledValue(self, type, lvalue, rvalue):
464         visitor = SwizzledValueRegistrator()
465         visitor.visit(type, lvalue, rvalue)
466
467     def declareRet(self, function):
468         if function.type is not stdapi.Void:
469             print '    %s _result;' % (function.type)
470
471     def invokeFunction(self, function):
472         arg_names = ", ".join(function.argNames())
473         if function.type is not stdapi.Void:
474             print '    _result = %s(%s);' % (function.name, arg_names)
475             print '    (void)_result;'
476             self.checkResult(function.type)
477         else:
478             print '    %s(%s);' % (function.name, arg_names)
479
480     def invokeInterfaceMethod(self, interface, method):
481         # On release our reference when we reach Release() == 0 call in the
482         # trace.
483         if method.name == 'Release':
484             print '    if (call.ret->toUInt() == 0) {'
485             print '        retrace::delObj(call.arg(0));'
486             print '    }'
487
488         arg_names = ", ".join(method.argNames())
489         if method.type is not stdapi.Void:
490             print '    _result = _this->%s(%s);' % (method.name, arg_names)
491             print '    (void)_result;'
492             self.checkResult(method.type)
493         else:
494             print '    _this->%s(%s);' % (method.name, arg_names)
495
496     def checkResult(self, resultType):
497         if str(resultType) == 'HRESULT':
498             print r'    if (FAILED(_result)) {'
499             print '         static char szMessageBuffer[128];'
500             print r'        retrace::warning(call) << "call returned 0x" << std::hex << _result << std::dec << ": " << (FormatMessageA(FORMAT_MESSAGE_FROM_SYSTEM, NULL, _result, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), szMessageBuffer, sizeof szMessageBuffer, NULL) ? szMessageBuffer : "???") << "\n";'
501             print r'    }'
502
503     def filterFunction(self, function):
504         return True
505
506     table_name = 'retrace::callbacks'
507
508     def retraceApi(self, api):
509
510         print '#include "os_time.hpp"'
511         print '#include "trace_parser.hpp"'
512         print '#include "retrace.hpp"'
513         print '#include "retrace_swizzle.hpp"'
514         print
515
516         types = api.getAllTypes()
517         handles = [type for type in types if isinstance(type, stdapi.Handle)]
518         handle_names = set()
519         for handle in handles:
520             if handle.name not in handle_names:
521                 if handle.key is None:
522                     print 'static retrace::map<%s> _%s_map;' % (handle.type, handle.name)
523                 else:
524                     key_name, key_type = handle.key
525                     print 'static std::map<%s, retrace::map<%s> > _%s_map;' % (key_type, handle.type, handle.name)
526                 handle_names.add(handle.name)
527         print
528
529         functions = filter(self.filterFunction, api.getAllFunctions())
530         for function in functions:
531             if function.sideeffects and not function.internal:
532                 self.retraceFunction(function)
533         interfaces = api.getAllInterfaces()
534         for interface in interfaces:
535             for method in interface.iterMethods():
536                 if method.sideeffects and not method.internal:
537                     self.retraceInterfaceMethod(interface, method)
538
539         print 'const retrace::Entry %s[] = {' % self.table_name
540         for function in functions:
541             if not function.internal:
542                 if function.sideeffects:
543                     print '    {"%s", &retrace_%s},' % (function.name, function.name)
544                 else:
545                     print '    {"%s", &retrace::ignore},' % (function.name,)
546         for interface in interfaces:
547             for method in interface.iterMethods():                
548                 if method.sideeffects:
549                     print '    {"%s::%s", &retrace_%s__%s},' % (interface.name, method.name, interface.name, method.name)
550                 else:
551                     print '    {"%s::%s", &retrace::ignore},' % (interface.name, method.name)
552         print '    {NULL, NULL}'
553         print '};'
554         print
555