]> git.cworth.org Git - apitrace/blob - retrace/retrace.py
replay: Support applications mixing glCreateProgramObjectARB and glUseProgram
[apitrace] / retrace / retrace.py
1 ##########################################################################
2 #
3 # Copyright 2010 VMware, Inc.
4 # All Rights Reserved.
5 #
6 # Permission is hereby granted, free of charge, to any person obtaining a copy
7 # of this software and associated documentation files (the "Software"), to deal
8 # in the Software without restriction, including without limitation the rights
9 # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 # copies of the Software, and to permit persons to whom the Software is
11 # furnished to do so, subject to the following conditions:
12 #
13 # The above copyright notice and this permission notice shall be included in
14 # all copies or substantial portions of the Software.
15 #
16 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
22 # THE SOFTWARE.
23 #
24 ##########################################################################/
25
26
27 """Generic retracing code generator."""
28
29
30 # Adjust path
31 import os.path
32 import sys
33 sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
34
35
36 import specs.stdapi as stdapi
37
38
39 class UnsupportedType(Exception):
40     pass
41
42
43 def lookupHandle(handle, value):
44     if handle.key is None:
45         return "_%s_map[%s]" % (handle.name, value)
46     else:
47         key_name, key_type = handle.key
48         return "_%s_map[%s][%s]" % (handle.name, key_name, value)
49
50
51 class ValueAllocator(stdapi.Visitor):
52
53     def visitLiteral(self, literal, lvalue, rvalue):
54         pass
55
56     def visitConst(self, const, lvalue, rvalue):
57         self.visit(const.type, lvalue, rvalue)
58
59     def visitAlias(self, alias, lvalue, rvalue):
60         self.visit(alias.type, lvalue, rvalue)
61
62     def visitEnum(self, enum, lvalue, rvalue):
63         pass
64
65     def visitBitmask(self, bitmask, lvalue, rvalue):
66         pass
67
68     def visitArray(self, array, lvalue, rvalue):
69         print '    %s = static_cast<%s *>(_allocator.alloc(&%s, sizeof *%s));' % (lvalue, array.type, rvalue, lvalue)
70
71     def visitPointer(self, pointer, lvalue, rvalue):
72         print '    %s = static_cast<%s *>(_allocator.alloc(&%s, sizeof *%s));' % (lvalue, pointer.type, rvalue, lvalue)
73
74     def visitIntPointer(self, pointer, lvalue, rvalue):
75         pass
76
77     def visitObjPointer(self, pointer, lvalue, rvalue):
78         pass
79
80     def visitLinearPointer(self, pointer, lvalue, rvalue):
81         pass
82
83     def visitReference(self, reference, lvalue, rvalue):
84         self.visit(reference.type, lvalue, rvalue);
85
86     def visitHandle(self, handle, lvalue, rvalue):
87         pass
88
89     def visitBlob(self, blob, lvalue, rvalue):
90         pass
91
92     def visitString(self, string, lvalue, rvalue):
93         pass
94
95     def visitStruct(self, struct, lvalue, rvalue):
96         pass
97
98     def visitPolymorphic(self, polymorphic, lvalue, rvalue):
99         assert polymorphic.defaultType is not None
100         self.visit(polymorphic.defaultType, lvalue, rvalue)
101
102     def visitOpaque(self, opaque, lvalue, rvalue):
103         pass
104
105
106 class ValueDeserializer(stdapi.Visitor, stdapi.ExpanderMixin):
107
108     def visitLiteral(self, literal, lvalue, rvalue):
109         print '    %s = (%s).to%s();' % (lvalue, rvalue, literal.kind)
110
111     def visitConst(self, const, lvalue, rvalue):
112         self.visit(const.type, lvalue, rvalue)
113
114     def visitAlias(self, alias, lvalue, rvalue):
115         self.visit(alias.type, lvalue, rvalue)
116     
117     def visitEnum(self, enum, lvalue, rvalue):
118         print '    %s = static_cast<%s>((%s).toSInt());' % (lvalue, enum, rvalue)
119
120     def visitBitmask(self, bitmask, lvalue, rvalue):
121         self.visit(bitmask.type, lvalue, rvalue)
122
123     def visitArray(self, array, lvalue, rvalue):
124
125         tmp = '_a_' + array.tag + '_' + str(self.seq)
126         self.seq += 1
127
128         print '    if (%s) {' % (lvalue,)
129         print '        const trace::Array *%s = dynamic_cast<const trace::Array *>(&%s);' % (tmp, rvalue)
130         length = '%s->values.size()' % (tmp,)
131         index = '_j' + array.tag
132         print '        for (size_t {i} = 0; {i} < {length}; ++{i}) {{'.format(i = index, length = length)
133         try:
134             self.visit(array.type, '%s[%s]' % (lvalue, index), '*%s->values[%s]' % (tmp, index))
135         finally:
136             print '        }'
137             print '    }'
138     
139     def visitPointer(self, pointer, lvalue, rvalue):
140         tmp = '_a_' + pointer.tag + '_' + str(self.seq)
141         self.seq += 1
142
143         print '    if (%s) {' % (lvalue,)
144         print '        const trace::Array *%s = dynamic_cast<const trace::Array *>(&%s);' % (tmp, rvalue)
145         try:
146             self.visit(pointer.type, '%s[0]' % (lvalue,), '*%s->values[0]' % (tmp,))
147         finally:
148             print '    }'
149
150     def visitIntPointer(self, pointer, lvalue, rvalue):
151         print '    %s = static_cast<%s>((%s).toPointer());' % (lvalue, pointer, rvalue)
152
153     def visitObjPointer(self, pointer, lvalue, rvalue):
154         print '    %s = static_cast<%s>(retrace::toObjPointer(call, %s));' % (lvalue, pointer, rvalue)
155
156     def visitLinearPointer(self, pointer, lvalue, rvalue):
157         print '    %s = static_cast<%s>(retrace::toPointer(%s));' % (lvalue, pointer, rvalue)
158
159     def visitReference(self, reference, lvalue, rvalue):
160         self.visit(reference.type, lvalue, rvalue);
161
162     def visitHandle(self, handle, lvalue, rvalue):
163         #OpaqueValueDeserializer().visit(handle.type, lvalue, rvalue);
164         self.visit(handle.type, lvalue, rvalue);
165         new_lvalue = lookupHandle(handle, lvalue)
166         print '    if (retrace::verbosity >= 2) {'
167         print '        std::cout << "%s " << size_t(%s) << " <- " << size_t(%s) << "\\n";' % (handle.name, lvalue, new_lvalue)
168         print '    }'
169         if (new_lvalue.startswith('_program_map') or new_lvalue.startswith('_shader_map')):
170             print 'if (glretrace::supportsARBShaderObjects) {'
171             print '    %s = _handleARB_map[%s];' % (lvalue, lvalue)
172             print '} else {'
173             print '    %s = %s;' % (lvalue, new_lvalue)
174             print '}'
175         else:
176             print '    %s = %s;' % (lvalue, new_lvalue)
177     
178     def visitBlob(self, blob, lvalue, rvalue):
179         print '    %s = static_cast<%s>((%s).toPointer());' % (lvalue, blob, rvalue)
180     
181     def visitString(self, string, lvalue, rvalue):
182         print '    %s = (%s)((%s).toString());' % (lvalue, string.expr, rvalue)
183
184     seq = 0
185
186     def visitStruct(self, struct, lvalue, rvalue):
187         tmp = '_s_' + struct.tag + '_' + str(self.seq)
188         self.seq += 1
189
190         print '    const trace::Struct *%s = dynamic_cast<const trace::Struct *>(&%s);' % (tmp, rvalue)
191         print '    assert(%s);' % (tmp)
192         for i in range(len(struct.members)):
193             member = struct.members[i]
194             self.visitMember(member, lvalue, '*%s->members[%s]' % (tmp, i))
195
196     def visitPolymorphic(self, polymorphic, lvalue, rvalue):
197         if polymorphic.defaultType is None:
198             switchExpr = self.expand(polymorphic.switchExpr)
199             print r'    switch (%s) {' % switchExpr
200             for cases, type in polymorphic.iterSwitch():
201                 for case in cases:
202                     print r'    %s:' % case
203                 caseLvalue = lvalue
204                 if type.expr is not None:
205                     caseLvalue = 'static_cast<%s>(%s)' % (type, caseLvalue)
206                 print r'        {'
207                 try:
208                     self.visit(type, caseLvalue, rvalue)
209                 finally:
210                     print r'        }'
211                 print r'        break;'
212             if polymorphic.defaultType is None:
213                 print r'    default:'
214                 print r'        retrace::warning(call) << "unexpected polymorphic case" << %s << "\n";' % (switchExpr,)
215                 print r'        break;'
216             print r'    }'
217         else:
218             self.visit(polymorphic.defaultType, lvalue, rvalue)
219     
220     def visitOpaque(self, opaque, lvalue, rvalue):
221         raise UnsupportedType
222
223
224 class OpaqueValueDeserializer(ValueDeserializer):
225     '''Value extractor that also understands opaque values.
226
227     Normally opaque values can't be retraced, unless they are being extracted
228     in the context of handles.'''
229
230     def visitOpaque(self, opaque, lvalue, rvalue):
231         print '    %s = static_cast<%s>(retrace::toPointer(%s));' % (lvalue, opaque, rvalue)
232
233
234 class SwizzledValueRegistrator(stdapi.Visitor, stdapi.ExpanderMixin):
235     '''Type visitor which will register (un)swizzled value pairs, to later be
236     swizzled.'''
237
238     def visitLiteral(self, literal, lvalue, rvalue):
239         pass
240
241     def visitAlias(self, alias, lvalue, rvalue):
242         self.visit(alias.type, lvalue, rvalue)
243     
244     def visitEnum(self, enum, lvalue, rvalue):
245         pass
246
247     def visitBitmask(self, bitmask, lvalue, rvalue):
248         pass
249
250     def visitArray(self, array, lvalue, rvalue):
251         print '    const trace::Array *_a%s = dynamic_cast<const trace::Array *>(&%s);' % (array.tag, rvalue)
252         print '    if (_a%s) {' % (array.tag)
253         length = '_a%s->values.size()' % array.tag
254         index = '_j' + array.tag
255         print '        for (size_t {i} = 0; {i} < {length}; ++{i}) {{'.format(i = index, length = length)
256         try:
257             self.visit(array.type, '%s[%s]' % (lvalue, index), '*_a%s->values[%s]' % (array.tag, index))
258         finally:
259             print '        }'
260             print '    }'
261     
262     def visitPointer(self, pointer, lvalue, rvalue):
263         print '    const trace::Array *_a%s = dynamic_cast<const trace::Array *>(&%s);' % (pointer.tag, rvalue)
264         print '    if (_a%s) {' % (pointer.tag)
265         try:
266             self.visit(pointer.type, '%s[0]' % (lvalue,), '*_a%s->values[0]' % (pointer.tag,))
267         finally:
268             print '    }'
269     
270     def visitIntPointer(self, pointer, lvalue, rvalue):
271         pass
272     
273     def visitObjPointer(self, pointer, lvalue, rvalue):
274         print r'    retrace::addObj(call, %s, %s);' % (rvalue, lvalue)
275     
276     def visitLinearPointer(self, pointer, lvalue, rvalue):
277         assert pointer.size is not None
278         if pointer.size is not None:
279             print r'    retrace::addRegion((%s).toUIntPtr(), %s, %s);' % (rvalue, lvalue, pointer.size)
280
281     def visitReference(self, reference, lvalue, rvalue):
282         pass
283     
284     def visitHandle(self, handle, lvalue, rvalue):
285         print '    %s _origResult;' % handle.type
286         OpaqueValueDeserializer().visit(handle.type, '_origResult', rvalue);
287         if handle.range is None:
288             rvalue = "_origResult"
289             entry = lookupHandle(handle, rvalue)
290             if (entry.startswith('_program_map') or entry.startswith('_shader_map')):
291                 print 'if (glretrace::supportsARBShaderObjects) {'
292                 print '    _handleARB_map[%s] = %s;' % (rvalue, lvalue)
293                 print '} else {'
294                 print '    %s = %s;' % (entry, lvalue)
295                 print '}'
296             else:
297                 print "    %s = %s;" % (entry, lvalue)
298             print '    if (retrace::verbosity >= 2) {'
299             print '        std::cout << "{handle.name} " << {rvalue} << " -> " << {lvalue} << "\\n";'.format(**locals())
300             print '    }'
301         else:
302             i = '_h' + handle.tag
303             lvalue = "%s + %s" % (lvalue, i)
304             rvalue = "_origResult + %s" % (i,)
305             entry = lookupHandle(handle, rvalue) 
306             print '    for ({handle.type} {i} = 0; {i} < {handle.range}; ++{i}) {{'.format(**locals())
307             print '        {entry} = {lvalue};'.format(**locals())
308             print '        if (retrace::verbosity >= 2) {'
309             print '            std::cout << "{handle.name} " << ({rvalue}) << " -> " << ({lvalue}) << "\\n";'.format(**locals())
310             print '        }'
311             print '    }'
312     
313     def visitBlob(self, blob, lvalue, rvalue):
314         pass
315     
316     def visitString(self, string, lvalue, rvalue):
317         pass
318
319     seq = 0
320
321     def visitStruct(self, struct, lvalue, rvalue):
322         tmp = '_s_' + struct.tag + '_' + str(self.seq)
323         self.seq += 1
324
325         print '    const trace::Struct *%s = dynamic_cast<const trace::Struct *>(&%s);' % (tmp, rvalue)
326         print '    assert(%s);' % (tmp,)
327         print '    (void)%s;' % (tmp,)
328         for i in range(len(struct.members)):
329             member = struct.members[i]
330             self.visitMember(member, lvalue, '*%s->members[%s]' % (tmp, i))
331     
332     def visitPolymorphic(self, polymorphic, lvalue, rvalue):
333         assert polymorphic.defaultType is not None
334         self.visit(polymorphic.defaultType, lvalue, rvalue)
335     
336     def visitOpaque(self, opaque, lvalue, rvalue):
337         pass
338
339
340 class Retracer:
341
342     def retraceFunction(self, function):
343         print 'static void retrace_%s(trace::Call &call) {' % function.name
344         self.retraceFunctionBody(function)
345         print '}'
346         print
347
348     def retraceInterfaceMethod(self, interface, method):
349         print 'static void retrace_%s__%s(trace::Call &call) {' % (interface.name, method.name)
350         self.retraceInterfaceMethodBody(interface, method)
351         print '}'
352         print
353
354     def retraceFunctionBody(self, function):
355         assert function.sideeffects
356
357         if function.type is not stdapi.Void:
358             self.checkOrigResult(function)
359
360         self.deserializeArgs(function)
361         
362         self.declareRet(function)
363         self.invokeFunction(function)
364
365         self.swizzleValues(function)
366
367     def retraceInterfaceMethodBody(self, interface, method):
368         assert method.sideeffects
369
370         if method.type is not stdapi.Void:
371             self.checkOrigResult(method)
372
373         self.deserializeThisPointer(interface)
374
375         self.deserializeArgs(method)
376         
377         self.declareRet(method)
378         self.invokeInterfaceMethod(interface, method)
379
380         self.swizzleValues(method)
381
382     def checkOrigResult(self, function):
383         '''Hook for checking the original result, to prevent succeeding now
384         where the original did not, which would cause diversion and potentially
385         unpredictable results.'''
386
387         assert function.type is not stdapi.Void
388
389         if str(function.type) == 'HRESULT':
390             print r'    if (call.ret && FAILED(call.ret->toSInt())) {'
391             print r'        return;'
392             print r'    }'
393
394     def deserializeThisPointer(self, interface):
395         print r'    %s *_this;' % (interface.name,)
396         print r'    _this = static_cast<%s *>(retrace::toObjPointer(call, call.arg(0)));' % (interface.name,)
397         print r'    if (!_this) {'
398         print r'        return;'
399         print r'    }'
400
401     def deserializeArgs(self, function):
402         print '    retrace::ScopedAllocator _allocator;'
403         print '    (void)_allocator;'
404         success = True
405         for arg in function.args:
406             arg_type = arg.type.mutable()
407             print '    %s %s;' % (arg_type, arg.name)
408             rvalue = 'call.arg(%u)' % (arg.index,)
409             lvalue = arg.name
410             try:
411                 self.extractArg(function, arg, arg_type, lvalue, rvalue)
412             except UnsupportedType:
413                 success =  False
414                 print '    memset(&%s, 0, sizeof %s); // FIXME' % (arg.name, arg.name)
415             print
416
417         if not success:
418             print '    if (1) {'
419             self.failFunction(function)
420             sys.stderr.write('warning: unsupported %s call\n' % function.name)
421             print '    }'
422
423     def swizzleValues(self, function):
424         for arg in function.args:
425             if arg.output:
426                 arg_type = arg.type.mutable()
427                 rvalue = 'call.arg(%u)' % (arg.index,)
428                 lvalue = arg.name
429                 try:
430                     self.regiterSwizzledValue(arg_type, lvalue, rvalue)
431                 except UnsupportedType:
432                     print '    // XXX: %s' % arg.name
433         if function.type is not stdapi.Void:
434             rvalue = '*call.ret'
435             lvalue = '_result'
436             try:
437                 self.regiterSwizzledValue(function.type, lvalue, rvalue)
438             except UnsupportedType:
439                 raise
440                 print '    // XXX: result'
441
442     def failFunction(self, function):
443         print '    if (retrace::verbosity >= 0) {'
444         print '        retrace::unsupported(call);'
445         print '    }'
446         print '    return;'
447
448     def extractArg(self, function, arg, arg_type, lvalue, rvalue):
449         ValueAllocator().visit(arg_type, lvalue, rvalue)
450         if arg.input:
451             ValueDeserializer().visit(arg_type, lvalue, rvalue)
452     
453     def extractOpaqueArg(self, function, arg, arg_type, lvalue, rvalue):
454         try:
455             ValueAllocator().visit(arg_type, lvalue, rvalue)
456         except UnsupportedType:
457             pass
458         OpaqueValueDeserializer().visit(arg_type, lvalue, rvalue)
459
460     def regiterSwizzledValue(self, type, lvalue, rvalue):
461         visitor = SwizzledValueRegistrator()
462         visitor.visit(type, lvalue, rvalue)
463
464     def declareRet(self, function):
465         if function.type is not stdapi.Void:
466             print '    %s _result;' % (function.type)
467
468     def invokeFunction(self, function):
469         arg_names = ", ".join(function.argNames())
470         if function.type is not stdapi.Void:
471             print '    _result = %s(%s);' % (function.name, arg_names)
472             print '    (void)_result;'
473             self.checkResult(function.type)
474         else:
475             print '    %s(%s);' % (function.name, arg_names)
476
477     def invokeInterfaceMethod(self, interface, method):
478         # On release our reference when we reach Release() == 0 call in the
479         # trace.
480         if method.name == 'Release':
481             print '    if (call.ret->toUInt()) {'
482             print '        return;'
483             print '    }'
484             print '    retrace::delObj(call.arg(0));'
485
486         arg_names = ", ".join(method.argNames())
487         if method.type is not stdapi.Void:
488             print '    _result = _this->%s(%s);' % (method.name, arg_names)
489             print '    (void)_result;'
490             self.checkResult(method.type)
491         else:
492             print '    _this->%s(%s);' % (method.name, arg_names)
493
494     def checkResult(self, resultType):
495         if str(resultType) == 'HRESULT':
496             print r'    if (FAILED(_result)) {'
497             print r'        retrace::warning(call) << "failed\n";'
498             print r'    }'
499
500     def filterFunction(self, function):
501         return True
502
503     table_name = 'retrace::callbacks'
504
505     def retraceApi(self, api):
506
507         print '#include "os_time.hpp"'
508         print '#include "trace_parser.hpp"'
509         print '#include "retrace.hpp"'
510         print '#include "retrace_swizzle.hpp"'
511         print
512
513         types = api.getAllTypes()
514         handles = [type for type in types if isinstance(type, stdapi.Handle)]
515         handle_names = set()
516         for handle in handles:
517             if handle.name not in handle_names:
518                 if handle.key is None:
519                     print 'static retrace::map<%s> _%s_map;' % (handle.type, handle.name)
520                 else:
521                     key_name, key_type = handle.key
522                     print 'static std::map<%s, retrace::map<%s> > _%s_map;' % (key_type, handle.type, handle.name)
523                 handle_names.add(handle.name)
524         print
525
526         functions = filter(self.filterFunction, api.getAllFunctions())
527         for function in functions:
528             if function.sideeffects and not function.internal:
529                 self.retraceFunction(function)
530         interfaces = api.getAllInterfaces()
531         for interface in interfaces:
532             for method in interface.iterMethods():
533                 if method.sideeffects and not method.internal:
534                     self.retraceInterfaceMethod(interface, method)
535
536         print 'const retrace::Entry %s[] = {' % self.table_name
537         for function in functions:
538             if not function.internal:
539                 if function.sideeffects:
540                     print '    {"%s", &retrace_%s},' % (function.name, function.name)
541                 else:
542                     print '    {"%s", &retrace::ignore},' % (function.name,)
543         for interface in interfaces:
544             for method in interface.iterMethods():                
545                 if method.sideeffects:
546                     print '    {"%s::%s", &retrace_%s__%s},' % (interface.name, method.name, interface.name, method.name)
547                 else:
548                     print '    {"%s::%s", &retrace::ignore},' % (interface.name, method.name)
549         print '    {NULL, NULL}'
550         print '};'
551         print
552