1 ##########################################################################
3 # Copyright 2010 VMware, Inc.
6 # Permission is hereby granted, free of charge, to any person obtaining a copy
7 # of this software and associated documentation files (the "Software"), to deal
8 # in the Software without restriction, including without limitation the rights
9 # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 # copies of the Software, and to permit persons to whom the Software is
11 # furnished to do so, subject to the following conditions:
13 # The above copyright notice and this permission notice shall be included in
14 # all copies or substantial portions of the Software.
16 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
24 ##########################################################################/
27 """Generic retracing code generator."""
33 sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
36 import specs.stdapi as stdapi
37 import specs.glapi as glapi
40 class UnsupportedType(Exception):
def lookupHandle(handle, value):
    """Return the C++ expression that indexes a handle's swizzle map.

    The map is named ``_<handle.name>_map``.  When ``handle.key`` is None the
    map is subscripted by ``value`` alone.  Otherwise ``handle.key`` is a
    ``(key_name, key_type)`` pair and the map is subscripted by ``key_name``
    first and by ``value`` second.
    """
    indices = [value]
    if handle.key is not None:
        # Only the key's name appears in the expression; its type is unused.
        key_name, _key_type = handle.key
        indices.insert(0, key_name)
    subscripts = "".join("[%s]" % (index,) for index in indices)
    return "_%s_map%s" % (handle.name, subscripts)
52 class ValueAllocator(stdapi.Visitor):
54 def visitLiteral(self, literal, lvalue, rvalue):
57 def visitConst(self, const, lvalue, rvalue):
58 self.visit(const.type, lvalue, rvalue)
60 def visitAlias(self, alias, lvalue, rvalue):
61 self.visit(alias.type, lvalue, rvalue)
63 def visitEnum(self, enum, lvalue, rvalue):
66 def visitBitmask(self, bitmask, lvalue, rvalue):
69 def visitArray(self, array, lvalue, rvalue):
70 print ' %s = _allocator.alloc<%s>(&%s);' % (lvalue, array.type, rvalue)
72 def visitPointer(self, pointer, lvalue, rvalue):
73 print ' %s = _allocator.alloc<%s>(&%s);' % (lvalue, pointer.type, rvalue)
75 def visitIntPointer(self, pointer, lvalue, rvalue):
78 def visitObjPointer(self, pointer, lvalue, rvalue):
81 def visitLinearPointer(self, pointer, lvalue, rvalue):
84 def visitReference(self, reference, lvalue, rvalue):
85 self.visit(reference.type, lvalue, rvalue);
87 def visitHandle(self, handle, lvalue, rvalue):
90 def visitBlob(self, blob, lvalue, rvalue):
93 def visitString(self, string, lvalue, rvalue):
96 def visitStruct(self, struct, lvalue, rvalue):
99 def visitPolymorphic(self, polymorphic, lvalue, rvalue):
100 self.visit(polymorphic.defaultType, lvalue, rvalue)
102 def visitOpaque(self, opaque, lvalue, rvalue):
106 class ValueDeserializer(stdapi.Visitor):
108 def visitLiteral(self, literal, lvalue, rvalue):
109 print ' %s = (%s).to%s();' % (lvalue, rvalue, literal.kind)
111 def visitConst(self, const, lvalue, rvalue):
112 self.visit(const.type, lvalue, rvalue)
114 def visitAlias(self, alias, lvalue, rvalue):
115 self.visit(alias.type, lvalue, rvalue)
117 def visitEnum(self, enum, lvalue, rvalue):
118 print ' %s = static_cast<%s>((%s).toSInt());' % (lvalue, enum, rvalue)
120 def visitBitmask(self, bitmask, lvalue, rvalue):
121 self.visit(bitmask.type, lvalue, rvalue)
123 def visitArray(self, array, lvalue, rvalue):
125 tmp = '_a_' + array.tag + '_' + str(self.seq)
128 print ' if (%s) {' % (lvalue,)
129 print ' const trace::Array *%s = dynamic_cast<const trace::Array *>(&%s);' % (tmp, rvalue)
130 length = '%s->values.size()' % (tmp,)
131 index = '_j' + array.tag
132 print ' for (size_t {i} = 0; {i} < {length}; ++{i}) {{'.format(i = index, length = length)
134 self.visit(array.type, '%s[%s]' % (lvalue, index), '*%s->values[%s]' % (tmp, index))
139 def visitPointer(self, pointer, lvalue, rvalue):
140 tmp = '_a_' + pointer.tag + '_' + str(self.seq)
143 print ' if (%s) {' % (lvalue,)
144 print ' const trace::Array *%s = dynamic_cast<const trace::Array *>(&%s);' % (tmp, rvalue)
146 self.visit(pointer.type, '%s[0]' % (lvalue,), '*%s->values[0]' % (tmp,))
150 def visitIntPointer(self, pointer, lvalue, rvalue):
151 print ' %s = static_cast<%s>((%s).toPointer());' % (lvalue, pointer, rvalue)
153 def visitObjPointer(self, pointer, lvalue, rvalue):
154 old_lvalue = '(%s).toUIntPtr()' % (rvalue,)
155 new_lvalue = '_obj_map[%s]' % (old_lvalue,)
156 print ' if (retrace::verbosity >= 2) {'
157 print ' std::cout << std::hex << "obj 0x" << size_t(%s) << " <- 0x" << size_t(%s) << std::dec <<"\\n";' % (old_lvalue, new_lvalue)
159 print ' %s = static_cast<%s>(%s);' % (lvalue, pointer, new_lvalue)
161 def visitLinearPointer(self, pointer, lvalue, rvalue):
162 print ' %s = static_cast<%s>(retrace::toPointer(%s));' % (lvalue, pointer, rvalue)
164 def visitReference(self, reference, lvalue, rvalue):
165 self.visit(reference.type, lvalue, rvalue);
167 def visitHandle(self, handle, lvalue, rvalue):
168 #OpaqueValueDeserializer().visit(handle.type, lvalue, rvalue);
169 self.visit(handle.type, lvalue, rvalue);
170 new_lvalue = lookupHandle(handle, lvalue)
171 print ' if (retrace::verbosity >= 2) {'
172 print ' std::cout << "%s " << size_t(%s) << " <- " << size_t(%s) << "\\n";' % (handle.name, lvalue, new_lvalue)
174 print ' %s = %s;' % (lvalue, new_lvalue)
176 def visitBlob(self, blob, lvalue, rvalue):
177 print ' %s = static_cast<%s>((%s).toPointer());' % (lvalue, blob, rvalue)
179 def visitString(self, string, lvalue, rvalue):
180 print ' %s = (%s)((%s).toString());' % (lvalue, string.expr, rvalue)
184 def visitStruct(self, struct, lvalue, rvalue):
185 tmp = '_s_' + struct.tag + '_' + str(self.seq)
188 print ' const trace::Struct *%s = dynamic_cast<const trace::Struct *>(&%s);' % (tmp, rvalue)
189 print ' assert(%s);' % (tmp)
190 for i in range(len(struct.members)):
191 member_type, member_name = struct.members[i]
192 self.visit(member_type, '%s.%s' % (lvalue, member_name), '*%s->members[%s]' % (tmp, i))
194 def visitPolymorphic(self, polymorphic, lvalue, rvalue):
195 self.visit(polymorphic.defaultType, lvalue, rvalue)
197 def visitOpaque(self, opaque, lvalue, rvalue):
198 raise UnsupportedType
class OpaqueValueDeserializer(ValueDeserializer):
    '''Deserializer variant that also knows how to extract opaque values.

    Opaque values normally cannot be retraced (the base deserializer raises
    UnsupportedType for them); the exception is when they are extracted in
    the context of handles.'''

    def visitOpaque(self, opaque, lvalue, rvalue):
        # Emit a C++ assignment: the traced value goes through
        # retrace::toPointer() and is cast to the opaque type.
        statement = ' %s = static_cast<%s>(retrace::toPointer(%s));' % (lvalue, opaque, rvalue)
        print(statement)
211 class SwizzledValueRegistrator(stdapi.Visitor):
212 '''Type visitor which will register (un)swizzled value pairs, to later be
215 def visitLiteral(self, literal, lvalue, rvalue):
218 def visitAlias(self, alias, lvalue, rvalue):
219 self.visit(alias.type, lvalue, rvalue)
221 def visitEnum(self, enum, lvalue, rvalue):
224 def visitBitmask(self, bitmask, lvalue, rvalue):
227 def visitArray(self, array, lvalue, rvalue):
228 print ' const trace::Array *_a%s = dynamic_cast<const trace::Array *>(&%s);' % (array.tag, rvalue)
229 print ' if (_a%s) {' % (array.tag)
230 length = '_a%s->values.size()' % array.tag
231 index = '_j' + array.tag
232 print ' for (size_t {i} = 0; {i} < {length}; ++{i}) {{'.format(i = index, length = length)
234 self.visit(array.type, '%s[%s]' % (lvalue, index), '*_a%s->values[%s]' % (array.tag, index))
239 def visitPointer(self, pointer, lvalue, rvalue):
240 print ' const trace::Array *_a%s = dynamic_cast<const trace::Array *>(&%s);' % (pointer.tag, rvalue)
241 print ' if (_a%s) {' % (pointer.tag)
243 self.visit(pointer.type, '%s[0]' % (lvalue,), '*_a%s->values[0]' % (pointer.tag,))
247 def visitIntPointer(self, pointer, lvalue, rvalue):
250 def visitObjPointer(self, pointer, lvalue, rvalue):
251 print r' _obj_map[(%s).toUIntPtr()] = %s;' % (rvalue, lvalue)
253 def visitLinearPointer(self, pointer, lvalue, rvalue):
254 assert pointer.size is not None
255 if pointer.size is not None:
256 print r' retrace::addRegion((%s).toUIntPtr(), %s, %s);' % (rvalue, lvalue, pointer.size)
258 def visitReference(self, reference, lvalue, rvalue):
261 def visitHandle(self, handle, lvalue, rvalue):
262 print ' %s _origResult;' % handle.type
263 OpaqueValueDeserializer().visit(handle.type, '_origResult', rvalue);
264 if handle.range is None:
265 rvalue = "_origResult"
266 entry = lookupHandle(handle, rvalue)
267 print " %s = %s;" % (entry, lvalue)
268 print ' if (retrace::verbosity >= 2) {'
269 print ' std::cout << "{handle.name} " << {rvalue} << " -> " << {lvalue} << "\\n";'.format(**locals())
272 i = '_h' + handle.tag
273 lvalue = "%s + %s" % (lvalue, i)
274 rvalue = "_origResult + %s" % (i,)
275 entry = lookupHandle(handle, rvalue)
276 print ' for ({handle.type} {i} = 0; {i} < {handle.range}; ++{i}) {{'.format(**locals())
277 print ' {entry} = {lvalue};'.format(**locals())
278 print ' if (retrace::verbosity >= 2) {'
279 print ' std::cout << "{handle.name} " << ({rvalue}) << " -> " << ({lvalue}) << "\\n";'.format(**locals())
283 def visitBlob(self, blob, lvalue, rvalue):
286 def visitString(self, string, lvalue, rvalue):
291 def visitStruct(self, struct, lvalue, rvalue):
292 tmp = '_s_' + struct.tag + '_' + str(self.seq)
295 print ' const trace::Struct *%s = dynamic_cast<const trace::Struct *>(&%s);' % (tmp, rvalue)
296 print ' assert(%s);' % (tmp,)
297 print ' (void)%s;' % (tmp,)
298 for i in range(len(struct.members)):
299 member_type, member_name = struct.members[i]
300 self.visit(member_type, '%s.%s' % (lvalue, member_name), '*%s->members[%s]' % (tmp, i))
302 def visitPolymorphic(self, polymorphic, lvalue, rvalue):
303 self.visit(polymorphic.defaultType, lvalue, rvalue)
305 def visitOpaque(self, opaque, lvalue, rvalue):
311 def retraceFunction(self, function):
312 print 'static void retrace_%s(trace::Call &call) {' % function.name
313 self.retraceFunctionBody(function)
317 def retraceInterfaceMethod(self, interface, method):
318 print 'static void retrace_%s__%s(trace::Call &call) {' % (interface.name, method.name)
319 self.retraceInterfaceMethodBody(interface, method)
323 def retraceFunctionBody(self, function):
324 assert function.sideeffects
326 if function.type is not stdapi.Void:
327 self.checkOrigResult(function)
329 self.deserializeArgs(function)
331 self.declareRet(function)
332 self.invokeFunction(function)
334 self.swizzleValues(function)
336 def retraceInterfaceMethodBody(self, interface, method):
337 assert method.sideeffects
339 if method.type is not stdapi.Void:
340 self.checkOrigResult(method)
342 self.deserializeThisPointer(interface)
344 self.deserializeArgs(method)
346 self.declareRet(method)
347 self.invokeInterfaceMethod(interface, method)
349 self.swizzleValues(method)
351 def checkOrigResult(self, function):
352 '''Hook for checking the original result, to prevent succeeding now
353 where the original did not, which would cause diversion and potentially
354 unpredictable results.'''
356 assert function.type is not stdapi.Void
358 if str(function.type) == 'HRESULT':
359 print r' if (call.ret && FAILED(call.ret->toSInt())) {'
363 def deserializeThisPointer(self, interface):
364 print r' %s *_this;' % (interface.name,)
365 print r' _this = static_cast<%s *>(_obj_map[call.arg(0).toUIntPtr()]);' % (interface.name,)
366 print r' if (!_this) {'
367 print r' retrace::warning(call) << "NULL this pointer\n";'
371 def deserializeArgs(self, function):
372 print ' retrace::ScopedAllocator _allocator;'
373 print ' (void)_allocator;'
375 for arg in function.args:
376 arg_type = arg.type.mutable()
377 print ' %s %s;' % (arg_type, arg.name)
378 rvalue = 'call.arg(%u)' % (arg.index,)
381 self.extractArg(function, arg, arg_type, lvalue, rvalue)
382 except UnsupportedType:
384 print ' memset(&%s, 0, sizeof %s); // FIXME' % (arg.name, arg.name)
389 self.failFunction(function)
390 if function.name[-1].islower():
391 sys.stderr.write('warning: unsupported %s call\n' % function.name)
394 def swizzleValues(self, function):
395 for arg in function.args:
397 arg_type = arg.type.mutable()
398 rvalue = 'call.arg(%u)' % (arg.index,)
401 self.regiterSwizzledValue(arg_type, lvalue, rvalue)
402 except UnsupportedType:
403 print ' // XXX: %s' % arg.name
404 if function.type is not stdapi.Void:
408 self.regiterSwizzledValue(function.type, lvalue, rvalue)
409 except UnsupportedType:
411 print ' // XXX: result'
413 def failFunction(self, function):
414 print ' if (retrace::verbosity >= 0) {'
415 print ' retrace::unsupported(call);'
419 def extractArg(self, function, arg, arg_type, lvalue, rvalue):
420 ValueAllocator().visit(arg_type, lvalue, rvalue)
422 ValueDeserializer().visit(arg_type, lvalue, rvalue)
424 def extractOpaqueArg(self, function, arg, arg_type, lvalue, rvalue):
426 ValueAllocator().visit(arg_type, lvalue, rvalue)
427 except UnsupportedType:
429 OpaqueValueDeserializer().visit(arg_type, lvalue, rvalue)
431 def regiterSwizzledValue(self, type, lvalue, rvalue):
432 visitor = SwizzledValueRegistrator()
433 visitor.visit(type, lvalue, rvalue)
435 def declareRet(self, function):
436 if function.type is not stdapi.Void:
437 print ' %s _result;' % (function.type)
439 def invokeFunction(self, function):
440 arg_names = ", ".join(function.argNames())
441 if function.type is not stdapi.Void:
442 print ' _result = %s(%s);' % (function.name, arg_names)
443 print ' (void)_result;'
445 print ' %s(%s);' % (function.name, arg_names)
447 def invokeInterfaceMethod(self, interface, method):
        # Only release our reference when we reach the Release() == 0 call in the
450 if method.name == 'Release':
451 print ' if (call.ret->toUInt()) {'
454 print ' _obj_map.erase(call.arg(0).toUIntPtr());'
456 arg_names = ", ".join(method.argNames())
457 if method.type is not stdapi.Void:
458 print ' _result = _this->%s(%s);' % (method.name, arg_names)
459 print ' (void)_result;'
461 print ' _this->%s(%s);' % (method.name, arg_names)
463 def filterFunction(self, function):
466 table_name = 'retrace::callbacks'
468 def retraceApi(self, api):
470 print '#include "os_time.hpp"'
471 print '#include "trace_parser.hpp"'
472 print '#include "retrace.hpp"'
473 print '#include "retrace_swizzle.hpp"'
476 types = api.getAllTypes()
477 handles = [type for type in types if isinstance(type, stdapi.Handle)]
479 for handle in handles:
480 if handle.name not in handle_names:
481 if handle.key is None:
482 print 'static retrace::map<%s> _%s_map;' % (handle.type, handle.name)
484 key_name, key_type = handle.key
485 print 'static std::map<%s, retrace::map<%s> > _%s_map;' % (key_type, handle.type, handle.name)
486 handle_names.add(handle.name)
489 print 'static std::map<unsigned long long, void *> _obj_map;'
492 functions = filter(self.filterFunction, api.functions)
493 for function in functions:
494 if function.sideeffects and not function.internal:
495 self.retraceFunction(function)
496 interfaces = api.getAllInterfaces()
497 for interface in interfaces:
498 for method in interface.iterMethods():
499 if method.sideeffects and not method.internal:
500 self.retraceInterfaceMethod(interface, method)
502 print 'const retrace::Entry %s[] = {' % self.table_name
503 for function in functions:
504 if not function.internal:
505 if function.sideeffects:
506 print ' {"%s", &retrace_%s},' % (function.name, function.name)
508 print ' {"%s", &retrace::ignore},' % (function.name,)
509 for interface in interfaces:
510 for method in interface.iterMethods():
511 if method.sideeffects:
512 print ' {"%s::%s", &retrace_%s__%s},' % (interface.name, method.name, interface.name, method.name)
514 print ' {"%s::%s", &retrace::ignore},' % (interface.name, method.name)
515 print ' {NULL, NULL}'