1 ##########################################################################
3 # Copyright 2010 VMware, Inc.
6 # Permission is hereby granted, free of charge, to any person obtaining a copy
7 # of this software and associated documentation files (the "Software"), to deal
8 # in the Software without restriction, including without limitation the rights
9 # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 # copies of the Software, and to permit persons to whom the Software is
11 # furnished to do so, subject to the following conditions:
13 # The above copyright notice and this permission notice shall be included in
14 # all copies or substantial portions of the Software.
16 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
24 ##########################################################################/
27 """Generic retracing code generator."""
32 import specs.stdapi as stdapi
33 import specs.glapi as glapi
# Type visitor that rebuilds a type with const qualifiers removed (per the
# class docstring). Its result is used as the declared type of each argument
# local (see deserializeArgs and swizzleValues).
36 class ConstRemover(stdapi.Rebuilder):
37 '''Type visitor which strips out const qualifiers from types.'''
39 def visitConst(self, const):
# NOTE(review): this method's body (original line 40) is missing from this
# listing; presumably it returns the unqualified const.type -- confirm upstream.
42 def visitOpaque(self, opaque):
# NOTE(review): this method's body (original line 43) is missing from this
# listing; presumably it returns `opaque` unchanged -- confirm upstream.
def lookupHandle(handle, value):
    """Return the C++ expression that looks ``value`` up in a handle map.

    Handles without a key use a flat ``__<name>_map[value]`` lookup. Keyed
    handles (``handle.key`` is a ``(key_name, key_type)`` pair) index by the
    key variable first: ``__<name>_map[<key_name>][value]``.
    """
    map_name = '__%s_map' % (handle.name,)
    if handle.key is not None:
        # Unpack (rather than index) so a malformed key still fails loudly.
        key_name, _ = handle.key
        return '%s[%s][%s]' % (map_name, key_name, value)
    return '%s[%s]' % (map_name, value)
# Type visitor that prints C++ statements deserializing a value recorded in
# the trace (the C++ expression `rvalue`, e.g. `call.arg(N)`) into the C++
# lvalue expression `lvalue`. Every visit* method takes (type, lvalue, rvalue)
# and writes the generated code to stdout with print.
# NOTE(review): this listing drops several short lines of the class (noted
# inline below); confirm against upstream before relying on the visible code.
54 class ValueDeserializer(stdapi.Visitor):
# Literals: call the trace value accessor matching the literal kind (to<kind>()).
56 def visitLiteral(self, literal, lvalue, rvalue):
57 print ' %s = (%s).to%s();' % (lvalue, rvalue, literal.kind)
# Const, alias and bitmask types deserialize exactly like their underlying type.
59 def visitConst(self, const, lvalue, rvalue):
60 self.visit(const.type, lvalue, rvalue)
62 def visitAlias(self, alias, lvalue, rvalue):
63 self.visit(alias.type, lvalue, rvalue)
# Enums are read back as signed integers.
65 def visitEnum(self, enum, lvalue, rvalue):
66 print ' %s = (%s).toSInt();' % (lvalue, rvalue)
68 def visitBitmask(self, bitmask, lvalue, rvalue):
69 self.visit(bitmask.type, lvalue, rvalue)
# Arrays: if the traced value is a trace::Array, allocate one element per
# traced value from `_allocator` and deserialize each element in a loop
# indexed by `__j<tag>`; otherwise (presumably) `lvalue` is set to NULL.
71 def visitArray(self, array, lvalue, rvalue):
72 print ' const trace::Array *__a%s = dynamic_cast<const trace::Array *>(&%s);' % (array.tag, rvalue)
73 print ' if (__a%s) {' % (array.tag)
74 length = '__a%s->values.size()' % array.tag
75 print ' %s = _allocator.alloc<%s>(%s);' % (lvalue, array.type, length)
76 index = '__j' + array.tag
77 print ' for (size_t {i} = 0; {i} < {length}; ++{i}) {{'.format(i = index, length = length)
79 self.visit(array.type, '%s[%s]' % (lvalue, index), '*__a%s->values[%s]' % (array.tag, index))
# NOTE(review): original lines 78, 80-82 and 84 are missing from this listing;
# presumably a try/finally that closes the emitted `for`/`if` blocks and opens
# the `else` branch. As listed, the emitted C++ braces are unbalanced.
83 print ' %s = NULL;' % lvalue
# Pointers: if the traced value is a trace::Array, allocate a single pointee
# from `_allocator` and deserialize element 0 into it; otherwise
# (presumably) `lvalue` is set to NULL.
86 def visitPointer(self, pointer, lvalue, rvalue):
87 print ' const trace::Array *__a%s = dynamic_cast<const trace::Array *>(&%s);' % (pointer.tag, rvalue)
88 print ' if (__a%s) {' % (pointer.tag)
89 print ' %s = _allocator.alloc<%s>();' % (lvalue, pointer.type)
91 self.visit(pointer.type, '%s[0]' % (lvalue,), '*__a%s->values[0]' % (pointer.tag,))
# NOTE(review): original lines 90, 92-93 and 95 are missing from this listing
# (presumably the try/finally and the `} else {` / closing-brace prints).
94 print ' %s = NULL;' % lvalue
# Integer pointers: cast the traced value's toPointer() result to the pointer type.
97 def visitIntPointer(self, pointer, lvalue, rvalue):
98 print ' %s = static_cast<%s>((%s).toPointer());' % (lvalue, pointer, rvalue)
# Linear pointers: convert via retrace::toPointer() instead of the value's own toPointer().
100 def visitLinearPointer(self, pointer, lvalue, rvalue):
101 print ' %s = static_cast<%s>(retrace::toPointer(%s));' % (lvalue, pointer, rvalue)
# Handles: deserialize the raw traced handle, then replace it with the value
# stored in the handle map (lookupHandle), logging the mapping when
# retrace::verbosity >= 2. The commented-out line is unused dead code; the
# trailing ';' on the visit call is harmless in Python.
103 def visitHandle(self, handle, lvalue, rvalue):
104 #OpaqueValueDeserializer().visit(handle.type, lvalue, rvalue);
105 self.visit(handle.type, lvalue, rvalue);
106 new_lvalue = lookupHandle(handle, lvalue)
107 print ' if (retrace::verbosity >= 2) {'
108 print ' std::cout << "%s " << size_t(%s) << " <- " << size_t(%s) << "\\n";' % (handle.name, lvalue, new_lvalue)
# NOTE(review): original line 109 is missing; presumably it closes the
# verbosity `if` block emitted above.
110 print ' %s = %s;' % (lvalue, new_lvalue)
# Blobs: cast the traced value's toPointer() result to the blob type.
112 def visitBlob(self, blob, lvalue, rvalue):
113 print ' %s = static_cast<%s>((%s).toPointer());' % (lvalue, blob, rvalue)
# Strings: cast the traced value's toString() result to `string.expr`.
115 def visitString(self, string, lvalue, rvalue):
116 print ' %s = (%s)((%s).toString());' % (lvalue, string.expr, rvalue)
class OpaqueValueDeserializer(ValueDeserializer):
    '''Value extractor that additionally handles opaque values.

    Opaque values cannot be retraced in general; they are only extracted
    when they occur in the context of handles.'''

    def visitOpaque(self, opaque, lvalue, rvalue):
        # Emit a C++ assignment that converts the traced value with
        # retrace::toPointer() and casts it to the opaque type.
        template = ' %s = static_cast<%s>(retrace::toPointer(%s));'
        print(template % (lvalue, opaque, rvalue))
# Type visitor that prints C++ code pairing values recorded in the trace
# (`rvalue`) with the values produced during replay (`lvalue`): handle-map
# entries for handles and retrace::addRegion() calls for linear pointers.
# NOTE(review): this listing drops many short lines of the class -- the
# docstring's closing line (original line 131) and several method bodies,
# noted inline below. Confirm against upstream before relying on this code.
129 class SwizzledValueRegistrator(stdapi.Visitor):
130 '''Type visitor which will register (un)swizzled value pairs, to later be
133 def visitLiteral(self, literal, lvalue, rvalue):
# NOTE(review): body (original line 134) missing from this listing.
# Aliases register exactly like their underlying type.
136 def visitAlias(self, alias, lvalue, rvalue):
137 self.visit(alias.type, lvalue, rvalue)
139 def visitEnum(self, enum, lvalue, rvalue):
# NOTE(review): body (original line 140) missing from this listing.
142 def visitBitmask(self, bitmask, lvalue, rvalue):
# NOTE(review): body (original line 143) missing from this listing.
# Arrays: when the traced value is a trace::Array, loop over its elements
# (index `__j<tag>`) and register each element pair.
145 def visitArray(self, array, lvalue, rvalue):
146 print ' const trace::Array *__a%s = dynamic_cast<const trace::Array *>(&%s);' % (array.tag, rvalue)
147 print ' if (__a%s) {' % (array.tag)
148 length = '__a%s->values.size()' % array.tag
149 index = '__j' + array.tag
150 print ' for (size_t {i} = 0; {i} < {length}; ++{i}) {{'.format(i = index, length = length)
152 self.visit(array.type, '%s[%s]' % (lvalue, index), '*__a%s->values[%s]' % (array.tag, index))
# NOTE(review): original lines 151 and 153-155 are missing; presumably the
# try/finally that closes the emitted `for` and `if` blocks.
# Pointers: when the traced value is a trace::Array, register element 0.
157 def visitPointer(self, pointer, lvalue, rvalue):
158 print ' const trace::Array *__a%s = dynamic_cast<const trace::Array *>(&%s);' % (pointer.tag, rvalue)
159 print ' if (__a%s) {' % (pointer.tag)
161 self.visit(pointer.type, '%s[0]' % (lvalue,), '*__a%s->values[0]' % (pointer.tag,))
# NOTE(review): original lines 160 and 162-163 are missing; presumably the
# try/finally that closes the emitted `if` block.
165 def visitIntPointer(self, pointer, lvalue, rvalue):
# NOTE(review): body (original line 166) missing from this listing.
# Linear pointers: emit retrace::addRegion() with the traced value (as
# toUIntPtr()), the replay pointer and pointer.size. The `if` re-tests the
# assert's condition, so under `python -O` (asserts stripped) pointers
# without a size are silently skipped instead of failing.
168 def visitLinearPointer(self, pointer, lvalue, rvalue):
169 assert pointer.size is not None
170 if pointer.size is not None:
171 print r' retrace::addRegion((%s).toUIntPtr(), %s, %s);' % (rvalue, lvalue, pointer.size)
# Handles: deserialize the traced handle into a local `__orig_result`, then
# record map[traced] = replayed via lookupHandle. Scalar handles (range is
# None) record one entry; ranged handles emit a loop over `handle.range`
# consecutive values, offsetting both sides by `__h<tag>`.
# `.format(**locals())` interpolates the local names (handle, rvalue, lvalue,
# entry, i) into the C++ templates.
173 def visitHandle(self, handle, lvalue, rvalue):
174 print ' %s __orig_result;' % handle.type
175 OpaqueValueDeserializer().visit(handle.type, '__orig_result', rvalue);
176 if handle.range is None:
177 rvalue = "__orig_result"
178 entry = lookupHandle(handle, rvalue)
179 print " %s = %s;" % (entry, lvalue)
180 print ' if (retrace::verbosity >= 2) {'
181 print ' std::cout << "{handle.name} " << {rvalue} << " -> " << {lvalue} << "\\n";'.format(**locals())
# NOTE(review): original lines 182-183 are missing -- presumably the close of
# the verbosity block and the `else:` introducing the ranged branch below.
# As listed, the ranged code would also run for scalar handles.
184 i = '__h' + handle.tag
185 lvalue = "%s + %s" % (lvalue, i)
186 rvalue = "__orig_result + %s" % (i,)
187 entry = lookupHandle(handle, rvalue)
188 print ' for ({handle.type} {i} = 0; {i} < {handle.range}; ++{i}) {{'.format(**locals())
189 print ' {entry} = {lvalue};'.format(**locals())
190 print ' if (retrace::verbosity >= 2) {'
191 print ' std::cout << "{handle.name} " << ({rvalue}) << " -> " << ({lvalue}) << "\\n";'.format(**locals())
# NOTE(review): original lines 192-193 are missing; presumably they close the
# emitted verbosity `if` and `for` blocks.
195 def visitBlob(self, blob, lvalue, rvalue):
# NOTE(review): body (original line 196) missing from this listing.
198 def visitString(self, string, lvalue, rvalue):
# NOTE(review): body (original line 199) missing from this listing.
# Emit a C++ function `static void retrace_<name>(trace::Call &call)` whose
# body is generated by retraceFunctionBody.
# NOTE(review): original lines 207-208 are missing from this listing;
# presumably they print the closing brace of the emitted function.
204 def retraceFunction(self, function):
205 print 'static void retrace_%s(trace::Call &call) {' % function.name
206 self.retraceFunctionBody(function)
# Emit a C++ function `static void retrace_<Interface>__<method>(trace::Call &call)`
# whose body is generated by retraceInterfaceMethodBody.
# NOTE(review): original lines 213-214 are missing from this listing;
# presumably they print the closing brace of the emitted function.
210 def retraceInterfaceMethod(self, interface, method):
211 print 'static void retrace_%s__%s(trace::Call &call) {' % (interface.name, method.name)
212 self.retraceInterfaceMethodBody(interface, method)
# Emit the body of a retrace function: deserialize the arguments, invoke the
# real function, then register swizzled (traced -> replayed) values.
216 def retraceFunctionBody(self, function):
217 if not function.sideeffects:
# NOTE(review): the body of the `if` above (original lines 218-219) is
# missing from this listing; presumably it emits a trivial body and returns
# early for side-effect-free functions -- confirm upstream.
221 self.deserializeArgs(function)
223 self.invokeFunction(function)
225 self.swizzleValues(function)
# Emit the body of a retrace function for an interface method: deserialize
# `_this` and the arguments, invoke the method on `_this`, then register
# swizzled (traced -> replayed) values.
227 def retraceInterfaceMethodBody(self, interface, method):
228 if not method.sideeffects:
# NOTE(review): the body of the `if` above (original lines 229-230) is
# missing from this listing; presumably it emits a trivial body and returns
# early for side-effect-free methods -- confirm upstream.
232 self.deserializeThisPointer(interface)
234 self.deserializeArgs(method)
236 self.invokeInterfaceMethod(interface, method)
238 self.swizzleValues(method)
# Emit the declaration of `_this`, a pointer to the interface type.
# NOTE(review): original line 242 is missing from this listing; presumably it
# initializes `_this` -- as listed, the emitted `_this` is never assigned.
240 def deserializeThisPointer(self, interface):
241 print ' %s *_this;' % (interface.name,)
# Emit declarations for every argument (using its const-stripped type) and
# code deserializing each from `call.arg(<index>)` via extractArg. A
# retrace::ScopedAllocator `_allocator` is declared first because the value
# deserializers allocate array/pointer storage from it. An argument whose type
# raises NotImplementedError is zero-initialised with a `// FIXME` comment;
# presumably failFunction is then emitted (and, for function names ending in a
# lowercase letter, a warning written to stderr) when any argument failed.
# NOTE(review): original lines 247, 253-254, 257 and 259-261 are missing from
# this listing (presumably `lvalue = arg.name`, the `try:` and the
# success-flag bookkeeping guarding failFunction). `sys` is used below but
# its import is not visible in this listing.
244 def deserializeArgs(self, function):
245 print ' retrace::ScopedAllocator _allocator;'
246 print ' (void)_allocator;'
248 for arg in function.args:
249 arg_type = ConstRemover().visit(arg.type)
250 #print ' // %s -> %s' % (arg.type, arg_type)
251 print ' %s %s;' % (arg_type, arg.name)
252 rvalue = 'call.arg(%u)' % (arg.index,)
255 self.extractArg(function, arg, arg_type, lvalue, rvalue)
256 except NotImplementedError:
258 print ' %s = 0; // FIXME' % arg.name
262 self.failFunction(function)
263 if function.name[-1].islower():
264 sys.stderr.write('warning: unsupported %s call\n' % function.name)
# Emit code registering (traced, replayed) value pairs for the arguments and,
# for non-void functions, for the return value, via regiterSwizzledValue.
# Types that raise NotImplementedError emit a `// XXX:` comment instead.
# NOTE(review): original lines 269, 272-273 and 278-280 are missing from this
# listing (presumably an output-argument filter, the `lvalue` assignments and
# the `try:` lines); `lvalue` is never assigned in the visible code.
267 def swizzleValues(self, function):
268 for arg in function.args:
270 arg_type = ConstRemover().visit(arg.type)
271 rvalue = 'call.arg(%u)' % (arg.index,)
274 self.regiterSwizzledValue(arg_type, lvalue, rvalue)
275 except NotImplementedError:
276 print ' // XXX: %s' % arg.name
277 if function.type is not stdapi.Void:
281 self.regiterSwizzledValue(function.type, lvalue, rvalue)
282 except NotImplementedError:
283 print ' // XXX: result'
# Emit code reporting the call as unsupported via retrace::unsupported(call)
# when retrace::verbosity >= 0.
# NOTE(review): the method continues past original line 287 but those lines
# are missing from this listing; presumably they close the emitted `if` block.
285 def failFunction(self, function):
286 print ' if (retrace::verbosity >= 0) {'
287 print ' retrace::unsupported(call);'
def extractArg(self, function, arg, arg_type, lvalue, rvalue):
    # Emit code deserializing the traced value `rvalue` into `lvalue` with the
    # plain ValueDeserializer. `function` and `arg` are unused here
    # (presumably passed so overrides can specialise per argument).
    deserializer = ValueDeserializer()
    deserializer.visit(arg_type, lvalue, rvalue)
def extractOpaqueArg(self, function, arg, arg_type, lvalue, rvalue):
    # Like extractArg, but uses OpaqueValueDeserializer so that opaque types
    # are converted via retrace::toPointer() instead of being rejected.
    deserializer = OpaqueValueDeserializer()
    deserializer.visit(arg_type, lvalue, rvalue)
def regiterSwizzledValue(self, type, lvalue, rvalue):
    # Emit code registering the (traced `rvalue`, replayed `lvalue`) pair for
    # a value of the given type. The misspelt method name is kept on purpose:
    # swizzleValues calls it under this name.
    SwizzledValueRegistrator().visit(type, lvalue, rvalue)
# Emit the real call using the deserialized argument locals. For non-void
# functions the result is stored in a local `__result`, followed by a
# `(void)__result;` statement (presumably to silence unused-variable warnings).
301 def invokeFunction(self, function):
302 arg_names = ", ".join(function.argNames())
303 if function.type is not stdapi.Void:
304 print ' %s __result;' % (function.type)
305 print ' __result = %s(%s);' % (function.name, arg_names)
306 print ' (void)__result;'
# NOTE(review): original line 307 is missing from this listing; presumably
# `else:`. As listed, the void-style call below would also be emitted for
# non-void functions, calling the function twice.
308 print ' %s(%s);' % (function.name, arg_names)
# Emit the real method call on `_this` using the deserialized argument locals.
# For non-void methods the result is stored in a local `__result`, followed by
# a `(void)__result;` statement (presumably to silence unused-variable warnings).
310 def invokeInterfaceMethod(self, interface, method):
311 arg_names = ", ".join(method.argNames())
312 if method.type is not stdapi.Void:
313 print ' %s __result;' % (method.type)
314 print ' __result = _this->%s(%s);' % (method.name, arg_names)
315 print ' (void)__result;'
# NOTE(review): original line 316 is missing from this listing; presumably
# `else:`. As listed, the void-style call below would also be emitted for
# non-void methods, calling the method twice.
317 print ' _this->%s(%s);' % (method.name, arg_names)
# Predicate deciding which API functions get retrace code; retraceApi passes
# it to filter().
319 def filterFunction(self, function):
# NOTE(review): the body (original line 320) is missing from this listing.
# Name of the C++ callback table emitted by retraceApi.
322 table_name = 'retrace::callbacks'
# Emit the retrace translation unit for `api`. It consists of:
#   - the includes;
#   - one static map per distinct handle name (keyed handles get an outer
#     std::map indexed by the key type);
#   - a retrace function per filtered API function and per interface method;
#   - a `retrace::Entry` table named self.table_name mapping call names to
#     those functions, ending with a {NULL, NULL} sentinel.
# NOTE(review): the method runs past the visible listing (the line closing the
# table after original line 357 is not shown), and original lines 325,
# 328-329, 332, 337, 341-342 and 350 are missing.
324 def retraceApi(self, api):
326 print '#include "trace_parser.hpp"'
327 print '#include "retrace.hpp"'
330 types = api.getAllTypes()
331 handles = [type for type in types if isinstance(type, stdapi.Handle)]
# NOTE(review): `handle_names` is used below but its initialiser (presumably
# `handle_names = set()` on original line 332) is missing from this listing.
# It de-duplicates handles sharing a name so each map is declared only once.
333 for handle in handles:
334 if handle.name not in handle_names:
335 if handle.key is None:
336 print 'static retrace::map<%s> __%s_map;' % (handle.type, handle.name)
# NOTE(review): original line 337 (presumably `else:`) is missing here.
338 key_name, key_type = handle.key
339 print 'static std::map<%s, retrace::map<%s> > __%s_map;' % (key_type, handle.type, handle.name)
340 handle_names.add(handle.name)
# NOTE(review): `functions` is iterated twice (here and when building the
# table). That relies on filter() returning a list, as in Python 2; under
# Python 3 filter() returns a one-shot iterator and the table would get no
# function entries.
343 functions = filter(self.filterFunction, api.functions)
344 for function in functions:
345 self.retraceFunction(function)
346 interfaces = api.getAllInterfaces()
347 for interface in interfaces:
348 for method in interface.iterMethods():
349 self.retraceInterfaceMethod(interface, method)
351 print 'const retrace::Entry %s[] = {' % self.table_name
352 for function in functions:
353 print ' {"%s", &retrace_%s},' % (function.name, function.name)
354 for interface in interfaces:
355 for method in interface.iterMethods():
356 print ' {"%s::%s", &retrace_%s__%s},' % (interface.name, method.name, interface.name, method.name)
357 print ' {NULL, NULL}'