]> git.cworth.org Git - apitrace/blob - retrace.py
Describe C++ references accurately.
[apitrace] / retrace.py
1 ##########################################################################
2 #
3 # Copyright 2010 VMware, Inc.
4 # All Rights Reserved.
5 #
6 # Permission is hereby granted, free of charge, to any person obtaining a copy
7 # of this software and associated documentation files (the "Software"), to deal
8 # in the Software without restriction, including without limitation the rights
9 # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 # copies of the Software, and to permit persons to whom the Software is
11 # furnished to do so, subject to the following conditions:
12 #
13 # The above copyright notice and this permission notice shall be included in
14 # all copies or substantial portions of the Software.
15 #
16 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
22 # THE SOFTWARE.
23 #
24 ##########################################################################/
25
26
27 """Generic retracing code generator."""
28
29
30 import sys
31
32 import specs.stdapi as stdapi
33 import specs.glapi as glapi
34
35
36 class ConstRemover(stdapi.Rebuilder):
37     '''Type visitor which strips out const qualifiers from types.'''
38
39     def visitConst(self, const):
40         return const.type
41
42     def visitAlias(self, alias):
43         type = self.visit(alias.type)
44         if type is alias.type:
45             return alias
46         return type
47
48     def visitReference(self, reference):
49         return reference.type
50
51     def visitOpaque(self, opaque):
52         return opaque
53
54
55 def lookupHandle(handle, value):
56     if handle.key is None:
57         return "__%s_map[%s]" % (handle.name, value)
58     else:
59         key_name, key_type = handle.key
60         return "__%s_map[%s][%s]" % (handle.name, key_name, value)
61
62
63 class ValueDeserializer(stdapi.Visitor):
64
65     def visitLiteral(self, literal, lvalue, rvalue):
66         print '    %s = (%s).to%s();' % (lvalue, rvalue, literal.kind)
67
68     def visitConst(self, const, lvalue, rvalue):
69         self.visit(const.type, lvalue, rvalue)
70
71     def visitAlias(self, alias, lvalue, rvalue):
72         self.visit(alias.type, lvalue, rvalue)
73     
74     def visitEnum(self, enum, lvalue, rvalue):
75         print '    %s = (%s).toSInt();' % (lvalue, rvalue)
76
77     def visitBitmask(self, bitmask, lvalue, rvalue):
78         self.visit(bitmask.type, lvalue, rvalue)
79
80     def visitArray(self, array, lvalue, rvalue):
81         print '    const trace::Array *__a%s = dynamic_cast<const trace::Array *>(&%s);' % (array.tag, rvalue)
82         print '    if (__a%s) {' % (array.tag)
83         length = '__a%s->values.size()' % array.tag
84         print '        %s = _allocator.alloc<%s>(%s);' % (lvalue, array.type, length)
85         index = '__j' + array.tag
86         print '        for (size_t {i} = 0; {i} < {length}; ++{i}) {{'.format(i = index, length = length)
87         try:
88             self.visit(array.type, '%s[%s]' % (lvalue, index), '*__a%s->values[%s]' % (array.tag, index))
89         finally:
90             print '        }'
91             print '    } else {'
92             print '        %s = NULL;' % lvalue
93             print '    }'
94     
95     def visitPointer(self, pointer, lvalue, rvalue):
96         print '    const trace::Array *__a%s = dynamic_cast<const trace::Array *>(&%s);' % (pointer.tag, rvalue)
97         print '    if (__a%s) {' % (pointer.tag)
98         print '        %s = _allocator.alloc<%s>();' % (lvalue, pointer.type)
99         try:
100             self.visit(pointer.type, '%s[0]' % (lvalue,), '*__a%s->values[0]' % (pointer.tag,))
101         finally:
102             print '    } else {'
103             print '        %s = NULL;' % lvalue
104             print '    }'
105
106     def visitIntPointer(self, pointer, lvalue, rvalue):
107         print '    %s = static_cast<%s>((%s).toPointer());' % (lvalue, pointer, rvalue)
108
109     def visitLinearPointer(self, pointer, lvalue, rvalue):
110         print '    %s = static_cast<%s>(retrace::toPointer(%s));' % (lvalue, pointer, rvalue)
111
112     def visitReference(self, reference, lvalue, rvalue):
113         self.visit(reference.type, lvalue, rvalue);
114
115     def visitHandle(self, handle, lvalue, rvalue):
116         #OpaqueValueDeserializer().visit(handle.type, lvalue, rvalue);
117         self.visit(handle.type, lvalue, rvalue);
118         new_lvalue = lookupHandle(handle, lvalue)
119         print '    if (retrace::verbosity >= 2) {'
120         print '        std::cout << "%s " << size_t(%s) << " <- " << size_t(%s) << "\\n";' % (handle.name, lvalue, new_lvalue)
121         print '    }'
122         print '    %s = %s;' % (lvalue, new_lvalue)
123     
124     def visitBlob(self, blob, lvalue, rvalue):
125         print '    %s = static_cast<%s>((%s).toPointer());' % (lvalue, blob, rvalue)
126     
127     def visitString(self, string, lvalue, rvalue):
128         print '    %s = (%s)((%s).toString());' % (lvalue, string.expr, rvalue)
129
130
131 class OpaqueValueDeserializer(ValueDeserializer):
132     '''Value extractor that also understands opaque values.
133
134     Normally opaque values can't be retraced, unless they are being extracted
135     in the context of handles.'''
136
137     def visitOpaque(self, opaque, lvalue, rvalue):
138         print '    %s = static_cast<%s>(retrace::toPointer(%s));' % (lvalue, opaque, rvalue)
139
140
141 class SwizzledValueRegistrator(stdapi.Visitor):
142     '''Type visitor which will register (un)swizzled value pairs, to later be
143     swizzled.'''
144
145     def visitLiteral(self, literal, lvalue, rvalue):
146         pass
147
148     def visitAlias(self, alias, lvalue, rvalue):
149         self.visit(alias.type, lvalue, rvalue)
150     
151     def visitEnum(self, enum, lvalue, rvalue):
152         pass
153
154     def visitBitmask(self, bitmask, lvalue, rvalue):
155         pass
156
157     def visitArray(self, array, lvalue, rvalue):
158         print '    const trace::Array *__a%s = dynamic_cast<const trace::Array *>(&%s);' % (array.tag, rvalue)
159         print '    if (__a%s) {' % (array.tag)
160         length = '__a%s->values.size()' % array.tag
161         index = '__j' + array.tag
162         print '        for (size_t {i} = 0; {i} < {length}; ++{i}) {{'.format(i = index, length = length)
163         try:
164             self.visit(array.type, '%s[%s]' % (lvalue, index), '*__a%s->values[%s]' % (array.tag, index))
165         finally:
166             print '        }'
167             print '    }'
168     
169     def visitPointer(self, pointer, lvalue, rvalue):
170         print '    const trace::Array *__a%s = dynamic_cast<const trace::Array *>(&%s);' % (pointer.tag, rvalue)
171         print '    if (__a%s) {' % (pointer.tag)
172         try:
173             self.visit(pointer.type, '%s[0]' % (lvalue,), '*__a%s->values[0]' % (pointer.tag,))
174         finally:
175             print '    }'
176     
177     def visitIntPointer(self, pointer, lvalue, rvalue):
178         pass
179     
180     def visitLinearPointer(self, pointer, lvalue, rvalue):
181         assert pointer.size is not None
182         if pointer.size is not None:
183             print r'    retrace::addRegion((%s).toUIntPtr(), %s, %s);' % (rvalue, lvalue, pointer.size)
184
185     def visitReference(self, reference, lvalue, rvalue):
186         pass
187     
188     def visitHandle(self, handle, lvalue, rvalue):
189         print '    %s __orig_result;' % handle.type
190         OpaqueValueDeserializer().visit(handle.type, '__orig_result', rvalue);
191         if handle.range is None:
192             rvalue = "__orig_result"
193             entry = lookupHandle(handle, rvalue) 
194             print "    %s = %s;" % (entry, lvalue)
195             print '    if (retrace::verbosity >= 2) {'
196             print '        std::cout << "{handle.name} " << {rvalue} << " -> " << {lvalue} << "\\n";'.format(**locals())
197             print '    }'
198         else:
199             i = '__h' + handle.tag
200             lvalue = "%s + %s" % (lvalue, i)
201             rvalue = "__orig_result + %s" % (i,)
202             entry = lookupHandle(handle, rvalue) 
203             print '    for ({handle.type} {i} = 0; {i} < {handle.range}; ++{i}) {{'.format(**locals())
204             print '        {entry} = {lvalue};'.format(**locals())
205             print '        if (retrace::verbosity >= 2) {'
206             print '            std::cout << "{handle.name} " << ({rvalue}) << " -> " << ({lvalue}) << "\\n";'.format(**locals())
207             print '        }'
208             print '    }'
209     
210     def visitBlob(self, blob, lvalue, rvalue):
211         pass
212     
213     def visitString(self, string, lvalue, rvalue):
214         pass
215
216
217 class Retracer:
218
219     def retraceFunction(self, function):
220         print 'static void retrace_%s(trace::Call &call) {' % function.name
221         self.retraceFunctionBody(function)
222         print '}'
223         print
224
225     def retraceInterfaceMethod(self, interface, method):
226         print 'static void retrace_%s__%s(trace::Call &call) {' % (interface.name, method.name)
227         self.retraceInterfaceMethodBody(interface, method)
228         print '}'
229         print
230
231     def retraceFunctionBody(self, function):
232         if not function.sideeffects:
233             print '    (void)call;'
234             return
235
236         self.deserializeArgs(function)
237         
238         self.invokeFunction(function)
239
240         self.swizzleValues(function)
241
242     def retraceInterfaceMethodBody(self, interface, method):
243         if not method.sideeffects:
244             print '    (void)call;'
245             return
246
247         self.deserializeThisPointer(interface)
248
249         self.deserializeArgs(method)
250         
251         self.invokeInterfaceMethod(interface, method)
252
253         self.swizzleValues(method)
254
255     def deserializeThisPointer(self, interface):
256         print '    %s *_this;' % (interface.name,)
257         # FIXME
258
259     def deserializeArgs(self, function):
260         print '    retrace::ScopedAllocator _allocator;'
261         print '    (void)_allocator;'
262         success = True
263         for arg in function.args:
264             arg_type = ConstRemover().visit(arg.type)
265             #print '    // %s ->  %s' % (arg.type, arg_type)
266             print '    %s %s;' % (arg_type, arg.name)
267             rvalue = 'call.arg(%u)' % (arg.index,)
268             lvalue = arg.name
269             try:
270                 self.extractArg(function, arg, arg_type, lvalue, rvalue)
271             except NotImplementedError:
272                 success =  False
273                 print '    %s = 0; // FIXME' % arg.name
274
275         if not success:
276             print '    if (1) {'
277             self.failFunction(function)
278             if function.name[-1].islower():
279                 sys.stderr.write('warning: unsupported %s call\n' % function.name)
280             print '    }'
281
282     def swizzleValues(self, function):
283         for arg in function.args:
284             if arg.output:
285                 arg_type = ConstRemover().visit(arg.type)
286                 rvalue = 'call.arg(%u)' % (arg.index,)
287                 lvalue = arg.name
288                 try:
289                     self.regiterSwizzledValue(arg_type, lvalue, rvalue)
290                 except NotImplementedError:
291                     print '    // XXX: %s' % arg.name
292         if function.type is not stdapi.Void:
293             rvalue = '*call.ret'
294             lvalue = '__result'
295             try:
296                 self.regiterSwizzledValue(function.type, lvalue, rvalue)
297             except NotImplementedError:
298                 print '    // XXX: result'
299
300     def failFunction(self, function):
301         print '    if (retrace::verbosity >= 0) {'
302         print '        retrace::unsupported(call);'
303         print '    }'
304         print '    return;'
305
306     def extractArg(self, function, arg, arg_type, lvalue, rvalue):
307         ValueDeserializer().visit(arg_type, lvalue, rvalue)
308     
309     def extractOpaqueArg(self, function, arg, arg_type, lvalue, rvalue):
310         OpaqueValueDeserializer().visit(arg_type, lvalue, rvalue)
311
312     def regiterSwizzledValue(self, type, lvalue, rvalue):
313         visitor = SwizzledValueRegistrator()
314         visitor.visit(type, lvalue, rvalue)
315
316     def invokeFunction(self, function):
317         arg_names = ", ".join(function.argNames())
318         if function.type is not stdapi.Void:
319             print '    %s __result;' % (function.type)
320             print '    __result = %s(%s);' % (function.name, arg_names)
321             print '    (void)__result;'
322         else:
323             print '    %s(%s);' % (function.name, arg_names)
324
325     def invokeInterfaceMethod(self, interface, method):
326         arg_names = ", ".join(method.argNames())
327         if method.type is not stdapi.Void:
328             print '    %s __result;' % (method.type)
329             print '    __result = _this->%s(%s);' % (method.name, arg_names)
330             print '    (void)__result;'
331         else:
332             print '    _this->%s(%s);' % (method.name, arg_names)
333
334     def filterFunction(self, function):
335         return True
336
337     table_name = 'retrace::callbacks'
338
339     def retraceApi(self, api):
340
341         print '#include "trace_parser.hpp"'
342         print '#include "retrace.hpp"'
343         print
344
345         types = api.getAllTypes()
346         handles = [type for type in types if isinstance(type, stdapi.Handle)]
347         handle_names = set()
348         for handle in handles:
349             if handle.name not in handle_names:
350                 if handle.key is None:
351                     print 'static retrace::map<%s> __%s_map;' % (handle.type, handle.name)
352                 else:
353                     key_name, key_type = handle.key
354                     print 'static std::map<%s, retrace::map<%s> > __%s_map;' % (key_type, handle.type, handle.name)
355                 handle_names.add(handle.name)
356         print
357
358         functions = filter(self.filterFunction, api.functions)
359         for function in functions:
360             self.retraceFunction(function)
361         interfaces = api.getAllInterfaces()
362         for interface in interfaces:
363             for method in interface.iterMethods():
364                 self.retraceInterfaceMethod(interface, method)
365
366         print 'const retrace::Entry %s[] = {' % self.table_name
367         for function in functions:
368             print '    {"%s", &retrace_%s},' % (function.name, function.name)
369         for interface in interfaces:
370             for method in interface.iterMethods():
371                 print '    {"%s::%s", &retrace_%s__%s},' % (interface.name, method.name, interface.name, method.name)
372         print '    {NULL, NULL}'
373         print '};'
374         print
375