1 ##########################################################################
3 # Copyright 2010 VMware, Inc.
6 # Permission is hereby granted, free of charge, to any person obtaining a copy
7 # of this software and associated documentation files (the "Software"), to deal
8 # in the Software without restriction, including without limitation the rights
9 # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 # copies of the Software, and to permit persons to whom the Software is
11 # furnished to do so, subject to the following conditions:
13 # The above copyright notice and this permission notice shall be included in
14 # all copies or substantial portions of the Software.
16 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
24 ##########################################################################/
27 """Generic retracing code generator."""
# Make the sibling 'specs' package importable.
import os.path
import sys
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))


import specs.stdapi as stdapi
class UnsupportedType(Exception):
    '''Raised by a type visitor when it cannot generate code for a type.

    The retracer catches it while generating argument/result handling and
    emits placeholder code (a zero-fill with a FIXME, or an ``// XXX``
    comment) instead of real conversion code.
    '''
def lookupHandle(handle, value):
    '''Return the C++ expression that looks ``value`` up in a handle map.

    Un-keyed handles use the flat map ``_<name>_map``; keyed handles
    (``handle.key`` is a ``(key_name, key_type)`` pair) use a nested map that
    is indexed first by the key variable and then by ``value``.
    '''
    mapExpr = "_%s_map" % (handle.name,)
    if handle.key is not None:
        keyName, _keyType = handle.key
        mapExpr = "%s[%s]" % (mapExpr, keyName)
    return "%s[%s]" % (mapExpr, value)
class ValueAllocator(stdapi.Visitor):
    '''Emit C++ that allocates storage for array and pointer values.

    Every visit method receives ``lvalue`` (the C++ variable to initialize)
    and ``rvalue`` (the C++ expression for the traced value).  Only arrays
    and pointers need storage; every other type emits nothing, and wrapper
    types (const, alias, reference, polymorphic) forward to their inner type.
    '''

    def _allocate(self, elemType, lvalue, rvalue):
        '''Emit an allocation of ``lvalue`` holding ``elemType`` elements.'''
        # NOTE(review): presumably ScopedAllocator::alloc derives the element
        # count from the traced value passed by address; ``sizeof *lvalue``
        # supplies the element size.
        print(' %s = static_cast<%s *>(_allocator.alloc(&%s, sizeof *%s));' % (lvalue, elemType, rvalue, lvalue))

    def visitLiteral(self, literal, lvalue, rvalue):
        pass

    def visitConst(self, const, lvalue, rvalue):
        self.visit(const.type, lvalue, rvalue)

    def visitAlias(self, alias, lvalue, rvalue):
        self.visit(alias.type, lvalue, rvalue)

    def visitEnum(self, enum, lvalue, rvalue):
        pass

    def visitBitmask(self, bitmask, lvalue, rvalue):
        pass

    def visitArray(self, array, lvalue, rvalue):
        self._allocate(array.type, lvalue, rvalue)

    def visitPointer(self, pointer, lvalue, rvalue):
        self._allocate(pointer.type, lvalue, rvalue)

    def visitIntPointer(self, pointer, lvalue, rvalue):
        pass

    def visitObjPointer(self, pointer, lvalue, rvalue):
        pass

    def visitLinearPointer(self, pointer, lvalue, rvalue):
        pass

    def visitReference(self, reference, lvalue, rvalue):
        self.visit(reference.type, lvalue, rvalue)

    def visitHandle(self, handle, lvalue, rvalue):
        pass

    def visitBlob(self, blob, lvalue, rvalue):
        pass

    def visitString(self, string, lvalue, rvalue):
        pass

    def visitStruct(self, struct, lvalue, rvalue):
        pass

    def visitPolymorphic(self, polymorphic, lvalue, rvalue):
        # Only polymorphic types with a default type can be allocated here.
        assert polymorphic.defaultType is not None
        self.visit(polymorphic.defaultType, lvalue, rvalue)

    def visitOpaque(self, opaque, lvalue, rvalue):
        pass
class ValueDeserializer(stdapi.Visitor, stdapi.ExpanderMixin):
    '''Emit C++ that reads a traced value into a native C++ lvalue.

    ``lvalue`` is the C++ destination expression and ``rvalue`` the C++
    expression for the ``trace::Value`` being read.  Array and pointer
    storage is expected to have been set up beforehand (see
    ``ValueAllocator``).  Opaque types raise ``UnsupportedType``.
    '''

    # Counter that keeps generated C++ temporaries uniquely named.
    seq = 0

    def _uniqueName(self, prefix, tag):
        '''Return a fresh C++ identifier of the form ``<prefix><tag>_<n>``.'''
        name = '%s%s_%s' % (prefix, tag, self.seq)
        self.seq += 1
        return name

    def visitLiteral(self, literal, lvalue, rvalue):
        print(' %s = (%s).to%s();' % (lvalue, rvalue, literal.kind))

    def visitConst(self, const, lvalue, rvalue):
        self.visit(const.type, lvalue, rvalue)

    def visitAlias(self, alias, lvalue, rvalue):
        self.visit(alias.type, lvalue, rvalue)

    def visitEnum(self, enum, lvalue, rvalue):
        print(' %s = static_cast<%s>((%s).toSInt());' % (lvalue, enum, rvalue))

    def visitBitmask(self, bitmask, lvalue, rvalue):
        self.visit(bitmask.type, lvalue, rvalue)

    def visitArray(self, array, lvalue, rvalue):
        tmp = self._uniqueName('_a_', array.tag)
        length = '%s->values.size()' % (tmp,)
        index = '_j' + array.tag

        # Only fill the buffer if storage was actually allocated for it.
        print(' if (%s) {' % (lvalue,))
        print(' const trace::Array *%s = dynamic_cast<const trace::Array *>(&%s);' % (tmp, rvalue))
        print(' for (size_t {i} = 0; {i} < {length}; ++{i}) {{'.format(i=index, length=length))
        try:
            self.visit(array.type, '%s[%s]' % (lvalue, index), '*%s->values[%s]' % (tmp, index))
        finally:
            # Close the loop and the guard even if the element type is
            # unsupported, so the emitted braces stay balanced.
            print(' }')
            print(' }')

    def visitPointer(self, pointer, lvalue, rvalue):
        tmp = self._uniqueName('_a_', pointer.tag)

        print(' if (%s) {' % (lvalue,))
        print(' const trace::Array *%s = dynamic_cast<const trace::Array *>(&%s);' % (tmp, rvalue))
        try:
            # The pointee is read from element 0 of the traced array.
            self.visit(pointer.type, '%s[0]' % (lvalue,), '*%s->values[0]' % (tmp,))
        finally:
            print(' }')

    def visitIntPointer(self, pointer, lvalue, rvalue):
        print(' %s = static_cast<%s>((%s).toPointer());' % (lvalue, pointer, rvalue))

    def visitObjPointer(self, pointer, lvalue, rvalue):
        print(' %s = static_cast<%s>(retrace::toObjPointer(call, %s));' % (lvalue, pointer, rvalue))

    def visitLinearPointer(self, pointer, lvalue, rvalue):
        print(' %s = static_cast<%s>(retrace::toPointer(%s));' % (lvalue, pointer, rvalue))

    def visitReference(self, reference, lvalue, rvalue):
        self.visit(reference.type, lvalue, rvalue)

    def visitHandle(self, handle, lvalue, rvalue):
        # Read the traced handle value, then replace it with the value that
        # was registered for it in the handle map.
        self.visit(handle.type, lvalue, rvalue)
        new_lvalue = lookupHandle(handle, lvalue)
        print(' if (retrace::verbosity >= 2) {')
        print(' std::cout << "%s " << size_t(%s) << " <- " << size_t(%s) << "\\n";' % (handle.name, lvalue, new_lvalue))
        print(' }')
        print(' %s = %s;' % (lvalue, new_lvalue))

    def visitBlob(self, blob, lvalue, rvalue):
        print(' %s = static_cast<%s>((%s).toPointer());' % (lvalue, blob, rvalue))

    def visitString(self, string, lvalue, rvalue):
        print(' %s = (%s)((%s).toString());' % (lvalue, string.expr, rvalue))

    def visitStruct(self, struct, lvalue, rvalue):
        tmp = self._uniqueName('_s_', struct.tag)

        print(' const trace::Struct *%s = dynamic_cast<const trace::Struct *>(&%s);' % (tmp, rvalue))
        print(' assert(%s);' % (tmp,))
        for index, member in enumerate(struct.members):
            self.visitMember(member, lvalue, '*%s->members[%s]' % (tmp, index))

    def visitPolymorphic(self, polymorphic, lvalue, rvalue):
        if polymorphic.defaultType is not None:
            self.visit(polymorphic.defaultType, lvalue, rvalue)
            return

        # No default type: dispatch on the switch expression at run time.
        switchExpr = self.expand(polymorphic.switchExpr)
        print(r' switch (%s) {' % switchExpr)
        for cases, caseType in polymorphic.iterSwitch():
            for case in cases:
                print(r' %s:' % case)
            caseLvalue = lvalue
            if caseType.expr is not None:
                caseLvalue = 'static_cast<%s>(%s)' % (caseType, caseLvalue)
            # Each case gets its own scope so temporaries declared by the
            # visit do not cross case labels.
            print(r' {')
            try:
                self.visit(caseType, caseLvalue, rvalue)
            finally:
                print(r' }')
            print(r' break;')
        # There is no default type (handled above), so unexpected selector
        # values are reported instead of silently ignored.
        print(r' default:')
        print(r' retrace::warning(call) << "unexpected polymorphic case" << %s << "\n";' % (switchExpr,))
        print(r' break;')
        print(r' }')

    def visitOpaque(self, opaque, lvalue, rvalue):
        raise UnsupportedType
class OpaqueValueDeserializer(ValueDeserializer):
    '''ValueDeserializer variant that also accepts opaque values.

    Opaque values normally cannot be retraced, but some contexts (such as
    reading the raw value behind a handle) still need them; here they are
    converted through ``retrace::toPointer``.
    '''

    def visitOpaque(self, opaque, lvalue, rvalue):
        pointerExpr = 'retrace::toPointer(%s)' % (rvalue,)
        print(' %s = static_cast<%s>(%s);' % (lvalue, opaque, pointerExpr))
class SwizzledValueRegistrator(stdapi.Visitor, stdapi.ExpanderMixin):
    '''Emit C++ that records (traced value, live value) pairs after a call.

    The pairs go into the handle maps or are registered with
    ``retrace::addObj`` / ``retrace::addRegion`` so later calls can swizzle
    traced values.  ``lvalue`` is the C++ expression holding the live value;
    ``rvalue`` is the C++ expression for the matching ``trace::Value``.
    '''

    # Counter that keeps generated C++ temporaries uniquely named.
    seq = 0

    def _uniqueName(self, prefix, tag):
        '''Return a fresh C++ identifier of the form ``<prefix><tag>_<n>``.'''
        name = '%s%s_%s' % (prefix, tag, self.seq)
        self.seq += 1
        return name

    def visitLiteral(self, literal, lvalue, rvalue):
        pass

    def visitConst(self, const, lvalue, rvalue):
        # Const-ness does not affect registration; forward to the inner type
        # like the allocator and deserializer visitors do.
        self.visit(const.type, lvalue, rvalue)

    def visitAlias(self, alias, lvalue, rvalue):
        self.visit(alias.type, lvalue, rvalue)

    def visitEnum(self, enum, lvalue, rvalue):
        pass

    def visitBitmask(self, bitmask, lvalue, rvalue):
        pass

    def visitArray(self, array, lvalue, rvalue):
        # Unique temporary name: two arguments of the same array type would
        # otherwise redeclare the same C++ variable in one scope.
        tmp = self._uniqueName('_a_', array.tag)
        length = '%s->values.size()' % (tmp,)
        index = '_j' + array.tag
        print(' const trace::Array *%s = dynamic_cast<const trace::Array *>(&%s);' % (tmp, rvalue))
        print(' if (%s) {' % (tmp,))
        print(' for (size_t {i} = 0; {i} < {length}; ++{i}) {{'.format(i=index, length=length))
        try:
            self.visit(array.type, '%s[%s]' % (lvalue, index), '*%s->values[%s]' % (tmp, index))
        finally:
            print(' }')
            print(' }')

    def visitPointer(self, pointer, lvalue, rvalue):
        tmp = self._uniqueName('_a_', pointer.tag)
        print(' const trace::Array *%s = dynamic_cast<const trace::Array *>(&%s);' % (tmp, rvalue))
        print(' if (%s) {' % (tmp,))
        try:
            self.visit(pointer.type, '%s[0]' % (lvalue,), '*%s->values[0]' % (tmp,))
        finally:
            print(' }')

    def visitIntPointer(self, pointer, lvalue, rvalue):
        pass

    def visitObjPointer(self, pointer, lvalue, rvalue):
        print(r' retrace::addObj(call, %s, %s);' % (rvalue, lvalue))

    def visitLinearPointer(self, pointer, lvalue, rvalue):
        # The assert documents the requirement; the check still guards the
        # emission when Python runs with assertions disabled (-O).
        assert pointer.size is not None
        if pointer.size is not None:
            print(r' retrace::addRegion((%s).toUIntPtr(), %s, %s);' % (rvalue, lvalue, pointer.size))

    def visitReference(self, reference, lvalue, rvalue):
        pass

    def visitHandle(self, handle, lvalue, rvalue):
        # Decode the traced handle value into _origResult, then map it to
        # the live value.
        print(' %s _origResult;' % handle.type)
        OpaqueValueDeserializer().visit(handle.type, '_origResult', rvalue)
        if handle.range is None:
            rvalue = "_origResult"
            entry = lookupHandle(handle, rvalue)
            print(" %s = %s;" % (entry, lvalue))
            print(' if (retrace::verbosity >= 2) {')
            print(' std::cout << "{handle.name} " << {rvalue} << " -> " << {lvalue} << "\\n";'.format(handle=handle, rvalue=rvalue, lvalue=lvalue))
            print(' }')
        else:
            # A ranged handle maps handle.range consecutive values.
            i = '_h' + handle.tag
            lvalue = "%s + %s" % (lvalue, i)
            rvalue = "_origResult + %s" % (i,)
            entry = lookupHandle(handle, rvalue)
            print(' for ({handle.type} {i} = 0; {i} < {handle.range}; ++{i}) {{'.format(handle=handle, i=i))
            print(' {entry} = {lvalue};'.format(entry=entry, lvalue=lvalue))
            print(' if (retrace::verbosity >= 2) {')
            print(' std::cout << "{handle.name} " << ({rvalue}) << " -> " << ({lvalue}) << "\\n";'.format(handle=handle, rvalue=rvalue, lvalue=lvalue))
            print(' }')
            print(' }')

    def visitBlob(self, blob, lvalue, rvalue):
        pass

    def visitString(self, string, lvalue, rvalue):
        pass

    def visitStruct(self, struct, lvalue, rvalue):
        tmp = self._uniqueName('_s_', struct.tag)
        print(' const trace::Struct *%s = dynamic_cast<const trace::Struct *>(&%s);' % (tmp, rvalue))
        print(' assert(%s);' % (tmp,))
        # Mark the temporary as used even when no member emits code or
        # asserts are compiled out.
        print(' (void)%s;' % (tmp,))
        for index, member in enumerate(struct.members):
            self.visitMember(member, lvalue, '*%s->members[%s]' % (tmp, index))

    def visitPolymorphic(self, polymorphic, lvalue, rvalue):
        assert polymorphic.defaultType is not None
        self.visit(polymorphic.defaultType, lvalue, rvalue)

    def visitOpaque(self, opaque, lvalue, rvalue):
        pass
def retraceFunction(self, function):
    '''Emit the complete C++ ``retrace_<name>`` callback for a function.'''
    print('static void retrace_%s(trace::Call &call) {' % function.name)
    self.retraceFunctionBody(function)
    # Close the callback and separate it from the next one.
    print('}')
    print('')
def retraceInterfaceMethod(self, interface, method):
    '''Emit the complete C++ ``retrace_<Interface>__<Method>`` callback.'''
    print('static void retrace_%s__%s(trace::Call &call) {' % (interface.name, method.name))
    self.retraceInterfaceMethodBody(interface, method)
    # Close the callback and separate it from the next one.
    print('}')
    print('')
def retraceFunctionBody(self, function):
    '''Emit the statements of a plain function's retrace callback.

    Steps, in order: check the original result (non-void only), decode the
    arguments, declare the result, perform the call, and register any values
    the call produced.
    '''
    assert function.sideeffects
    returnsValue = function.type is not stdapi.Void
    if returnsValue:
        self.checkOrigResult(function)
    self.deserializeArgs(function)
    self.declareRet(function)
    self.invokeFunction(function)
    self.swizzleValues(function)
def retraceInterfaceMethodBody(self, interface, method):
    '''Emit the statements of an interface method's retrace callback.

    Same sequence as for plain functions, except that the ``this`` pointer
    is resolved before the arguments are decoded and the call is made on it.
    '''
    assert method.sideeffects
    returnsValue = method.type is not stdapi.Void
    if returnsValue:
        self.checkOrigResult(method)
    self.deserializeThisPointer(interface)
    self.deserializeArgs(method)
    self.declareRet(method)
    self.invokeInterfaceMethod(interface, method)
    self.swizzleValues(method)
def checkOrigResult(self, function):
    '''Hook for checking the original result, to prevent succeeding now
    where the original did not, which would cause diversion and potentially
    unpredictable results.

    The default only handles HRESULT: when the traced call failed, the
    generated callback returns without replaying it.
    '''
    assert function.type is not stdapi.Void

    if str(function.type) == 'HRESULT':
        print(r' if (call.ret && FAILED(call.ret->toSInt())) {')
        print(r' return;')
        print(r' }')
def deserializeThisPointer(self, interface):
    '''Emit C++ that resolves the traced ``this`` pointer (argument 0).

    If the lookup yields NULL, the generated callback returns early.
    '''
    print(r' %s *_this;' % (interface.name,))
    print(r' _this = static_cast<%s *>(retrace::toObjPointer(call, call.arg(0)));' % (interface.name,))
    print(r' if (!_this) {')
    print(r' return;')
    print(r' }')
def deserializeArgs(self, function):
    '''Emit C++ that declares each argument and decodes it from the trace.

    Arguments whose type raises UnsupportedType are zero-filled and tagged
    with a FIXME.  If any such argument exists, the generated callback
    reports the call as unsupported and bails out, and a warning is written
    to stderr at generation time.
    '''
    print(' retrace::ScopedAllocator _allocator;')
    print(' (void)_allocator;')
    success = True
    for arg in function.args:
        arg_type = arg.type.mutable()
        print(' %s %s;' % (arg_type, arg.name))
        rvalue = 'call.arg(%u)' % (arg.index,)
        lvalue = arg.name
        try:
            self.extractArg(function, arg, arg_type, lvalue, rvalue)
        except UnsupportedType:
            success = False
            print(' memset(&%s, 0, sizeof %s); // FIXME' % (arg.name, arg.name))
        print('')

    if not success:
        # NOTE(review): the ``if (1)`` wrapper presumably keeps C++ compilers
        # from flagging the code emitted after the early return as unreachable.
        print(' if (1) {')
        self.failFunction(function)
        sys.stderr.write('warning: unsupported %s call\n' % function.name)
        print(' }')
def swizzleValues(self, function):
    '''Emit C++ that registers the values produced by the call.

    Each output argument is paired with its traced counterpart
    (``call.arg(i)``), and a non-void result held in ``_result`` is paired
    with ``*call.ret``.  Types that cannot be registered are marked with an
    ``// XXX`` comment in the generated code.
    '''
    for arg in function.args:
        # Only values written by the call need registering.
        if not arg.output:
            continue
        arg_type = arg.type.mutable()
        rvalue = 'call.arg(%u)' % (arg.index,)
        lvalue = arg.name
        try:
            self.regiterSwizzledValue(arg_type, lvalue, rvalue)
        except UnsupportedType:
            print(' // XXX: %s' % arg.name)
    if function.type is not stdapi.Void:
        rvalue = '*call.ret'
        lvalue = '_result'
        try:
            self.regiterSwizzledValue(function.type, lvalue, rvalue)
        except UnsupportedType:
            # NOTE(review): confirm whether an unsupported result type should
            # abort generation (re-raise) rather than emit a placeholder.
            print(' // XXX: result')
def failFunction(self, function):
    '''Emit C++ that reports the call as unsupported and returns early.'''
    print(' if (retrace::verbosity >= 0) {')
    print(' retrace::unsupported(call);')
    print(' }')
    print(' return;')
def extractArg(self, function, arg, arg_type, lvalue, rvalue):
    '''Emit C++ that allocates storage for ``arg`` and, if it is an input,
    decodes its traced value into ``lvalue``.

    Raises UnsupportedType (from the visitors) for types that cannot be
    handled.
    '''
    ValueAllocator().visit(arg_type, lvalue, rvalue)
    # Output-only arguments are produced by the call; nothing to decode.
    if arg.input:
        ValueDeserializer().visit(arg_type, lvalue, rvalue)
def extractOpaqueArg(self, function, arg, arg_type, lvalue, rvalue):
    '''Emit C++ that decodes ``arg`` while accepting opaque types.

    Storage allocation is best-effort: UnsupportedType from the allocator is
    ignored, and the value is always decoded with OpaqueValueDeserializer.
    '''
    try:
        ValueAllocator().visit(arg_type, lvalue, rvalue)
    except UnsupportedType:
        # Deliberately ignored: decoding below does not depend on it.
        pass
    OpaqueValueDeserializer().visit(arg_type, lvalue, rvalue)
def regiterSwizzledValue(self, type, lvalue, rvalue):
    '''Emit C++ registering the (traced ``rvalue``, live ``lvalue``) pair
    for a value of ``type``.

    The misspelled name is kept because callers refer to it.
    '''
    SwizzledValueRegistrator().visit(type, lvalue, rvalue)
def declareRet(self, function):
    '''Emit the ``_result`` declaration for functions that return a value.'''
    if function.type is stdapi.Void:
        return
    print(' %s _result;' % (function.type,))
def invokeFunction(self, function):
    '''Emit the C++ call of ``function`` with its decoded arguments.

    Non-void results are stored in ``_result`` and passed to checkResult;
    void functions are emitted as a single plain call statement.
    '''
    arg_names = ", ".join(function.argNames())
    if function.type is not stdapi.Void:
        print(' _result = %s(%s);' % (function.name, arg_names))
        print(' (void)_result;')
        self.checkResult(function.type)
    else:
        # Without this else the call would be emitted a second time for
        # non-void functions.
        print(' %s(%s);' % (function.name, arg_names))
def invokeInterfaceMethod(self, interface, method):
    '''Emit the C++ call of ``method`` on ``_this``.

    For ``Release`` the generated code returns early unless the traced call
    returned 0; otherwise the traced object is passed to retrace::delObj
    before the call is replayed.
    '''
    if method.name == 'Release':
        # Only release our reference when Release() returned 0 in the trace.
        print(' if (call.ret->toUInt()) {')
        print(' return;')
        print(' }')
        print(' retrace::delObj(call.arg(0));')

    arg_names = ", ".join(method.argNames())
    if method.type is not stdapi.Void:
        print(' _result = _this->%s(%s);' % (method.name, arg_names))
        print(' (void)_result;')
        self.checkResult(method.type)
    else:
        print(' _this->%s(%s);' % (method.name, arg_names))
def checkResult(self, resultType):
    '''Emit a C++ warning for a replayed call whose HRESULT indicates failure.

    Other result types emit nothing.
    '''
    if str(resultType) == 'HRESULT':
        print(r' if (FAILED(_result)) {')
        print(r' retrace::warning(call) << "failed\n";')
        print(r' }')
def filterFunction(self, function):
    '''Predicate passed to filter() over the API's functions; accepts all.'''
    return True

# Name of the generated C++ table of retrace entry points.
table_name = 'retrace::callbacks'
491 def retraceApi(self, api):
493 print '#include "os_time.hpp"'
494 print '#include "trace_parser.hpp"'
495 print '#include "retrace.hpp"'
496 print '#include "retrace_swizzle.hpp"'
499 types = api.getAllTypes()
500 handles = [type for type in types if isinstance(type, stdapi.Handle)]
502 for handle in handles:
503 if handle.name not in handle_names:
504 if handle.key is None:
505 print 'static retrace::map<%s> _%s_map;' % (handle.type, handle.name)
507 key_name, key_type = handle.key
508 print 'static std::map<%s, retrace::map<%s> > _%s_map;' % (key_type, handle.type, handle.name)
509 handle_names.add(handle.name)
512 functions = filter(self.filterFunction, api.getAllFunctions())
513 for function in functions:
514 if function.sideeffects and not function.internal:
515 self.retraceFunction(function)
516 interfaces = api.getAllInterfaces()
517 for interface in interfaces:
518 for method in interface.iterMethods():
519 if method.sideeffects and not method.internal:
520 self.retraceInterfaceMethod(interface, method)
522 print 'const retrace::Entry %s[] = {' % self.table_name
523 for function in functions:
524 if not function.internal:
525 if function.sideeffects:
526 print ' {"%s", &retrace_%s},' % (function.name, function.name)
528 print ' {"%s", &retrace::ignore},' % (function.name,)
529 for interface in interfaces:
530 for method in interface.iterMethods():
531 if method.sideeffects:
532 print ' {"%s::%s", &retrace_%s__%s},' % (interface.name, method.name, interface.name, method.name)
534 print ' {"%s::%s", &retrace::ignore},' % (interface.name, method.name)
535 print ' {NULL, NULL}'