]> git.cworth.org Git - apitrace/blob - retrace/retrace.py
Factor out object swizzling.
[apitrace] / retrace / retrace.py
1 ##########################################################################
2 #
3 # Copyright 2010 VMware, Inc.
4 # All Rights Reserved.
5 #
6 # Permission is hereby granted, free of charge, to any person obtaining a copy
7 # of this software and associated documentation files (the "Software"), to deal
8 # in the Software without restriction, including without limitation the rights
9 # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 # copies of the Software, and to permit persons to whom the Software is
11 # furnished to do so, subject to the following conditions:
12 #
13 # The above copyright notice and this permission notice shall be included in
14 # all copies or substantial portions of the Software.
15 #
16 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
22 # THE SOFTWARE.
23 #
24 ##########################################################################/
25
26
27 """Generic retracing code generator."""
28
29
30 # Adjust path
31 import os.path
32 import sys
33 sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
34
35
36 import specs.stdapi as stdapi
37 import specs.glapi as glapi
38
39
40 class UnsupportedType(Exception):
41     pass
42
43
44 def lookupHandle(handle, value):
45     if handle.key is None:
46         return "_%s_map[%s]" % (handle.name, value)
47     else:
48         key_name, key_type = handle.key
49         return "_%s_map[%s][%s]" % (handle.name, key_name, value)
50
51
52 class ValueAllocator(stdapi.Visitor):
53
54     def visitLiteral(self, literal, lvalue, rvalue):
55         pass
56
57     def visitConst(self, const, lvalue, rvalue):
58         self.visit(const.type, lvalue, rvalue)
59
60     def visitAlias(self, alias, lvalue, rvalue):
61         self.visit(alias.type, lvalue, rvalue)
62
63     def visitEnum(self, enum, lvalue, rvalue):
64         pass
65
66     def visitBitmask(self, bitmask, lvalue, rvalue):
67         pass
68
69     def visitArray(self, array, lvalue, rvalue):
70         print '    %s = _allocator.alloc<%s>(&%s);' % (lvalue, array.type, rvalue)
71
72     def visitPointer(self, pointer, lvalue, rvalue):
73         print '    %s = _allocator.alloc<%s>(&%s);' % (lvalue, pointer.type, rvalue)
74
75     def visitIntPointer(self, pointer, lvalue, rvalue):
76         pass
77
78     def visitObjPointer(self, pointer, lvalue, rvalue):
79         pass
80
81     def visitLinearPointer(self, pointer, lvalue, rvalue):
82         pass
83
84     def visitReference(self, reference, lvalue, rvalue):
85         self.visit(reference.type, lvalue, rvalue);
86
87     def visitHandle(self, handle, lvalue, rvalue):
88         pass
89
90     def visitBlob(self, blob, lvalue, rvalue):
91         pass
92
93     def visitString(self, string, lvalue, rvalue):
94         pass
95
96     def visitStruct(self, struct, lvalue, rvalue):
97         pass
98
99     def visitPolymorphic(self, polymorphic, lvalue, rvalue):
100         self.visit(polymorphic.defaultType, lvalue, rvalue)
101
102     def visitOpaque(self, opaque, lvalue, rvalue):
103         pass
104
105
106 class ValueDeserializer(stdapi.Visitor):
107
108     def visitLiteral(self, literal, lvalue, rvalue):
109         print '    %s = (%s).to%s();' % (lvalue, rvalue, literal.kind)
110
111     def visitConst(self, const, lvalue, rvalue):
112         self.visit(const.type, lvalue, rvalue)
113
114     def visitAlias(self, alias, lvalue, rvalue):
115         self.visit(alias.type, lvalue, rvalue)
116     
117     def visitEnum(self, enum, lvalue, rvalue):
118         print '    %s = static_cast<%s>((%s).toSInt());' % (lvalue, enum, rvalue)
119
120     def visitBitmask(self, bitmask, lvalue, rvalue):
121         self.visit(bitmask.type, lvalue, rvalue)
122
123     def visitArray(self, array, lvalue, rvalue):
124
125         tmp = '_a_' + array.tag + '_' + str(self.seq)
126         self.seq += 1
127
128         print '    if (%s) {' % (lvalue,)
129         print '        const trace::Array *%s = dynamic_cast<const trace::Array *>(&%s);' % (tmp, rvalue)
130         length = '%s->values.size()' % (tmp,)
131         index = '_j' + array.tag
132         print '        for (size_t {i} = 0; {i} < {length}; ++{i}) {{'.format(i = index, length = length)
133         try:
134             self.visit(array.type, '%s[%s]' % (lvalue, index), '*%s->values[%s]' % (tmp, index))
135         finally:
136             print '        }'
137             print '    }'
138     
139     def visitPointer(self, pointer, lvalue, rvalue):
140         tmp = '_a_' + pointer.tag + '_' + str(self.seq)
141         self.seq += 1
142
143         print '    if (%s) {' % (lvalue,)
144         print '        const trace::Array *%s = dynamic_cast<const trace::Array *>(&%s);' % (tmp, rvalue)
145         try:
146             self.visit(pointer.type, '%s[0]' % (lvalue,), '*%s->values[0]' % (tmp,))
147         finally:
148             print '    }'
149
150     def visitIntPointer(self, pointer, lvalue, rvalue):
151         print '    %s = static_cast<%s>((%s).toPointer());' % (lvalue, pointer, rvalue)
152
153     def visitObjPointer(self, pointer, lvalue, rvalue):
154         print '    %s = static_cast<%s>(retrace::toObjPointer(%s));' % (lvalue, pointer, rvalue)
155
156     def visitLinearPointer(self, pointer, lvalue, rvalue):
157         print '    %s = static_cast<%s>(retrace::toPointer(%s));' % (lvalue, pointer, rvalue)
158
159     def visitReference(self, reference, lvalue, rvalue):
160         self.visit(reference.type, lvalue, rvalue);
161
162     def visitHandle(self, handle, lvalue, rvalue):
163         #OpaqueValueDeserializer().visit(handle.type, lvalue, rvalue);
164         self.visit(handle.type, lvalue, rvalue);
165         new_lvalue = lookupHandle(handle, lvalue)
166         print '    if (retrace::verbosity >= 2) {'
167         print '        std::cout << "%s " << size_t(%s) << " <- " << size_t(%s) << "\\n";' % (handle.name, lvalue, new_lvalue)
168         print '    }'
169         print '    %s = %s;' % (lvalue, new_lvalue)
170     
171     def visitBlob(self, blob, lvalue, rvalue):
172         print '    %s = static_cast<%s>((%s).toPointer());' % (lvalue, blob, rvalue)
173     
174     def visitString(self, string, lvalue, rvalue):
175         print '    %s = (%s)((%s).toString());' % (lvalue, string.expr, rvalue)
176
177     seq = 0
178
179     def visitStruct(self, struct, lvalue, rvalue):
180         tmp = '_s_' + struct.tag + '_' + str(self.seq)
181         self.seq += 1
182
183         print '    const trace::Struct *%s = dynamic_cast<const trace::Struct *>(&%s);' % (tmp, rvalue)
184         print '    assert(%s);' % (tmp)
185         for i in range(len(struct.members)):
186             member_type, member_name = struct.members[i]
187             self.visit(member_type, '%s.%s' % (lvalue, member_name), '*%s->members[%s]' % (tmp, i))
188
189     def visitPolymorphic(self, polymorphic, lvalue, rvalue):
190         self.visit(polymorphic.defaultType, lvalue, rvalue)
191     
192     def visitOpaque(self, opaque, lvalue, rvalue):
193         raise UnsupportedType
194
195
196 class OpaqueValueDeserializer(ValueDeserializer):
197     '''Value extractor that also understands opaque values.
198
199     Normally opaque values can't be retraced, unless they are being extracted
200     in the context of handles.'''
201
202     def visitOpaque(self, opaque, lvalue, rvalue):
203         print '    %s = static_cast<%s>(retrace::toPointer(%s));' % (lvalue, opaque, rvalue)
204
205
206 class SwizzledValueRegistrator(stdapi.Visitor):
207     '''Type visitor which will register (un)swizzled value pairs, to later be
208     swizzled.'''
209
210     def visitLiteral(self, literal, lvalue, rvalue):
211         pass
212
213     def visitAlias(self, alias, lvalue, rvalue):
214         self.visit(alias.type, lvalue, rvalue)
215     
216     def visitEnum(self, enum, lvalue, rvalue):
217         pass
218
219     def visitBitmask(self, bitmask, lvalue, rvalue):
220         pass
221
222     def visitArray(self, array, lvalue, rvalue):
223         print '    const trace::Array *_a%s = dynamic_cast<const trace::Array *>(&%s);' % (array.tag, rvalue)
224         print '    if (_a%s) {' % (array.tag)
225         length = '_a%s->values.size()' % array.tag
226         index = '_j' + array.tag
227         print '        for (size_t {i} = 0; {i} < {length}; ++{i}) {{'.format(i = index, length = length)
228         try:
229             self.visit(array.type, '%s[%s]' % (lvalue, index), '*_a%s->values[%s]' % (array.tag, index))
230         finally:
231             print '        }'
232             print '    }'
233     
234     def visitPointer(self, pointer, lvalue, rvalue):
235         print '    const trace::Array *_a%s = dynamic_cast<const trace::Array *>(&%s);' % (pointer.tag, rvalue)
236         print '    if (_a%s) {' % (pointer.tag)
237         try:
238             self.visit(pointer.type, '%s[0]' % (lvalue,), '*_a%s->values[0]' % (pointer.tag,))
239         finally:
240             print '    }'
241     
242     def visitIntPointer(self, pointer, lvalue, rvalue):
243         pass
244     
245     def visitObjPointer(self, pointer, lvalue, rvalue):
246         print r'    retrace::addObj(%s, %s);' % (rvalue, lvalue)
247     
248     def visitLinearPointer(self, pointer, lvalue, rvalue):
249         assert pointer.size is not None
250         if pointer.size is not None:
251             print r'    retrace::addRegion((%s).toUIntPtr(), %s, %s);' % (rvalue, lvalue, pointer.size)
252
253     def visitReference(self, reference, lvalue, rvalue):
254         pass
255     
256     def visitHandle(self, handle, lvalue, rvalue):
257         print '    %s _origResult;' % handle.type
258         OpaqueValueDeserializer().visit(handle.type, '_origResult', rvalue);
259         if handle.range is None:
260             rvalue = "_origResult"
261             entry = lookupHandle(handle, rvalue) 
262             print "    %s = %s;" % (entry, lvalue)
263             print '    if (retrace::verbosity >= 2) {'
264             print '        std::cout << "{handle.name} " << {rvalue} << " -> " << {lvalue} << "\\n";'.format(**locals())
265             print '    }'
266         else:
267             i = '_h' + handle.tag
268             lvalue = "%s + %s" % (lvalue, i)
269             rvalue = "_origResult + %s" % (i,)
270             entry = lookupHandle(handle, rvalue) 
271             print '    for ({handle.type} {i} = 0; {i} < {handle.range}; ++{i}) {{'.format(**locals())
272             print '        {entry} = {lvalue};'.format(**locals())
273             print '        if (retrace::verbosity >= 2) {'
274             print '            std::cout << "{handle.name} " << ({rvalue}) << " -> " << ({lvalue}) << "\\n";'.format(**locals())
275             print '        }'
276             print '    }'
277     
278     def visitBlob(self, blob, lvalue, rvalue):
279         pass
280     
281     def visitString(self, string, lvalue, rvalue):
282         pass
283
284     seq = 0
285
286     def visitStruct(self, struct, lvalue, rvalue):
287         tmp = '_s_' + struct.tag + '_' + str(self.seq)
288         self.seq += 1
289
290         print '    const trace::Struct *%s = dynamic_cast<const trace::Struct *>(&%s);' % (tmp, rvalue)
291         print '    assert(%s);' % (tmp,)
292         print '    (void)%s;' % (tmp,)
293         for i in range(len(struct.members)):
294             member_type, member_name = struct.members[i]
295             self.visit(member_type, '%s.%s' % (lvalue, member_name), '*%s->members[%s]' % (tmp, i))
296     
297     def visitPolymorphic(self, polymorphic, lvalue, rvalue):
298         self.visit(polymorphic.defaultType, lvalue, rvalue)
299     
300     def visitOpaque(self, opaque, lvalue, rvalue):
301         pass
302
303
304 class Retracer:
305
306     def retraceFunction(self, function):
307         print 'static void retrace_%s(trace::Call &call) {' % function.name
308         self.retraceFunctionBody(function)
309         print '}'
310         print
311
312     def retraceInterfaceMethod(self, interface, method):
313         print 'static void retrace_%s__%s(trace::Call &call) {' % (interface.name, method.name)
314         self.retraceInterfaceMethodBody(interface, method)
315         print '}'
316         print
317
318     def retraceFunctionBody(self, function):
319         assert function.sideeffects
320
321         if function.type is not stdapi.Void:
322             self.checkOrigResult(function)
323
324         self.deserializeArgs(function)
325         
326         self.declareRet(function)
327         self.invokeFunction(function)
328
329         self.swizzleValues(function)
330
331     def retraceInterfaceMethodBody(self, interface, method):
332         assert method.sideeffects
333
334         if method.type is not stdapi.Void:
335             self.checkOrigResult(method)
336
337         self.deserializeThisPointer(interface)
338
339         self.deserializeArgs(method)
340         
341         self.declareRet(method)
342         self.invokeInterfaceMethod(interface, method)
343
344         self.swizzleValues(method)
345
346     def checkOrigResult(self, function):
347         '''Hook for checking the original result, to prevent succeeding now
348         where the original did not, which would cause diversion and potentially
349         unpredictable results.'''
350
351         assert function.type is not stdapi.Void
352
353         if str(function.type) == 'HRESULT':
354             print r'    if (call.ret && FAILED(call.ret->toSInt())) {'
355             print r'        return;'
356             print r'    }'
357
358     def deserializeThisPointer(self, interface):
359         print r'    %s *_this;' % (interface.name,)
360         print r'    _this = static_cast<%s *>(retrace::toObjPointer(call.arg(0)));' % (interface.name,)
361         print r'    if (!_this) {'
362         print r'        retrace::warning(call) << "NULL this pointer\n";'
363         print r'        return;'
364         print r'    }'
365
366     def deserializeArgs(self, function):
367         print '    retrace::ScopedAllocator _allocator;'
368         print '    (void)_allocator;'
369         success = True
370         for arg in function.args:
371             arg_type = arg.type.mutable()
372             print '    %s %s;' % (arg_type, arg.name)
373             rvalue = 'call.arg(%u)' % (arg.index,)
374             lvalue = arg.name
375             try:
376                 self.extractArg(function, arg, arg_type, lvalue, rvalue)
377             except UnsupportedType:
378                 success =  False
379                 print '    memset(&%s, 0, sizeof %s); // FIXME' % (arg.name, arg.name)
380             print
381
382         if not success:
383             print '    if (1) {'
384             self.failFunction(function)
385             if function.name[-1].islower():
386                 sys.stderr.write('warning: unsupported %s call\n' % function.name)
387             print '    }'
388
389     def swizzleValues(self, function):
390         for arg in function.args:
391             if arg.output:
392                 arg_type = arg.type.mutable()
393                 rvalue = 'call.arg(%u)' % (arg.index,)
394                 lvalue = arg.name
395                 try:
396                     self.regiterSwizzledValue(arg_type, lvalue, rvalue)
397                 except UnsupportedType:
398                     print '    // XXX: %s' % arg.name
399         if function.type is not stdapi.Void:
400             rvalue = '*call.ret'
401             lvalue = '_result'
402             try:
403                 self.regiterSwizzledValue(function.type, lvalue, rvalue)
404             except UnsupportedType:
405                 raise
406                 print '    // XXX: result'
407
408     def failFunction(self, function):
409         print '    if (retrace::verbosity >= 0) {'
410         print '        retrace::unsupported(call);'
411         print '    }'
412         print '    return;'
413
414     def extractArg(self, function, arg, arg_type, lvalue, rvalue):
415         ValueAllocator().visit(arg_type, lvalue, rvalue)
416         if arg.input:
417             ValueDeserializer().visit(arg_type, lvalue, rvalue)
418     
419     def extractOpaqueArg(self, function, arg, arg_type, lvalue, rvalue):
420         try:
421             ValueAllocator().visit(arg_type, lvalue, rvalue)
422         except UnsupportedType:
423             pass
424         OpaqueValueDeserializer().visit(arg_type, lvalue, rvalue)
425
426     def regiterSwizzledValue(self, type, lvalue, rvalue):
427         visitor = SwizzledValueRegistrator()
428         visitor.visit(type, lvalue, rvalue)
429
430     def declareRet(self, function):
431         if function.type is not stdapi.Void:
432             print '    %s _result;' % (function.type)
433
434     def invokeFunction(self, function):
435         arg_names = ", ".join(function.argNames())
436         if function.type is not stdapi.Void:
437             print '    _result = %s(%s);' % (function.name, arg_names)
438             print '    (void)_result;'
439         else:
440             print '    %s(%s);' % (function.name, arg_names)
441
442     def invokeInterfaceMethod(self, interface, method):
443         # On release our reference when we reach Release() == 0 call in the
444         # trace.
445         if method.name == 'Release':
446             print '    if (call.ret->toUInt()) {'
447             print '        return;'
448             print '    }'
449             print '    retrace::delObj(call.arg(0));'
450
451         arg_names = ", ".join(method.argNames())
452         if method.type is not stdapi.Void:
453             print '    _result = _this->%s(%s);' % (method.name, arg_names)
454             print '    (void)_result;'
455         else:
456             print '    _this->%s(%s);' % (method.name, arg_names)
457
458     def filterFunction(self, function):
459         return True
460
461     table_name = 'retrace::callbacks'
462
463     def retraceApi(self, api):
464
465         print '#include "os_time.hpp"'
466         print '#include "trace_parser.hpp"'
467         print '#include "retrace.hpp"'
468         print '#include "retrace_swizzle.hpp"'
469         print
470
471         types = api.getAllTypes()
472         handles = [type for type in types if isinstance(type, stdapi.Handle)]
473         handle_names = set()
474         for handle in handles:
475             if handle.name not in handle_names:
476                 if handle.key is None:
477                     print 'static retrace::map<%s> _%s_map;' % (handle.type, handle.name)
478                 else:
479                     key_name, key_type = handle.key
480                     print 'static std::map<%s, retrace::map<%s> > _%s_map;' % (key_type, handle.type, handle.name)
481                 handle_names.add(handle.name)
482         print
483
484         functions = filter(self.filterFunction, api.functions)
485         for function in functions:
486             if function.sideeffects and not function.internal:
487                 self.retraceFunction(function)
488         interfaces = api.getAllInterfaces()
489         for interface in interfaces:
490             for method in interface.iterMethods():
491                 if method.sideeffects and not method.internal:
492                     self.retraceInterfaceMethod(interface, method)
493
494         print 'const retrace::Entry %s[] = {' % self.table_name
495         for function in functions:
496             if not function.internal:
497                 if function.sideeffects:
498                     print '    {"%s", &retrace_%s},' % (function.name, function.name)
499                 else:
500                     print '    {"%s", &retrace::ignore},' % (function.name,)
501         for interface in interfaces:
502             for method in interface.iterMethods():                
503                 if method.sideeffects:
504                     print '    {"%s::%s", &retrace_%s__%s},' % (interface.name, method.name, interface.name, method.name)
505                 else:
506                     print '    {"%s::%s", &retrace::ignore},' % (interface.name, method.name)
507         print '    {NULL, NULL}'
508         print '};'
509         print
510