1 ##########################################################################
3 # Copyright 2010 VMware, Inc.
6 # Permission is hereby granted, free of charge, to any person obtaining a copy
7 # of this software and associated documentation files (the "Software"), to deal
8 # in the Software without restriction, including without limitation the rights
9 # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 # copies of the Software, and to permit persons to whom the Software is
11 # furnished to do so, subject to the following conditions:
13 # The above copyright notice and this permission notice shall be included in
14 # all copies or substantial portions of the Software.
16 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
24 ##########################################################################/
27 """Generic retracing code generator."""
import os.path
import sys

# Make the parent directory importable so that the `specs` package resolves.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))

import specs.stdapi as stdapi
class UnsupportedType(Exception):
    """Raised by the type visitors when a type cannot be retraced.

    Code generators catch it to emit a fallback (for instance zero-filling
    an argument, or an ``// XXX`` note) instead of aborting generation.
    """
def lookupHandle(handle, value):
    """Return the C++ expression that looks ``value`` up in ``handle``'s map.

    Unkeyed handles live in a single map named ``_<name>_map``.  Keyed
    handles live in a map of maps, indexed first by the key variable's name
    (the first element of ``handle.key``) and then by ``value``.
    """
    map_expr = '_%s_map' % (handle.name,)
    if handle.key is not None:
        key_name, _key_type = handle.key
        map_expr = '%s[%s]' % (map_expr, key_name)
    return '%s[%s]' % (map_expr, value)
class ValueAllocator(stdapi.Visitor):
    '''Type visitor that emits C++ giving an argument its backing storage.

    Visit methods receive ``lvalue`` (C++ expression for the native variable)
    and ``rvalue`` (C++ expression for the recorded ``trace::Value``).  Only
    arrays and pointers need storage, which is taken from the generated
    function's ``_allocator``.  Const, alias, reference and polymorphic types
    are unwrapped and visited; every other type emits nothing.
    '''

    def _allocate(self, element_type, lvalue, rvalue):
        # Arrays and pointers both need a buffer of their element type.
        print(' %s = _allocator.alloc<%s>(&%s);' % (lvalue, element_type, rvalue))

    def visitLiteral(self, literal, lvalue, rvalue):
        pass

    def visitConst(self, const, lvalue, rvalue):
        self.visit(const.type, lvalue, rvalue)

    def visitAlias(self, alias, lvalue, rvalue):
        self.visit(alias.type, lvalue, rvalue)

    def visitEnum(self, enum, lvalue, rvalue):
        pass

    def visitBitmask(self, bitmask, lvalue, rvalue):
        pass

    def visitArray(self, array, lvalue, rvalue):
        self._allocate(array.type, lvalue, rvalue)

    def visitPointer(self, pointer, lvalue, rvalue):
        self._allocate(pointer.type, lvalue, rvalue)

    def visitIntPointer(self, pointer, lvalue, rvalue):
        pass

    def visitObjPointer(self, pointer, lvalue, rvalue):
        pass

    def visitLinearPointer(self, pointer, lvalue, rvalue):
        pass

    def visitReference(self, reference, lvalue, rvalue):
        self.visit(reference.type, lvalue, rvalue)

    def visitHandle(self, handle, lvalue, rvalue):
        pass

    def visitBlob(self, blob, lvalue, rvalue):
        pass

    def visitString(self, string, lvalue, rvalue):
        pass

    def visitStruct(self, struct, lvalue, rvalue):
        pass

    def visitPolymorphic(self, polymorphic, lvalue, rvalue):
        if polymorphic.defaultType is None:
            # No default alternative: there is no single type to allocate for.
            raise UnsupportedType
        self.visit(polymorphic.defaultType, lvalue, rvalue)

    def visitOpaque(self, opaque, lvalue, rvalue):
        pass
class ValueDeserializer(stdapi.Visitor):
    '''Type visitor that emits C++ converting a recorded value to a native one.

    Visit methods receive ``lvalue`` (C++ expression for the native variable
    to assign) and ``rvalue`` (C++ expression for the ``trace::Value`` to
    read), and print the conversion code to stdout.  Types that cannot be
    converted raise ``UnsupportedType``.
    '''

    # Counter for naming generated temporaries.  It is bumped on the class,
    # not the instance: a fresh deserializer is created per argument, so a
    # per-instance counter would restart at 0 and two same-typed struct
    # arguments would declare the same `_s_<tag>_0` twice in one C++ scope.
    seq = 0

    def _tempName(self, prefix, tag):
        # Return a C++ identifier that is unique within the generated file.
        name = prefix + tag + '_' + str(ValueDeserializer.seq)
        ValueDeserializer.seq += 1
        return name

    def visitLiteral(self, literal, lvalue, rvalue):
        print(' %s = (%s).to%s();' % (lvalue, rvalue, literal.kind))

    def visitConst(self, const, lvalue, rvalue):
        self.visit(const.type, lvalue, rvalue)

    def visitAlias(self, alias, lvalue, rvalue):
        self.visit(alias.type, lvalue, rvalue)

    def visitEnum(self, enum, lvalue, rvalue):
        print(' %s = static_cast<%s>((%s).toSInt());' % (lvalue, enum, rvalue))

    def visitBitmask(self, bitmask, lvalue, rvalue):
        self.visit(bitmask.type, lvalue, rvalue)

    def visitArray(self, array, lvalue, rvalue):
        tmp = self._tempName('_a_', array.tag)
        index = '_j' + array.tag
        length = '%s->values.size()' % (tmp,)

        print(' if (%s) {' % (lvalue,))
        print(' const trace::Array *%s = dynamic_cast<const trace::Array *>(&%s);' % (tmp, rvalue))
        print(' for (size_t {i} = 0; {i} < {length}; ++{i}) {{'.format(i=index, length=length))
        try:
            self.visit(array.type, '%s[%s]' % (lvalue, index), '*%s->values[%s]' % (tmp, index))
        finally:
            # Close the loop and the null check even when the element type is
            # unsupported, so the emitted C++ stays brace-balanced for the
            # caller's fallback code.
            print(' }')
            print(' }')

    def visitPointer(self, pointer, lvalue, rvalue):
        tmp = self._tempName('_a_', pointer.tag)

        print(' if (%s) {' % (lvalue,))
        print(' const trace::Array *%s = dynamic_cast<const trace::Array *>(&%s);' % (tmp, rvalue))
        try:
            # A pointer is recorded as an array; only its first element is read.
            self.visit(pointer.type, '%s[0]' % (lvalue,), '*%s->values[0]' % (tmp,))
        finally:
            print(' }')

    def visitIntPointer(self, pointer, lvalue, rvalue):
        print(' %s = static_cast<%s>((%s).toPointer());' % (lvalue, pointer, rvalue))

    def visitObjPointer(self, pointer, lvalue, rvalue):
        print(' %s = static_cast<%s>(retrace::toObjPointer(call, %s));' % (lvalue, pointer, rvalue))

    def visitLinearPointer(self, pointer, lvalue, rvalue):
        print(' %s = static_cast<%s>(retrace::toPointer(%s));' % (lvalue, pointer, rvalue))

    def visitReference(self, reference, lvalue, rvalue):
        self.visit(reference.type, lvalue, rvalue)

    def visitHandle(self, handle, lvalue, rvalue):
        # Read the recorded handle value, then replace it with the live value
        # found in the handle's map.
        self.visit(handle.type, lvalue, rvalue)
        new_lvalue = lookupHandle(handle, lvalue)
        print(' if (retrace::verbosity >= 2) {')
        print(' std::cout << "%s " << size_t(%s) << " <- " << size_t(%s) << "\\n";' % (handle.name, lvalue, new_lvalue))
        print(' }')
        print(' %s = %s;' % (lvalue, new_lvalue))

    def visitBlob(self, blob, lvalue, rvalue):
        print(' %s = static_cast<%s>((%s).toPointer());' % (lvalue, blob, rvalue))

    def visitString(self, string, lvalue, rvalue):
        print(' %s = (%s)((%s).toString());' % (lvalue, string.expr, rvalue))

    def visitStruct(self, struct, lvalue, rvalue):
        tmp = self._tempName('_s_', struct.tag)

        print(' const trace::Struct *%s = dynamic_cast<const trace::Struct *>(&%s);' % (tmp, rvalue))
        print(' assert(%s);' % (tmp,))
        # Keeps a member-less struct from triggering an unused-variable
        # warning when assert() is compiled out.
        print(' (void)%s;' % (tmp,))
        for i, (member_type, member_name) in enumerate(struct.members):
            self.visit(member_type, '%s.%s' % (lvalue, member_name), '*%s->members[%s]' % (tmp, i))

    def visitPolymorphic(self, polymorphic, lvalue, rvalue):
        if polymorphic.defaultType is None:
            # No default alternative: the concrete type can't be chosen here.
            raise UnsupportedType
        self.visit(polymorphic.defaultType, lvalue, rvalue)

    def visitOpaque(self, opaque, lvalue, rvalue):
        raise UnsupportedType
class OpaqueValueDeserializer(ValueDeserializer):
    """ValueDeserializer variant that also accepts opaque types.

    The base class rejects opaque values because they normally cannot be
    replayed.  When they are extracted only for handle bookkeeping, this
    subclass converts them as raw pointers via ``retrace::toPointer``.
    """

    def visitOpaque(self, opaque, lvalue, rvalue):
        template = ' %s = static_cast<%s>(retrace::toPointer(%s));'
        print(template % (lvalue, opaque, rvalue))
class SwizzledValueRegistrator(stdapi.Visitor):
    '''Type visitor that emits C++ recording (recorded, replayed) value pairs.

    It runs after the call has been replayed.  ``lvalue`` is the C++
    expression for the value produced now, and ``rvalue`` is the
    ``trace::Value`` recorded in the trace.  The emitted code registers the
    pair so that recorded values met in later calls can be translated
    ("swizzled") into live ones:

    - objects are registered through ``retrace::addObj``;
    - memory is registered through ``retrace::addRegion``;
    - handles are registered in their ``_<name>_map``.
    '''

    # Class-wide counter for unique temporary names.  A fresh registrator is
    # created per registered value, so an instance counter would restart at 0
    # and repeat names within one generated C++ function.
    seq = 0

    def _tempName(self, prefix, tag):
        # Return a C++ identifier that is unique within the generated file.
        name = prefix + tag + '_' + str(SwizzledValueRegistrator.seq)
        SwizzledValueRegistrator.seq += 1
        return name

    def visitLiteral(self, literal, lvalue, rvalue):
        pass

    def visitAlias(self, alias, lvalue, rvalue):
        self.visit(alias.type, lvalue, rvalue)

    def visitEnum(self, enum, lvalue, rvalue):
        pass

    def visitBitmask(self, bitmask, lvalue, rvalue):
        pass

    def visitArray(self, array, lvalue, rvalue):
        # The temporary is declared at the current C++ scope, so it must be
        # unique; the same array type may be registered more than once.
        tmp = self._tempName('_a_', array.tag)
        index = '_j' + array.tag
        length = '%s->values.size()' % (tmp,)
        print(' const trace::Array *%s = dynamic_cast<const trace::Array *>(&%s);' % (tmp, rvalue))
        print(' if (%s) {' % (tmp,))
        print(' for (size_t {i} = 0; {i} < {length}; ++{i}) {{'.format(i=index, length=length))
        try:
            self.visit(array.type, '%s[%s]' % (lvalue, index), '*%s->values[%s]' % (tmp, index))
        finally:
            # Keep the emitted C++ brace-balanced even if the element type
            # turns out to be unsupported.
            print(' }')
            print(' }')

    def visitPointer(self, pointer, lvalue, rvalue):
        tmp = self._tempName('_a_', pointer.tag)
        print(' const trace::Array *%s = dynamic_cast<const trace::Array *>(&%s);' % (tmp, rvalue))
        print(' if (%s) {' % (tmp,))
        try:
            self.visit(pointer.type, '%s[0]' % (lvalue,), '*%s->values[0]' % (tmp,))
        finally:
            print(' }')

    def visitIntPointer(self, pointer, lvalue, rvalue):
        pass

    def visitObjPointer(self, pointer, lvalue, rvalue):
        print(r' retrace::addObj(call, %s, %s);' % (rvalue, lvalue))

    def visitLinearPointer(self, pointer, lvalue, rvalue):
        assert pointer.size is not None
        # The `if` still guards when asserts are stripped (python -O).
        if pointer.size is not None:
            print(r' retrace::addRegion((%s).toUIntPtr(), %s, %s);' % (rvalue, lvalue, pointer.size))

    def visitReference(self, reference, lvalue, rvalue):
        pass

    def visitHandle(self, handle, lvalue, rvalue):
        # Decode the recorded handle into `_origResult`, then map it (or each
        # of the `handle.range` consecutive values starting at it) to the
        # value produced by the replayed call.
        # NOTE(review): `_origResult` is declared at the current C++ scope;
        # two handle registrations in the same scope would clash.
        print(' %s _origResult;' % handle.type)
        OpaqueValueDeserializer().visit(handle.type, '_origResult', rvalue)
        if handle.range is None:
            rvalue = "_origResult"
            entry = lookupHandle(handle, rvalue)
            print(" %s = %s;" % (entry, lvalue))
            print(' if (retrace::verbosity >= 2) {')
            print(' std::cout << "{handle.name} " << {rvalue} << " -> " << {lvalue} << "\\n";'.format(
                handle=handle, rvalue=rvalue, lvalue=lvalue))
            print(' }')
        else:
            i = '_h' + handle.tag
            lvalue = "%s + %s" % (lvalue, i)
            rvalue = "_origResult + %s" % (i,)
            entry = lookupHandle(handle, rvalue)
            print(' for ({handle.type} {i} = 0; {i} < {handle.range}; ++{i}) {{'.format(handle=handle, i=i))
            print(' {entry} = {lvalue};'.format(entry=entry, lvalue=lvalue))
            print(' if (retrace::verbosity >= 2) {')
            print(' std::cout << "{handle.name} " << ({rvalue}) << " -> " << ({lvalue}) << "\\n";'.format(
                handle=handle, rvalue=rvalue, lvalue=lvalue))
            print(' }')
            print(' }')

    def visitBlob(self, blob, lvalue, rvalue):
        pass

    def visitString(self, string, lvalue, rvalue):
        pass

    def visitStruct(self, struct, lvalue, rvalue):
        tmp = self._tempName('_s_', struct.tag)
        print(' const trace::Struct *%s = dynamic_cast<const trace::Struct *>(&%s);' % (tmp, rvalue))
        print(' assert(%s);' % (tmp,))
        print(' (void)%s;' % (tmp,))
        for i, (member_type, member_name) in enumerate(struct.members):
            self.visit(member_type, '%s.%s' % (lvalue, member_name), '*%s->members[%s]' % (tmp, i))

    def visitPolymorphic(self, polymorphic, lvalue, rvalue):
        if polymorphic.defaultType is None:
            # No default alternative: the concrete type can't be chosen here.
            raise UnsupportedType
        self.visit(polymorphic.defaultType, lvalue, rvalue)

    def visitOpaque(self, opaque, lvalue, rvalue):
        pass
314 def retraceFunction(self, function):
315 print 'static void retrace_%s(trace::Call &call) {' % function.name
316 self.retraceFunctionBody(function)
320 def retraceInterfaceMethod(self, interface, method):
321 print 'static void retrace_%s__%s(trace::Call &call) {' % (interface.name, method.name)
322 self.retraceInterfaceMethodBody(interface, method)
326 def retraceFunctionBody(self, function):
327 assert function.sideeffects
329 if function.type is not stdapi.Void:
330 self.checkOrigResult(function)
332 self.deserializeArgs(function)
334 self.declareRet(function)
335 self.invokeFunction(function)
337 self.swizzleValues(function)
339 def retraceInterfaceMethodBody(self, interface, method):
340 assert method.sideeffects
342 if method.type is not stdapi.Void:
343 self.checkOrigResult(method)
345 self.deserializeThisPointer(interface)
347 self.deserializeArgs(method)
349 self.declareRet(method)
350 self.invokeInterfaceMethod(interface, method)
352 self.swizzleValues(method)
354 def checkOrigResult(self, function):
355 '''Hook for checking the original result, to prevent succeeding now
356 where the original did not, which would cause diversion and potentially
357 unpredictable results.'''
359 assert function.type is not stdapi.Void
361 if str(function.type) == 'HRESULT':
362 print r' if (call.ret && FAILED(call.ret->toSInt())) {'
366 def deserializeThisPointer(self, interface):
367 print r' %s *_this;' % (interface.name,)
368 print r' _this = static_cast<%s *>(retrace::toObjPointer(call, call.arg(0)));' % (interface.name,)
369 print r' if (!_this) {'
373 def deserializeArgs(self, function):
374 print ' retrace::ScopedAllocator _allocator;'
375 print ' (void)_allocator;'
377 for arg in function.args:
378 arg_type = arg.type.mutable()
379 print ' %s %s;' % (arg_type, arg.name)
380 rvalue = 'call.arg(%u)' % (arg.index,)
383 self.extractArg(function, arg, arg_type, lvalue, rvalue)
384 except UnsupportedType:
386 print ' memset(&%s, 0, sizeof %s); // FIXME' % (arg.name, arg.name)
391 self.failFunction(function)
392 sys.stderr.write('warning: unsupported %s call\n' % function.name)
395 def swizzleValues(self, function):
396 for arg in function.args:
398 arg_type = arg.type.mutable()
399 rvalue = 'call.arg(%u)' % (arg.index,)
402 self.regiterSwizzledValue(arg_type, lvalue, rvalue)
403 except UnsupportedType:
404 print ' // XXX: %s' % arg.name
405 if function.type is not stdapi.Void:
409 self.regiterSwizzledValue(function.type, lvalue, rvalue)
410 except UnsupportedType:
412 print ' // XXX: result'
414 def failFunction(self, function):
415 print ' if (retrace::verbosity >= 0) {'
416 print ' retrace::unsupported(call);'
420 def extractArg(self, function, arg, arg_type, lvalue, rvalue):
421 ValueAllocator().visit(arg_type, lvalue, rvalue)
423 ValueDeserializer().visit(arg_type, lvalue, rvalue)
425 def extractOpaqueArg(self, function, arg, arg_type, lvalue, rvalue):
427 ValueAllocator().visit(arg_type, lvalue, rvalue)
428 except UnsupportedType:
430 OpaqueValueDeserializer().visit(arg_type, lvalue, rvalue)
432 def regiterSwizzledValue(self, type, lvalue, rvalue):
433 visitor = SwizzledValueRegistrator()
434 visitor.visit(type, lvalue, rvalue)
436 def declareRet(self, function):
437 if function.type is not stdapi.Void:
438 print ' %s _result;' % (function.type)
440 def invokeFunction(self, function):
441 arg_names = ", ".join(function.argNames())
442 if function.type is not stdapi.Void:
443 print ' _result = %s(%s);' % (function.name, arg_names)
444 print ' (void)_result;'
445 self.checkResult(function.type)
447 print ' %s(%s);' % (function.name, arg_names)
449 def invokeInterfaceMethod(self, interface, method):
450 # On release our reference when we reach Release() == 0 call in the
452 if method.name == 'Release':
453 print ' if (call.ret->toUInt()) {'
456 print ' retrace::delObj(call.arg(0));'
458 arg_names = ", ".join(method.argNames())
459 if method.type is not stdapi.Void:
460 print ' _result = _this->%s(%s);' % (method.name, arg_names)
461 print ' (void)_result;'
462 self.checkResult(method.type)
464 print ' _this->%s(%s);' % (method.name, arg_names)
466 def checkResult(self, resultType):
467 if str(resultType) == 'HRESULT':
468 print r' if (FAILED(_result)) {'
469 print r' retrace::warning(call) << "failed\n";'
472 def filterFunction(self, function):
475 table_name = 'retrace::callbacks'
477 def retraceApi(self, api):
479 print '#include "os_time.hpp"'
480 print '#include "trace_parser.hpp"'
481 print '#include "retrace.hpp"'
482 print '#include "retrace_swizzle.hpp"'
485 types = api.getAllTypes()
486 handles = [type for type in types if isinstance(type, stdapi.Handle)]
488 for handle in handles:
489 if handle.name not in handle_names:
490 if handle.key is None:
491 print 'static retrace::map<%s> _%s_map;' % (handle.type, handle.name)
493 key_name, key_type = handle.key
494 print 'static std::map<%s, retrace::map<%s> > _%s_map;' % (key_type, handle.type, handle.name)
495 handle_names.add(handle.name)
498 functions = filter(self.filterFunction, api.getAllFunctions())
499 for function in functions:
500 if function.sideeffects and not function.internal:
501 self.retraceFunction(function)
502 interfaces = api.getAllInterfaces()
503 for interface in interfaces:
504 for method in interface.iterMethods():
505 if method.sideeffects and not method.internal:
506 self.retraceInterfaceMethod(interface, method)
508 print 'const retrace::Entry %s[] = {' % self.table_name
509 for function in functions:
510 if not function.internal:
511 if function.sideeffects:
512 print ' {"%s", &retrace_%s},' % (function.name, function.name)
514 print ' {"%s", &retrace::ignore},' % (function.name,)
515 for interface in interfaces:
516 for method in interface.iterMethods():
517 if method.sideeffects:
518 print ' {"%s::%s", &retrace_%s__%s},' % (interface.name, method.name, interface.name, method.name)
520 print ' {"%s::%s", &retrace::ignore},' % (interface.name, method.name)
521 print ' {NULL, NULL}'