]> git.cworth.org Git - apitrace/blob - retrace.py
Avoid conflicting tmps.
[apitrace] / retrace.py
1 ##########################################################################
2 #
3 # Copyright 2010 VMware, Inc.
4 # All Rights Reserved.
5 #
6 # Permission is hereby granted, free of charge, to any person obtaining a copy
7 # of this software and associated documentation files (the "Software"), to deal
8 # in the Software without restriction, including without limitation the rights
9 # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 # copies of the Software, and to permit persons to whom the Software is
11 # furnished to do so, subject to the following conditions:
12 #
13 # The above copyright notice and this permission notice shall be included in
14 # all copies or substantial portions of the Software.
15 #
16 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
22 # THE SOFTWARE.
23 #
24 ##########################################################################/
25
26
27 """Generic retracing code generator."""
28
29
30 import sys
31
32 import specs.stdapi as stdapi
33 import specs.glapi as glapi
34
35
36 class MutableRebuilder(stdapi.Rebuilder):
37     '''Type visitor which derives a mutable type.'''
38
39     def visitConst(self, const):
40         # Strip out const qualifier
41         return const.type
42
43     def visitAlias(self, alias):
44         # Tear the alias on type changes
45         type = self.visit(alias.type)
46         if type is alias.type:
47             return alias
48         return type
49
50     def visitReference(self, reference):
51         # Strip out references
52         return reference.type
53
54     def visitOpaque(self, opaque):
55         # Don't recursule
56         return opaque
57
58
59 def lookupHandle(handle, value):
60     if handle.key is None:
61         return "__%s_map[%s]" % (handle.name, value)
62     else:
63         key_name, key_type = handle.key
64         return "__%s_map[%s][%s]" % (handle.name, key_name, value)
65
66
67 class ValueDeserializer(stdapi.Visitor):
68
69     def visitLiteral(self, literal, lvalue, rvalue):
70         print '    %s = (%s).to%s();' % (lvalue, rvalue, literal.kind)
71
72     def visitConst(self, const, lvalue, rvalue):
73         self.visit(const.type, lvalue, rvalue)
74
75     def visitAlias(self, alias, lvalue, rvalue):
76         self.visit(alias.type, lvalue, rvalue)
77     
78     def visitEnum(self, enum, lvalue, rvalue):
79         print '    %s = static_cast<%s>((%s).toSInt());' % (lvalue, enum, rvalue)
80
81     def visitBitmask(self, bitmask, lvalue, rvalue):
82         self.visit(bitmask.type, lvalue, rvalue)
83
84     allocated = False
85
86     def visitArray(self, array, lvalue, rvalue):
87         tmp = '__a_' + array.tag + '_' + str(self.seq)
88         self.seq += 1
89
90         print '    const trace::Array *%s = dynamic_cast<const trace::Array *>(&%s);' % (tmp, rvalue)
91         length = '%s->values.size()' % (tmp,)
92         allocated = self.allocated
93         if not allocated:
94             print '    if (%s) {' % (tmp,)
95             print '        %s = _allocator.alloc<%s>(%s);' % (lvalue, array.type, length)
96             self.allocated = True
97         index = '__j' + array.tag
98         print '        for (size_t {i} = 0; {i} < {length}; ++{i}) {{'.format(i = index, length = length)
99         try:
100             self.visit(array.type, '%s[%s]' % (lvalue, index), '*%s->values[%s]' % (tmp, index))
101         finally:
102             print '        }'
103             if not allocated:
104                 print '    } else {'
105                 print '        %s = NULL;' % lvalue
106                 print '    }'
107     
108     def visitPointer(self, pointer, lvalue, rvalue):
109         tmp = '__a_' + pointer.tag + '_' + str(self.seq)
110         self.seq += 1
111
112         print '    const trace::Array *%s = dynamic_cast<const trace::Array *>(&%s);' % (tmp, rvalue)
113         allocated = self.allocated
114         if not allocated:
115             print '    if (%s) {' % (tmp)
116             print '        %s = _allocator.alloc<%s>();' % (lvalue, pointer.type)
117             self.allocated = True
118         try:
119             self.visit(pointer.type, '%s[0]' % (lvalue,), '*%s->values[0]' % (tmp,))
120         finally:
121             if not allocated:
122                 print '    } else {'
123                 print '        %s = NULL;' % lvalue
124                 print '    }'
125
126     def visitIntPointer(self, pointer, lvalue, rvalue):
127         print '    %s = static_cast<%s>((%s).toPointer());' % (lvalue, pointer, rvalue)
128
129     def visitObjPointer(self, pointer, lvalue, rvalue):
130         print '    %s = static_cast<%s>((%s).toPointer());' % (lvalue, pointer, rvalue)
131
132     def visitLinearPointer(self, pointer, lvalue, rvalue):
133         print '    %s = static_cast<%s>(retrace::toPointer(%s));' % (lvalue, pointer, rvalue)
134
135     def visitReference(self, reference, lvalue, rvalue):
136         self.visit(reference.type, lvalue, rvalue);
137
138     def visitHandle(self, handle, lvalue, rvalue):
139         #OpaqueValueDeserializer().visit(handle.type, lvalue, rvalue);
140         self.visit(handle.type, lvalue, rvalue);
141         new_lvalue = lookupHandle(handle, lvalue)
142         print '    if (retrace::verbosity >= 2) {'
143         print '        std::cout << "%s " << size_t(%s) << " <- " << size_t(%s) << "\\n";' % (handle.name, lvalue, new_lvalue)
144         print '    }'
145         print '    %s = %s;' % (lvalue, new_lvalue)
146     
147     def visitBlob(self, blob, lvalue, rvalue):
148         print '    %s = static_cast<%s>((%s).toPointer());' % (lvalue, blob, rvalue)
149     
150     def visitString(self, string, lvalue, rvalue):
151         print '    %s = (%s)((%s).toString());' % (lvalue, string.expr, rvalue)
152
153     seq = 0
154
155     def visitStruct(self, struct, lvalue, rvalue):
156         tmp = '__s_' + struct.tag + '_' + str(self.seq)
157         self.seq += 1
158
159         print '    const trace::Struct *%s = dynamic_cast<const trace::Struct *>(&%s);' % (tmp, rvalue)
160         print '    assert(%s);' % (tmp)
161         self.allocated = True
162         for i in range(len(struct.members)):
163             member_type, member_name = struct.members[i]
164             self.visit(member_type, '%s.%s' % (lvalue, member_name), '*%s->members[%s]' % (tmp, i))
165
166
167 class OpaqueValueDeserializer(ValueDeserializer):
168     '''Value extractor that also understands opaque values.
169
170     Normally opaque values can't be retraced, unless they are being extracted
171     in the context of handles.'''
172
173     def visitOpaque(self, opaque, lvalue, rvalue):
174         print '    %s = static_cast<%s>(retrace::toPointer(%s));' % (lvalue, opaque, rvalue)
175
176
177 class SwizzledValueRegistrator(stdapi.Visitor):
178     '''Type visitor which will register (un)swizzled value pairs, to later be
179     swizzled.'''
180
181     def visitLiteral(self, literal, lvalue, rvalue):
182         pass
183
184     def visitAlias(self, alias, lvalue, rvalue):
185         self.visit(alias.type, lvalue, rvalue)
186     
187     def visitEnum(self, enum, lvalue, rvalue):
188         pass
189
190     def visitBitmask(self, bitmask, lvalue, rvalue):
191         pass
192
193     def visitArray(self, array, lvalue, rvalue):
194         print '    const trace::Array *__a%s = dynamic_cast<const trace::Array *>(&%s);' % (array.tag, rvalue)
195         print '    if (__a%s) {' % (array.tag)
196         length = '__a%s->values.size()' % array.tag
197         index = '__j' + array.tag
198         print '        for (size_t {i} = 0; {i} < {length}; ++{i}) {{'.format(i = index, length = length)
199         try:
200             self.visit(array.type, '%s[%s]' % (lvalue, index), '*__a%s->values[%s]' % (array.tag, index))
201         finally:
202             print '        }'
203             print '    }'
204     
205     def visitPointer(self, pointer, lvalue, rvalue):
206         print '    const trace::Array *__a%s = dynamic_cast<const trace::Array *>(&%s);' % (pointer.tag, rvalue)
207         print '    if (__a%s) {' % (pointer.tag)
208         try:
209             self.visit(pointer.type, '%s[0]' % (lvalue,), '*__a%s->values[0]' % (pointer.tag,))
210         finally:
211             print '    }'
212     
213     def visitIntPointer(self, pointer, lvalue, rvalue):
214         pass
215     
216     def visitObjPointer(self, pointer, lvalue, rvalue):
217         print r'    _obj_map[(%s).toUIntPtr()] = %s;' % (rvalue, lvalue)
218     
219     def visitLinearPointer(self, pointer, lvalue, rvalue):
220         assert pointer.size is not None
221         if pointer.size is not None:
222             print r'    retrace::addRegion((%s).toUIntPtr(), %s, %s);' % (rvalue, lvalue, pointer.size)
223
224     def visitReference(self, reference, lvalue, rvalue):
225         pass
226     
227     def visitHandle(self, handle, lvalue, rvalue):
228         print '    %s __orig_result;' % handle.type
229         OpaqueValueDeserializer().visit(handle.type, '__orig_result', rvalue);
230         if handle.range is None:
231             rvalue = "__orig_result"
232             entry = lookupHandle(handle, rvalue) 
233             print "    %s = %s;" % (entry, lvalue)
234             print '    if (retrace::verbosity >= 2) {'
235             print '        std::cout << "{handle.name} " << {rvalue} << " -> " << {lvalue} << "\\n";'.format(**locals())
236             print '    }'
237         else:
238             i = '__h' + handle.tag
239             lvalue = "%s + %s" % (lvalue, i)
240             rvalue = "__orig_result + %s" % (i,)
241             entry = lookupHandle(handle, rvalue) 
242             print '    for ({handle.type} {i} = 0; {i} < {handle.range}; ++{i}) {{'.format(**locals())
243             print '        {entry} = {lvalue};'.format(**locals())
244             print '        if (retrace::verbosity >= 2) {'
245             print '            std::cout << "{handle.name} " << ({rvalue}) << " -> " << ({lvalue}) << "\\n";'.format(**locals())
246             print '        }'
247             print '    }'
248     
249     def visitBlob(self, blob, lvalue, rvalue):
250         pass
251     
252     def visitString(self, string, lvalue, rvalue):
253         pass
254
255
256 class Retracer:
257
258     def retraceFunction(self, function):
259         print 'static void retrace_%s(trace::Call &call) {' % function.name
260         self.retraceFunctionBody(function)
261         print '}'
262         print
263
264     def retraceInterfaceMethod(self, interface, method):
265         print 'static void retrace_%s__%s(trace::Call &call) {' % (interface.name, method.name)
266         self.retraceInterfaceMethodBody(interface, method)
267         print '}'
268         print
269
270     def retraceFunctionBody(self, function):
271         if not function.sideeffects:
272             print '    (void)call;'
273             return
274
275         self.deserializeArgs(function)
276         
277         self.invokeFunction(function)
278
279         self.swizzleValues(function)
280
281     def retraceInterfaceMethodBody(self, interface, method):
282         if not method.sideeffects:
283             print '    (void)call;'
284             return
285
286         self.deserializeThisPointer(interface)
287
288         self.deserializeArgs(method)
289         
290         self.invokeInterfaceMethod(interface, method)
291
292         self.swizzleValues(method)
293
294     def deserializeThisPointer(self, interface):
295         print r'    %s *_this;' % (interface.name,)
296         print r'    _this = static_cast<%s *>(_obj_map[call.arg(0).toUIntPtr()]);' % (interface.name,)
297         print r'    if (!_this) {'
298         print r'        retrace::warning(call) << "NULL this pointer\n";'
299         print r'        return;'
300         print r'    }'
301
302     def deserializeArgs(self, function):
303         print '    retrace::ScopedAllocator _allocator;'
304         print '    (void)_allocator;'
305         success = True
306         for arg in function.args:
307             arg_type = MutableRebuilder().visit(arg.type)
308             #print '    // %s ->  %s' % (arg.type, arg_type)
309             print '    %s %s;' % (arg_type, arg.name)
310             rvalue = 'call.arg(%u)' % (arg.index,)
311             lvalue = arg.name
312             try:
313                 self.extractArg(function, arg, arg_type, lvalue, rvalue)
314             except NotImplementedError:
315                 success =  False
316                 print '    memset(&%s, 0, sizeof %s); // FIXME' % (arg.name, arg.name)
317
318         if not success:
319             print '    if (1) {'
320             self.failFunction(function)
321             if function.name[-1].islower():
322                 sys.stderr.write('warning: unsupported %s call\n' % function.name)
323             print '    }'
324
325     def swizzleValues(self, function):
326         for arg in function.args:
327             if arg.output:
328                 arg_type = MutableRebuilder().visit(arg.type)
329                 rvalue = 'call.arg(%u)' % (arg.index,)
330                 lvalue = arg.name
331                 try:
332                     self.regiterSwizzledValue(arg_type, lvalue, rvalue)
333                 except NotImplementedError:
334                     print '    // XXX: %s' % arg.name
335         if function.type is not stdapi.Void:
336             rvalue = '*call.ret'
337             lvalue = '__result'
338             try:
339                 self.regiterSwizzledValue(function.type, lvalue, rvalue)
340             except NotImplementedError:
341                 raise
342                 print '    // XXX: result'
343
344     def failFunction(self, function):
345         print '    if (retrace::verbosity >= 0) {'
346         print '        retrace::unsupported(call);'
347         print '    }'
348         print '    return;'
349
350     def extractArg(self, function, arg, arg_type, lvalue, rvalue):
351         ValueDeserializer().visit(arg_type, lvalue, rvalue)
352     
353     def extractOpaqueArg(self, function, arg, arg_type, lvalue, rvalue):
354         OpaqueValueDeserializer().visit(arg_type, lvalue, rvalue)
355
356     def regiterSwizzledValue(self, type, lvalue, rvalue):
357         visitor = SwizzledValueRegistrator()
358         visitor.visit(type, lvalue, rvalue)
359
360     def invokeFunction(self, function):
361         arg_names = ", ".join(function.argNames())
362         if function.type is not stdapi.Void:
363             print '    %s __result;' % (function.type)
364             print '    __result = %s(%s);' % (function.name, arg_names)
365             print '    (void)__result;'
366         else:
367             print '    %s(%s);' % (function.name, arg_names)
368
369     def invokeInterfaceMethod(self, interface, method):
370         arg_names = ", ".join(method.argNames())
371         if method.type is not stdapi.Void:
372             print '    %s __result;' % (method.type)
373             print '    __result = _this->%s(%s);' % (method.name, arg_names)
374             print '    (void)__result;'
375         else:
376             print '    _this->%s(%s);' % (method.name, arg_names)
377
378     def filterFunction(self, function):
379         return True
380
381     table_name = 'retrace::callbacks'
382
383     def retraceApi(self, api):
384
385         print '#include "os_time.hpp"'
386         print '#include "trace_parser.hpp"'
387         print '#include "retrace.hpp"'
388         print
389
390         types = api.getAllTypes()
391         handles = [type for type in types if isinstance(type, stdapi.Handle)]
392         handle_names = set()
393         for handle in handles:
394             if handle.name not in handle_names:
395                 if handle.key is None:
396                     print 'static retrace::map<%s> __%s_map;' % (handle.type, handle.name)
397                 else:
398                     key_name, key_type = handle.key
399                     print 'static std::map<%s, retrace::map<%s> > __%s_map;' % (key_type, handle.type, handle.name)
400                 handle_names.add(handle.name)
401         print
402
403         print 'static std::map<unsigned long long, void *> _obj_map;'
404         print
405
406         functions = filter(self.filterFunction, api.functions)
407         for function in functions:
408             self.retraceFunction(function)
409         interfaces = api.getAllInterfaces()
410         for interface in interfaces:
411             for method in interface.iterMethods():
412                 self.retraceInterfaceMethod(interface, method)
413
414         print 'const retrace::Entry %s[] = {' % self.table_name
415         for function in functions:
416             print '    {"%s", &retrace_%s},' % (function.name, function.name)
417         for interface in interfaces:
418             for method in interface.iterMethods():
419                 print '    {"%s::%s", &retrace_%s__%s},' % (interface.name, method.name, interface.name, method.name)
420         print '    {NULL, NULL}'
421         print '};'
422         print
423