]> git.cworth.org Git - apitrace/blob - retrace.py
9e9af208f53186d3fe6054406b4f3aad6d99d8e0
[apitrace] / retrace.py
1 ##########################################################################
2 #
3 # Copyright 2010 VMware, Inc.
4 # All Rights Reserved.
5 #
6 # Permission is hereby granted, free of charge, to any person obtaining a copy
7 # of this software and associated documentation files (the "Software"), to deal
8 # in the Software without restriction, including without limitation the rights
9 # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 # copies of the Software, and to permit persons to whom the Software is
11 # furnished to do so, subject to the following conditions:
12 #
13 # The above copyright notice and this permission notice shall be included in
14 # all copies or substantial portions of the Software.
15 #
16 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
22 # THE SOFTWARE.
23 #
24 ##########################################################################/
25
26
27 """Generic retracing code generator."""
28
29
30 import sys
31
32 import specs.stdapi as stdapi
33 import specs.glapi as glapi
34
35
36 class UnsupportedType(Exception):
37     pass
38
39
40 class MutableRebuilder(stdapi.Rebuilder):
41     '''Type visitor which derives a mutable type.'''
42
43     def visitConst(self, const):
44         # Strip out const qualifier
45         return const.type
46
47     def visitAlias(self, alias):
48         # Tear the alias on type changes
49         type = self.visit(alias.type)
50         if type is alias.type:
51             return alias
52         return type
53
54     def visitReference(self, reference):
55         # Strip out references
56         return reference.type
57
58
59 def lookupHandle(handle, value):
60     if handle.key is None:
61         return "__%s_map[%s]" % (handle.name, value)
62     else:
63         key_name, key_type = handle.key
64         return "__%s_map[%s][%s]" % (handle.name, key_name, value)
65
66
67 class ValueAllocator(stdapi.Visitor):
68
69     def visitLiteral(self, literal, lvalue, rvalue):
70         pass
71
72     def visitConst(self, const, lvalue, rvalue):
73         self.visit(const.type, lvalue, rvalue)
74
75     def visitAlias(self, alias, lvalue, rvalue):
76         self.visit(alias.type, lvalue, rvalue)
77
78     def visitEnum(self, enum, lvalue, rvalue):
79         pass
80
81     def visitBitmask(self, bitmask, lvalue, rvalue):
82         pass
83
84     def visitArray(self, array, lvalue, rvalue):
85         print '    %s = _allocator.alloc<%s>(&%s);' % (lvalue, array.type, rvalue)
86
87     def visitPointer(self, pointer, lvalue, rvalue):
88         print '    %s = _allocator.alloc<%s>(&%s);' % (lvalue, pointer.type, rvalue)
89
90     def visitIntPointer(self, pointer, lvalue, rvalue):
91         pass
92
93     def visitObjPointer(self, pointer, lvalue, rvalue):
94         pass
95
96     def visitLinearPointer(self, pointer, lvalue, rvalue):
97         pass
98
99     def visitReference(self, reference, lvalue, rvalue):
100         self.visit(reference.type, lvalue, rvalue);
101
102     def visitHandle(self, handle, lvalue, rvalue):
103         pass
104
105     def visitBlob(self, blob, lvalue, rvalue):
106         pass
107
108     def visitString(self, string, lvalue, rvalue):
109         pass
110
111     def visitStruct(self, struct, lvalue, rvalue):
112         pass
113
114     def visitPolymorphic(self, polymorphic, lvalue, rvalue):
115         self.visit(polymorphic.defaultType, lvalue, rvalue)
116
117     def visitOpaque(self, opaque, lvalue, rvalue):
118         pass
119
120
121 class ValueDeserializer(stdapi.Visitor):
122
123     def visitLiteral(self, literal, lvalue, rvalue):
124         print '    %s = (%s).to%s();' % (lvalue, rvalue, literal.kind)
125
126     def visitConst(self, const, lvalue, rvalue):
127         self.visit(const.type, lvalue, rvalue)
128
129     def visitAlias(self, alias, lvalue, rvalue):
130         self.visit(alias.type, lvalue, rvalue)
131     
132     def visitEnum(self, enum, lvalue, rvalue):
133         print '    %s = static_cast<%s>((%s).toSInt());' % (lvalue, enum, rvalue)
134
135     def visitBitmask(self, bitmask, lvalue, rvalue):
136         self.visit(bitmask.type, lvalue, rvalue)
137
138     def visitArray(self, array, lvalue, rvalue):
139
140         tmp = '__a_' + array.tag + '_' + str(self.seq)
141         self.seq += 1
142
143         print '    if (%s) {' % (lvalue,)
144         print '        const trace::Array *%s = dynamic_cast<const trace::Array *>(&%s);' % (tmp, rvalue)
145         length = '%s->values.size()' % (tmp,)
146         index = '__j' + array.tag
147         print '        for (size_t {i} = 0; {i} < {length}; ++{i}) {{'.format(i = index, length = length)
148         try:
149             self.visit(array.type, '%s[%s]' % (lvalue, index), '*%s->values[%s]' % (tmp, index))
150         finally:
151             print '        }'
152             print '    }'
153     
154     def visitPointer(self, pointer, lvalue, rvalue):
155         tmp = '__a_' + pointer.tag + '_' + str(self.seq)
156         self.seq += 1
157
158         print '    if (%s) {' % (lvalue,)
159         print '        const trace::Array *%s = dynamic_cast<const trace::Array *>(&%s);' % (tmp, rvalue)
160         try:
161             self.visit(pointer.type, '%s[0]' % (lvalue,), '*%s->values[0]' % (tmp,))
162         finally:
163             print '    }'
164
165     def visitIntPointer(self, pointer, lvalue, rvalue):
166         print '    %s = static_cast<%s>((%s).toPointer());' % (lvalue, pointer, rvalue)
167
168     def visitObjPointer(self, pointer, lvalue, rvalue):
169         old_lvalue = '(%s).toUIntPtr()' % (rvalue,)
170         new_lvalue = '_obj_map[%s]' % (old_lvalue,)
171         print '    if (retrace::verbosity >= 2) {'
172         print '        std::cout << std::hex << "obj 0x" << size_t(%s) << " <- 0x" << size_t(%s) << std::dec <<"\\n";' % (old_lvalue, new_lvalue)
173         print '    }'
174         print '    %s = static_cast<%s>(%s);' % (lvalue, pointer, new_lvalue)
175
176     def visitLinearPointer(self, pointer, lvalue, rvalue):
177         print '    %s = static_cast<%s>(retrace::toPointer(%s));' % (lvalue, pointer, rvalue)
178
179     def visitReference(self, reference, lvalue, rvalue):
180         self.visit(reference.type, lvalue, rvalue);
181
182     def visitHandle(self, handle, lvalue, rvalue):
183         #OpaqueValueDeserializer().visit(handle.type, lvalue, rvalue);
184         self.visit(handle.type, lvalue, rvalue);
185         new_lvalue = lookupHandle(handle, lvalue)
186         print '    if (retrace::verbosity >= 2) {'
187         print '        std::cout << "%s " << size_t(%s) << " <- " << size_t(%s) << "\\n";' % (handle.name, lvalue, new_lvalue)
188         print '    }'
189         print '    %s = %s;' % (lvalue, new_lvalue)
190     
191     def visitBlob(self, blob, lvalue, rvalue):
192         print '    %s = static_cast<%s>((%s).toPointer());' % (lvalue, blob, rvalue)
193     
194     def visitString(self, string, lvalue, rvalue):
195         print '    %s = (%s)((%s).toString());' % (lvalue, string.expr, rvalue)
196
197     seq = 0
198
199     def visitStruct(self, struct, lvalue, rvalue):
200         tmp = '__s_' + struct.tag + '_' + str(self.seq)
201         self.seq += 1
202
203         print '    const trace::Struct *%s = dynamic_cast<const trace::Struct *>(&%s);' % (tmp, rvalue)
204         print '    assert(%s);' % (tmp)
205         for i in range(len(struct.members)):
206             member_type, member_name = struct.members[i]
207             self.visit(member_type, '%s.%s' % (lvalue, member_name), '*%s->members[%s]' % (tmp, i))
208
209     def visitPolymorphic(self, polymorphic, lvalue, rvalue):
210         self.visit(polymorphic.defaultType, lvalue, rvalue)
211     
212     def visitOpaque(self, opaque, lvalue, rvalue):
213         raise UnsupportedType
214
215
216 class OpaqueValueDeserializer(ValueDeserializer):
217     '''Value extractor that also understands opaque values.
218
219     Normally opaque values can't be retraced, unless they are being extracted
220     in the context of handles.'''
221
222     def visitOpaque(self, opaque, lvalue, rvalue):
223         print '    %s = static_cast<%s>(retrace::toPointer(%s));' % (lvalue, opaque, rvalue)
224
225
226 class SwizzledValueRegistrator(stdapi.Visitor):
227     '''Type visitor which will register (un)swizzled value pairs, to later be
228     swizzled.'''
229
230     def visitLiteral(self, literal, lvalue, rvalue):
231         pass
232
233     def visitAlias(self, alias, lvalue, rvalue):
234         self.visit(alias.type, lvalue, rvalue)
235     
236     def visitEnum(self, enum, lvalue, rvalue):
237         pass
238
239     def visitBitmask(self, bitmask, lvalue, rvalue):
240         pass
241
242     def visitArray(self, array, lvalue, rvalue):
243         print '    const trace::Array *__a%s = dynamic_cast<const trace::Array *>(&%s);' % (array.tag, rvalue)
244         print '    if (__a%s) {' % (array.tag)
245         length = '__a%s->values.size()' % array.tag
246         index = '__j' + array.tag
247         print '        for (size_t {i} = 0; {i} < {length}; ++{i}) {{'.format(i = index, length = length)
248         try:
249             self.visit(array.type, '%s[%s]' % (lvalue, index), '*__a%s->values[%s]' % (array.tag, index))
250         finally:
251             print '        }'
252             print '    }'
253     
254     def visitPointer(self, pointer, lvalue, rvalue):
255         print '    const trace::Array *__a%s = dynamic_cast<const trace::Array *>(&%s);' % (pointer.tag, rvalue)
256         print '    if (__a%s) {' % (pointer.tag)
257         try:
258             self.visit(pointer.type, '%s[0]' % (lvalue,), '*__a%s->values[0]' % (pointer.tag,))
259         finally:
260             print '    }'
261     
262     def visitIntPointer(self, pointer, lvalue, rvalue):
263         pass
264     
265     def visitObjPointer(self, pointer, lvalue, rvalue):
266         print r'    _obj_map[(%s).toUIntPtr()] = %s;' % (rvalue, lvalue)
267     
268     def visitLinearPointer(self, pointer, lvalue, rvalue):
269         assert pointer.size is not None
270         if pointer.size is not None:
271             print r'    retrace::addRegion((%s).toUIntPtr(), %s, %s);' % (rvalue, lvalue, pointer.size)
272
273     def visitReference(self, reference, lvalue, rvalue):
274         pass
275     
276     def visitHandle(self, handle, lvalue, rvalue):
277         print '    %s __orig_result;' % handle.type
278         OpaqueValueDeserializer().visit(handle.type, '__orig_result', rvalue);
279         if handle.range is None:
280             rvalue = "__orig_result"
281             entry = lookupHandle(handle, rvalue) 
282             print "    %s = %s;" % (entry, lvalue)
283             print '    if (retrace::verbosity >= 2) {'
284             print '        std::cout << "{handle.name} " << {rvalue} << " -> " << {lvalue} << "\\n";'.format(**locals())
285             print '    }'
286         else:
287             i = '__h' + handle.tag
288             lvalue = "%s + %s" % (lvalue, i)
289             rvalue = "__orig_result + %s" % (i,)
290             entry = lookupHandle(handle, rvalue) 
291             print '    for ({handle.type} {i} = 0; {i} < {handle.range}; ++{i}) {{'.format(**locals())
292             print '        {entry} = {lvalue};'.format(**locals())
293             print '        if (retrace::verbosity >= 2) {'
294             print '            std::cout << "{handle.name} " << ({rvalue}) << " -> " << ({lvalue}) << "\\n";'.format(**locals())
295             print '        }'
296             print '    }'
297     
298     def visitBlob(self, blob, lvalue, rvalue):
299         pass
300     
301     def visitString(self, string, lvalue, rvalue):
302         pass
303
304     seq = 0
305
306     def visitStruct(self, struct, lvalue, rvalue):
307         tmp = '__s_' + struct.tag + '_' + str(self.seq)
308         self.seq += 1
309
310         print '    const trace::Struct *%s = dynamic_cast<const trace::Struct *>(&%s);' % (tmp, rvalue)
311         print '    assert(%s);' % (tmp,)
312         print '    (void)%s;' % (tmp,)
313         for i in range(len(struct.members)):
314             member_type, member_name = struct.members[i]
315             self.visit(member_type, '%s.%s' % (lvalue, member_name), '*%s->members[%s]' % (tmp, i))
316     
317     def visitPolymorphic(self, polymorphic, lvalue, rvalue):
318         self.visit(polymorphic.defaultType, lvalue, rvalue)
319     
320     def visitOpaque(self, opaque, lvalue, rvalue):
321         pass
322
323
324 class Retracer:
325
326     def retraceFunction(self, function):
327         print 'static void retrace_%s(trace::Call &call) {' % function.name
328         self.retraceFunctionBody(function)
329         print '}'
330         print
331
332     def retraceInterfaceMethod(self, interface, method):
333         print 'static void retrace_%s__%s(trace::Call &call) {' % (interface.name, method.name)
334         self.retraceInterfaceMethodBody(interface, method)
335         print '}'
336         print
337
338     def retraceFunctionBody(self, function):
339         assert function.sideeffects
340
341         self.deserializeArgs(function)
342         
343         self.invokeFunction(function)
344
345         self.swizzleValues(function)
346
347     def retraceInterfaceMethodBody(self, interface, method):
348         assert method.sideeffects
349
350         self.deserializeThisPointer(interface)
351
352         self.deserializeArgs(method)
353         
354         self.invokeInterfaceMethod(interface, method)
355
356         self.swizzleValues(method)
357
358     def deserializeThisPointer(self, interface):
359         print r'    %s *_this;' % (interface.name,)
360         print r'    _this = static_cast<%s *>(_obj_map[call.arg(0).toUIntPtr()]);' % (interface.name,)
361         print r'    if (!_this) {'
362         print r'        retrace::warning(call) << "NULL this pointer\n";'
363         print r'        return;'
364         print r'    }'
365
366     def deserializeArgs(self, function):
367         print '    retrace::ScopedAllocator _allocator;'
368         print '    (void)_allocator;'
369         success = True
370         for arg in function.args:
371             arg_type = MutableRebuilder().visit(arg.type)
372             print '    %s %s;' % (arg_type, arg.name)
373             rvalue = 'call.arg(%u)' % (arg.index,)
374             lvalue = arg.name
375             try:
376                 self.extractArg(function, arg, arg_type, lvalue, rvalue)
377             except UnsupportedType:
378                 success =  False
379                 print '    memset(&%s, 0, sizeof %s); // FIXME' % (arg.name, arg.name)
380             print
381
382         if not success:
383             print '    if (1) {'
384             self.failFunction(function)
385             if function.name[-1].islower():
386                 sys.stderr.write('warning: unsupported %s call\n' % function.name)
387             print '    }'
388
389     def swizzleValues(self, function):
390         for arg in function.args:
391             if arg.output:
392                 arg_type = MutableRebuilder().visit(arg.type)
393                 rvalue = 'call.arg(%u)' % (arg.index,)
394                 lvalue = arg.name
395                 try:
396                     self.regiterSwizzledValue(arg_type, lvalue, rvalue)
397                 except UnsupportedType:
398                     print '    // XXX: %s' % arg.name
399         if function.type is not stdapi.Void:
400             rvalue = '*call.ret'
401             lvalue = '__result'
402             try:
403                 self.regiterSwizzledValue(function.type, lvalue, rvalue)
404             except UnsupportedType:
405                 raise
406                 print '    // XXX: result'
407
408     def failFunction(self, function):
409         print '    if (retrace::verbosity >= 0) {'
410         print '        retrace::unsupported(call);'
411         print '    }'
412         print '    return;'
413
414     def extractArg(self, function, arg, arg_type, lvalue, rvalue):
415         ValueAllocator().visit(arg_type, lvalue, rvalue)
416         if arg.input:
417             ValueDeserializer().visit(arg_type, lvalue, rvalue)
418     
419     def extractOpaqueArg(self, function, arg, arg_type, lvalue, rvalue):
420         try:
421             ValueAllocator().visit(arg_type, lvalue, rvalue)
422         except UnsupportedType:
423             pass
424         OpaqueValueDeserializer().visit(arg_type, lvalue, rvalue)
425
426     def regiterSwizzledValue(self, type, lvalue, rvalue):
427         visitor = SwizzledValueRegistrator()
428         visitor.visit(type, lvalue, rvalue)
429
430     def invokeFunction(self, function):
431         arg_names = ", ".join(function.argNames())
432         if function.type is not stdapi.Void:
433             print '    %s __result;' % (function.type)
434             print '    __result = %s(%s);' % (function.name, arg_names)
435             print '    (void)__result;'
436         else:
437             print '    %s(%s);' % (function.name, arg_names)
438
439     def invokeInterfaceMethod(self, interface, method):
440         arg_names = ", ".join(method.argNames())
441         if method.type is not stdapi.Void:
442             print '    %s __result;' % (method.type)
443             print '    __result = _this->%s(%s);' % (method.name, arg_names)
444             print '    (void)__result;'
445         else:
446             print '    _this->%s(%s);' % (method.name, arg_names)
447
448     def filterFunction(self, function):
449         return True
450
451     table_name = 'retrace::callbacks'
452
453     def retraceApi(self, api):
454
455         print '#include "os_time.hpp"'
456         print '#include "trace_parser.hpp"'
457         print '#include "retrace.hpp"'
458         print
459
460         types = api.getAllTypes()
461         handles = [type for type in types if isinstance(type, stdapi.Handle)]
462         handle_names = set()
463         for handle in handles:
464             if handle.name not in handle_names:
465                 if handle.key is None:
466                     print 'static retrace::map<%s> __%s_map;' % (handle.type, handle.name)
467                 else:
468                     key_name, key_type = handle.key
469                     print 'static std::map<%s, retrace::map<%s> > __%s_map;' % (key_type, handle.type, handle.name)
470                 handle_names.add(handle.name)
471         print
472
473         print 'static std::map<unsigned long long, void *> _obj_map;'
474         print
475
476         functions = filter(self.filterFunction, api.functions)
477         for function in functions:
478             if function.sideeffects:
479                 self.retraceFunction(function)
480         interfaces = api.getAllInterfaces()
481         for interface in interfaces:
482             for method in interface.iterMethods():
483                 if method.sideeffects:
484                     self.retraceInterfaceMethod(interface, method)
485
486         print 'const retrace::Entry %s[] = {' % self.table_name
487         for function in functions:
488             if function.sideeffects:
489                 print '    {"%s", &retrace_%s},' % (function.name, function.name)
490             else:
491                 print '    {"%s", &retrace::ignore},' % (function.name,)
492         for interface in interfaces:
493             for method in interface.iterMethods():                
494                 if method.sideeffects:
495                     print '    {"%s::%s", &retrace_%s__%s},' % (interface.name, method.name, interface.name, method.name)
496                 else:
497                     print '    {"%s::%s", &retrace::ignore},' % (interface.name, method.name)
498         print '    {NULL, NULL}'
499         print '};'
500         print
501