1 ##########################################################################
3 # Copyright 2010 VMware, Inc.
6 # Permission is hereby granted, free of charge, to any person obtaining a copy
7 # of this software and associated documentation files (the "Software"), to deal
8 # in the Software without restriction, including without limitation the rights
9 # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 # copies of the Software, and to permit persons to whom the Software is
11 # furnished to do so, subject to the following conditions:
13 # The above copyright notice and this permission notice shall be included in
14 # all copies or substantial portions of the Software.
16 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
24 ##########################################################################/
27 """Generic retracing code generator."""
33 sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
36 import specs.stdapi as stdapi
37 import specs.glapi as glapi
class UnsupportedType(Exception):
    '''Raised by a type visitor when it cannot generate retrace code for a
    type (e.g. ValueDeserializer on opaque types).'''
def lookupHandle(handle, value):
    '''Return the C++ expression indexing the lookup map of `handle` by `value`.

    Handles without a key use a flat map `_<name>_map[value]`; keyed handles
    use a nested map indexed first by the key expression and then by `value`.
    '''
    map_name = "_%s_map" % (handle.name,)
    if handle.key is not None:
        key_name, _key_type = handle.key
        return "%s[%s][%s]" % (map_name, key_name, value)
    return "%s[%s]" % (map_name, value)
class ValueAllocator(stdapi.Visitor):
    '''Type visitor that emits C++ reserving storage for argument values.

    Only arrays and pointers get storage: an `_allocator.alloc<T>(&rvalue)`
    call is emitted and assigned to `lvalue`.  Wrapper types (const, alias,
    reference, polymorphic) delegate to the type they wrap; all other types
    need no allocation and emit nothing.
    '''

    def _emitAlloc(self, element_type, lvalue, rvalue):
        # lvalue = _allocator.alloc<element_type>(&rvalue);
        print('    %s = _allocator.alloc<%s>(&%s);' % (lvalue, element_type, rvalue))

    # -- types that need storage --------------------------------------------

    def visitArray(self, array, lvalue, rvalue):
        self._emitAlloc(array.type, lvalue, rvalue)

    def visitPointer(self, pointer, lvalue, rvalue):
        self._emitAlloc(pointer.type, lvalue, rvalue)

    # -- wrapper types: allocate for whatever they wrap -----------------------

    def visitConst(self, const, lvalue, rvalue):
        self.visit(const.type, lvalue, rvalue)

    def visitAlias(self, alias, lvalue, rvalue):
        self.visit(alias.type, lvalue, rvalue)

    def visitReference(self, reference, lvalue, rvalue):
        self.visit(reference.type, lvalue, rvalue)

    def visitPolymorphic(self, polymorphic, lvalue, rvalue):
        self.visit(polymorphic.defaultType, lvalue, rvalue)

    # -- types needing no allocation ------------------------------------------

    def visitLiteral(self, literal, lvalue, rvalue):
        '''Literals need no allocation.'''

    def visitEnum(self, enum, lvalue, rvalue):
        '''Enums need no allocation.'''

    def visitBitmask(self, bitmask, lvalue, rvalue):
        '''Bitmasks need no allocation.'''

    def visitIntPointer(self, pointer, lvalue, rvalue):
        '''Integer-valued pointers need no allocation.'''

    def visitObjPointer(self, pointer, lvalue, rvalue):
        '''Object pointers need no allocation.'''

    def visitLinearPointer(self, pointer, lvalue, rvalue):
        '''Linear pointers need no allocation.'''

    def visitHandle(self, handle, lvalue, rvalue):
        '''Handles need no allocation.'''

    def visitBlob(self, blob, lvalue, rvalue):
        '''Blobs need no allocation.'''

    def visitString(self, string, lvalue, rvalue):
        '''Strings need no allocation.'''

    def visitStruct(self, struct, lvalue, rvalue):
        '''Structs need no allocation.'''

    def visitOpaque(self, opaque, lvalue, rvalue):
        '''Opaque values need no allocation.'''
class ValueDeserializer(stdapi.Visitor):
    '''Type visitor that emits C++ converting a traced value into a native one.

    `rvalue` is a C++ expression naming a trace value (e.g. `call.arg(0)`),
    and `lvalue` is the native C++ variable/expression to fill.  Object
    pointers and handles are additionally translated to their replay-side
    counterparts through `_obj_map` and the handle lookup maps.  Opaque types
    cannot be deserialized and raise UnsupportedType.
    '''

    # Counter used to give generated temporaries unique names within the
    # output of one visitor instance.
    seq = 0

    def _newTemp(self, prefix, tag):
        '''Return a fresh C++ temporary name of the form <prefix><tag>_<seq>.'''
        name = '%s%s_%s' % (prefix, tag, self.seq)
        self.seq += 1
        return name

    def visitLiteral(self, literal, lvalue, rvalue):
        print('    %s = (%s).to%s();' % (lvalue, rvalue, literal.kind))

    def visitConst(self, const, lvalue, rvalue):
        self.visit(const.type, lvalue, rvalue)

    def visitAlias(self, alias, lvalue, rvalue):
        self.visit(alias.type, lvalue, rvalue)

    def visitEnum(self, enum, lvalue, rvalue):
        print('    %s = static_cast<%s>((%s).toSInt());' % (lvalue, enum, rvalue))

    def visitBitmask(self, bitmask, lvalue, rvalue):
        self.visit(bitmask.type, lvalue, rvalue)

    def visitArray(self, array, lvalue, rvalue):
        tmp = self._newTemp('_a_', array.tag)
        # Only fill elements when storage exists for `lvalue`.
        print('    if (%s) {' % (lvalue,))
        print('        const trace::Array *%s = dynamic_cast<const trace::Array *>(&%s);' % (tmp, rvalue))
        length = '%s->values.size()' % (tmp,)
        index = '_j' + array.tag
        print('        for (size_t {i} = 0; {i} < {length}; ++{i}) {{'.format(i=index, length=length))
        try:
            self.visit(array.type, '%s[%s]' % (lvalue, index), '*%s->values[%s]' % (tmp, index))
        finally:
            # Keep the generated braces balanced even when the element type
            # raises UnsupportedType.
            print('        }')
            print('    }')

    def visitPointer(self, pointer, lvalue, rvalue):
        tmp = self._newTemp('_a_', pointer.tag)
        # A pointer is traced as a one-element array; fill element 0.
        print('    if (%s) {' % (lvalue,))
        print('        const trace::Array *%s = dynamic_cast<const trace::Array *>(&%s);' % (tmp, rvalue))
        try:
            self.visit(pointer.type, '%s[0]' % (lvalue,), '*%s->values[0]' % (tmp,))
        finally:
            print('    }')

    def visitIntPointer(self, pointer, lvalue, rvalue):
        print('    %s = static_cast<%s>((%s).toPointer());' % (lvalue, pointer, rvalue))

    def visitObjPointer(self, pointer, lvalue, rvalue):
        # Translate the traced object address into the replayed object.
        traced = '(%s).toUIntPtr()' % (rvalue,)
        mapped = '_obj_map[%s]' % (traced,)
        print('    if (retrace::verbosity >= 2) {')
        print('        std::cout << std::hex << "obj 0x" << size_t(%s) << " <- 0x" << size_t(%s) << std::dec <<"\\n";' % (traced, mapped))
        print('    }')
        print('    %s = static_cast<%s>(%s);' % (lvalue, pointer, mapped))

    def visitLinearPointer(self, pointer, lvalue, rvalue):
        print('    %s = static_cast<%s>(retrace::toPointer(%s));' % (lvalue, pointer, rvalue))

    def visitReference(self, reference, lvalue, rvalue):
        self.visit(reference.type, lvalue, rvalue)

    def visitHandle(self, handle, lvalue, rvalue):
        # Read the traced handle value, then replace it with the value stored
        # for it in the handle's lookup map.
        self.visit(handle.type, lvalue, rvalue)
        mapped = lookupHandle(handle, lvalue)
        print('    if (retrace::verbosity >= 2) {')
        print('        std::cout << "%s " << size_t(%s) << " <- " << size_t(%s) << "\\n";' % (handle.name, lvalue, mapped))
        print('    }')
        print('    %s = %s;' % (lvalue, mapped))

    def visitBlob(self, blob, lvalue, rvalue):
        print('    %s = static_cast<%s>((%s).toPointer());' % (lvalue, blob, rvalue))

    def visitString(self, string, lvalue, rvalue):
        print('    %s = (%s)((%s).toString());' % (lvalue, string.expr, rvalue))

    def visitStruct(self, struct, lvalue, rvalue):
        tmp = self._newTemp('_s_', struct.tag)
        print('    const trace::Struct *%s = dynamic_cast<const trace::Struct *>(&%s);' % (tmp, rvalue))
        print('    assert(%s);' % (tmp,))
        # Keep the temporary referenced when asserts are compiled out and the
        # struct has no members (SwizzledValueRegistrator does the same).
        print('    (void)%s;' % (tmp,))
        for i, (member_type, member_name) in enumerate(struct.members):
            self.visit(member_type, '%s.%s' % (lvalue, member_name), '*%s->members[%s]' % (tmp, i))

    def visitPolymorphic(self, polymorphic, lvalue, rvalue):
        self.visit(polymorphic.defaultType, lvalue, rvalue)

    def visitOpaque(self, opaque, lvalue, rvalue):
        raise UnsupportedType
class OpaqueValueDeserializer(ValueDeserializer):
    '''ValueDeserializer variant that also accepts opaque types.

    Opaque values normally cannot be retraced.  The exception is when they
    are extracted as the underlying value of a handle.  In that case the
    traced value is converted with retrace::toPointer and cast to the opaque
    type.
    '''

    def visitOpaque(self, opaque, lvalue, rvalue):
        statement = '    %s = static_cast<%s>(retrace::toPointer(%s));' % (lvalue, opaque, rvalue)
        print(statement)
class SwizzledValueRegistrator(stdapi.Visitor):
    '''Type visitor which will register (un)swizzled value pairs, to later be
    swapped.

    For each output value it emits C++ recording how the traced value
    (`rvalue`) maps to the value obtained during replay (`lvalue`):
    - object pointers go into `_obj_map`;
    - handles go into their `_<name>_map` lookup tables;
    - linear pointers are registered with retrace::addRegion.
    Literals, enums, bitmasks, strings, blobs and similar values need no
    remapping and emit nothing.
    '''

    # Counter used to give generated struct temporaries unique names within
    # the output of one visitor instance.
    seq = 0

    def visitLiteral(self, literal, lvalue, rvalue):
        pass

    def visitConst(self, const, lvalue, rvalue):
        # Forward to the wrapped type, as ValueAllocator and ValueDeserializer
        # do; without this override const-wrapped types fell through to the
        # base-class implementation.
        self.visit(const.type, lvalue, rvalue)

    def visitAlias(self, alias, lvalue, rvalue):
        self.visit(alias.type, lvalue, rvalue)

    def visitEnum(self, enum, lvalue, rvalue):
        pass

    def visitBitmask(self, bitmask, lvalue, rvalue):
        pass

    def visitArray(self, array, lvalue, rvalue):
        tmp = '_a%s' % (array.tag,)
        print('    const trace::Array *%s = dynamic_cast<const trace::Array *>(&%s);' % (tmp, rvalue))
        print('    if (%s) {' % (tmp,))
        length = '%s->values.size()' % (tmp,)
        index = '_j' + array.tag
        print('        for (size_t {i} = 0; {i} < {length}; ++{i}) {{'.format(i=index, length=length))
        try:
            self.visit(array.type, '%s[%s]' % (lvalue, index), '*%s->values[%s]' % (tmp, index))
        finally:
            # Keep the generated braces balanced even on errors.
            print('        }')
            print('    }')

    def visitPointer(self, pointer, lvalue, rvalue):
        tmp = '_a%s' % (pointer.tag,)
        print('    const trace::Array *%s = dynamic_cast<const trace::Array *>(&%s);' % (tmp, rvalue))
        print('    if (%s) {' % (tmp,))
        try:
            self.visit(pointer.type, '%s[0]' % (lvalue,), '*%s->values[0]' % (tmp,))
        finally:
            print('    }')

    def visitIntPointer(self, pointer, lvalue, rvalue):
        pass

    def visitObjPointer(self, pointer, lvalue, rvalue):
        print(r'    _obj_map[(%s).toUIntPtr()] = %s;' % (rvalue, lvalue))

    def visitLinearPointer(self, pointer, lvalue, rvalue):
        assert pointer.size is not None
        # The explicit check still guards when asserts are disabled (-O).
        if pointer.size is not None:
            print(r'    retrace::addRegion((%s).toUIntPtr(), %s, %s);' % (rvalue, lvalue, pointer.size))

    def visitReference(self, reference, lvalue, rvalue):
        pass

    def visitHandle(self, handle, lvalue, rvalue):
        # Deserialize the traced handle value into a local first.
        print('    %s _origResult;' % (handle.type,))
        OpaqueValueDeserializer().visit(handle.type, '_origResult', rvalue)
        if handle.range is None:
            # Single handle: map the traced value to the replayed one.
            entry = lookupHandle(handle, '_origResult')
            print('    %s = %s;' % (entry, lvalue))
            print('    if (retrace::verbosity >= 2) {')
            print('        std::cout << "{name} " << {old} << " -> " << {new} << "\\n";'.format(
                name=handle.name, old='_origResult', new=lvalue))
            print('    }')
        else:
            # Range of handles: map each of the `handle.range` consecutive values.
            index = '_h' + handle.tag
            new = '%s + %s' % (lvalue, index)
            old = '_origResult + %s' % (index,)
            entry = lookupHandle(handle, old)
            print('    for ({t} {i} = 0; {i} < {n}; ++{i}) {{'.format(
                t=handle.type, i=index, n=handle.range))
            print('        {entry} = {new};'.format(entry=entry, new=new))
            print('        if (retrace::verbosity >= 2) {')
            print('            std::cout << "{name} " << ({old}) << " -> " << ({new}) << "\\n";'.format(
                name=handle.name, old=old, new=new))
            print('        }')
            print('    }')

    def visitBlob(self, blob, lvalue, rvalue):
        pass

    def visitString(self, string, lvalue, rvalue):
        pass

    def visitStruct(self, struct, lvalue, rvalue):
        tmp = '_s_' + struct.tag + '_' + str(self.seq)
        self.seq += 1
        print('    const trace::Struct *%s = dynamic_cast<const trace::Struct *>(&%s);' % (tmp, rvalue))
        print('    assert(%s);' % (tmp,))
        # Keep the temporary referenced when asserts are compiled out.
        print('    (void)%s;' % (tmp,))
        for i, (member_type, member_name) in enumerate(struct.members):
            self.visit(member_type, '%s.%s' % (lvalue, member_name), '*%s->members[%s]' % (tmp, i))

    def visitPolymorphic(self, polymorphic, lvalue, rvalue):
        self.visit(polymorphic.defaultType, lvalue, rvalue)

    def visitOpaque(self, opaque, lvalue, rvalue):
        pass
311 def retraceFunction(self, function):
312 print 'static void retrace_%s(trace::Call &call) {' % function.name
313 self.retraceFunctionBody(function)
317 def retraceInterfaceMethod(self, interface, method):
318 print 'static void retrace_%s__%s(trace::Call &call) {' % (interface.name, method.name)
319 self.retraceInterfaceMethodBody(interface, method)
323 def retraceFunctionBody(self, function):
324 assert function.sideeffects
326 if function.type is not stdapi.Void:
327 self.checkOrigResult(function)
329 self.deserializeArgs(function)
331 self.invokeFunction(function)
333 self.swizzleValues(function)
335 def retraceInterfaceMethodBody(self, interface, method):
336 assert method.sideeffects
338 if method.type is not stdapi.Void:
339 self.checkOrigResult(method)
341 self.deserializeThisPointer(interface)
343 self.deserializeArgs(method)
345 self.invokeInterfaceMethod(interface, method)
347 self.swizzleValues(method)
349 def checkOrigResult(self, function):
350 '''Hook for checking the original result, to prevent succeeding now
351 where the original did not, which would cause diversion and potentially
352 unpredictable results.'''
354 assert function.type is not stdapi.Void
356 if str(function.type) == 'HRESULT':
357 print r' if (call.ret && FAILED(call.ret->toSInt())) {'
361 def deserializeThisPointer(self, interface):
362 print r' %s *_this;' % (interface.name,)
363 print r' _this = static_cast<%s *>(_obj_map[call.arg(0).toUIntPtr()]);' % (interface.name,)
364 print r' if (!_this) {'
365 print r' retrace::warning(call) << "NULL this pointer\n";'
369 def deserializeArgs(self, function):
370 print ' retrace::ScopedAllocator _allocator;'
371 print ' (void)_allocator;'
373 for arg in function.args:
374 arg_type = arg.type.mutable()
375 print ' %s %s;' % (arg_type, arg.name)
376 rvalue = 'call.arg(%u)' % (arg.index,)
379 self.extractArg(function, arg, arg_type, lvalue, rvalue)
380 except UnsupportedType:
382 print ' memset(&%s, 0, sizeof %s); // FIXME' % (arg.name, arg.name)
387 self.failFunction(function)
388 if function.name[-1].islower():
389 sys.stderr.write('warning: unsupported %s call\n' % function.name)
392 def swizzleValues(self, function):
393 for arg in function.args:
395 arg_type = arg.type.mutable()
396 rvalue = 'call.arg(%u)' % (arg.index,)
399 self.regiterSwizzledValue(arg_type, lvalue, rvalue)
400 except UnsupportedType:
401 print ' // XXX: %s' % arg.name
402 if function.type is not stdapi.Void:
406 self.regiterSwizzledValue(function.type, lvalue, rvalue)
407 except UnsupportedType:
409 print ' // XXX: result'
411 def failFunction(self, function):
412 print ' if (retrace::verbosity >= 0) {'
413 print ' retrace::unsupported(call);'
417 def extractArg(self, function, arg, arg_type, lvalue, rvalue):
418 ValueAllocator().visit(arg_type, lvalue, rvalue)
420 ValueDeserializer().visit(arg_type, lvalue, rvalue)
422 def extractOpaqueArg(self, function, arg, arg_type, lvalue, rvalue):
424 ValueAllocator().visit(arg_type, lvalue, rvalue)
425 except UnsupportedType:
427 OpaqueValueDeserializer().visit(arg_type, lvalue, rvalue)
429 def regiterSwizzledValue(self, type, lvalue, rvalue):
430 visitor = SwizzledValueRegistrator()
431 visitor.visit(type, lvalue, rvalue)
433 def invokeFunction(self, function):
434 arg_names = ", ".join(function.argNames())
435 if function.type is not stdapi.Void:
436 print ' %s _result;' % (function.type)
437 print ' _result = %s(%s);' % (function.name, arg_names)
438 print ' (void)_result;'
440 print ' %s(%s);' % (function.name, arg_names)
442 def invokeInterfaceMethod(self, interface, method):
443 # On release our reference when we reach Release() == 0 call in the
445 if method.name == 'Release':
446 print ' if (call.ret->toUInt()) {'
449 print ' _obj_map.erase(call.arg(0).toUIntPtr());'
451 arg_names = ", ".join(method.argNames())
452 if method.type is not stdapi.Void:
453 print ' %s _result;' % (method.type)
454 print ' _result = _this->%s(%s);' % (method.name, arg_names)
455 print ' (void)_result;'
457 print ' _this->%s(%s);' % (method.name, arg_names)
459 def filterFunction(self, function):
462 table_name = 'retrace::callbacks'
464 def retraceApi(self, api):
466 print '#include "os_time.hpp"'
467 print '#include "trace_parser.hpp"'
468 print '#include "retrace.hpp"'
471 types = api.getAllTypes()
472 handles = [type for type in types if isinstance(type, stdapi.Handle)]
474 for handle in handles:
475 if handle.name not in handle_names:
476 if handle.key is None:
477 print 'static retrace::map<%s> _%s_map;' % (handle.type, handle.name)
479 key_name, key_type = handle.key
480 print 'static std::map<%s, retrace::map<%s> > _%s_map;' % (key_type, handle.type, handle.name)
481 handle_names.add(handle.name)
484 print 'static std::map<unsigned long long, void *> _obj_map;'
487 functions = filter(self.filterFunction, api.functions)
488 for function in functions:
489 if function.sideeffects:
490 self.retraceFunction(function)
491 interfaces = api.getAllInterfaces()
492 for interface in interfaces:
493 for method in interface.iterMethods():
494 if method.sideeffects:
495 self.retraceInterfaceMethod(interface, method)
497 print 'const retrace::Entry %s[] = {' % self.table_name
498 for function in functions:
499 if function.sideeffects:
500 print ' {"%s", &retrace_%s},' % (function.name, function.name)
502 print ' {"%s", &retrace::ignore},' % (function.name,)
503 for interface in interfaces:
504 for method in interface.iterMethods():
505 if method.sideeffects:
506 print ' {"%s::%s", &retrace_%s__%s},' % (interface.name, method.name, interface.name, method.name)
508 print ' {"%s::%s", &retrace::ignore},' % (interface.name, method.name)
509 print ' {NULL, NULL}'