]> git.cworth.org Git - apitrace/blob - retrace.py
ebbbf44c77a40202f3458c9cd87425a2512e06d0
[apitrace] / retrace.py
1 ##########################################################################
2 #
3 # Copyright 2010 VMware, Inc.
4 # All Rights Reserved.
5 #
6 # Permission is hereby granted, free of charge, to any person obtaining a copy
7 # of this software and associated documentation files (the "Software"), to deal
8 # in the Software without restriction, including without limitation the rights
9 # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 # copies of the Software, and to permit persons to whom the Software is
11 # furnished to do so, subject to the following conditions:
12 #
13 # The above copyright notice and this permission notice shall be included in
14 # all copies or substantial portions of the Software.
15 #
16 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
22 # THE SOFTWARE.
23 #
24 ##########################################################################/
25
26
27 """Generic retracing code generator."""
28
29
30 import sys
31
32 import specs.stdapi as stdapi
33 import specs.glapi as glapi
34
35
36 class MutableRebuilder(stdapi.Rebuilder):
37     '''Type visitor which derives a mutable type.'''
38
39     def visitConst(self, const):
40         # Strip out const qualifier
41         return const.type
42
43     def visitAlias(self, alias):
44         # Tear the alias on type changes
45         type = self.visit(alias.type)
46         if type is alias.type:
47             return alias
48         return type
49
50     def visitReference(self, reference):
51         # Strip out references
52         return reference.type
53
54     def visitOpaque(self, opaque):
55         # Don't recursule
56         return opaque
57
58
59 def lookupHandle(handle, value):
60     if handle.key is None:
61         return "__%s_map[%s]" % (handle.name, value)
62     else:
63         key_name, key_type = handle.key
64         return "__%s_map[%s][%s]" % (handle.name, key_name, value)
65
66
67 class ValueDeserializer(stdapi.Visitor):
68
69     def visitLiteral(self, literal, lvalue, rvalue):
70         print '    %s = (%s).to%s();' % (lvalue, rvalue, literal.kind)
71
72     def visitConst(self, const, lvalue, rvalue):
73         self.visit(const.type, lvalue, rvalue)
74
75     def visitAlias(self, alias, lvalue, rvalue):
76         self.visit(alias.type, lvalue, rvalue)
77     
78     def visitEnum(self, enum, lvalue, rvalue):
79         print '    %s = static_cast<%s>((%s).toSInt());' % (lvalue, enum, rvalue)
80
81     def visitBitmask(self, bitmask, lvalue, rvalue):
82         self.visit(bitmask.type, lvalue, rvalue)
83
84     allocated = False
85
86     def visitArray(self, array, lvalue, rvalue):
87         print '    const trace::Array *__a%s = dynamic_cast<const trace::Array *>(&%s);' % (array.tag, rvalue)
88         length = '__a%s->values.size()' % array.tag
89         allocated = self.allocated
90         if not allocated:
91             print '    if (__a%s) {' % (array.tag)
92             print '        %s = _allocator.alloc<%s>(%s);' % (lvalue, array.type, length)
93             self.allocated = True
94         index = '__j' + array.tag
95         print '        for (size_t {i} = 0; {i} < {length}; ++{i}) {{'.format(i = index, length = length)
96         try:
97             self.visit(array.type, '%s[%s]' % (lvalue, index), '*__a%s->values[%s]' % (array.tag, index))
98         finally:
99             print '        }'
100             if not allocated:
101                 print '    } else {'
102                 print '        %s = NULL;' % lvalue
103                 print '    }'
104     
105     def visitPointer(self, pointer, lvalue, rvalue):
106         print '    const trace::Array *__a%s = dynamic_cast<const trace::Array *>(&%s);' % (pointer.tag, rvalue)
107         allocated = self.allocated
108         if not allocated:
109             print '    if (__a%s) {' % (pointer.tag)
110             print '        %s = _allocator.alloc<%s>();' % (lvalue, pointer.type)
111             self.allocated = True
112         try:
113             self.visit(pointer.type, '%s[0]' % (lvalue,), '*__a%s->values[0]' % (pointer.tag,))
114         finally:
115             if not allocated:
116                 print '    } else {'
117                 print '        %s = NULL;' % lvalue
118                 print '    }'
119
120     def visitIntPointer(self, pointer, lvalue, rvalue):
121         print '    %s = static_cast<%s>((%s).toPointer());' % (lvalue, pointer, rvalue)
122
123     def visitLinearPointer(self, pointer, lvalue, rvalue):
124         print '    %s = static_cast<%s>(retrace::toPointer(%s));' % (lvalue, pointer, rvalue)
125
126     def visitReference(self, reference, lvalue, rvalue):
127         self.visit(reference.type, lvalue, rvalue);
128
129     def visitHandle(self, handle, lvalue, rvalue):
130         #OpaqueValueDeserializer().visit(handle.type, lvalue, rvalue);
131         self.visit(handle.type, lvalue, rvalue);
132         new_lvalue = lookupHandle(handle, lvalue)
133         print '    if (retrace::verbosity >= 2) {'
134         print '        std::cout << "%s " << size_t(%s) << " <- " << size_t(%s) << "\\n";' % (handle.name, lvalue, new_lvalue)
135         print '    }'
136         print '    %s = %s;' % (lvalue, new_lvalue)
137     
138     def visitBlob(self, blob, lvalue, rvalue):
139         print '    %s = static_cast<%s>((%s).toPointer());' % (lvalue, blob, rvalue)
140     
141     def visitString(self, string, lvalue, rvalue):
142         print '    %s = (%s)((%s).toString());' % (lvalue, string.expr, rvalue)
143
144     seq = 0
145
146     def visitStruct(self, struct, lvalue, rvalue):
147         tmp = '__s_' + struct.tag + '_' + str(self.seq)
148         self.seq += 1
149
150         print '    const trace::Struct *%s = dynamic_cast<const trace::Struct *>(&%s);' % (tmp, rvalue)
151         print '    assert(%s);' % (tmp)
152         self.allocated = True
153         for i in range(len(struct.members)):
154             member_type, member_name = struct.members[i]
155             self.visit(member_type, '%s.%s' % (lvalue, member_name), '*%s->members[%s]' % (tmp, i))
156
157
158 class OpaqueValueDeserializer(ValueDeserializer):
159     '''Value extractor that also understands opaque values.
160
161     Normally opaque values can't be retraced, unless they are being extracted
162     in the context of handles.'''
163
164     def visitOpaque(self, opaque, lvalue, rvalue):
165         print '    %s = static_cast<%s>(retrace::toPointer(%s));' % (lvalue, opaque, rvalue)
166
167
168 class SwizzledValueRegistrator(stdapi.Visitor):
169     '''Type visitor which will register (un)swizzled value pairs, to later be
170     swizzled.'''
171
172     def visitLiteral(self, literal, lvalue, rvalue):
173         pass
174
175     def visitAlias(self, alias, lvalue, rvalue):
176         self.visit(alias.type, lvalue, rvalue)
177     
178     def visitEnum(self, enum, lvalue, rvalue):
179         pass
180
181     def visitBitmask(self, bitmask, lvalue, rvalue):
182         pass
183
184     def visitArray(self, array, lvalue, rvalue):
185         print '    const trace::Array *__a%s = dynamic_cast<const trace::Array *>(&%s);' % (array.tag, rvalue)
186         print '    if (__a%s) {' % (array.tag)
187         length = '__a%s->values.size()' % array.tag
188         index = '__j' + array.tag
189         print '        for (size_t {i} = 0; {i} < {length}; ++{i}) {{'.format(i = index, length = length)
190         try:
191             self.visit(array.type, '%s[%s]' % (lvalue, index), '*__a%s->values[%s]' % (array.tag, index))
192         finally:
193             print '        }'
194             print '    }'
195     
196     def visitPointer(self, pointer, lvalue, rvalue):
197         print '    const trace::Array *__a%s = dynamic_cast<const trace::Array *>(&%s);' % (pointer.tag, rvalue)
198         print '    if (__a%s) {' % (pointer.tag)
199         try:
200             self.visit(pointer.type, '%s[0]' % (lvalue,), '*__a%s->values[0]' % (pointer.tag,))
201         finally:
202             print '    }'
203     
204     def visitIntPointer(self, pointer, lvalue, rvalue):
205         pass
206     
207     def visitLinearPointer(self, pointer, lvalue, rvalue):
208         assert pointer.size is not None
209         if pointer.size is not None:
210             print r'    retrace::addRegion((%s).toUIntPtr(), %s, %s);' % (rvalue, lvalue, pointer.size)
211
212     def visitReference(self, reference, lvalue, rvalue):
213         pass
214     
215     def visitHandle(self, handle, lvalue, rvalue):
216         print '    %s __orig_result;' % handle.type
217         OpaqueValueDeserializer().visit(handle.type, '__orig_result', rvalue);
218         if handle.range is None:
219             rvalue = "__orig_result"
220             entry = lookupHandle(handle, rvalue) 
221             print "    %s = %s;" % (entry, lvalue)
222             print '    if (retrace::verbosity >= 2) {'
223             print '        std::cout << "{handle.name} " << {rvalue} << " -> " << {lvalue} << "\\n";'.format(**locals())
224             print '    }'
225         else:
226             i = '__h' + handle.tag
227             lvalue = "%s + %s" % (lvalue, i)
228             rvalue = "__orig_result + %s" % (i,)
229             entry = lookupHandle(handle, rvalue) 
230             print '    for ({handle.type} {i} = 0; {i} < {handle.range}; ++{i}) {{'.format(**locals())
231             print '        {entry} = {lvalue};'.format(**locals())
232             print '        if (retrace::verbosity >= 2) {'
233             print '            std::cout << "{handle.name} " << ({rvalue}) << " -> " << ({lvalue}) << "\\n";'.format(**locals())
234             print '        }'
235             print '    }'
236     
237     def visitBlob(self, blob, lvalue, rvalue):
238         pass
239     
240     def visitString(self, string, lvalue, rvalue):
241         pass
242
243
244 class Retracer:
245
246     def retraceFunction(self, function):
247         print 'static void retrace_%s(trace::Call &call) {' % function.name
248         self.retraceFunctionBody(function)
249         print '}'
250         print
251
252     def retraceInterfaceMethod(self, interface, method):
253         print 'static void retrace_%s__%s(trace::Call &call) {' % (interface.name, method.name)
254         self.retraceInterfaceMethodBody(interface, method)
255         print '}'
256         print
257
258     def retraceFunctionBody(self, function):
259         if not function.sideeffects:
260             print '    (void)call;'
261             return
262
263         self.deserializeArgs(function)
264         
265         self.invokeFunction(function)
266
267         self.swizzleValues(function)
268
269     def retraceInterfaceMethodBody(self, interface, method):
270         if not method.sideeffects:
271             print '    (void)call;'
272             return
273
274         self.deserializeThisPointer(interface)
275
276         self.deserializeArgs(method)
277         
278         self.invokeInterfaceMethod(interface, method)
279
280         self.swizzleValues(method)
281
282     def deserializeThisPointer(self, interface):
283         print '    %s *_this;' % (interface.name,)
284         # FIXME
285
286     def deserializeArgs(self, function):
287         print '    retrace::ScopedAllocator _allocator;'
288         print '    (void)_allocator;'
289         success = True
290         for arg in function.args:
291             arg_type = MutableRebuilder().visit(arg.type)
292             #print '    // %s ->  %s' % (arg.type, arg_type)
293             print '    %s %s;' % (arg_type, arg.name)
294             rvalue = 'call.arg(%u)' % (arg.index,)
295             lvalue = arg.name
296             try:
297                 self.extractArg(function, arg, arg_type, lvalue, rvalue)
298             except NotImplementedError:
299                 success =  False
300                 print '    memset(&%s, 0, sizeof %s); // FIXME' % (arg.name, arg.name)
301
302         if not success:
303             print '    if (1) {'
304             self.failFunction(function)
305             if function.name[-1].islower():
306                 sys.stderr.write('warning: unsupported %s call\n' % function.name)
307             print '    }'
308
309     def swizzleValues(self, function):
310         for arg in function.args:
311             if arg.output:
312                 arg_type = MutableRebuilder().visit(arg.type)
313                 rvalue = 'call.arg(%u)' % (arg.index,)
314                 lvalue = arg.name
315                 try:
316                     self.regiterSwizzledValue(arg_type, lvalue, rvalue)
317                 except NotImplementedError:
318                     print '    // XXX: %s' % arg.name
319         if function.type is not stdapi.Void:
320             rvalue = '*call.ret'
321             lvalue = '__result'
322             try:
323                 self.regiterSwizzledValue(function.type, lvalue, rvalue)
324             except NotImplementedError:
325                 print '    // XXX: result'
326
327     def failFunction(self, function):
328         print '    if (retrace::verbosity >= 0) {'
329         print '        retrace::unsupported(call);'
330         print '    }'
331         print '    return;'
332
333     def extractArg(self, function, arg, arg_type, lvalue, rvalue):
334         ValueDeserializer().visit(arg_type, lvalue, rvalue)
335     
336     def extractOpaqueArg(self, function, arg, arg_type, lvalue, rvalue):
337         OpaqueValueDeserializer().visit(arg_type, lvalue, rvalue)
338
339     def regiterSwizzledValue(self, type, lvalue, rvalue):
340         visitor = SwizzledValueRegistrator()
341         visitor.visit(type, lvalue, rvalue)
342
343     def invokeFunction(self, function):
344         arg_names = ", ".join(function.argNames())
345         if function.type is not stdapi.Void:
346             print '    %s __result;' % (function.type)
347             print '    __result = %s(%s);' % (function.name, arg_names)
348             print '    (void)__result;'
349         else:
350             print '    %s(%s);' % (function.name, arg_names)
351
352     def invokeInterfaceMethod(self, interface, method):
353         arg_names = ", ".join(method.argNames())
354         if method.type is not stdapi.Void:
355             print '    %s __result;' % (method.type)
356             print '    __result = _this->%s(%s);' % (method.name, arg_names)
357             print '    (void)__result;'
358         else:
359             print '    _this->%s(%s);' % (method.name, arg_names)
360
361     def filterFunction(self, function):
362         return True
363
364     table_name = 'retrace::callbacks'
365
366     def retraceApi(self, api):
367
368         print '#include "os_time.hpp"'
369         print '#include "trace_parser.hpp"'
370         print '#include "retrace.hpp"'
371         print
372
373         types = api.getAllTypes()
374         handles = [type for type in types if isinstance(type, stdapi.Handle)]
375         handle_names = set()
376         for handle in handles:
377             if handle.name not in handle_names:
378                 if handle.key is None:
379                     print 'static retrace::map<%s> __%s_map;' % (handle.type, handle.name)
380                 else:
381                     key_name, key_type = handle.key
382                     print 'static std::map<%s, retrace::map<%s> > __%s_map;' % (key_type, handle.type, handle.name)
383                 handle_names.add(handle.name)
384         print
385
386         functions = filter(self.filterFunction, api.functions)
387         for function in functions:
388             self.retraceFunction(function)
389         interfaces = api.getAllInterfaces()
390         for interface in interfaces:
391             for method in interface.iterMethods():
392                 self.retraceInterfaceMethod(interface, method)
393
394         print 'const retrace::Entry %s[] = {' % self.table_name
395         for function in functions:
396             print '    {"%s", &retrace_%s},' % (function.name, function.name)
397         for interface in interfaces:
398             for method in interface.iterMethods():
399                 print '    {"%s::%s", &retrace_%s__%s},' % (interface.name, method.name, interface.name, method.name)
400         print '    {NULL, NULL}'
401         print '};'
402         print
403