]> git.cworth.org Git - apitrace/blob - retrace/retrace.py
Move retracers to their own directory.
[apitrace] / retrace / retrace.py
1 ##########################################################################
2 #
3 # Copyright 2010 VMware, Inc.
4 # All Rights Reserved.
5 #
6 # Permission is hereby granted, free of charge, to any person obtaining a copy
7 # of this software and associated documentation files (the "Software"), to deal
8 # in the Software without restriction, including without limitation the rights
9 # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 # copies of the Software, and to permit persons to whom the Software is
11 # furnished to do so, subject to the following conditions:
12 #
13 # The above copyright notice and this permission notice shall be included in
14 # all copies or substantial portions of the Software.
15 #
16 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
22 # THE SOFTWARE.
23 #
24 ##########################################################################/
25
26
27 """Generic retracing code generator."""
28
29
30 # Adjust path
31 import os.path
32 import sys
33 sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
34
35
36 import specs.stdapi as stdapi
37 import specs.glapi as glapi
38
39
40 class UnsupportedType(Exception):
41     pass
42
43
44 class MutableRebuilder(stdapi.Rebuilder):
45     '''Type visitor which derives a mutable type.'''
46
47     def visitConst(self, const):
48         # Strip out const qualifier
49         return const.type
50
51     def visitAlias(self, alias):
52         # Tear the alias on type changes
53         type = self.visit(alias.type)
54         if type is alias.type:
55             return alias
56         return type
57
58     def visitReference(self, reference):
59         # Strip out references
60         return reference.type
61
62
63 def lookupHandle(handle, value):
64     if handle.key is None:
65         return "__%s_map[%s]" % (handle.name, value)
66     else:
67         key_name, key_type = handle.key
68         return "__%s_map[%s][%s]" % (handle.name, key_name, value)
69
70
71 class ValueAllocator(stdapi.Visitor):
72
73     def visitLiteral(self, literal, lvalue, rvalue):
74         pass
75
76     def visitConst(self, const, lvalue, rvalue):
77         self.visit(const.type, lvalue, rvalue)
78
79     def visitAlias(self, alias, lvalue, rvalue):
80         self.visit(alias.type, lvalue, rvalue)
81
82     def visitEnum(self, enum, lvalue, rvalue):
83         pass
84
85     def visitBitmask(self, bitmask, lvalue, rvalue):
86         pass
87
88     def visitArray(self, array, lvalue, rvalue):
89         print '    %s = _allocator.alloc<%s>(&%s);' % (lvalue, array.type, rvalue)
90
91     def visitPointer(self, pointer, lvalue, rvalue):
92         print '    %s = _allocator.alloc<%s>(&%s);' % (lvalue, pointer.type, rvalue)
93
94     def visitIntPointer(self, pointer, lvalue, rvalue):
95         pass
96
97     def visitObjPointer(self, pointer, lvalue, rvalue):
98         pass
99
100     def visitLinearPointer(self, pointer, lvalue, rvalue):
101         pass
102
103     def visitReference(self, reference, lvalue, rvalue):
104         self.visit(reference.type, lvalue, rvalue);
105
106     def visitHandle(self, handle, lvalue, rvalue):
107         pass
108
109     def visitBlob(self, blob, lvalue, rvalue):
110         pass
111
112     def visitString(self, string, lvalue, rvalue):
113         pass
114
115     def visitStruct(self, struct, lvalue, rvalue):
116         pass
117
118     def visitPolymorphic(self, polymorphic, lvalue, rvalue):
119         self.visit(polymorphic.defaultType, lvalue, rvalue)
120
121     def visitOpaque(self, opaque, lvalue, rvalue):
122         pass
123
124
125 class ValueDeserializer(stdapi.Visitor):
126
127     def visitLiteral(self, literal, lvalue, rvalue):
128         print '    %s = (%s).to%s();' % (lvalue, rvalue, literal.kind)
129
130     def visitConst(self, const, lvalue, rvalue):
131         self.visit(const.type, lvalue, rvalue)
132
133     def visitAlias(self, alias, lvalue, rvalue):
134         self.visit(alias.type, lvalue, rvalue)
135     
136     def visitEnum(self, enum, lvalue, rvalue):
137         print '    %s = static_cast<%s>((%s).toSInt());' % (lvalue, enum, rvalue)
138
139     def visitBitmask(self, bitmask, lvalue, rvalue):
140         self.visit(bitmask.type, lvalue, rvalue)
141
142     def visitArray(self, array, lvalue, rvalue):
143
144         tmp = '__a_' + array.tag + '_' + str(self.seq)
145         self.seq += 1
146
147         print '    if (%s) {' % (lvalue,)
148         print '        const trace::Array *%s = dynamic_cast<const trace::Array *>(&%s);' % (tmp, rvalue)
149         length = '%s->values.size()' % (tmp,)
150         index = '__j' + array.tag
151         print '        for (size_t {i} = 0; {i} < {length}; ++{i}) {{'.format(i = index, length = length)
152         try:
153             self.visit(array.type, '%s[%s]' % (lvalue, index), '*%s->values[%s]' % (tmp, index))
154         finally:
155             print '        }'
156             print '    }'
157     
158     def visitPointer(self, pointer, lvalue, rvalue):
159         tmp = '__a_' + pointer.tag + '_' + str(self.seq)
160         self.seq += 1
161
162         print '    if (%s) {' % (lvalue,)
163         print '        const trace::Array *%s = dynamic_cast<const trace::Array *>(&%s);' % (tmp, rvalue)
164         try:
165             self.visit(pointer.type, '%s[0]' % (lvalue,), '*%s->values[0]' % (tmp,))
166         finally:
167             print '    }'
168
169     def visitIntPointer(self, pointer, lvalue, rvalue):
170         print '    %s = static_cast<%s>((%s).toPointer());' % (lvalue, pointer, rvalue)
171
172     def visitObjPointer(self, pointer, lvalue, rvalue):
173         old_lvalue = '(%s).toUIntPtr()' % (rvalue,)
174         new_lvalue = '_obj_map[%s]' % (old_lvalue,)
175         print '    if (retrace::verbosity >= 2) {'
176         print '        std::cout << std::hex << "obj 0x" << size_t(%s) << " <- 0x" << size_t(%s) << std::dec <<"\\n";' % (old_lvalue, new_lvalue)
177         print '    }'
178         print '    %s = static_cast<%s>(%s);' % (lvalue, pointer, new_lvalue)
179
180     def visitLinearPointer(self, pointer, lvalue, rvalue):
181         print '    %s = static_cast<%s>(retrace::toPointer(%s));' % (lvalue, pointer, rvalue)
182
183     def visitReference(self, reference, lvalue, rvalue):
184         self.visit(reference.type, lvalue, rvalue);
185
186     def visitHandle(self, handle, lvalue, rvalue):
187         #OpaqueValueDeserializer().visit(handle.type, lvalue, rvalue);
188         self.visit(handle.type, lvalue, rvalue);
189         new_lvalue = lookupHandle(handle, lvalue)
190         print '    if (retrace::verbosity >= 2) {'
191         print '        std::cout << "%s " << size_t(%s) << " <- " << size_t(%s) << "\\n";' % (handle.name, lvalue, new_lvalue)
192         print '    }'
193         print '    %s = %s;' % (lvalue, new_lvalue)
194     
195     def visitBlob(self, blob, lvalue, rvalue):
196         print '    %s = static_cast<%s>((%s).toPointer());' % (lvalue, blob, rvalue)
197     
198     def visitString(self, string, lvalue, rvalue):
199         print '    %s = (%s)((%s).toString());' % (lvalue, string.expr, rvalue)
200
201     seq = 0
202
203     def visitStruct(self, struct, lvalue, rvalue):
204         tmp = '__s_' + struct.tag + '_' + str(self.seq)
205         self.seq += 1
206
207         print '    const trace::Struct *%s = dynamic_cast<const trace::Struct *>(&%s);' % (tmp, rvalue)
208         print '    assert(%s);' % (tmp)
209         for i in range(len(struct.members)):
210             member_type, member_name = struct.members[i]
211             self.visit(member_type, '%s.%s' % (lvalue, member_name), '*%s->members[%s]' % (tmp, i))
212
213     def visitPolymorphic(self, polymorphic, lvalue, rvalue):
214         self.visit(polymorphic.defaultType, lvalue, rvalue)
215     
216     def visitOpaque(self, opaque, lvalue, rvalue):
217         raise UnsupportedType
218
219
220 class OpaqueValueDeserializer(ValueDeserializer):
221     '''Value extractor that also understands opaque values.
222
223     Normally opaque values can't be retraced, unless they are being extracted
224     in the context of handles.'''
225
226     def visitOpaque(self, opaque, lvalue, rvalue):
227         print '    %s = static_cast<%s>(retrace::toPointer(%s));' % (lvalue, opaque, rvalue)
228
229
230 class SwizzledValueRegistrator(stdapi.Visitor):
231     '''Type visitor which will register (un)swizzled value pairs, to later be
232     swizzled.'''
233
234     def visitLiteral(self, literal, lvalue, rvalue):
235         pass
236
237     def visitAlias(self, alias, lvalue, rvalue):
238         self.visit(alias.type, lvalue, rvalue)
239     
240     def visitEnum(self, enum, lvalue, rvalue):
241         pass
242
243     def visitBitmask(self, bitmask, lvalue, rvalue):
244         pass
245
246     def visitArray(self, array, lvalue, rvalue):
247         print '    const trace::Array *__a%s = dynamic_cast<const trace::Array *>(&%s);' % (array.tag, rvalue)
248         print '    if (__a%s) {' % (array.tag)
249         length = '__a%s->values.size()' % array.tag
250         index = '__j' + array.tag
251         print '        for (size_t {i} = 0; {i} < {length}; ++{i}) {{'.format(i = index, length = length)
252         try:
253             self.visit(array.type, '%s[%s]' % (lvalue, index), '*__a%s->values[%s]' % (array.tag, index))
254         finally:
255             print '        }'
256             print '    }'
257     
258     def visitPointer(self, pointer, lvalue, rvalue):
259         print '    const trace::Array *__a%s = dynamic_cast<const trace::Array *>(&%s);' % (pointer.tag, rvalue)
260         print '    if (__a%s) {' % (pointer.tag)
261         try:
262             self.visit(pointer.type, '%s[0]' % (lvalue,), '*__a%s->values[0]' % (pointer.tag,))
263         finally:
264             print '    }'
265     
266     def visitIntPointer(self, pointer, lvalue, rvalue):
267         pass
268     
269     def visitObjPointer(self, pointer, lvalue, rvalue):
270         print r'    _obj_map[(%s).toUIntPtr()] = %s;' % (rvalue, lvalue)
271     
272     def visitLinearPointer(self, pointer, lvalue, rvalue):
273         assert pointer.size is not None
274         if pointer.size is not None:
275             print r'    retrace::addRegion((%s).toUIntPtr(), %s, %s);' % (rvalue, lvalue, pointer.size)
276
277     def visitReference(self, reference, lvalue, rvalue):
278         pass
279     
280     def visitHandle(self, handle, lvalue, rvalue):
281         print '    %s __orig_result;' % handle.type
282         OpaqueValueDeserializer().visit(handle.type, '__orig_result', rvalue);
283         if handle.range is None:
284             rvalue = "__orig_result"
285             entry = lookupHandle(handle, rvalue) 
286             print "    %s = %s;" % (entry, lvalue)
287             print '    if (retrace::verbosity >= 2) {'
288             print '        std::cout << "{handle.name} " << {rvalue} << " -> " << {lvalue} << "\\n";'.format(**locals())
289             print '    }'
290         else:
291             i = '__h' + handle.tag
292             lvalue = "%s + %s" % (lvalue, i)
293             rvalue = "__orig_result + %s" % (i,)
294             entry = lookupHandle(handle, rvalue) 
295             print '    for ({handle.type} {i} = 0; {i} < {handle.range}; ++{i}) {{'.format(**locals())
296             print '        {entry} = {lvalue};'.format(**locals())
297             print '        if (retrace::verbosity >= 2) {'
298             print '            std::cout << "{handle.name} " << ({rvalue}) << " -> " << ({lvalue}) << "\\n";'.format(**locals())
299             print '        }'
300             print '    }'
301     
302     def visitBlob(self, blob, lvalue, rvalue):
303         pass
304     
305     def visitString(self, string, lvalue, rvalue):
306         pass
307
308     seq = 0
309
310     def visitStruct(self, struct, lvalue, rvalue):
311         tmp = '__s_' + struct.tag + '_' + str(self.seq)
312         self.seq += 1
313
314         print '    const trace::Struct *%s = dynamic_cast<const trace::Struct *>(&%s);' % (tmp, rvalue)
315         print '    assert(%s);' % (tmp,)
316         print '    (void)%s;' % (tmp,)
317         for i in range(len(struct.members)):
318             member_type, member_name = struct.members[i]
319             self.visit(member_type, '%s.%s' % (lvalue, member_name), '*%s->members[%s]' % (tmp, i))
320     
321     def visitPolymorphic(self, polymorphic, lvalue, rvalue):
322         self.visit(polymorphic.defaultType, lvalue, rvalue)
323     
324     def visitOpaque(self, opaque, lvalue, rvalue):
325         pass
326
327
328 class Retracer:
329
330     def retraceFunction(self, function):
331         print 'static void retrace_%s(trace::Call &call) {' % function.name
332         self.retraceFunctionBody(function)
333         print '}'
334         print
335
336     def retraceInterfaceMethod(self, interface, method):
337         print 'static void retrace_%s__%s(trace::Call &call) {' % (interface.name, method.name)
338         self.retraceInterfaceMethodBody(interface, method)
339         print '}'
340         print
341
342     def retraceFunctionBody(self, function):
343         assert function.sideeffects
344
345         self.deserializeArgs(function)
346         
347         self.invokeFunction(function)
348
349         self.swizzleValues(function)
350
351     def retraceInterfaceMethodBody(self, interface, method):
352         assert method.sideeffects
353
354         self.deserializeThisPointer(interface)
355
356         self.deserializeArgs(method)
357         
358         self.invokeInterfaceMethod(interface, method)
359
360         self.swizzleValues(method)
361
362     def deserializeThisPointer(self, interface):
363         print r'    %s *_this;' % (interface.name,)
364         print r'    _this = static_cast<%s *>(_obj_map[call.arg(0).toUIntPtr()]);' % (interface.name,)
365         print r'    if (!_this) {'
366         print r'        retrace::warning(call) << "NULL this pointer\n";'
367         print r'        return;'
368         print r'    }'
369
370     def deserializeArgs(self, function):
371         print '    retrace::ScopedAllocator _allocator;'
372         print '    (void)_allocator;'
373         success = True
374         for arg in function.args:
375             arg_type = MutableRebuilder().visit(arg.type)
376             print '    %s %s;' % (arg_type, arg.name)
377             rvalue = 'call.arg(%u)' % (arg.index,)
378             lvalue = arg.name
379             try:
380                 self.extractArg(function, arg, arg_type, lvalue, rvalue)
381             except UnsupportedType:
382                 success =  False
383                 print '    memset(&%s, 0, sizeof %s); // FIXME' % (arg.name, arg.name)
384             print
385
386         if not success:
387             print '    if (1) {'
388             self.failFunction(function)
389             if function.name[-1].islower():
390                 sys.stderr.write('warning: unsupported %s call\n' % function.name)
391             print '    }'
392
393     def swizzleValues(self, function):
394         for arg in function.args:
395             if arg.output:
396                 arg_type = MutableRebuilder().visit(arg.type)
397                 rvalue = 'call.arg(%u)' % (arg.index,)
398                 lvalue = arg.name
399                 try:
400                     self.regiterSwizzledValue(arg_type, lvalue, rvalue)
401                 except UnsupportedType:
402                     print '    // XXX: %s' % arg.name
403         if function.type is not stdapi.Void:
404             rvalue = '*call.ret'
405             lvalue = '__result'
406             try:
407                 self.regiterSwizzledValue(function.type, lvalue, rvalue)
408             except UnsupportedType:
409                 raise
410                 print '    // XXX: result'
411
412     def failFunction(self, function):
413         print '    if (retrace::verbosity >= 0) {'
414         print '        retrace::unsupported(call);'
415         print '    }'
416         print '    return;'
417
418     def extractArg(self, function, arg, arg_type, lvalue, rvalue):
419         ValueAllocator().visit(arg_type, lvalue, rvalue)
420         if arg.input:
421             ValueDeserializer().visit(arg_type, lvalue, rvalue)
422     
423     def extractOpaqueArg(self, function, arg, arg_type, lvalue, rvalue):
424         try:
425             ValueAllocator().visit(arg_type, lvalue, rvalue)
426         except UnsupportedType:
427             pass
428         OpaqueValueDeserializer().visit(arg_type, lvalue, rvalue)
429
430     def regiterSwizzledValue(self, type, lvalue, rvalue):
431         visitor = SwizzledValueRegistrator()
432         visitor.visit(type, lvalue, rvalue)
433
434     def invokeFunction(self, function):
435         arg_names = ", ".join(function.argNames())
436         if function.type is not stdapi.Void:
437             print '    %s __result;' % (function.type)
438             print '    __result = %s(%s);' % (function.name, arg_names)
439             print '    (void)__result;'
440         else:
441             print '    %s(%s);' % (function.name, arg_names)
442
443     def invokeInterfaceMethod(self, interface, method):
444         arg_names = ", ".join(method.argNames())
445         if method.type is not stdapi.Void:
446             print '    %s __result;' % (method.type)
447             print '    __result = _this->%s(%s);' % (method.name, arg_names)
448             print '    (void)__result;'
449         else:
450             print '    _this->%s(%s);' % (method.name, arg_names)
451
452     def filterFunction(self, function):
453         return True
454
455     table_name = 'retrace::callbacks'
456
457     def retraceApi(self, api):
458
459         print '#include "os_time.hpp"'
460         print '#include "trace_parser.hpp"'
461         print '#include "retrace.hpp"'
462         print
463
464         types = api.getAllTypes()
465         handles = [type for type in types if isinstance(type, stdapi.Handle)]
466         handle_names = set()
467         for handle in handles:
468             if handle.name not in handle_names:
469                 if handle.key is None:
470                     print 'static retrace::map<%s> __%s_map;' % (handle.type, handle.name)
471                 else:
472                     key_name, key_type = handle.key
473                     print 'static std::map<%s, retrace::map<%s> > __%s_map;' % (key_type, handle.type, handle.name)
474                 handle_names.add(handle.name)
475         print
476
477         print 'static std::map<unsigned long long, void *> _obj_map;'
478         print
479
480         functions = filter(self.filterFunction, api.functions)
481         for function in functions:
482             if function.sideeffects:
483                 self.retraceFunction(function)
484         interfaces = api.getAllInterfaces()
485         for interface in interfaces:
486             for method in interface.iterMethods():
487                 if method.sideeffects:
488                     self.retraceInterfaceMethod(interface, method)
489
490         print 'const retrace::Entry %s[] = {' % self.table_name
491         for function in functions:
492             if function.sideeffects:
493                 print '    {"%s", &retrace_%s},' % (function.name, function.name)
494             else:
495                 print '    {"%s", &retrace::ignore},' % (function.name,)
496         for interface in interfaces:
497             for method in interface.iterMethods():                
498                 if method.sideeffects:
499                     print '    {"%s::%s", &retrace_%s__%s},' % (interface.name, method.name, interface.name, method.name)
500                 else:
501                     print '    {"%s::%s", &retrace::ignore},' % (interface.name, method.name)
502         print '    {NULL, NULL}'
503         print '};'
504         print
505