]> git.cworth.org Git - apitrace/blob - retrace/retrace.py
specs: Initial attempt to support unions.
[apitrace] / retrace / retrace.py
1 ##########################################################################
2 #
3 # Copyright 2010 VMware, Inc.
4 # All Rights Reserved.
5 #
6 # Permission is hereby granted, free of charge, to any person obtaining a copy
7 # of this software and associated documentation files (the "Software"), to deal
8 # in the Software without restriction, including without limitation the rights
9 # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 # copies of the Software, and to permit persons to whom the Software is
11 # furnished to do so, subject to the following conditions:
12 #
13 # The above copyright notice and this permission notice shall be included in
14 # all copies or substantial portions of the Software.
15 #
16 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
22 # THE SOFTWARE.
23 #
24 ##########################################################################/
25
26
27 """Generic retracing code generator."""
28
29
30 # Adjust path
31 import os.path
32 import sys
33 sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
34
35
36 import specs.stdapi as stdapi
37
38
39 class UnsupportedType(Exception):
40     pass
41
42
43 def lookupHandle(handle, value):
44     if handle.key is None:
45         return "_%s_map[%s]" % (handle.name, value)
46     else:
47         key_name, key_type = handle.key
48         return "_%s_map[%s][%s]" % (handle.name, key_name, value)
49
50
51 class ValueAllocator(stdapi.Visitor):
52
53     def visitLiteral(self, literal, lvalue, rvalue):
54         pass
55
56     def visitConst(self, const, lvalue, rvalue):
57         self.visit(const.type, lvalue, rvalue)
58
59     def visitAlias(self, alias, lvalue, rvalue):
60         self.visit(alias.type, lvalue, rvalue)
61
62     def visitEnum(self, enum, lvalue, rvalue):
63         pass
64
65     def visitBitmask(self, bitmask, lvalue, rvalue):
66         pass
67
68     def visitArray(self, array, lvalue, rvalue):
69         print '    %s = _allocator.alloc<%s>(&%s);' % (lvalue, array.type, rvalue)
70
71     def visitPointer(self, pointer, lvalue, rvalue):
72         print '    %s = _allocator.alloc<%s>(&%s);' % (lvalue, pointer.type, rvalue)
73
74     def visitIntPointer(self, pointer, lvalue, rvalue):
75         pass
76
77     def visitObjPointer(self, pointer, lvalue, rvalue):
78         pass
79
80     def visitLinearPointer(self, pointer, lvalue, rvalue):
81         pass
82
83     def visitReference(self, reference, lvalue, rvalue):
84         self.visit(reference.type, lvalue, rvalue);
85
86     def visitHandle(self, handle, lvalue, rvalue):
87         pass
88
89     def visitBlob(self, blob, lvalue, rvalue):
90         pass
91
92     def visitString(self, string, lvalue, rvalue):
93         pass
94
95     def visitStruct(self, struct, lvalue, rvalue):
96         pass
97
98     def visitPolymorphic(self, polymorphic, lvalue, rvalue):
99         if polymorphic.defaultType is None:
100             # FIXME
101             raise UnsupportedType
102         self.visit(polymorphic.defaultType, lvalue, rvalue)
103
104     def visitOpaque(self, opaque, lvalue, rvalue):
105         pass
106
107
108 class ValueDeserializer(stdapi.Visitor):
109
110     def visitLiteral(self, literal, lvalue, rvalue):
111         print '    %s = (%s).to%s();' % (lvalue, rvalue, literal.kind)
112
113     def visitConst(self, const, lvalue, rvalue):
114         self.visit(const.type, lvalue, rvalue)
115
116     def visitAlias(self, alias, lvalue, rvalue):
117         self.visit(alias.type, lvalue, rvalue)
118     
119     def visitEnum(self, enum, lvalue, rvalue):
120         print '    %s = static_cast<%s>((%s).toSInt());' % (lvalue, enum, rvalue)
121
122     def visitBitmask(self, bitmask, lvalue, rvalue):
123         self.visit(bitmask.type, lvalue, rvalue)
124
125     def visitArray(self, array, lvalue, rvalue):
126
127         tmp = '_a_' + array.tag + '_' + str(self.seq)
128         self.seq += 1
129
130         print '    if (%s) {' % (lvalue,)
131         print '        const trace::Array *%s = dynamic_cast<const trace::Array *>(&%s);' % (tmp, rvalue)
132         length = '%s->values.size()' % (tmp,)
133         index = '_j' + array.tag
134         print '        for (size_t {i} = 0; {i} < {length}; ++{i}) {{'.format(i = index, length = length)
135         try:
136             self.visit(array.type, '%s[%s]' % (lvalue, index), '*%s->values[%s]' % (tmp, index))
137         finally:
138             print '        }'
139             print '    }'
140     
141     def visitPointer(self, pointer, lvalue, rvalue):
142         tmp = '_a_' + pointer.tag + '_' + str(self.seq)
143         self.seq += 1
144
145         print '    if (%s) {' % (lvalue,)
146         print '        const trace::Array *%s = dynamic_cast<const trace::Array *>(&%s);' % (tmp, rvalue)
147         try:
148             self.visit(pointer.type, '%s[0]' % (lvalue,), '*%s->values[0]' % (tmp,))
149         finally:
150             print '    }'
151
152     def visitIntPointer(self, pointer, lvalue, rvalue):
153         print '    %s = static_cast<%s>((%s).toPointer());' % (lvalue, pointer, rvalue)
154
155     def visitObjPointer(self, pointer, lvalue, rvalue):
156         print '    %s = static_cast<%s>(retrace::toObjPointer(call, %s));' % (lvalue, pointer, rvalue)
157
158     def visitLinearPointer(self, pointer, lvalue, rvalue):
159         print '    %s = static_cast<%s>(retrace::toPointer(%s));' % (lvalue, pointer, rvalue)
160
161     def visitReference(self, reference, lvalue, rvalue):
162         self.visit(reference.type, lvalue, rvalue);
163
164     def visitHandle(self, handle, lvalue, rvalue):
165         #OpaqueValueDeserializer().visit(handle.type, lvalue, rvalue);
166         self.visit(handle.type, lvalue, rvalue);
167         new_lvalue = lookupHandle(handle, lvalue)
168         print '    if (retrace::verbosity >= 2) {'
169         print '        std::cout << "%s " << size_t(%s) << " <- " << size_t(%s) << "\\n";' % (handle.name, lvalue, new_lvalue)
170         print '    }'
171         print '    %s = %s;' % (lvalue, new_lvalue)
172     
173     def visitBlob(self, blob, lvalue, rvalue):
174         print '    %s = static_cast<%s>((%s).toPointer());' % (lvalue, blob, rvalue)
175     
176     def visitString(self, string, lvalue, rvalue):
177         print '    %s = (%s)((%s).toString());' % (lvalue, string.expr, rvalue)
178
179     seq = 0
180
181     def visitStruct(self, struct, lvalue, rvalue):
182         tmp = '_s_' + struct.tag + '_' + str(self.seq)
183         self.seq += 1
184
185         print '    const trace::Struct *%s = dynamic_cast<const trace::Struct *>(&%s);' % (tmp, rvalue)
186         print '    assert(%s);' % (tmp)
187         for i in range(len(struct.members)):
188             member_type, member_name = struct.members[i]
189             self.visit(member_type, '%s.%s' % (lvalue, member_name), '*%s->members[%s]' % (tmp, i))
190
191     def visitPolymorphic(self, polymorphic, lvalue, rvalue):
192         if polymorphic.defaultType is None:
193             # FIXME
194             raise UnsupportedType
195         self.visit(polymorphic.defaultType, lvalue, rvalue)
196     
197     def visitOpaque(self, opaque, lvalue, rvalue):
198         raise UnsupportedType
199
200
201 class OpaqueValueDeserializer(ValueDeserializer):
202     '''Value extractor that also understands opaque values.
203
204     Normally opaque values can't be retraced, unless they are being extracted
205     in the context of handles.'''
206
207     def visitOpaque(self, opaque, lvalue, rvalue):
208         print '    %s = static_cast<%s>(retrace::toPointer(%s));' % (lvalue, opaque, rvalue)
209
210
211 class SwizzledValueRegistrator(stdapi.Visitor):
212     '''Type visitor which will register (un)swizzled value pairs, to later be
213     swizzled.'''
214
215     def visitLiteral(self, literal, lvalue, rvalue):
216         pass
217
218     def visitAlias(self, alias, lvalue, rvalue):
219         self.visit(alias.type, lvalue, rvalue)
220     
221     def visitEnum(self, enum, lvalue, rvalue):
222         pass
223
224     def visitBitmask(self, bitmask, lvalue, rvalue):
225         pass
226
227     def visitArray(self, array, lvalue, rvalue):
228         print '    const trace::Array *_a%s = dynamic_cast<const trace::Array *>(&%s);' % (array.tag, rvalue)
229         print '    if (_a%s) {' % (array.tag)
230         length = '_a%s->values.size()' % array.tag
231         index = '_j' + array.tag
232         print '        for (size_t {i} = 0; {i} < {length}; ++{i}) {{'.format(i = index, length = length)
233         try:
234             self.visit(array.type, '%s[%s]' % (lvalue, index), '*_a%s->values[%s]' % (array.tag, index))
235         finally:
236             print '        }'
237             print '    }'
238     
239     def visitPointer(self, pointer, lvalue, rvalue):
240         print '    const trace::Array *_a%s = dynamic_cast<const trace::Array *>(&%s);' % (pointer.tag, rvalue)
241         print '    if (_a%s) {' % (pointer.tag)
242         try:
243             self.visit(pointer.type, '%s[0]' % (lvalue,), '*_a%s->values[0]' % (pointer.tag,))
244         finally:
245             print '    }'
246     
247     def visitIntPointer(self, pointer, lvalue, rvalue):
248         pass
249     
250     def visitObjPointer(self, pointer, lvalue, rvalue):
251         print r'    retrace::addObj(call, %s, %s);' % (rvalue, lvalue)
252     
253     def visitLinearPointer(self, pointer, lvalue, rvalue):
254         assert pointer.size is not None
255         if pointer.size is not None:
256             print r'    retrace::addRegion((%s).toUIntPtr(), %s, %s);' % (rvalue, lvalue, pointer.size)
257
258     def visitReference(self, reference, lvalue, rvalue):
259         pass
260     
261     def visitHandle(self, handle, lvalue, rvalue):
262         print '    %s _origResult;' % handle.type
263         OpaqueValueDeserializer().visit(handle.type, '_origResult', rvalue);
264         if handle.range is None:
265             rvalue = "_origResult"
266             entry = lookupHandle(handle, rvalue) 
267             print "    %s = %s;" % (entry, lvalue)
268             print '    if (retrace::verbosity >= 2) {'
269             print '        std::cout << "{handle.name} " << {rvalue} << " -> " << {lvalue} << "\\n";'.format(**locals())
270             print '    }'
271         else:
272             i = '_h' + handle.tag
273             lvalue = "%s + %s" % (lvalue, i)
274             rvalue = "_origResult + %s" % (i,)
275             entry = lookupHandle(handle, rvalue) 
276             print '    for ({handle.type} {i} = 0; {i} < {handle.range}; ++{i}) {{'.format(**locals())
277             print '        {entry} = {lvalue};'.format(**locals())
278             print '        if (retrace::verbosity >= 2) {'
279             print '            std::cout << "{handle.name} " << ({rvalue}) << " -> " << ({lvalue}) << "\\n";'.format(**locals())
280             print '        }'
281             print '    }'
282     
283     def visitBlob(self, blob, lvalue, rvalue):
284         pass
285     
286     def visitString(self, string, lvalue, rvalue):
287         pass
288
289     seq = 0
290
291     def visitStruct(self, struct, lvalue, rvalue):
292         tmp = '_s_' + struct.tag + '_' + str(self.seq)
293         self.seq += 1
294
295         print '    const trace::Struct *%s = dynamic_cast<const trace::Struct *>(&%s);' % (tmp, rvalue)
296         print '    assert(%s);' % (tmp,)
297         print '    (void)%s;' % (tmp,)
298         for i in range(len(struct.members)):
299             member_type, member_name = struct.members[i]
300             self.visit(member_type, '%s.%s' % (lvalue, member_name), '*%s->members[%s]' % (tmp, i))
301     
302     def visitPolymorphic(self, polymorphic, lvalue, rvalue):
303         if polymorphic.defaultType is None:
304             # FIXME
305             raise UnsupportedType
306         self.visit(polymorphic.defaultType, lvalue, rvalue)
307     
308     def visitOpaque(self, opaque, lvalue, rvalue):
309         pass
310
311
312 class Retracer:
313
314     def retraceFunction(self, function):
315         print 'static void retrace_%s(trace::Call &call) {' % function.name
316         self.retraceFunctionBody(function)
317         print '}'
318         print
319
320     def retraceInterfaceMethod(self, interface, method):
321         print 'static void retrace_%s__%s(trace::Call &call) {' % (interface.name, method.name)
322         self.retraceInterfaceMethodBody(interface, method)
323         print '}'
324         print
325
326     def retraceFunctionBody(self, function):
327         assert function.sideeffects
328
329         if function.type is not stdapi.Void:
330             self.checkOrigResult(function)
331
332         self.deserializeArgs(function)
333         
334         self.declareRet(function)
335         self.invokeFunction(function)
336
337         self.swizzleValues(function)
338
339     def retraceInterfaceMethodBody(self, interface, method):
340         assert method.sideeffects
341
342         if method.type is not stdapi.Void:
343             self.checkOrigResult(method)
344
345         self.deserializeThisPointer(interface)
346
347         self.deserializeArgs(method)
348         
349         self.declareRet(method)
350         self.invokeInterfaceMethod(interface, method)
351
352         self.swizzleValues(method)
353
354     def checkOrigResult(self, function):
355         '''Hook for checking the original result, to prevent succeeding now
356         where the original did not, which would cause diversion and potentially
357         unpredictable results.'''
358
359         assert function.type is not stdapi.Void
360
361         if str(function.type) == 'HRESULT':
362             print r'    if (call.ret && FAILED(call.ret->toSInt())) {'
363             print r'        return;'
364             print r'    }'
365
366     def deserializeThisPointer(self, interface):
367         print r'    %s *_this;' % (interface.name,)
368         print r'    _this = static_cast<%s *>(retrace::toObjPointer(call, call.arg(0)));' % (interface.name,)
369         print r'    if (!_this) {'
370         print r'        return;'
371         print r'    }'
372
373     def deserializeArgs(self, function):
374         print '    retrace::ScopedAllocator _allocator;'
375         print '    (void)_allocator;'
376         success = True
377         for arg in function.args:
378             arg_type = arg.type.mutable()
379             print '    %s %s;' % (arg_type, arg.name)
380             rvalue = 'call.arg(%u)' % (arg.index,)
381             lvalue = arg.name
382             try:
383                 self.extractArg(function, arg, arg_type, lvalue, rvalue)
384             except UnsupportedType:
385                 success =  False
386                 print '    memset(&%s, 0, sizeof %s); // FIXME' % (arg.name, arg.name)
387             print
388
389         if not success:
390             print '    if (1) {'
391             self.failFunction(function)
392             sys.stderr.write('warning: unsupported %s call\n' % function.name)
393             print '    }'
394
395     def swizzleValues(self, function):
396         for arg in function.args:
397             if arg.output:
398                 arg_type = arg.type.mutable()
399                 rvalue = 'call.arg(%u)' % (arg.index,)
400                 lvalue = arg.name
401                 try:
402                     self.regiterSwizzledValue(arg_type, lvalue, rvalue)
403                 except UnsupportedType:
404                     print '    // XXX: %s' % arg.name
405         if function.type is not stdapi.Void:
406             rvalue = '*call.ret'
407             lvalue = '_result'
408             try:
409                 self.regiterSwizzledValue(function.type, lvalue, rvalue)
410             except UnsupportedType:
411                 raise
412                 print '    // XXX: result'
413
414     def failFunction(self, function):
415         print '    if (retrace::verbosity >= 0) {'
416         print '        retrace::unsupported(call);'
417         print '    }'
418         print '    return;'
419
420     def extractArg(self, function, arg, arg_type, lvalue, rvalue):
421         ValueAllocator().visit(arg_type, lvalue, rvalue)
422         if arg.input:
423             ValueDeserializer().visit(arg_type, lvalue, rvalue)
424     
425     def extractOpaqueArg(self, function, arg, arg_type, lvalue, rvalue):
426         try:
427             ValueAllocator().visit(arg_type, lvalue, rvalue)
428         except UnsupportedType:
429             pass
430         OpaqueValueDeserializer().visit(arg_type, lvalue, rvalue)
431
432     def regiterSwizzledValue(self, type, lvalue, rvalue):
433         visitor = SwizzledValueRegistrator()
434         visitor.visit(type, lvalue, rvalue)
435
436     def declareRet(self, function):
437         if function.type is not stdapi.Void:
438             print '    %s _result;' % (function.type)
439
440     def invokeFunction(self, function):
441         arg_names = ", ".join(function.argNames())
442         if function.type is not stdapi.Void:
443             print '    _result = %s(%s);' % (function.name, arg_names)
444             print '    (void)_result;'
445             self.checkResult(function.type)
446         else:
447             print '    %s(%s);' % (function.name, arg_names)
448
449     def invokeInterfaceMethod(self, interface, method):
450         # On release our reference when we reach Release() == 0 call in the
451         # trace.
452         if method.name == 'Release':
453             print '    if (call.ret->toUInt()) {'
454             print '        return;'
455             print '    }'
456             print '    retrace::delObj(call.arg(0));'
457
458         arg_names = ", ".join(method.argNames())
459         if method.type is not stdapi.Void:
460             print '    _result = _this->%s(%s);' % (method.name, arg_names)
461             print '    (void)_result;'
462             self.checkResult(method.type)
463         else:
464             print '    _this->%s(%s);' % (method.name, arg_names)
465
466     def checkResult(self, resultType):
467         if str(resultType) == 'HRESULT':
468             print r'    if (FAILED(_result)) {'
469             print r'        retrace::warning(call) << "failed\n";'
470             print r'    }'
471
472     def filterFunction(self, function):
473         return True
474
475     table_name = 'retrace::callbacks'
476
477     def retraceApi(self, api):
478
479         print '#include "os_time.hpp"'
480         print '#include "trace_parser.hpp"'
481         print '#include "retrace.hpp"'
482         print '#include "retrace_swizzle.hpp"'
483         print
484
485         types = api.getAllTypes()
486         handles = [type for type in types if isinstance(type, stdapi.Handle)]
487         handle_names = set()
488         for handle in handles:
489             if handle.name not in handle_names:
490                 if handle.key is None:
491                     print 'static retrace::map<%s> _%s_map;' % (handle.type, handle.name)
492                 else:
493                     key_name, key_type = handle.key
494                     print 'static std::map<%s, retrace::map<%s> > _%s_map;' % (key_type, handle.type, handle.name)
495                 handle_names.add(handle.name)
496         print
497
498         functions = filter(self.filterFunction, api.getAllFunctions())
499         for function in functions:
500             if function.sideeffects and not function.internal:
501                 self.retraceFunction(function)
502         interfaces = api.getAllInterfaces()
503         for interface in interfaces:
504             for method in interface.iterMethods():
505                 if method.sideeffects and not method.internal:
506                     self.retraceInterfaceMethod(interface, method)
507
508         print 'const retrace::Entry %s[] = {' % self.table_name
509         for function in functions:
510             if not function.internal:
511                 if function.sideeffects:
512                     print '    {"%s", &retrace_%s},' % (function.name, function.name)
513                 else:
514                     print '    {"%s", &retrace::ignore},' % (function.name,)
515         for interface in interfaces:
516             for method in interface.iterMethods():                
517                 if method.sideeffects:
518                     print '    {"%s::%s", &retrace_%s__%s},' % (interface.name, method.name, interface.name, method.name)
519                 else:
520                     print '    {"%s::%s", &retrace::ignore},' % (interface.name, method.name)
521         print '    {NULL, NULL}'
522         print '};'
523         print
524