]> git.cworth.org Git - apitrace/blob - retrace/retrace.py
cd5ef1d2136770f21c5097f0480c6c8b661059d7
[apitrace] / retrace / retrace.py
1 ##########################################################################
2 #
3 # Copyright 2010 VMware, Inc.
4 # All Rights Reserved.
5 #
6 # Permission is hereby granted, free of charge, to any person obtaining a copy
7 # of this software and associated documentation files (the "Software"), to deal
8 # in the Software without restriction, including without limitation the rights
9 # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 # copies of the Software, and to permit persons to whom the Software is
11 # furnished to do so, subject to the following conditions:
12 #
13 # The above copyright notice and this permission notice shall be included in
14 # all copies or substantial portions of the Software.
15 #
16 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
22 # THE SOFTWARE.
23 #
24 ##########################################################################/
25
26
27 """Generic retracing code generator."""
28
29
30 # Adjust path
31 import os.path
32 import sys
33 sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
34
35
36 import specs.stdapi as stdapi
37
38
39 class UnsupportedType(Exception):
40     pass
41
42
43 def lookupHandle(handle, value):
44     if handle.key is None:
45         return "_%s_map[%s]" % (handle.name, value)
46     else:
47         key_name, key_type = handle.key
48         return "_%s_map[%s][%s]" % (handle.name, key_name, value)
49
50
51 class ValueAllocator(stdapi.Visitor):
52
53     def visitLiteral(self, literal, lvalue, rvalue):
54         pass
55
56     def visitConst(self, const, lvalue, rvalue):
57         self.visit(const.type, lvalue, rvalue)
58
59     def visitAlias(self, alias, lvalue, rvalue):
60         self.visit(alias.type, lvalue, rvalue)
61
62     def visitEnum(self, enum, lvalue, rvalue):
63         pass
64
65     def visitBitmask(self, bitmask, lvalue, rvalue):
66         pass
67
68     def visitArray(self, array, lvalue, rvalue):
69         print '    %s = static_cast<%s *>(_allocator.alloc(&%s, sizeof *%s));' % (lvalue, array.type, rvalue, lvalue)
70
71     def visitPointer(self, pointer, lvalue, rvalue):
72         print '    %s = static_cast<%s *>(_allocator.alloc(&%s, sizeof *%s));' % (lvalue, pointer.type, rvalue, lvalue)
73
74     def visitIntPointer(self, pointer, lvalue, rvalue):
75         pass
76
77     def visitObjPointer(self, pointer, lvalue, rvalue):
78         pass
79
80     def visitLinearPointer(self, pointer, lvalue, rvalue):
81         pass
82
83     def visitReference(self, reference, lvalue, rvalue):
84         self.visit(reference.type, lvalue, rvalue);
85
86     def visitHandle(self, handle, lvalue, rvalue):
87         pass
88
89     def visitBlob(self, blob, lvalue, rvalue):
90         pass
91
92     def visitString(self, string, lvalue, rvalue):
93         pass
94
95     def visitStruct(self, struct, lvalue, rvalue):
96         pass
97
98     def visitPolymorphic(self, polymorphic, lvalue, rvalue):
99         assert polymorphic.defaultType is not None
100         self.visit(polymorphic.defaultType, lvalue, rvalue)
101
102     def visitOpaque(self, opaque, lvalue, rvalue):
103         pass
104
105
106 class ValueDeserializer(stdapi.Visitor, stdapi.ExpanderMixin):
107
108     def visitLiteral(self, literal, lvalue, rvalue):
109         print '    %s = (%s).to%s();' % (lvalue, rvalue, literal.kind)
110
111     def visitConst(self, const, lvalue, rvalue):
112         self.visit(const.type, lvalue, rvalue)
113
114     def visitAlias(self, alias, lvalue, rvalue):
115         self.visit(alias.type, lvalue, rvalue)
116     
117     def visitEnum(self, enum, lvalue, rvalue):
118         print '    %s = static_cast<%s>((%s).toSInt());' % (lvalue, enum, rvalue)
119
120     def visitBitmask(self, bitmask, lvalue, rvalue):
121         self.visit(bitmask.type, lvalue, rvalue)
122
123     def visitArray(self, array, lvalue, rvalue):
124
125         tmp = '_a_' + array.tag + '_' + str(self.seq)
126         self.seq += 1
127
128         print '    if (%s) {' % (lvalue,)
129         print '        const trace::Array *%s = dynamic_cast<const trace::Array *>(&%s);' % (tmp, rvalue)
130         length = '%s->values.size()' % (tmp,)
131         index = '_j' + array.tag
132         print '        for (size_t {i} = 0; {i} < {length}; ++{i}) {{'.format(i = index, length = length)
133         try:
134             self.visit(array.type, '%s[%s]' % (lvalue, index), '*%s->values[%s]' % (tmp, index))
135         finally:
136             print '        }'
137             print '    }'
138     
139     def visitPointer(self, pointer, lvalue, rvalue):
140         tmp = '_a_' + pointer.tag + '_' + str(self.seq)
141         self.seq += 1
142
143         print '    if (%s) {' % (lvalue,)
144         print '        const trace::Array *%s = dynamic_cast<const trace::Array *>(&%s);' % (tmp, rvalue)
145         try:
146             self.visit(pointer.type, '%s[0]' % (lvalue,), '*%s->values[0]' % (tmp,))
147         finally:
148             print '    }'
149
150     def visitIntPointer(self, pointer, lvalue, rvalue):
151         print '    %s = static_cast<%s>((%s).toPointer());' % (lvalue, pointer, rvalue)
152
153     def visitObjPointer(self, pointer, lvalue, rvalue):
154         print '    %s = static_cast<%s>(retrace::toObjPointer(call, %s));' % (lvalue, pointer, rvalue)
155
156     def visitLinearPointer(self, pointer, lvalue, rvalue):
157         print '    %s = static_cast<%s>(retrace::toPointer(%s));' % (lvalue, pointer, rvalue)
158
159     def visitReference(self, reference, lvalue, rvalue):
160         self.visit(reference.type, lvalue, rvalue);
161
162     def visitHandle(self, handle, lvalue, rvalue):
163         #OpaqueValueDeserializer().visit(handle.type, lvalue, rvalue);
164         self.visit(handle.type, lvalue, rvalue);
165         new_lvalue = lookupHandle(handle, lvalue)
166         print '    if (retrace::verbosity >= 2) {'
167         print '        std::cout << "%s " << size_t(%s) << " <- " << size_t(%s) << "\\n";' % (handle.name, lvalue, new_lvalue)
168         print '    }'
169         print '    %s = %s;' % (lvalue, new_lvalue)
170     
171     def visitBlob(self, blob, lvalue, rvalue):
172         print '    %s = static_cast<%s>((%s).toPointer());' % (lvalue, blob, rvalue)
173     
174     def visitString(self, string, lvalue, rvalue):
175         print '    %s = (%s)((%s).toString());' % (lvalue, string.expr, rvalue)
176
177     seq = 0
178
179     def visitStruct(self, struct, lvalue, rvalue):
180         tmp = '_s_' + struct.tag + '_' + str(self.seq)
181         self.seq += 1
182
183         print '    const trace::Struct *%s = dynamic_cast<const trace::Struct *>(&%s);' % (tmp, rvalue)
184         print '    assert(%s);' % (tmp)
185         for i in range(len(struct.members)):
186             member = struct.members[i]
187             self.visitMember(member, lvalue, '*%s->members[%s]' % (tmp, i))
188
189     def visitPolymorphic(self, polymorphic, lvalue, rvalue):
190         if polymorphic.defaultType is None:
191             switchExpr = self.expand(polymorphic.switchExpr)
192             print r'    switch (%s) {' % switchExpr
193             for cases, type in polymorphic.iterSwitch():
194                 for case in cases:
195                     print r'    %s:' % case
196                 caseLvalue = lvalue
197                 if type.expr is not None:
198                     caseLvalue = 'static_cast<%s>(%s)' % (type, caseLvalue)
199                 print r'        {'
200                 try:
201                     self.visit(type, caseLvalue, rvalue)
202                 finally:
203                     print r'        }'
204                 print r'        break;'
205             if polymorphic.defaultType is None:
206                 print r'    default:'
207                 print r'        retrace::warning(call) << "unexpected polymorphic case" << %s << "\n";' % (switchExpr,)
208                 print r'        break;'
209             print r'    }'
210         else:
211             self.visit(polymorphic.defaultType, lvalue, rvalue)
212     
213     def visitOpaque(self, opaque, lvalue, rvalue):
214         raise UnsupportedType
215
216
217 class OpaqueValueDeserializer(ValueDeserializer):
218     '''Value extractor that also understands opaque values.
219
220     Normally opaque values can't be retraced, unless they are being extracted
221     in the context of handles.'''
222
223     def visitOpaque(self, opaque, lvalue, rvalue):
224         print '    %s = static_cast<%s>(retrace::toPointer(%s));' % (lvalue, opaque, rvalue)
225
226
227 class SwizzledValueRegistrator(stdapi.Visitor, stdapi.ExpanderMixin):
228     '''Type visitor which will register (un)swizzled value pairs, to later be
229     swizzled.'''
230
231     def visitLiteral(self, literal, lvalue, rvalue):
232         pass
233
234     def visitAlias(self, alias, lvalue, rvalue):
235         self.visit(alias.type, lvalue, rvalue)
236     
237     def visitEnum(self, enum, lvalue, rvalue):
238         pass
239
240     def visitBitmask(self, bitmask, lvalue, rvalue):
241         pass
242
243     def visitArray(self, array, lvalue, rvalue):
244         print '    const trace::Array *_a%s = dynamic_cast<const trace::Array *>(&%s);' % (array.tag, rvalue)
245         print '    if (_a%s) {' % (array.tag)
246         length = '_a%s->values.size()' % array.tag
247         index = '_j' + array.tag
248         print '        for (size_t {i} = 0; {i} < {length}; ++{i}) {{'.format(i = index, length = length)
249         try:
250             self.visit(array.type, '%s[%s]' % (lvalue, index), '*_a%s->values[%s]' % (array.tag, index))
251         finally:
252             print '        }'
253             print '    }'
254     
255     def visitPointer(self, pointer, lvalue, rvalue):
256         print '    const trace::Array *_a%s = dynamic_cast<const trace::Array *>(&%s);' % (pointer.tag, rvalue)
257         print '    if (_a%s) {' % (pointer.tag)
258         try:
259             self.visit(pointer.type, '%s[0]' % (lvalue,), '*_a%s->values[0]' % (pointer.tag,))
260         finally:
261             print '    }'
262     
263     def visitIntPointer(self, pointer, lvalue, rvalue):
264         pass
265     
266     def visitObjPointer(self, pointer, lvalue, rvalue):
267         print r'    retrace::addObj(call, %s, %s);' % (rvalue, lvalue)
268     
269     def visitLinearPointer(self, pointer, lvalue, rvalue):
270         assert pointer.size is not None
271         if pointer.size is not None:
272             print r'    retrace::addRegion((%s).toUIntPtr(), %s, %s);' % (rvalue, lvalue, pointer.size)
273
274     def visitReference(self, reference, lvalue, rvalue):
275         pass
276     
277     def visitHandle(self, handle, lvalue, rvalue):
278         print '    %s _origResult;' % handle.type
279         OpaqueValueDeserializer().visit(handle.type, '_origResult', rvalue);
280         if handle.range is None:
281             rvalue = "_origResult"
282             entry = lookupHandle(handle, rvalue) 
283             print "    %s = %s;" % (entry, lvalue)
284             print '    if (retrace::verbosity >= 2) {'
285             print '        std::cout << "{handle.name} " << {rvalue} << " -> " << {lvalue} << "\\n";'.format(**locals())
286             print '    }'
287         else:
288             i = '_h' + handle.tag
289             lvalue = "%s + %s" % (lvalue, i)
290             rvalue = "_origResult + %s" % (i,)
291             entry = lookupHandle(handle, rvalue) 
292             print '    for ({handle.type} {i} = 0; {i} < {handle.range}; ++{i}) {{'.format(**locals())
293             print '        {entry} = {lvalue};'.format(**locals())
294             print '        if (retrace::verbosity >= 2) {'
295             print '            std::cout << "{handle.name} " << ({rvalue}) << " -> " << ({lvalue}) << "\\n";'.format(**locals())
296             print '        }'
297             print '    }'
298     
299     def visitBlob(self, blob, lvalue, rvalue):
300         pass
301     
302     def visitString(self, string, lvalue, rvalue):
303         pass
304
305     seq = 0
306
307     def visitStruct(self, struct, lvalue, rvalue):
308         tmp = '_s_' + struct.tag + '_' + str(self.seq)
309         self.seq += 1
310
311         print '    const trace::Struct *%s = dynamic_cast<const trace::Struct *>(&%s);' % (tmp, rvalue)
312         print '    assert(%s);' % (tmp,)
313         print '    (void)%s;' % (tmp,)
314         for i in range(len(struct.members)):
315             member = struct.members[i]
316             self.visitMember(member, lvalue, '*%s->members[%s]' % (tmp, i))
317     
318     def visitPolymorphic(self, polymorphic, lvalue, rvalue):
319         assert polymorphic.defaultType is not None
320         self.visit(polymorphic.defaultType, lvalue, rvalue)
321     
322     def visitOpaque(self, opaque, lvalue, rvalue):
323         pass
324
325
326 class Retracer:
327
328     def retraceFunction(self, function):
329         print 'static void retrace_%s(trace::Call &call) {' % function.name
330         self.retraceFunctionBody(function)
331         print '}'
332         print
333
334     def retraceInterfaceMethod(self, interface, method):
335         print 'static void retrace_%s__%s(trace::Call &call) {' % (interface.name, method.name)
336         self.retraceInterfaceMethodBody(interface, method)
337         print '}'
338         print
339
340     def retraceFunctionBody(self, function):
341         assert function.sideeffects
342
343         if function.type is not stdapi.Void:
344             self.checkOrigResult(function)
345
346         self.deserializeArgs(function)
347         
348         self.declareRet(function)
349         self.invokeFunction(function)
350
351         self.swizzleValues(function)
352
353     def retraceInterfaceMethodBody(self, interface, method):
354         assert method.sideeffects
355
356         if method.type is not stdapi.Void:
357             self.checkOrigResult(method)
358
359         self.deserializeThisPointer(interface)
360
361         self.deserializeArgs(method)
362         
363         self.declareRet(method)
364         self.invokeInterfaceMethod(interface, method)
365
366         self.swizzleValues(method)
367
368     def checkOrigResult(self, function):
369         '''Hook for checking the original result, to prevent succeeding now
370         where the original did not, which would cause diversion and potentially
371         unpredictable results.'''
372
373         assert function.type is not stdapi.Void
374
375         if str(function.type) == 'HRESULT':
376             print r'    if (call.ret && FAILED(call.ret->toSInt())) {'
377             print r'        return;'
378             print r'    }'
379
380     def deserializeThisPointer(self, interface):
381         print r'    %s *_this;' % (interface.name,)
382         print r'    _this = static_cast<%s *>(retrace::toObjPointer(call, call.arg(0)));' % (interface.name,)
383         print r'    if (!_this) {'
384         print r'        return;'
385         print r'    }'
386
387     def deserializeArgs(self, function):
388         print '    retrace::ScopedAllocator _allocator;'
389         print '    (void)_allocator;'
390         success = True
391         for arg in function.args:
392             arg_type = arg.type.mutable()
393             print '    %s %s;' % (arg_type, arg.name)
394             rvalue = 'call.arg(%u)' % (arg.index,)
395             lvalue = arg.name
396             try:
397                 self.extractArg(function, arg, arg_type, lvalue, rvalue)
398             except UnsupportedType:
399                 success =  False
400                 print '    memset(&%s, 0, sizeof %s); // FIXME' % (arg.name, arg.name)
401             print
402
403         if not success:
404             print '    if (1) {'
405             self.failFunction(function)
406             sys.stderr.write('warning: unsupported %s call\n' % function.name)
407             print '    }'
408
409     def swizzleValues(self, function):
410         for arg in function.args:
411             if arg.output:
412                 arg_type = arg.type.mutable()
413                 rvalue = 'call.arg(%u)' % (arg.index,)
414                 lvalue = arg.name
415                 try:
416                     self.regiterSwizzledValue(arg_type, lvalue, rvalue)
417                 except UnsupportedType:
418                     print '    // XXX: %s' % arg.name
419         if function.type is not stdapi.Void:
420             rvalue = '*call.ret'
421             lvalue = '_result'
422             try:
423                 self.regiterSwizzledValue(function.type, lvalue, rvalue)
424             except UnsupportedType:
425                 raise
426                 print '    // XXX: result'
427
428     def failFunction(self, function):
429         print '    if (retrace::verbosity >= 0) {'
430         print '        retrace::unsupported(call);'
431         print '    }'
432         print '    return;'
433
434     def extractArg(self, function, arg, arg_type, lvalue, rvalue):
435         ValueAllocator().visit(arg_type, lvalue, rvalue)
436         if arg.input:
437             ValueDeserializer().visit(arg_type, lvalue, rvalue)
438     
439     def extractOpaqueArg(self, function, arg, arg_type, lvalue, rvalue):
440         try:
441             ValueAllocator().visit(arg_type, lvalue, rvalue)
442         except UnsupportedType:
443             pass
444         OpaqueValueDeserializer().visit(arg_type, lvalue, rvalue)
445
446     def regiterSwizzledValue(self, type, lvalue, rvalue):
447         visitor = SwizzledValueRegistrator()
448         visitor.visit(type, lvalue, rvalue)
449
450     def declareRet(self, function):
451         if function.type is not stdapi.Void:
452             print '    %s _result;' % (function.type)
453
454     def invokeFunction(self, function):
455         arg_names = ", ".join(function.argNames())
456         if function.type is not stdapi.Void:
457             print '    _result = %s(%s);' % (function.name, arg_names)
458             print '    (void)_result;'
459             self.checkResult(function.type)
460         else:
461             print '    %s(%s);' % (function.name, arg_names)
462
463     def invokeInterfaceMethod(self, interface, method):
464         # On release our reference when we reach Release() == 0 call in the
465         # trace.
466         if method.name == 'Release':
467             print '    if (call.ret->toUInt()) {'
468             print '        return;'
469             print '    }'
470             print '    retrace::delObj(call.arg(0));'
471
472         arg_names = ", ".join(method.argNames())
473         if method.type is not stdapi.Void:
474             print '    _result = _this->%s(%s);' % (method.name, arg_names)
475             print '    (void)_result;'
476             self.checkResult(method.type)
477         else:
478             print '    _this->%s(%s);' % (method.name, arg_names)
479
480     def checkResult(self, resultType):
481         if str(resultType) == 'HRESULT':
482             print r'    if (FAILED(_result)) {'
483             print r'        retrace::warning(call) << "failed\n";'
484             print r'    }'
485
486     def filterFunction(self, function):
487         return True
488
489     table_name = 'retrace::callbacks'
490
491     def retraceApi(self, api):
492
493         print '#include "os_time.hpp"'
494         print '#include "trace_parser.hpp"'
495         print '#include "retrace.hpp"'
496         print '#include "retrace_swizzle.hpp"'
497         print
498
499         types = api.getAllTypes()
500         handles = [type for type in types if isinstance(type, stdapi.Handle)]
501         handle_names = set()
502         for handle in handles:
503             if handle.name not in handle_names:
504                 if handle.key is None:
505                     print 'static retrace::map<%s> _%s_map;' % (handle.type, handle.name)
506                 else:
507                     key_name, key_type = handle.key
508                     print 'static std::map<%s, retrace::map<%s> > _%s_map;' % (key_type, handle.type, handle.name)
509                 handle_names.add(handle.name)
510         print
511
512         functions = filter(self.filterFunction, api.getAllFunctions())
513         for function in functions:
514             if function.sideeffects and not function.internal:
515                 self.retraceFunction(function)
516         interfaces = api.getAllInterfaces()
517         for interface in interfaces:
518             for method in interface.iterMethods():
519                 if method.sideeffects and not method.internal:
520                     self.retraceInterfaceMethod(interface, method)
521
522         print 'const retrace::Entry %s[] = {' % self.table_name
523         for function in functions:
524             if not function.internal:
525                 if function.sideeffects:
526                     print '    {"%s", &retrace_%s},' % (function.name, function.name)
527                 else:
528                     print '    {"%s", &retrace::ignore},' % (function.name,)
529         for interface in interfaces:
530             for method in interface.iterMethods():                
531                 if method.sideeffects:
532                     print '    {"%s::%s", &retrace_%s__%s},' % (interface.name, method.name, interface.name, method.name)
533                 else:
534                     print '    {"%s::%s", &retrace::ignore},' % (interface.name, method.name)
535         print '    {NULL, NULL}'
536         print '};'
537         print
538