1 ##########################################################################
3 # Copyright 2010 VMware, Inc.
6 # Permission is hereby granted, free of charge, to any person obtaining a copy
7 # of this software and associated documentation files (the "Software"), to deal
8 # in the Software without restriction, including without limitation the rights
9 # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 # copies of the Software, and to permit persons to whom the Software is
11 # furnished to do so, subject to the following conditions:
13 # The above copyright notice and this permission notice shall be included in
14 # all copies or substantial portions of the Software.
16 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
24 ##########################################################################/
27 """Generic retracing code generator."""
# Make the directory above this script importable, so that the ``specs``
# package resolves regardless of the current working directory.
import os.path
import sys
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))

import specs.stdapi as stdapi
class UnsupportedType(Exception):
    '''Raised by the value visitors when a type cannot be retraced.

    Callers in this module catch it and fall back to a placeholder in the
    generated code (a ``memset`` of the argument, or an ``// XXX`` comment).
    '''
def lookupHandle(handle, value):
    '''Return the C++ expression indexing the swizzle map of *handle*.

    Unkeyed handles index ``_<name>_map`` directly.  Keyed handles first
    index by the key expression (``handle.key[0]``) and then by *value*.

    :param handle: handle type with ``name`` and ``key`` attributes; ``key``
        is ``None`` or a ``(key_name, key_type)`` pair.
    :param value: C++ expression for the traced handle value.
    :returns: the map-entry expression, as a string.
    '''
    map_expr = "_%s_map" % (handle.name,)
    if handle.key is not None:
        key_name, _key_type = handle.key
        map_expr = "%s[%s]" % (map_expr, key_name)
    return "%s[%s]" % (map_expr, value)
class ValueAllocator(stdapi.Visitor):
    '''Type visitor that emits C++ allocating storage for argument values.

    Only arrays and pointers get storage here; it comes from the
    ``_allocator`` object declared in every generated callback.  Wrapper
    types (const, alias, reference, polymorphic) are forwarded to the type
    they wrap; every other type emits nothing.

    Each ``visit*`` method receives the C++ ``lvalue`` to assign and the C++
    ``rvalue`` expression of the traced value.
    '''

    def visitLiteral(self, literal, lvalue, rvalue):
        pass

    def visitConst(self, const, lvalue, rvalue):
        self.visit(const.type, lvalue, rvalue)

    def visitAlias(self, alias, lvalue, rvalue):
        self.visit(alias.type, lvalue, rvalue)

    def visitEnum(self, enum, lvalue, rvalue):
        pass

    def visitBitmask(self, bitmask, lvalue, rvalue):
        pass

    def visitArray(self, array, lvalue, rvalue):
        print('    %s = _allocator.alloc<%s>(&%s);' % (lvalue, array.type, rvalue))

    def visitPointer(self, pointer, lvalue, rvalue):
        print('    %s = _allocator.alloc<%s>(&%s);' % (lvalue, pointer.type, rvalue))

    def visitIntPointer(self, pointer, lvalue, rvalue):
        pass

    def visitObjPointer(self, pointer, lvalue, rvalue):
        pass

    def visitLinearPointer(self, pointer, lvalue, rvalue):
        pass

    def visitReference(self, reference, lvalue, rvalue):
        self.visit(reference.type, lvalue, rvalue)

    def visitHandle(self, handle, lvalue, rvalue):
        pass

    def visitBlob(self, blob, lvalue, rvalue):
        pass

    def visitString(self, string, lvalue, rvalue):
        pass

    def visitStruct(self, struct, lvalue, rvalue):
        pass

    def visitPolymorphic(self, polymorphic, lvalue, rvalue):
        self.visit(polymorphic.defaultType, lvalue, rvalue)

    def visitOpaque(self, opaque, lvalue, rvalue):
        pass
class ValueDeserializer(stdapi.Visitor):
    '''Type visitor that emits C++ copying a traced value into a local.

    Each ``visit*`` method prints statements assigning the C++ ``lvalue``
    from the ``trace::Value`` expression ``rvalue``.  Types that cannot be
    retraced raise ``UnsupportedType``.
    '''

    # Counter for unique C++ temporaries; see _tempName.
    seq = 0

    def _tempName(self, prefix, tag):
        '''Return a C++ temporary name unique within this generator run.

        The counter lives on ValueDeserializer itself (shared with
        subclasses) because a fresh deserializer is created for every
        argument while all arguments are emitted into the same generated
        function; a per-instance counter would hand out duplicate names.
        '''
        name = prefix + tag + '_' + str(ValueDeserializer.seq)
        ValueDeserializer.seq += 1
        return name

    def visitLiteral(self, literal, lvalue, rvalue):
        print('    %s = (%s).to%s();' % (lvalue, rvalue, literal.kind))

    def visitConst(self, const, lvalue, rvalue):
        self.visit(const.type, lvalue, rvalue)

    def visitAlias(self, alias, lvalue, rvalue):
        self.visit(alias.type, lvalue, rvalue)

    def visitEnum(self, enum, lvalue, rvalue):
        print('    %s = static_cast<%s>((%s).toSInt());' % (lvalue, enum, rvalue))

    def visitBitmask(self, bitmask, lvalue, rvalue):
        self.visit(bitmask.type, lvalue, rvalue)

    def visitArray(self, array, lvalue, rvalue):
        tmp = self._tempName('_a_', array.tag)
        length = '%s->values.size()' % (tmp,)
        index = '_j' + array.tag
        print('    if (%s) {' % (lvalue,))
        print('        const trace::Array *%s = dynamic_cast<const trace::Array *>(&%s);' % (tmp, rvalue))
        print('        for (size_t {i} = 0; {i} < {length}; ++{i}) {{'.format(i=index, length=length))
        try:
            self.visit(array.type, '%s[%s]' % (lvalue, index), '*%s->values[%s]' % (tmp, index))
        finally:
            # Close the loop and the if-block even when the element type is
            # unsupported, so the emitted braces always balance.
            print('        }')
            print('    }')

    def visitPointer(self, pointer, lvalue, rvalue):
        tmp = self._tempName('_a_', pointer.tag)
        print('    if (%s) {' % (lvalue,))
        print('        const trace::Array *%s = dynamic_cast<const trace::Array *>(&%s);' % (tmp, rvalue))
        try:
            self.visit(pointer.type, '%s[0]' % (lvalue,), '*%s->values[0]' % (tmp,))
        finally:
            print('    }')

    def visitIntPointer(self, pointer, lvalue, rvalue):
        print('    %s = static_cast<%s>((%s).toPointer());' % (lvalue, pointer, rvalue))

    def visitObjPointer(self, pointer, lvalue, rvalue):
        print('    %s = static_cast<%s>(retrace::toObjPointer(call, %s));' % (lvalue, pointer, rvalue))

    def visitLinearPointer(self, pointer, lvalue, rvalue):
        print('    %s = static_cast<%s>(retrace::toPointer(%s));' % (lvalue, pointer, rvalue))

    def visitReference(self, reference, lvalue, rvalue):
        self.visit(reference.type, lvalue, rvalue)

    def visitHandle(self, handle, lvalue, rvalue):
        # Deserialize the traced handle, then replace it with the value the
        # swizzle map holds for it.
        self.visit(handle.type, lvalue, rvalue)
        new_lvalue = lookupHandle(handle, lvalue)
        print('    if (retrace::verbosity >= 2) {')
        print('        std::cout << "%s " << size_t(%s) << " <- " << size_t(%s) << "\\n";' % (handle.name, lvalue, new_lvalue))
        print('    }')
        print('    %s = %s;' % (lvalue, new_lvalue))

    def visitBlob(self, blob, lvalue, rvalue):
        print('    %s = static_cast<%s>((%s).toPointer());' % (lvalue, blob, rvalue))

    def visitString(self, string, lvalue, rvalue):
        print('    %s = (%s)((%s).toString());' % (lvalue, string.expr, rvalue))

    def visitStruct(self, struct, lvalue, rvalue):
        tmp = self._tempName('_s_', struct.tag)
        print('    const trace::Struct *%s = dynamic_cast<const trace::Struct *>(&%s);' % (tmp, rvalue))
        print('    assert(%s);' % (tmp,))
        # Keeps the temporary "used" when assert() compiles out and the
        # struct has no members.
        print('    (void)%s;' % (tmp,))
        for index, (member_type, member_name) in enumerate(struct.members):
            self.visit(member_type, '%s.%s' % (lvalue, member_name), '*%s->members[%s]' % (tmp, index))

    def visitPolymorphic(self, polymorphic, lvalue, rvalue):
        self.visit(polymorphic.defaultType, lvalue, rvalue)

    def visitOpaque(self, opaque, lvalue, rvalue):
        raise UnsupportedType
class OpaqueValueDeserializer(ValueDeserializer):
    '''ValueDeserializer variant that also accepts opaque types.

    Opaque values are normally not retraceable (the base class raises
    UnsupportedType for them); this variant is meant for extracting them in
    the context of handles, converting the traced value with
    ``retrace::toPointer``.
    '''

    def visitOpaque(self, opaque, lvalue, rvalue):
        statement = '    %s = static_cast<%s>(retrace::toPointer(%s));' % (lvalue, opaque, rvalue)
        print(statement)
class SwizzledValueRegistrator(stdapi.Visitor):
    '''Type visitor which will register (un)swizzled value pairs, to later be
    swapped.

    Each ``visit*`` method prints C++ recording how the value seen in the
    trace (``rvalue``, a ``trace::Value`` expression) corresponds to the
    value produced during retrace (``lvalue``).  Types with nothing to
    register emit no code.
    '''

    # Counter for unique C++ temporaries; see _tempName.
    seq = 0

    def _tempName(self, prefix, tag):
        '''Return a C++ temporary name unique within this generator run.

        The counter lives on the class because a fresh registrator is
        created for every value while all of them emit into the same
        generated function; a per-instance counter would repeat names.
        '''
        name = prefix + tag + '_' + str(SwizzledValueRegistrator.seq)
        SwizzledValueRegistrator.seq += 1
        return name

    def visitLiteral(self, literal, lvalue, rvalue):
        pass

    def visitConst(self, const, lvalue, rvalue):
        # Qualifiers do not change what must be registered.
        self.visit(const.type, lvalue, rvalue)

    def visitAlias(self, alias, lvalue, rvalue):
        self.visit(alias.type, lvalue, rvalue)

    def visitEnum(self, enum, lvalue, rvalue):
        pass

    def visitBitmask(self, bitmask, lvalue, rvalue):
        pass

    def visitArray(self, array, lvalue, rvalue):
        tmp = self._tempName('_a_', array.tag)
        length = '%s->values.size()' % (tmp,)
        index = '_j' + array.tag
        print('    const trace::Array *%s = dynamic_cast<const trace::Array *>(&%s);' % (tmp, rvalue))
        print('    if (%s) {' % (tmp,))
        print('        for (size_t {i} = 0; {i} < {length}; ++{i}) {{'.format(i=index, length=length))
        try:
            self.visit(array.type, '%s[%s]' % (lvalue, index), '*%s->values[%s]' % (tmp, index))
        finally:
            # Keep the emitted braces balanced even on UnsupportedType.
            print('        }')
            print('    }')

    def visitPointer(self, pointer, lvalue, rvalue):
        tmp = self._tempName('_a_', pointer.tag)
        print('    const trace::Array *%s = dynamic_cast<const trace::Array *>(&%s);' % (tmp, rvalue))
        print('    if (%s) {' % (tmp,))
        try:
            self.visit(pointer.type, '%s[0]' % (lvalue,), '*%s->values[0]' % (tmp,))
        finally:
            print('    }')

    def visitIntPointer(self, pointer, lvalue, rvalue):
        pass

    def visitObjPointer(self, pointer, lvalue, rvalue):
        print(r'    retrace::addObj(call, %s, %s);' % (rvalue, lvalue))

    def visitLinearPointer(self, pointer, lvalue, rvalue):
        # A region can only be registered when its size is known.
        assert pointer.size is not None
        if pointer.size is not None:
            print(r'    retrace::addRegion((%s).toUIntPtr(), %s, %s);' % (rvalue, lvalue, pointer.size))

    def visitReference(self, reference, lvalue, rvalue):
        pass

    def visitHandle(self, handle, lvalue, rvalue):
        # Each handle gets its own C++ block so that _origResult does not
        # clash when several handles are registered in the same scope.
        print('    {')
        try:
            print('    %s _origResult;' % handle.type)
            OpaqueValueDeserializer().visit(handle.type, '_origResult', rvalue)
            if handle.range is None:
                rvalue = "_origResult"
                entry = lookupHandle(handle, rvalue)
                print("    %s = %s;" % (entry, lvalue))
                print('    if (retrace::verbosity >= 2) {')
                print('        std::cout << "{handle.name} " << {rvalue} << " -> " << {lvalue} << "\\n";'.format(**locals()))
                print('    }')
            else:
                # A ranged handle maps handle.range consecutive values.
                i = '_h' + handle.tag
                lvalue = "%s + %s" % (lvalue, i)
                rvalue = "_origResult + %s" % (i,)
                entry = lookupHandle(handle, rvalue)
                print('    for ({handle.type} {i} = 0; {i} < {handle.range}; ++{i}) {{'.format(**locals()))
                print('        {entry} = {lvalue};'.format(**locals()))
                print('        if (retrace::verbosity >= 2) {')
                print('            std::cout << "{handle.name} " << ({rvalue}) << " -> " << ({lvalue}) << "\\n";'.format(**locals()))
                print('        }')
                print('    }')
        finally:
            print('    }')

    def visitBlob(self, blob, lvalue, rvalue):
        pass

    def visitString(self, string, lvalue, rvalue):
        pass

    def visitStruct(self, struct, lvalue, rvalue):
        tmp = self._tempName('_s_', struct.tag)
        print('    const trace::Struct *%s = dynamic_cast<const trace::Struct *>(&%s);' % (tmp, rvalue))
        print('    assert(%s);' % (tmp,))
        print('    (void)%s;' % (tmp,))
        for index, (member_type, member_name) in enumerate(struct.members):
            self.visit(member_type, '%s.%s' % (lvalue, member_name), '*%s->members[%s]' % (tmp, index))

    def visitPolymorphic(self, polymorphic, lvalue, rvalue):
        self.visit(polymorphic.defaultType, lvalue, rvalue)

    def visitOpaque(self, opaque, lvalue, rvalue):
        pass
def retraceFunction(self, function):
    '''Emit the complete C++ ``retrace_<name>`` callback for *function*.'''
    print('static void retrace_%s(trace::Call &call) {' % function.name)
    self.retraceFunctionBody(function)
    # Close the callback opened above; the empty line separates callbacks.
    print('}')
    print('')
def retraceInterfaceMethod(self, interface, method):
    '''Emit the complete C++ ``retrace_<Interface>__<method>`` callback.'''
    print('static void retrace_%s__%s(trace::Call &call) {' % (interface.name, method.name))
    self.retraceInterfaceMethodBody(interface, method)
    # Close the callback opened above; the empty line separates callbacks.
    print('}')
    print('')
def retraceFunctionBody(self, function):
    '''Emit the statements of a free function's retrace callback.

    In order: the original-result check (non-void functions only), argument
    deserialization, the ``_result`` declaration and the call itself, and
    finally registration of swizzled values.
    '''
    assert function.sideeffects

    returns_value = function.type is not stdapi.Void
    if returns_value:
        self.checkOrigResult(function)

    self.deserializeArgs(function)

    self.declareRet(function)
    self.invokeFunction(function)

    self.swizzleValues(function)
def retraceInterfaceMethodBody(self, interface, method):
    '''Emit the statements of an interface method's retrace callback.

    Same sequence as for free functions, except that the ``_this`` object
    is resolved before the method's arguments are deserialized.
    '''
    assert method.sideeffects

    returns_value = method.type is not stdapi.Void
    if returns_value:
        self.checkOrigResult(method)

    self.deserializeThisPointer(interface)

    self.deserializeArgs(method)

    self.declareRet(method)
    self.invokeInterfaceMethod(interface, method)

    self.swizzleValues(method)
def checkOrigResult(self, function):
    '''Hook for checking the original result, to prevent succeeding now
    where the original did not, which would cause diversion and potentially
    unpredictable results.

    This implementation only handles ``HRESULT``: when the traced call
    failed, the generated callback returns without replaying it.
    '''
    assert function.type is not stdapi.Void

    if str(function.type) == 'HRESULT':
        print(r'    if (call.ret && FAILED(call.ret->toSInt())) {')
        print(r'        return;')
        print(r'    }')
def deserializeThisPointer(self, interface):
    '''Emit code resolving ``_this`` from the call's first argument.

    When no object is mapped for it, the generated callback warns and
    returns, since the method cannot be invoked without an object.
    '''
    print(r'    %s *_this;' % (interface.name,))
    print(r'    _this = static_cast<%s *>(retrace::toObjPointer(call, call.arg(0)));' % (interface.name,))
    print(r'    if (!_this) {')
    print(r'        retrace::warning(call) << "NULL this pointer\n";')
    print(r'        return;')
    print(r'    }')
def deserializeArgs(self, function):
    '''Emit declarations and deserialization code for every argument.

    An argument whose type raises ``UnsupportedType`` is zero-filled instead,
    and the whole call is then reported through ``failFunction``.
    '''
    print('    retrace::ScopedAllocator _allocator;')
    print('    (void)_allocator;')
    success = True
    for arg in function.args:
        arg_type = arg.type.mutable()
        print('    %s %s;' % (arg_type, arg.name))
        rvalue = 'call.arg(%u)' % (arg.index,)
        lvalue = arg.name
        try:
            self.extractArg(function, arg, arg_type, lvalue, rvalue)
        except UnsupportedType:
            success = False
            print('    memset(&%s, 0, sizeof %s); // FIXME' % (arg.name, arg.name))
        print('')

    if not success:
        # NOTE(review): the `if (1)` wrapper presumably keeps the statements
        # emitted after the failure path from being flagged as unreachable.
        print('    if (1) {')
        self.failFunction(function)
        # Only names ending in a lowercase letter are reported (presumably
        # to skip vendor-suffixed extension entry points -- TODO confirm).
        if function.name[-1].islower():
            sys.stderr.write('warning: unsupported %s call\n' % function.name)
        print('    }')
def swizzleValues(self, function):
    '''Emit code registering the values produced by the replayed call.

    Output arguments and, for non-void functions, the ``_result`` local are
    handed to ``regiterSwizzledValue``; types it cannot handle leave an
    ``// XXX`` comment in the generated code instead.
    '''
    for arg in function.args:
        # Only values written by the call need their mapping recorded.
        if not arg.output:
            continue
        arg_type = arg.type.mutable()
        rvalue = 'call.arg(%u)' % (arg.index,)
        lvalue = arg.name
        try:
            self.regiterSwizzledValue(arg_type, lvalue, rvalue)
        except UnsupportedType:
            print('    // XXX: %s' % arg.name)
    if function.type is not stdapi.Void:
        # Pair the traced return value with the value just obtained.
        rvalue = '*call.ret'
        lvalue = '_result'
        try:
            self.regiterSwizzledValue(function.type, lvalue, rvalue)
        except UnsupportedType:
            print('    // XXX: result')
def failFunction(self, function):
    '''Emit code reporting the call as unsupported and leaving the callback.'''
    print('    if (retrace::verbosity >= 0) {')
    print('        retrace::unsupported(call);')
    print('    }')
    print('    return;')
def extractArg(self, function, arg, arg_type, lvalue, rvalue):
    '''Emit storage allocation for *arg* and, for inputs, copy its traced value.'''
    ValueAllocator().visit(arg_type, lvalue, rvalue)
    # Output-only arguments just need storage; the call fills them in.
    if arg.input:
        ValueDeserializer().visit(arg_type, lvalue, rvalue)
def extractOpaqueArg(self, function, arg, arg_type, lvalue, rvalue):
    '''Like extractArg, but deserializes with OpaqueValueDeserializer.

    An ``UnsupportedType`` raised while allocating storage is deliberately
    ignored so that extraction still proceeds.
    '''
    try:
        ValueAllocator().visit(arg_type, lvalue, rvalue)
    except UnsupportedType:
        # Best effort: allocation is optional for opaque extraction.
        pass
    OpaqueValueDeserializer().visit(arg_type, lvalue, rvalue)
def regiterSwizzledValue(self, type, lvalue, rvalue):
    '''Emit code registering *lvalue* as the counterpart of the traced *rvalue*.

    The method name keeps its existing spelling because callers such as
    swizzleValues use it.
    '''
    SwizzledValueRegistrator().visit(type, lvalue, rvalue)
def declareRet(self, function):
    '''Declare the ``_result`` local that receives a non-void return value.'''
    if function.type is stdapi.Void:
        return
    print('    %s _result;' % function.type)
def invokeFunction(self, function):
    '''Emit the actual call of *function* with its deserialized arguments.

    Non-void results are stored in ``_result`` and passed to checkResult;
    void functions are called as a plain statement.
    '''
    arg_names = ", ".join(function.argNames())
    if function.type is not stdapi.Void:
        print('    _result = %s(%s);' % (function.name, arg_names))
        print('    (void)_result;')
        self.checkResult(function.type)
    else:
        print('    %s(%s);' % (function.name, arg_names))
def invokeInterfaceMethod(self, interface, method):
    '''Emit the call of *method* on the ``_this`` object.

    Non-void results are stored in ``_result`` and passed to checkResult.
    '''
    # Only release our reference when we reach the Release() == 0 call in
    # the trace.
    if method.name == 'Release':
        # Skip the call while the traced refcount is still non-zero; a
        # missing traced return value is treated like zero.
        print('    if (call.ret && call.ret->toUInt()) {')
        print('        return;')
        print('    }')
        print('    retrace::delObj(call.arg(0));')

    arg_names = ", ".join(method.argNames())
    if method.type is not stdapi.Void:
        print('    _result = _this->%s(%s);' % (method.name, arg_names))
        print('    (void)_result;')
        self.checkResult(method.type)
    else:
        print('    _this->%s(%s);' % (method.name, arg_names))
def checkResult(self, resultType):
    '''Emit a warning for failed ``HRESULT`` results of the replayed call.

    Other result types emit nothing.
    '''
    if str(resultType) == 'HRESULT':
        print(r'    if (FAILED(_result)) {')
        print(r'        retrace::warning(call) << "failed\n";')
        print(r'    }')
def filterFunction(self, function):
    '''Predicate selecting which API functions get a retrace callback.

    Accepts every function; override to restrict the set.
    '''
    return True

# Name of the generated C++ table mapping call names to callbacks.
table_name = 'retrace::callbacks'
469 def retraceApi(self, api):
471 print '#include "os_time.hpp"'
472 print '#include "trace_parser.hpp"'
473 print '#include "retrace.hpp"'
474 print '#include "retrace_swizzle.hpp"'
477 types = api.getAllTypes()
478 handles = [type for type in types if isinstance(type, stdapi.Handle)]
480 for handle in handles:
481 if handle.name not in handle_names:
482 if handle.key is None:
483 print 'static retrace::map<%s> _%s_map;' % (handle.type, handle.name)
485 key_name, key_type = handle.key
486 print 'static std::map<%s, retrace::map<%s> > _%s_map;' % (key_type, handle.type, handle.name)
487 handle_names.add(handle.name)
490 functions = filter(self.filterFunction, api.getAllFunctions())
491 for function in functions:
492 if function.sideeffects and not function.internal:
493 self.retraceFunction(function)
494 interfaces = api.getAllInterfaces()
495 for interface in interfaces:
496 for method in interface.iterMethods():
497 if method.sideeffects and not method.internal:
498 self.retraceInterfaceMethod(interface, method)
500 print 'const retrace::Entry %s[] = {' % self.table_name
501 for function in functions:
502 if not function.internal:
503 if function.sideeffects:
504 print ' {"%s", &retrace_%s},' % (function.name, function.name)
506 print ' {"%s", &retrace::ignore},' % (function.name,)
507 for interface in interfaces:
508 for method in interface.iterMethods():
509 if method.sideeffects:
510 print ' {"%s::%s", &retrace_%s__%s},' % (interface.name, method.name, interface.name, method.name)
512 print ' {"%s::%s", &retrace::ignore},' % (interface.name, method.name)
513 print ' {NULL, NULL}'