1 ##########################################################################
3 # Copyright 2010 VMware, Inc.
6 # Permission is hereby granted, free of charge, to any person obtaining a copy
7 # of this software and associated documentation files (the "Software"), to deal
8 # in the Software without restriction, including without limitation the rights
9 # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 # copies of the Software, and to permit persons to whom the Software is
11 # furnished to do so, subject to the following conditions:
13 # The above copyright notice and this permission notice shall be included in
14 # all copies or substantial portions of the Software.
16 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
24 ##########################################################################/
27 """Generic retracing code generator."""
import os.path
import sys

sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))

import specs.stdapi as stdapi
class UnsupportedType(Exception):
    '''Raised by the visitors below when a type cannot be retraced.'''
def lookupHandle(handle, value):
    '''Return the C++ expression indexing the swizzle map of *handle*.

    Un-keyed handles use a flat ``_<name>_map[value]``; keyed handles use a
    map of maps, indexed first by the key expression and then by *value*.'''
    prefix = "_%s_map" % (handle.name,)
    if handle.key is not None:
        key_name, _key_type = handle.key
        prefix += "[%s]" % (key_name,)
    return "%s[%s]" % (prefix, value)
class ValueAllocator(stdapi.Visitor):
    '''Type visitor that emits C++ statements reserving storage for an
    argument before it is deserialized.

    Only arrays and pointers get storage, obtained from
    ``_allocator.alloc<T>(&rvalue)`` (presumably sized from the recorded
    value -- confirm in retrace::ScopedAllocator).  Wrapper kinds forward to
    their underlying type; every other kind emits nothing.'''

    def _emitAlloc(self, element_type, lvalue, rvalue):
        # Point lvalue at storage for element_type obtained from _allocator.
        print(' %s = _allocator.alloc<%s>(&%s);' % (lvalue, element_type, rvalue))

    def visitLiteral(self, literal, lvalue, rvalue):
        pass

    def visitConst(self, const, lvalue, rvalue):
        self.visit(const.type, lvalue, rvalue)

    def visitAlias(self, alias, lvalue, rvalue):
        self.visit(alias.type, lvalue, rvalue)

    def visitEnum(self, enum, lvalue, rvalue):
        pass

    def visitBitmask(self, bitmask, lvalue, rvalue):
        pass

    def visitArray(self, array, lvalue, rvalue):
        self._emitAlloc(array.type, lvalue, rvalue)

    def visitPointer(self, pointer, lvalue, rvalue):
        self._emitAlloc(pointer.type, lvalue, rvalue)

    def visitIntPointer(self, pointer, lvalue, rvalue):
        pass

    def visitObjPointer(self, pointer, lvalue, rvalue):
        pass

    def visitLinearPointer(self, pointer, lvalue, rvalue):
        pass

    def visitReference(self, reference, lvalue, rvalue):
        self.visit(reference.type, lvalue, rvalue)

    def visitHandle(self, handle, lvalue, rvalue):
        pass

    def visitBlob(self, blob, lvalue, rvalue):
        pass

    def visitString(self, string, lvalue, rvalue):
        pass

    def visitStruct(self, struct, lvalue, rvalue):
        pass

    def visitPolymorphic(self, polymorphic, lvalue, rvalue):
        self.visit(polymorphic.defaultType, lvalue, rvalue)

    def visitOpaque(self, opaque, lvalue, rvalue):
        pass
class ValueDeserializer(stdapi.Visitor):
    '''Type visitor that emits C++ statements converting the traced
    ``trace::Value`` expression *rvalue* into the native variable *lvalue*.

    Array, pointer and struct visits declare C++ temporaries; their names
    carry the per-instance counter ``seq`` so that repeated visits of the
    same type within one generated function do not clash.  Opaque types
    cannot be converted and raise UnsupportedType.'''

    # Suffix counter for generated temporaries; each instance counts from 0.
    seq = 0

    def _nextTemp(self, prefix, tag):
        '''Return a fresh C++ temporary name and advance ``seq``.'''
        name = prefix + tag + '_' + str(self.seq)
        self.seq += 1
        return name

    def visitLiteral(self, literal, lvalue, rvalue):
        print(' %s = (%s).to%s();' % (lvalue, rvalue, literal.kind))

    def visitConst(self, const, lvalue, rvalue):
        self.visit(const.type, lvalue, rvalue)

    def visitAlias(self, alias, lvalue, rvalue):
        self.visit(alias.type, lvalue, rvalue)

    def visitEnum(self, enum, lvalue, rvalue):
        print(' %s = static_cast<%s>((%s).toSInt());' % (lvalue, enum, rvalue))

    def visitBitmask(self, bitmask, lvalue, rvalue):
        self.visit(bitmask.type, lvalue, rvalue)

    def visitArray(self, array, lvalue, rvalue):
        tmp = self._nextTemp('_a_', array.tag)
        print(' if (%s) {' % (lvalue,))
        print(' const trace::Array *%s = dynamic_cast<const trace::Array *>(&%s);' % (tmp, rvalue))
        length = '%s->values.size()' % (tmp,)
        index = '_j' + array.tag
        print(' for (size_t {i} = 0; {i} < {length}; ++{i}) {{'.format(i=index, length=length))
        try:
            self.visit(array.type, '%s[%s]' % (lvalue, index), '*%s->values[%s]' % (tmp, index))
        finally:
            # Close the for loop and the if block even when the element type
            # is unsupported, so the generated C++ stays balanced.
            print(' }')
            print(' }')

    def visitPointer(self, pointer, lvalue, rvalue):
        tmp = self._nextTemp('_a_', pointer.tag)
        print(' if (%s) {' % (lvalue,))
        print(' const trace::Array *%s = dynamic_cast<const trace::Array *>(&%s);' % (tmp, rvalue))
        try:
            # A pointer is recorded as an array; only element 0 is used.
            self.visit(pointer.type, '%s[0]' % (lvalue,), '*%s->values[0]' % (tmp,))
        finally:
            print(' }')

    def visitIntPointer(self, pointer, lvalue, rvalue):
        print(' %s = static_cast<%s>((%s).toPointer());' % (lvalue, pointer, rvalue))

    def visitObjPointer(self, pointer, lvalue, rvalue):
        print(' %s = static_cast<%s>(retrace::toObjPointer(call, %s));' % (lvalue, pointer, rvalue))

    def visitLinearPointer(self, pointer, lvalue, rvalue):
        print(' %s = static_cast<%s>(retrace::toPointer(%s));' % (lvalue, pointer, rvalue))

    def visitReference(self, reference, lvalue, rvalue):
        self.visit(reference.type, lvalue, rvalue)

    def visitHandle(self, handle, lvalue, rvalue):
        # Deserialize the traced handle value, then replace it with the live
        # value stored in the handle's swizzle map.
        self.visit(handle.type, lvalue, rvalue)
        new_lvalue = lookupHandle(handle, lvalue)
        print(' if (retrace::verbosity >= 2) {')
        print(' std::cout << "%s " << size_t(%s) << " <- " << size_t(%s) << "\\n";' % (handle.name, lvalue, new_lvalue))
        print(' }')
        print(' %s = %s;' % (lvalue, new_lvalue))

    def visitBlob(self, blob, lvalue, rvalue):
        print(' %s = static_cast<%s>((%s).toPointer());' % (lvalue, blob, rvalue))

    def visitString(self, string, lvalue, rvalue):
        print(' %s = (%s)((%s).toString());' % (lvalue, string.expr, rvalue))

    def visitStruct(self, struct, lvalue, rvalue):
        tmp = self._nextTemp('_s_', struct.tag)
        print(' const trace::Struct *%s = dynamic_cast<const trace::Struct *>(&%s);' % (tmp, rvalue))
        print(' assert(%s);' % (tmp,))
        for i, (member_type, member_name) in enumerate(struct.members):
            self.visit(member_type, '%s.%s' % (lvalue, member_name), '*%s->members[%s]' % (tmp, i))

    def visitPolymorphic(self, polymorphic, lvalue, rvalue):
        self.visit(polymorphic.defaultType, lvalue, rvalue)

    def visitOpaque(self, opaque, lvalue, rvalue):
        raise UnsupportedType
class OpaqueValueDeserializer(ValueDeserializer):
    '''ValueDeserializer variant that also accepts opaque types.

    Opaque values cannot normally be retraced.  When they are extracted as
    part of handle handling, this visitor converts the recorded value with
    retrace::toPointer() instead of raising UnsupportedType.'''

    def visitOpaque(self, opaque, lvalue, rvalue):
        template = ' %s = static_cast<%s>(retrace::toPointer(%s));'
        print(template % (lvalue, opaque, rvalue))
class SwizzledValueRegistrator(stdapi.Visitor):
    '''Type visitor which will register (un)swizzled value pairs, to later be
    swizzled.

    *rvalue* is the traced ``trace::Value`` expression and *lvalue* the
    native C++ expression holding the replayed value.  Object pointers are
    registered with retrace::addObj(), sized linear pointers with
    retrace::addRegion(), and handles are stored in their ``_<name>_map``
    (see lookupHandle).  Kinds with nothing to register emit no code.'''

    # Suffix counter for struct temporaries; each instance counts from 0.
    seq = 0

    def visitLiteral(self, literal, lvalue, rvalue):
        pass

    def visitAlias(self, alias, lvalue, rvalue):
        self.visit(alias.type, lvalue, rvalue)

    def visitEnum(self, enum, lvalue, rvalue):
        pass

    def visitBitmask(self, bitmask, lvalue, rvalue):
        pass

    def visitArray(self, array, lvalue, rvalue):
        print(' const trace::Array *_a%s = dynamic_cast<const trace::Array *>(&%s);' % (array.tag, rvalue))
        print(' if (_a%s) {' % (array.tag,))
        length = '_a%s->values.size()' % (array.tag,)
        index = '_j' + array.tag
        print(' for (size_t {i} = 0; {i} < {length}; ++{i}) {{'.format(i=index, length=length))
        try:
            self.visit(array.type, '%s[%s]' % (lvalue, index), '*_a%s->values[%s]' % (array.tag, index))
        finally:
            # Keep the generated C++ balanced even if the element is unsupported.
            print(' }')
            print(' }')

    def visitPointer(self, pointer, lvalue, rvalue):
        print(' const trace::Array *_a%s = dynamic_cast<const trace::Array *>(&%s);' % (pointer.tag, rvalue))
        print(' if (_a%s) {' % (pointer.tag,))
        try:
            self.visit(pointer.type, '%s[0]' % (lvalue,), '*_a%s->values[0]' % (pointer.tag,))
        finally:
            print(' }')

    def visitIntPointer(self, pointer, lvalue, rvalue):
        pass

    def visitObjPointer(self, pointer, lvalue, rvalue):
        print(r' retrace::addObj(%s, %s);' % (rvalue, lvalue))

    def visitLinearPointer(self, pointer, lvalue, rvalue):
        # The assert flags unsized pointers during development; the guard
        # still skips them when asserts are stripped (python -O).
        assert pointer.size is not None
        if pointer.size is not None:
            print(r' retrace::addRegion((%s).toUIntPtr(), %s, %s);' % (rvalue, lvalue, pointer.size))

    def visitReference(self, reference, lvalue, rvalue):
        pass

    def visitHandle(self, handle, lvalue, rvalue):
        # Recover the traced handle value into _origResult first.
        print(' %s _origResult;' % (handle.type,))
        OpaqueValueDeserializer().visit(handle.type, '_origResult', rvalue)
        if handle.range is None:
            rvalue = "_origResult"
            entry = lookupHandle(handle, rvalue)
            print(" %s = %s;" % (entry, lvalue))
            print(' if (retrace::verbosity >= 2) {')
            print(' std::cout << "{handle.name} " << {rvalue} << " -> " << {lvalue} << "\\n";'.format(
                handle=handle, rvalue=rvalue, lvalue=lvalue))
            print(' }')
        else:
            # A ranged handle covers handle.range consecutive values.
            i = '_h' + handle.tag
            lvalue = "%s + %s" % (lvalue, i)
            rvalue = "_origResult + %s" % (i,)
            entry = lookupHandle(handle, rvalue)
            print(' for ({handle.type} {i} = 0; {i} < {handle.range}; ++{i}) {{'.format(handle=handle, i=i))
            print(' {entry} = {lvalue};'.format(entry=entry, lvalue=lvalue))
            print(' if (retrace::verbosity >= 2) {')
            print(' std::cout << "{handle.name} " << ({rvalue}) << " -> " << ({lvalue}) << "\\n";'.format(
                handle=handle, rvalue=rvalue, lvalue=lvalue))
            print(' }')
            print(' }')

    def visitBlob(self, blob, lvalue, rvalue):
        pass

    def visitString(self, string, lvalue, rvalue):
        pass

    def visitStruct(self, struct, lvalue, rvalue):
        tmp = '_s_' + struct.tag + '_' + str(self.seq)
        self.seq += 1
        print(' const trace::Struct *%s = dynamic_cast<const trace::Struct *>(&%s);' % (tmp, rvalue))
        print(' assert(%s);' % (tmp,))
        # Silences unused-variable warnings when asserts are compiled out.
        print(' (void)%s;' % (tmp,))
        for i, (member_type, member_name) in enumerate(struct.members):
            self.visit(member_type, '%s.%s' % (lvalue, member_name), '*%s->members[%s]' % (tmp, i))

    def visitPolymorphic(self, polymorphic, lvalue, rvalue):
        self.visit(polymorphic.defaultType, lvalue, rvalue)

    def visitOpaque(self, opaque, lvalue, rvalue):
        pass
def retraceFunction(self, function):
    '''Emit the complete C++ replay function ``retrace_<name>`` for
    *function*; the body comes from retraceFunctionBody().'''
    print('static void retrace_%s(trace::Call &call) {' % function.name)
    self.retraceFunctionBody(function)
    # Close the generated function and separate it from the next one.
    print('}')
    print('')
def retraceInterfaceMethod(self, interface, method):
    '''Emit the complete C++ replay function ``retrace_<Interface>__<Method>``
    for *method*; the body comes from retraceInterfaceMethodBody().'''
    print('static void retrace_%s__%s(trace::Call &call) {' % (interface.name, method.name))
    self.retraceInterfaceMethodBody(interface, method)
    # Close the generated function and separate it from the next one.
    print('}')
    print('')
def retraceFunctionBody(self, function):
    '''Emit the statements of the replay function for *function*.

    Steps, in order:
    1. Optionally check the traced result (non-void functions only).
    2. Deserialize the arguments.
    3. Declare ``_result`` and invoke the real entry point.
    4. Register the swizzled values.'''
    assert function.sideeffects

    returns_value = function.type is not stdapi.Void
    if returns_value:
        self.checkOrigResult(function)

    self.deserializeArgs(function)

    self.declareRet(function)
    self.invokeFunction(function)

    self.swizzleValues(function)
def retraceInterfaceMethodBody(self, interface, method):
    '''Emit the statements of the replay function for *method* of
    *interface*.

    This follows the same sequence as retraceFunctionBody(). The one
    difference is that the ``_this`` object pointer is resolved before the
    arguments are deserialized.'''
    assert method.sideeffects

    returns_value = method.type is not stdapi.Void
    if returns_value:
        self.checkOrigResult(method)

    self.deserializeThisPointer(interface)

    self.deserializeArgs(method)

    self.declareRet(method)
    self.invokeInterfaceMethod(interface, method)

    self.swizzleValues(method)
def checkOrigResult(self, function):
    '''Hook for checking the original result, to prevent succeeding now
    where the original did not, which would cause diversion and potentially
    unpredictable results.

    The default only handles HRESULT-returning calls. When the traced
    result is FAILED(), the generated function returns without replaying
    the call.'''

    assert function.type is not stdapi.Void

    if str(function.type) == 'HRESULT':
        print(r' if (call.ret && FAILED(call.ret->toSInt())) {')
        print(r' return;')
        print(r' }')
def deserializeThisPointer(self, interface):
    '''Emit the declaration and lookup of the ``_this`` object pointer.

    The traced object pointer in call.arg(0) is translated with
    retrace::toObjPointer(). This presumably yields the live object
    registered via retrace::addObj; confirm. If no object is found, the
    generated function returns without invoking the method.'''
    print(r' %s *_this;' % (interface.name,))
    print(r' _this = static_cast<%s *>(retrace::toObjPointer(call, call.arg(0)));' % (interface.name,))
    print(r' if (!_this) {')
    # Nothing to invoke the method on: skip this call.
    print(r' return;')
    print(r' }')
def deserializeArgs(self, function):
    '''Emit declarations and deserialization code for every argument.

    Each argument gets a C++ local named after it, filled by extractArg()
    from ``call.arg(index)``. An argument whose type raises UnsupportedType
    is zero-filled with memset instead. If any argument was unsupported,
    the failure path from failFunction() is emitted once after all
    arguments have been processed.'''
    print(' retrace::ScopedAllocator _allocator;')
    print(' (void)_allocator;')
    success = True
    for arg in function.args:
        arg_type = arg.type.mutable()
        print(' %s %s;' % (arg_type, arg.name))
        rvalue = 'call.arg(%u)' % (arg.index,)
        lvalue = arg.name
        try:
            self.extractArg(function, arg, arg_type, lvalue, rvalue)
        except UnsupportedType:
            success = False
            print(' memset(&%s, 0, sizeof %s); // FIXME' % (arg.name, arg.name))
        print('')

    if not success:
        # Wrap the failure path in its own block so that an early return
        # emitted by failFunction() does not leave the rest of the generated
        # function unconditionally dead.
        print(' if (1) {')
        self.failFunction(function)
        # NOTE(review): only names ending in a lowercase character are
        # reported; presumably this skips vendor-suffixed entry points. Confirm.
        if function.name[-1].islower():
            sys.stderr.write('warning: unsupported %s call\n' % function.name)
        print(' }')
def swizzleValues(self, function):
    '''Emit registration of traced/replayed value pairs for the output
    arguments and, if any, the return value of *function*.

    Types that raise UnsupportedType leave an ``// XXX`` marker in the
    generated code instead.'''
    for arg in function.args:
        # Only values written by the call need registering; input values
        # have already been translated during deserialization.
        if not arg.output:
            continue
        arg_type = arg.type.mutable()
        rvalue = 'call.arg(%u)' % (arg.index,)
        lvalue = arg.name
        try:
            self.regiterSwizzledValue(arg_type, lvalue, rvalue)
        except UnsupportedType:
            print(' // XXX: %s' % arg.name)
    if function.type is not stdapi.Void:
        # call.ret is a pointer to the traced return value; _result holds
        # the replayed one (declared by declareRet()).
        rvalue = '*call.ret'
        lvalue = '_result'
        try:
            self.regiterSwizzledValue(function.type, lvalue, rvalue)
        except UnsupportedType:
            print(' // XXX: result')
def failFunction(self, function):
    '''Emit the bail-out path for a call that cannot be replayed.

    The call is reported through retrace::unsupported() when verbosity is
    at least 0. The generated function then returns instead of invoking the
    entry point with zero-filled arguments.'''
    print(' if (retrace::verbosity >= 0) {')
    print(' retrace::unsupported(call);')
    print(' }')
    print(' return;')
def extractArg(self, function, arg, arg_type, lvalue, rvalue):
    '''Emit code that fills the C++ variable *lvalue* from the traced value
    *rvalue*. Any storage the type needs is reserved first, then the value
    is deserialized; UnsupportedType propagates to the caller.'''
    # NOTE(review): both steps run for every argument here; confirm whether
    # output-only arguments should skip deserialization.
    for visitor_class in (ValueAllocator, ValueDeserializer):
        visitor_class().visit(arg_type, lvalue, rvalue)
def extractOpaqueArg(self, function, arg, arg_type, lvalue, rvalue):
    '''Like extractArg(), but opaque types are accepted.

    Storage allocation is best-effort here. Deserialization uses
    OpaqueValueDeserializer, which converts opaque values via
    retrace::toPointer().'''
    try:
        ValueAllocator().visit(arg_type, lvalue, rvalue)
    except UnsupportedType:
        # Deliberately ignored: the opaque deserializer below does not
        # depend on allocated storage.
        pass
    OpaqueValueDeserializer().visit(arg_type, lvalue, rvalue)
def regiterSwizzledValue(self, type, lvalue, rvalue):
    '''Emit registration of the traced/replayed pair (rvalue, lvalue) for a
    value of *type* by running a SwizzledValueRegistrator over it.

    The misspelled name is kept because swizzleValues() calls it.'''
    SwizzledValueRegistrator().visit(type, lvalue, rvalue)
def declareRet(self, function):
    '''Emit the declaration of ``_result`` for non-void *function*s.'''
    if function.type is stdapi.Void:
        return
    print(' %s _result;' % (function.type))
def invokeFunction(self, function):
    '''Emit the call of the real entry point with the deserialized
    arguments.

    The result is stored in ``_result`` unless the function returns void.
    Exactly one call is emitted either way.'''
    arg_names = ", ".join(function.argNames())
    if function.type is not stdapi.Void:
        print(' _result = %s(%s);' % (function.name, arg_names))
        print(' (void)_result;')
    else:
        print(' %s(%s);' % (function.name, arg_names))
def invokeInterfaceMethod(self, interface, method):
    '''Emit the call of *method* on ``_this``.

    The result is stored in ``_result`` unless the method returns void.
    Release() gets special handling so the object is only unregistered once
    its traced reference count reached zero.'''
    # Only release our reference when we reach the Release() == 0 call in
    # the trace.
    if method.name == 'Release':
        print(' if (call.ret->toUInt()) {')
        # The traced Release() left the object alive: skip this call.
        print(' return;')
        print(' }')
        print(' retrace::delObj(call.arg(0));')

    arg_names = ", ".join(method.argNames())
    if method.type is not stdapi.Void:
        print(' _result = _this->%s(%s);' % (method.name, arg_names))
        print(' (void)_result;')
    else:
        print(' _this->%s(%s);' % (method.name, arg_names))
def filterFunction(self, function):
    '''Hook for subclasses to exclude functions from retracing.

    It is used as the predicate of ``filter()`` over the API's functions,
    so returning a falsy value drops the function. The default keeps every
    function.'''
    return True

# Name of the generated C++ array of retrace::Entry callbacks.
table_name = 'retrace::callbacks'
461 def retraceApi(self, api):
463 print '#include "os_time.hpp"'
464 print '#include "trace_parser.hpp"'
465 print '#include "retrace.hpp"'
466 print '#include "retrace_swizzle.hpp"'
469 types = api.getAllTypes()
470 handles = [type for type in types if isinstance(type, stdapi.Handle)]
472 for handle in handles:
473 if handle.name not in handle_names:
474 if handle.key is None:
475 print 'static retrace::map<%s> _%s_map;' % (handle.type, handle.name)
477 key_name, key_type = handle.key
478 print 'static std::map<%s, retrace::map<%s> > _%s_map;' % (key_type, handle.type, handle.name)
479 handle_names.add(handle.name)
482 functions = filter(self.filterFunction, api.getAllFunctions())
483 for function in functions:
484 if function.sideeffects and not function.internal:
485 self.retraceFunction(function)
486 interfaces = api.getAllInterfaces()
487 for interface in interfaces:
488 for method in interface.iterMethods():
489 if method.sideeffects and not method.internal:
490 self.retraceInterfaceMethod(interface, method)
492 print 'const retrace::Entry %s[] = {' % self.table_name
493 for function in functions:
494 if not function.internal:
495 if function.sideeffects:
496 print ' {"%s", &retrace_%s},' % (function.name, function.name)
498 print ' {"%s", &retrace::ignore},' % (function.name,)
499 for interface in interfaces:
500 for method in interface.iterMethods():
501 if method.sideeffects:
502 print ' {"%s::%s", &retrace_%s__%s},' % (interface.name, method.name, interface.name, method.name)
504 print ' {"%s::%s", &retrace::ignore},' % (interface.name, method.name)
505 print ' {NULL, NULL}'