1 ##########################################################################
3 # Copyright 2010 VMware, Inc.
6 # Permission is hereby granted, free of charge, to any person obtaining a copy
7 # of this software and associated documentation files (the "Software"), to deal
8 # in the Software without restriction, including without limitation the rights
9 # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 # copies of the Software, and to permit persons to whom the Software is
11 # furnished to do so, subject to the following conditions:
13 # The above copyright notice and this permission notice shall be included in
14 # all copies or substantial portions of the Software.
16 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
24 ##########################################################################/
27 """Generic retracing code generator."""
32 import specs.stdapi as stdapi
33 import specs.glapi as glapi
36 class ConstRemover(stdapi.Rebuilder):
37 '''Type visitor which strips out const qualifiers from types.'''
39 def visitConst(self, const):
42 def visitOpaque(self, opaque):
def lookupHandle(handle, value):
    """Build the C++ expression that looks up *value* in a handle's map.

    Every handle type gets a generated map named ``__<handle.name>_map``.
    Unkeyed handles index that map directly (``__name_map[value]``); keyed
    handles carry a ``(key_name, key_type)`` pair and index the map first by
    the key expression, then by *value* (``__name_map[key][value]``).

    Returns the expression as a string; nothing is printed.
    """
    key = handle.key
    map_expr = "__%s_map" % (handle.name,)
    if key is None:
        return "%s[%s]" % (map_expr, value)
    # Only the key's C++ expression is needed; its type is used elsewhere
    # when the map itself is declared.
    key_name, _key_type = key
    return "%s[%s][%s]" % (map_expr, key_name, value)
54 class ValueDeserializer(stdapi.Visitor):
def visitLiteral(self, literal, lvalue, rvalue):
    """Emit an assignment reading a literal through ``to<kind>()``.

    The accessor name is ``'to'`` followed by ``literal.kind``.
    """
    accessor = 'to%s' % (literal.kind,)
    print(' %(lhs)s = (%(rhs)s).%(acc)s();' % {'lhs': lvalue, 'rhs': rvalue, 'acc': accessor})
def visitConst(self, const, lvalue, rvalue):
    """Deserialize a const-qualified type by visiting the type it wraps."""
    wrapped = const.type
    self.visit(wrapped, lvalue, rvalue)
def visitAlias(self, alias, lvalue, rvalue):
    """Deserialize an alias exactly like the type it names."""
    target = alias.type
    self.visit(target, lvalue, rvalue)
def visitEnum(self, enum, lvalue, rvalue):
    """Emit an assignment that reads the enum value via ``toSInt()``."""
    print(' %(lhs)s = (%(rhs)s).toSInt();' % {'lhs': lvalue, 'rhs': rvalue})
def visitBitmask(self, bitmask, lvalue, rvalue):
    """Deserialize a bitmask through its underlying ``bitmask.type``."""
    underlying = bitmask.type
    self.visit(underlying, lvalue, rvalue)
71 def visitArray(self, array, lvalue, rvalue):
72 print ' const trace::Array *__a%s = dynamic_cast<const trace::Array *>(&%s);' % (array.tag, rvalue)
73 print ' if (__a%s) {' % (array.tag)
74 length = '__a%s->values.size()' % array.tag
75 print ' %s = new %s[%s];' % (lvalue, array.type, length)
76 index = '__j' + array.tag
77 print ' for (size_t {i} = 0; {i} < {length}; ++{i}) {{'.format(i = index, length = length)
79 self.visit(array.type, '%s[%s]' % (lvalue, index), '*__a%s->values[%s]' % (array.tag, index))
83 print ' %s = NULL;' % lvalue
86 def visitPointer(self, pointer, lvalue, rvalue):
87 print ' const trace::Array *__a%s = dynamic_cast<const trace::Array *>(&%s);' % (pointer.tag, rvalue)
88 print ' if (__a%s) {' % (pointer.tag)
89 print ' %s = new %s;' % (lvalue, pointer.type)
91 self.visit(pointer.type, '%s[0]' % (lvalue,), '*__a%s->values[0]' % (pointer.tag,))
94 print ' %s = NULL;' % lvalue
def visitIntPointer(self, pointer, lvalue, rvalue):
    """Emit a static_cast of the value's ``toPointer()`` result to *pointer*."""
    fields = {'lhs': lvalue, 'type': pointer, 'rhs': rvalue}
    print(' %(lhs)s = static_cast<%(type)s>((%(rhs)s).toPointer());' % fields)
def visitLinearPointer(self, pointer, lvalue, rvalue):
    """Emit a static_cast of ``retrace::toPointer(rvalue)`` to *pointer*.

    Unlike visitIntPointer, the conversion goes through the generated
    ``retrace::toPointer`` helper rather than the value's own accessor.
    """
    fields = {'lhs': lvalue, 'type': pointer, 'rhs': rvalue}
    print(' %(lhs)s = static_cast<%(type)s>(retrace::toPointer(%(rhs)s));' % fields)
103 def visitHandle(self, handle, lvalue, rvalue):
104 #OpaqueValueDeserializer().visit(handle.type, lvalue, rvalue);
105 self.visit(handle.type, lvalue, rvalue);
106 new_lvalue = lookupHandle(handle, lvalue)
107 print ' if (retrace::verbosity >= 2) {'
108 print ' std::cout << "%s " << size_t(%s) << " <- " << size_t(%s) << "\\n";' % (handle.name, lvalue, new_lvalue)
110 print ' %s = %s;' % (lvalue, new_lvalue)
def visitBlob(self, blob, lvalue, rvalue):
    """Emit a static_cast of the value's ``toPointer()`` result to *blob*."""
    fields = {'lhs': lvalue, 'type': blob, 'rhs': rvalue}
    print(' %(lhs)s = static_cast<%(type)s>((%(rhs)s).toPointer());' % fields)
def visitString(self, string, lvalue, rvalue):
    """Emit a C-style cast of the value's ``toString()`` to ``string.expr``."""
    fields = {'lhs': lvalue, 'cast': string.expr, 'rhs': rvalue}
    print(' %(lhs)s = (%(cast)s)((%(rhs)s).toString());' % fields)
class OpaqueValueDeserializer(ValueDeserializer):
    '''ValueDeserializer variant that can also extract opaque values.

    Opaque values generally cannot be retraced. The exception is when they
    are extracted in the context of handles, where the original traced value
    is only needed as a lookup key (SwizzledValueRegistrator.visitHandle
    uses this class for that purpose).'''

    def visitOpaque(self, opaque, lvalue, rvalue):
        # Opaque values are recovered as raw pointers via retrace::toPointer
        # and then cast to the opaque C type.
        fields = {'lhs': lvalue, 'type': opaque, 'rhs': rvalue}
        print(' %(lhs)s = static_cast<%(type)s>(retrace::toPointer(%(rhs)s));' % fields)
129 class SwizzledValueRegistrator(stdapi.Visitor):
130 '''Type visitor which will register (un)swizzled value pairs, to later be
133 def visitLiteral(self, literal, lvalue, rvalue):
def visitAlias(self, alias, lvalue, rvalue):
    """Register swizzled values for an alias via the type it names."""
    target = alias.type
    self.visit(target, lvalue, rvalue)
139 def visitEnum(self, enum, lvalue, rvalue):
142 def visitBitmask(self, bitmask, lvalue, rvalue):
145 def visitArray(self, array, lvalue, rvalue):
146 print ' const trace::Array *__a%s = dynamic_cast<const trace::Array *>(&%s);' % (array.tag, rvalue)
147 print ' if (__a%s) {' % (array.tag)
148 length = '__a%s->values.size()' % array.tag
149 index = '__j' + array.tag
150 print ' for (size_t {i} = 0; {i} < {length}; ++{i}) {{'.format(i = index, length = length)
152 self.visit(array.type, '%s[%s]' % (lvalue, index), '*__a%s->values[%s]' % (array.tag, index))
157 def visitPointer(self, pointer, lvalue, rvalue):
158 print ' const trace::Array *__a%s = dynamic_cast<const trace::Array *>(&%s);' % (pointer.tag, rvalue)
159 print ' if (__a%s) {' % (pointer.tag)
161 self.visit(pointer.type, '%s[0]' % (lvalue,), '*__a%s->values[0]' % (pointer.tag,))
165 def visitIntPointer(self, pointer, lvalue, rvalue):
def visitLinearPointer(self, pointer, lvalue, rvalue):
    """Emit a ``retrace::addRegion`` call mapping the traced address to lvalue.

    A known ``pointer.size`` is required: with assertions enabled a missing
    size raises AssertionError, and under ``python -O`` the region is
    silently not registered.
    """
    size = pointer.size
    assert size is not None
    if size is None:
        # Only reachable when assertions are stripped (python -O).
        return
    fields = {'rhs': rvalue, 'lhs': lvalue, 'size': size}
    print(r' retrace::addRegion((%(rhs)s).toUIntPtr(), %(lhs)s, %(size)s);' % fields)
173 def visitHandle(self, handle, lvalue, rvalue):
174 print ' %s __orig_result;' % handle.type
175 OpaqueValueDeserializer().visit(handle.type, '__orig_result', rvalue);
176 if handle.range is None:
177 rvalue = "__orig_result"
178 entry = lookupHandle(handle, rvalue)
179 print " %s = %s;" % (entry, lvalue)
180 print ' if (retrace::verbosity >= 2) {'
181 print ' std::cout << "{handle.name} " << {rvalue} << " -> " << {lvalue} << "\\n";'.format(**locals())
184 i = '__h' + handle.tag
185 lvalue = "%s + %s" % (lvalue, i)
186 rvalue = "__orig_result + %s" % (i,)
187 entry = lookupHandle(handle, rvalue)
188 print ' for ({handle.type} {i} = 0; {i} < {handle.range}; ++{i}) {{'.format(**locals())
189 print ' {entry} = {lvalue};'.format(**locals())
190 print ' if (retrace::verbosity >= 2) {'
191 print ' std::cout << "{handle.name} " << ({rvalue}) << " -> " << ({lvalue}) << "\\n";'.format(**locals())
195 def visitBlob(self, blob, lvalue, rvalue):
198 def visitString(self, string, lvalue, rvalue):
204 def retraceFunction(self, function):
205 print 'static void retrace_%s(trace::Call &call) {' % function.name
206 self.retraceFunctionBody(function)
210 def retraceInterfaceMethod(self, interface, method):
211 print 'static void retrace_%s__%s(trace::Call &call) {' % (interface.name, method.name)
212 self.retraceInterfaceMethodBody(interface, method)
216 def retraceFunctionBody(self, function):
217 if not function.sideeffects:
221 self.deserializeArgs(function)
223 self.invokeFunction(function)
225 self.swizzleValues(function)
227 def retraceInterfaceMethodBody(self, interface, method):
228 if not method.sideeffects:
232 self.deserializeThisPointer(interface)
234 self.deserializeArgs(method)
236 self.invokeInterfaceMethod(interface, method)
238 self.swizzleValues(method)
240 def deserializeThisPointer(self, interface):
241 print ' %s *_this;' % (interface.name,)
244 def deserializeArgs(self, function):
246 for arg in function.args:
247 arg_type = ConstRemover().visit(arg.type)
248 #print ' // %s -> %s' % (arg.type, arg_type)
249 print ' %s %s;' % (arg_type, arg.name)
250 rvalue = 'call.arg(%u)' % (arg.index,)
253 self.extractArg(function, arg, arg_type, lvalue, rvalue)
254 except NotImplementedError:
256 print ' %s = 0; // FIXME' % arg.name
260 self.failFunction(function)
261 if function.name[-1].islower():
262 sys.stderr.write('warning: unsupported %s call\n' % function.name)
265 def swizzleValues(self, function):
266 for arg in function.args:
268 arg_type = ConstRemover().visit(arg.type)
269 rvalue = 'call.arg(%u)' % (arg.index,)
272 self.regiterSwizzledValue(arg_type, lvalue, rvalue)
273 except NotImplementedError:
274 print ' // XXX: %s' % arg.name
275 if function.type is not stdapi.Void:
279 self.regiterSwizzledValue(function.type, lvalue, rvalue)
280 except NotImplementedError:
281 print ' // XXX: result'
283 def failFunction(self, function):
284 print ' if (retrace::verbosity >= 0) {'
285 print ' retrace::unsupported(call);'
def extractArg(self, function, arg, arg_type, lvalue, rvalue):
    """Emit code that deserializes one call argument into *lvalue*.

    Uses a plain ValueDeserializer, so opaque argument types are not
    handled here. *function* and *arg* are not used by this default
    implementation; presumably they exist so subclasses can specialize
    extraction per function/argument — TODO confirm.
    """
    deserializer = ValueDeserializer()
    deserializer.visit(arg_type, lvalue, rvalue)
def extractOpaqueArg(self, function, arg, arg_type, lvalue, rvalue):
    """Emit code that deserializes one argument, allowing opaque types.

    Same contract as extractArg, but uses OpaqueValueDeserializer so opaque
    values are extracted as pointers instead of being rejected.
    *function* and *arg* are unused here.
    """
    deserializer = OpaqueValueDeserializer()
    deserializer.visit(arg_type, lvalue, rvalue)
def regiterSwizzledValue(self, type, lvalue, rvalue):
    """Emit code registering (un)swizzled value pairs for *lvalue*.

    Delegates to a fresh SwizzledValueRegistrator. NOTE: the method name is
    misspelled ("regiter"), but swizzleValues calls it under this name, so
    the spelling is kept.
    """
    SwizzledValueRegistrator().visit(type, lvalue, rvalue)
299 def invokeFunction(self, function):
300 arg_names = ", ".join(function.argNames())
301 if function.type is not stdapi.Void:
302 print ' %s __result;' % (function.type)
303 print ' __result = %s(%s);' % (function.name, arg_names)
304 print ' (void)__result;'
306 print ' %s(%s);' % (function.name, arg_names)
308 def invokeInterfaceMethod(self, interface, method):
309 arg_names = ", ".join(method.argNames())
310 if method.type is not stdapi.Void:
311 print ' %s __result;' % (method.type)
312 print ' __result = _this->%s(%s);' % (method.name, arg_names)
313 print ' (void)__result;'
315 print ' _this->%s(%s);' % (method.name, arg_names)
317 def filterFunction(self, function):
320 table_name = 'retrace::callbacks'
322 def retraceApi(self, api):
324 print '#include "trace_parser.hpp"'
325 print '#include "retrace.hpp"'
328 types = api.getAllTypes()
329 handles = [type for type in types if isinstance(type, stdapi.Handle)]
331 for handle in handles:
332 if handle.name not in handle_names:
333 if handle.key is None:
334 print 'static retrace::map<%s> __%s_map;' % (handle.type, handle.name)
336 key_name, key_type = handle.key
337 print 'static std::map<%s, retrace::map<%s> > __%s_map;' % (key_type, handle.type, handle.name)
338 handle_names.add(handle.name)
341 functions = filter(self.filterFunction, api.functions)
342 for function in functions:
343 self.retraceFunction(function)
344 interfaces = api.getAllInterfaces()
345 for interface in interfaces:
346 for method in interface.iterMethods():
347 self.retraceInterfaceMethod(interface, method)
349 print 'const retrace::Entry %s[] = {' % self.table_name
350 for function in functions:
351 print ' {"%s", &retrace_%s},' % (function.name, function.name)
352 for interface in interfaces:
353 for method in interface.iterMethods():
354 print ' {"%s::%s", &retrace_%s__%s},' % (interface.name, method.name, interface.name, method.name)
355 print ' {NULL, NULL}'