1 ##########################################################################
3 # Copyright 2010 VMware, Inc.
6 # Permission is hereby granted, free of charge, to any person obtaining a copy
7 # of this software and associated documentation files (the "Software"), to deal
8 # in the Software without restriction, including without limitation the rights
9 # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 # copies of the Software, and to permit persons to whom the Software is
11 # furnished to do so, subject to the following conditions:
13 # The above copyright notice and this permission notice shall be included in
14 # all copies or substantial portions of the Software.
16 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
24 ##########################################################################/
27 """Generic retracing code generator."""
32 import specs.stdapi as stdapi
33 import specs.glapi as glapi
36 class MutableRebuilder(stdapi.Rebuilder):
37 '''Type visitor which derives a mutable type.'''
39 def visitConst(self, const):
40 # Strip out const qualifier
43 def visitAlias(self, alias):
44 # Tear the alias on type changes
45 type = self.visit(alias.type)
46 if type is alias.type:
50 def visitReference(self, reference):
51 # Strip out references
54 def visitOpaque(self, opaque):
def lookupHandle(handle, value):
    """Return the C++ expression addressing *value*'s slot in a handle map.

    Handles without a key use a flat ``__<name>_map[value]``; keyed handles
    use a nested map indexed first by the key expression and then by *value*.
    """
    map_name = "__%s_map" % handle.name
    if handle.key is not None:
        # handle.key is a (key_name, key_type) pair; only the name is used here.
        key_name, _key_type = handle.key
        return "%s[%s][%s]" % (map_name, key_name, value)
    return "%s[%s]" % (map_name, value)
67 class ValueDeserializer(stdapi.Visitor):
def visitLiteral(self, literal, lvalue, rvalue):
    """Emit a statement converting trace value *rvalue* into *lvalue*.

    The conversion accessor is ``to<kind>()``, taken from ``literal.kind``.
    """
    statement = ' %s = (%s).to%s();' % (lvalue, rvalue, literal.kind)
    print(statement)
def visitConst(self, const, lvalue, rvalue):
    """Look through the const qualifier and deserialize the wrapped type."""
    wrapped = const.type
    self.visit(wrapped, lvalue, rvalue)
def visitAlias(self, alias, lvalue, rvalue):
    """Aliases are transparent here: deserialize as the aliased type."""
    target = alias.type
    self.visit(target, lvalue, rvalue)
def visitEnum(self, enum, lvalue, rvalue):
    """Emit a statement reading *rvalue* as a signed int and casting it to *enum*."""
    statement = ' %s = static_cast<%s>((%s).toSInt());' % (lvalue, enum, rvalue)
    print(statement)
def visitBitmask(self, bitmask, lvalue, rvalue):
    """Deserialize a bitmask via its underlying ``bitmask.type``."""
    underlying = bitmask.type
    self.visit(underlying, lvalue, rvalue)
86 def visitArray(self, array, lvalue, rvalue):
87 print ' const trace::Array *__a%s = dynamic_cast<const trace::Array *>(&%s);' % (array.tag, rvalue)
88 length = '__a%s->values.size()' % array.tag
89 allocated = self.allocated
91 print ' if (__a%s) {' % (array.tag)
92 print ' %s = _allocator.alloc<%s>(%s);' % (lvalue, array.type, length)
94 index = '__j' + array.tag
95 print ' for (size_t {i} = 0; {i} < {length}; ++{i}) {{'.format(i = index, length = length)
97 self.visit(array.type, '%s[%s]' % (lvalue, index), '*__a%s->values[%s]' % (array.tag, index))
102 print ' %s = NULL;' % lvalue
105 def visitPointer(self, pointer, lvalue, rvalue):
106 print ' const trace::Array *__a%s = dynamic_cast<const trace::Array *>(&%s);' % (pointer.tag, rvalue)
107 allocated = self.allocated
109 print ' if (__a%s) {' % (pointer.tag)
110 print ' %s = _allocator.alloc<%s>();' % (lvalue, pointer.type)
111 self.allocated = True
113 self.visit(pointer.type, '%s[0]' % (lvalue,), '*__a%s->values[0]' % (pointer.tag,))
117 print ' %s = NULL;' % lvalue
def visitIntPointer(self, pointer, lvalue, rvalue):
    """Emit a statement casting the trace pointer value *rvalue* to *pointer*."""
    statement = ' %s = static_cast<%s>((%s).toPointer());' % (lvalue, pointer, rvalue)
    print(statement)
def visitLinearPointer(self, pointer, lvalue, rvalue):
    """Emit a statement resolving *rvalue* through retrace::toPointer, cast to *pointer*."""
    statement = ' %s = static_cast<%s>(retrace::toPointer(%s));' % (lvalue, pointer, rvalue)
    print(statement)
def visitReference(self, reference, lvalue, rvalue):
    """Deserialize a reference as the type it refers to."""
    referent = reference.type
    self.visit(referent, lvalue, rvalue)
129 def visitHandle(self, handle, lvalue, rvalue):
130 #OpaqueValueDeserializer().visit(handle.type, lvalue, rvalue);
131 self.visit(handle.type, lvalue, rvalue);
132 new_lvalue = lookupHandle(handle, lvalue)
133 print ' if (retrace::verbosity >= 2) {'
134 print ' std::cout << "%s " << size_t(%s) << " <- " << size_t(%s) << "\\n";' % (handle.name, lvalue, new_lvalue)
136 print ' %s = %s;' % (lvalue, new_lvalue)
def visitBlob(self, blob, lvalue, rvalue):
    """Emit a statement casting the trace blob's pointer value to *blob*."""
    statement = ' %s = static_cast<%s>((%s).toPointer());' % (lvalue, blob, rvalue)
    print(statement)
def visitString(self, string, lvalue, rvalue):
    """Emit a statement reading *rvalue* as a string, cast to ``string.expr``."""
    statement = ' %s = (%s)((%s).toString());' % (lvalue, string.expr, rvalue)
    print(statement)
146 def visitStruct(self, struct, lvalue, rvalue):
147 tmp = '__s_' + struct.tag + '_' + str(self.seq)
150 print ' const trace::Struct *%s = dynamic_cast<const trace::Struct *>(&%s);' % (tmp, rvalue)
151 print ' assert(%s);' % (tmp)
152 self.allocated = True
153 for i in range(len(struct.members)):
154 member_type, member_name = struct.members[i]
155 self.visit(member_type, '%s.%s' % (lvalue, member_name), '*%s->members[%s]' % (tmp, i))
class OpaqueValueDeserializer(ValueDeserializer):
    '''ValueDeserializer variant that additionally accepts opaque values.

    Opaque values can't normally be retraced; this subclass exists for the
    cases where they are extracted in the context of handles.'''

    def visitOpaque(self, opaque, lvalue, rvalue):
        # Recover the opaque value through retrace::toPointer and cast it
        # to the opaque C type.
        statement = ' %s = static_cast<%s>(retrace::toPointer(%s));' % (lvalue, opaque, rvalue)
        print(statement)
168 class SwizzledValueRegistrator(stdapi.Visitor):
169 '''Type visitor which will register (un)swizzled value pairs, to later be
172 def visitLiteral(self, literal, lvalue, rvalue):
def visitAlias(self, alias, lvalue, rvalue):
    """Register swizzled values through the aliased type."""
    target = alias.type
    self.visit(target, lvalue, rvalue)
178 def visitEnum(self, enum, lvalue, rvalue):
181 def visitBitmask(self, bitmask, lvalue, rvalue):
184 def visitArray(self, array, lvalue, rvalue):
185 print ' const trace::Array *__a%s = dynamic_cast<const trace::Array *>(&%s);' % (array.tag, rvalue)
186 print ' if (__a%s) {' % (array.tag)
187 length = '__a%s->values.size()' % array.tag
188 index = '__j' + array.tag
189 print ' for (size_t {i} = 0; {i} < {length}; ++{i}) {{'.format(i = index, length = length)
191 self.visit(array.type, '%s[%s]' % (lvalue, index), '*__a%s->values[%s]' % (array.tag, index))
196 def visitPointer(self, pointer, lvalue, rvalue):
197 print ' const trace::Array *__a%s = dynamic_cast<const trace::Array *>(&%s);' % (pointer.tag, rvalue)
198 print ' if (__a%s) {' % (pointer.tag)
200 self.visit(pointer.type, '%s[0]' % (lvalue,), '*__a%s->values[0]' % (pointer.tag,))
204 def visitIntPointer(self, pointer, lvalue, rvalue):
def visitLinearPointer(self, pointer, lvalue, rvalue):
    """Emit a retrace::addRegion call for the traced address, *lvalue* and ``pointer.size``.

    A known size is required. The assert enforces this in normal runs. Under
    ``python -O`` the assert is stripped, and a size-less pointer emits nothing.
    """
    assert pointer.size is not None
    if pointer.size is None:
        return
    print(r' retrace::addRegion((%s).toUIntPtr(), %s, %s);' % (rvalue, lvalue, pointer.size))
212 def visitReference(self, reference, lvalue, rvalue):
215 def visitHandle(self, handle, lvalue, rvalue):
216 print ' %s __orig_result;' % handle.type
217 OpaqueValueDeserializer().visit(handle.type, '__orig_result', rvalue);
218 if handle.range is None:
219 rvalue = "__orig_result"
220 entry = lookupHandle(handle, rvalue)
221 print " %s = %s;" % (entry, lvalue)
222 print ' if (retrace::verbosity >= 2) {'
223 print ' std::cout << "{handle.name} " << {rvalue} << " -> " << {lvalue} << "\\n";'.format(**locals())
226 i = '__h' + handle.tag
227 lvalue = "%s + %s" % (lvalue, i)
228 rvalue = "__orig_result + %s" % (i,)
229 entry = lookupHandle(handle, rvalue)
230 print ' for ({handle.type} {i} = 0; {i} < {handle.range}; ++{i}) {{'.format(**locals())
231 print ' {entry} = {lvalue};'.format(**locals())
232 print ' if (retrace::verbosity >= 2) {'
233 print ' std::cout << "{handle.name} " << ({rvalue}) << " -> " << ({lvalue}) << "\\n";'.format(**locals())
237 def visitBlob(self, blob, lvalue, rvalue):
240 def visitString(self, string, lvalue, rvalue):
246 def retraceFunction(self, function):
247 print 'static void retrace_%s(trace::Call &call) {' % function.name
248 self.retraceFunctionBody(function)
252 def retraceInterfaceMethod(self, interface, method):
253 print 'static void retrace_%s__%s(trace::Call &call) {' % (interface.name, method.name)
254 self.retraceInterfaceMethodBody(interface, method)
258 def retraceFunctionBody(self, function):
259 if not function.sideeffects:
263 self.deserializeArgs(function)
265 self.invokeFunction(function)
267 self.swizzleValues(function)
269 def retraceInterfaceMethodBody(self, interface, method):
270 if not method.sideeffects:
274 self.deserializeThisPointer(interface)
276 self.deserializeArgs(method)
278 self.invokeInterfaceMethod(interface, method)
280 self.swizzleValues(method)
282 def deserializeThisPointer(self, interface):
283 print ' %s *_this;' % (interface.name,)
286 def deserializeArgs(self, function):
287 print ' retrace::ScopedAllocator _allocator;'
288 print ' (void)_allocator;'
290 for arg in function.args:
291 arg_type = MutableRebuilder().visit(arg.type)
292 #print ' // %s -> %s' % (arg.type, arg_type)
293 print ' %s %s;' % (arg_type, arg.name)
294 rvalue = 'call.arg(%u)' % (arg.index,)
297 self.extractArg(function, arg, arg_type, lvalue, rvalue)
298 except NotImplementedError:
300 print ' memset(&%s, 0, sizeof %s); // FIXME' % (arg.name, arg.name)
304 self.failFunction(function)
305 if function.name[-1].islower():
306 sys.stderr.write('warning: unsupported %s call\n' % function.name)
309 def swizzleValues(self, function):
310 for arg in function.args:
312 arg_type = MutableRebuilder().visit(arg.type)
313 rvalue = 'call.arg(%u)' % (arg.index,)
316 self.regiterSwizzledValue(arg_type, lvalue, rvalue)
317 except NotImplementedError:
318 print ' // XXX: %s' % arg.name
319 if function.type is not stdapi.Void:
323 self.regiterSwizzledValue(function.type, lvalue, rvalue)
324 except NotImplementedError:
325 print ' // XXX: result'
327 def failFunction(self, function):
328 print ' if (retrace::verbosity >= 0) {'
329 print ' retrace::unsupported(call);'
def extractArg(self, function, arg, arg_type, lvalue, rvalue):
    """Emit code deserializing trace value *rvalue* of type *arg_type* into *lvalue*.

    *function* and *arg* are not used by this default implementation;
    presumably they are there so overrides can special-case particular
    functions/arguments — confirm against subclasses.
    """
    deserializer = ValueDeserializer()
    deserializer.visit(arg_type, lvalue, rvalue)
def extractOpaqueArg(self, function, arg, arg_type, lvalue, rvalue):
    """Like extractArg, but uses OpaqueValueDeserializer so opaque types are accepted.

    *function* and *arg* are not used by this implementation.
    """
    deserializer = OpaqueValueDeserializer()
    deserializer.visit(arg_type, lvalue, rvalue)
def regiterSwizzledValue(self, type, lvalue, rvalue):
    """Emit code registering the (traced *rvalue*, replayed *lvalue*) pair for *type*.

    NOTE: the misspelling ("regiter") is part of the interface; swizzleValues
    calls it under this name, so it is intentionally not renamed here.
    """
    SwizzledValueRegistrator().visit(type, lvalue, rvalue)
343 def invokeFunction(self, function):
344 arg_names = ", ".join(function.argNames())
345 if function.type is not stdapi.Void:
346 print ' %s __result;' % (function.type)
347 print ' __result = %s(%s);' % (function.name, arg_names)
348 print ' (void)__result;'
350 print ' %s(%s);' % (function.name, arg_names)
352 def invokeInterfaceMethod(self, interface, method):
353 arg_names = ", ".join(method.argNames())
354 if method.type is not stdapi.Void:
355 print ' %s __result;' % (method.type)
356 print ' __result = _this->%s(%s);' % (method.name, arg_names)
357 print ' (void)__result;'
359 print ' _this->%s(%s);' % (method.name, arg_names)
361 def filterFunction(self, function):
# C++ name of the retrace::Entry dispatch table that retraceApi emits
# (overridable by subclasses via self.table_name).
table_name = 'retrace::callbacks'
366 def retraceApi(self, api):
368 print '#include "os_time.hpp"'
369 print '#include "trace_parser.hpp"'
370 print '#include "retrace.hpp"'
373 types = api.getAllTypes()
374 handles = [type for type in types if isinstance(type, stdapi.Handle)]
376 for handle in handles:
377 if handle.name not in handle_names:
378 if handle.key is None:
379 print 'static retrace::map<%s> __%s_map;' % (handle.type, handle.name)
381 key_name, key_type = handle.key
382 print 'static std::map<%s, retrace::map<%s> > __%s_map;' % (key_type, handle.type, handle.name)
383 handle_names.add(handle.name)
386 functions = filter(self.filterFunction, api.functions)
387 for function in functions:
388 self.retraceFunction(function)
389 interfaces = api.getAllInterfaces()
390 for interface in interfaces:
391 for method in interface.iterMethods():
392 self.retraceInterfaceMethod(interface, method)
394 print 'const retrace::Entry %s[] = {' % self.table_name
395 for function in functions:
396 print ' {"%s", &retrace_%s},' % (function.name, function.name)
397 for interface in interfaces:
398 for method in interface.iterMethods():
399 print ' {"%s::%s", &retrace_%s__%s},' % (interface.name, method.name, interface.name, method.name)
400 print ' {NULL, NULL}'