1 ##########################################################################
3 # Copyright 2010 VMware, Inc.
6 # Permission is hereby granted, free of charge, to any person obtaining a copy
7 # of this software and associated documentation files (the "Software"), to deal
8 # in the Software without restriction, including without limitation the rights
9 # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 # copies of the Software, and to permit persons to whom the Software is
11 # furnished to do so, subject to the following conditions:
13 # The above copyright notice and this permission notice shall be included in
14 # all copies or substantial portions of the Software.
16 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
##########################################################################
"""Generic retracing code generator."""


# Adjust path so the sibling ``specs`` package can be imported when this
# script is run directly.  ``os.path`` and ``sys`` must be imported before
# they are used here (and ``sys`` is used again for warnings later on).
import os.path
import sys
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))


import specs.stdapi as stdapi
import specs.glapi as glapi
class UnsupportedType(Exception):
    '''Raised by the value visitors (e.g. ValueDeserializer.visitOpaque) when a
    type cannot be turned into retrace code.

    Code generators catch it to fall back to a placeholder, such as a FIXME
    ``memset`` of the argument or an ``// XXX`` comment, instead of real code.'''
def lookupHandle(handle, value):
    '''Return the C++ expression that looks `value` up in `handle`'s map.

    Unkeyed handles use the flat ``_<name>_map[value]`` table.  Keyed handles
    use a two-level ``_<name>_map[key][value]`` lookup, where the key
    expression is the first element of the ``handle.key`` pair.'''
    map_name = '_%s_map' % (handle.name,)
    if handle.key is not None:
        key_name, _key_type = handle.key
        return '%s[%s][%s]' % (map_name, key_name, value)
    return '%s[%s]' % (map_name, value)
class ValueAllocator(stdapi.Visitor):
    '''Emit C++ that reserves storage for an argument before it is deserialized.

    Only arrays and plain pointers need storage.  It is obtained from the
    generated ``_allocator`` object via ``_allocator.alloc<T>(&rvalue)``,
    passing the address of the traced value.  Wrapper types (const, alias,
    reference, polymorphic) are unwrapped; every other type needs nothing.'''

    def visitLiteral(self, literal, lvalue, rvalue):
        '''Nothing to allocate for literals.'''

    def visitConst(self, const, lvalue, rvalue):
        self.visit(const.type, lvalue, rvalue)

    def visitAlias(self, alias, lvalue, rvalue):
        self.visit(alias.type, lvalue, rvalue)

    def visitEnum(self, enum, lvalue, rvalue):
        '''Nothing to allocate for enums.'''

    def visitBitmask(self, bitmask, lvalue, rvalue):
        '''Nothing to allocate for bitmasks.'''

    def visitArray(self, array, lvalue, rvalue):
        element_type = array.type
        print(' %s = _allocator.alloc<%s>(&%s);' % (lvalue, element_type, rvalue))

    def visitPointer(self, pointer, lvalue, rvalue):
        pointee_type = pointer.type
        print(' %s = _allocator.alloc<%s>(&%s);' % (lvalue, pointee_type, rvalue))

    def visitIntPointer(self, pointer, lvalue, rvalue):
        '''Nothing to allocate for integer-valued pointers.'''

    def visitObjPointer(self, pointer, lvalue, rvalue):
        '''Nothing to allocate for object pointers.'''

    def visitLinearPointer(self, pointer, lvalue, rvalue):
        '''Nothing to allocate for linear (memory-region) pointers.'''

    def visitReference(self, reference, lvalue, rvalue):
        self.visit(reference.type, lvalue, rvalue)

    def visitHandle(self, handle, lvalue, rvalue):
        '''Nothing to allocate for handles.'''

    def visitBlob(self, blob, lvalue, rvalue):
        '''Nothing to allocate for blobs.'''

    def visitString(self, string, lvalue, rvalue):
        '''Nothing to allocate for strings.'''

    def visitStruct(self, struct, lvalue, rvalue):
        '''Nothing to allocate for structs.'''

    def visitPolymorphic(self, polymorphic, lvalue, rvalue):
        self.visit(polymorphic.defaultType, lvalue, rvalue)

    def visitOpaque(self, opaque, lvalue, rvalue):
        '''Nothing to allocate for opaque values.'''
class ValueDeserializer(stdapi.Visitor):
    '''Emit C++ that converts the traced value ``rvalue`` (a ``trace::Value``
    expression) into the native value stored in ``lvalue``.

    Raises UnsupportedType for opaque types, which cannot be reconstructed
    from a trace in general (see OpaqueValueDeserializer for the exception).'''

    # Counter giving generated temporaries (_a_*/_s_*) distinct names.
    seq = 0

    def visitLiteral(self, literal, lvalue, rvalue):
        print(' %s = (%s).to%s();' % (lvalue, rvalue, literal.kind))

    def visitConst(self, const, lvalue, rvalue):
        self.visit(const.type, lvalue, rvalue)

    def visitAlias(self, alias, lvalue, rvalue):
        self.visit(alias.type, lvalue, rvalue)

    def visitEnum(self, enum, lvalue, rvalue):
        print(' %s = static_cast<%s>((%s).toSInt());' % (lvalue, enum, rvalue))

    def visitBitmask(self, bitmask, lvalue, rvalue):
        self.visit(bitmask.type, lvalue, rvalue)

    def visitArray(self, array, lvalue, rvalue):
        tmp = '_a_' + array.tag + '_' + str(self.seq)
        self.seq += 1

        # Skip element-wise filling when lvalue is NULL (e.g. no storage was
        # allocated for it by ValueAllocator).
        print(' if (%s) {' % (lvalue,))
        print(' const trace::Array *%s = dynamic_cast<const trace::Array *>(&%s);' % (tmp, rvalue))
        length = '%s->values.size()' % (tmp,)
        index = '_j' + array.tag
        print(' for (size_t {i} = 0; {i} < {length}; ++{i}) {{'.format(i=index, length=length))
        try:
            self.visit(array.type, '%s[%s]' % (lvalue, index), '*%s->values[%s]' % (tmp, index))
        finally:
            # Close the loop and the if even if the element type raised, so
            # the emitted C++ stays balanced for the caller's fallback.
            print(' }')
            print(' }')

    def visitPointer(self, pointer, lvalue, rvalue):
        tmp = '_a_' + pointer.tag + '_' + str(self.seq)
        self.seq += 1

        print(' if (%s) {' % (lvalue,))
        print(' const trace::Array *%s = dynamic_cast<const trace::Array *>(&%s);' % (tmp, rvalue))
        try:
            # The pointee is read from element 0 of the traced array.
            self.visit(pointer.type, '%s[0]' % (lvalue,), '*%s->values[0]' % (tmp,))
        finally:
            print(' }')

    def visitIntPointer(self, pointer, lvalue, rvalue):
        print(' %s = static_cast<%s>((%s).toPointer());' % (lvalue, pointer, rvalue))

    def visitObjPointer(self, pointer, lvalue, rvalue):
        print(' %s = static_cast<%s>(retrace::toObjPointer(%s));' % (lvalue, pointer, rvalue))

    def visitLinearPointer(self, pointer, lvalue, rvalue):
        print(' %s = static_cast<%s>(retrace::toPointer(%s));' % (lvalue, pointer, rvalue))

    def visitReference(self, reference, lvalue, rvalue):
        self.visit(reference.type, lvalue, rvalue)

    def visitHandle(self, handle, lvalue, rvalue):
        # Deserialize the traced handle value, then replace it with the value
        # it is mapped to in the handle table.  Opaque handle types raise
        # UnsupportedType here via visitOpaque.
        self.visit(handle.type, lvalue, rvalue)
        new_lvalue = lookupHandle(handle, lvalue)
        print(' if (retrace::verbosity >= 2) {')
        print(' std::cout << "%s " << size_t(%s) << " <- " << size_t(%s) << "\\n";' % (handle.name, lvalue, new_lvalue))
        print(' }')
        print(' %s = %s;' % (lvalue, new_lvalue))

    def visitBlob(self, blob, lvalue, rvalue):
        print(' %s = static_cast<%s>((%s).toPointer());' % (lvalue, blob, rvalue))

    def visitString(self, string, lvalue, rvalue):
        print(' %s = (%s)((%s).toString());' % (lvalue, string.expr, rvalue))

    def visitStruct(self, struct, lvalue, rvalue):
        tmp = '_s_' + struct.tag + '_' + str(self.seq)
        self.seq += 1

        print(' const trace::Struct *%s = dynamic_cast<const trace::Struct *>(&%s);' % (tmp, rvalue))
        print(' assert(%s);' % (tmp,))
        # Keep the temporary "used" when assert() compiles away and the struct
        # has no members (same as SwizzledValueRegistrator.visitStruct).
        print(' (void)%s;' % (tmp,))
        for i, (member_type, member_name) in enumerate(struct.members):
            self.visit(member_type, '%s.%s' % (lvalue, member_name), '*%s->members[%s]' % (tmp, i))

    def visitPolymorphic(self, polymorphic, lvalue, rvalue):
        self.visit(polymorphic.defaultType, lvalue, rvalue)

    def visitOpaque(self, opaque, lvalue, rvalue):
        raise UnsupportedType
class OpaqueValueDeserializer(ValueDeserializer):
    '''ValueDeserializer variant that also accepts opaque values.

    Opaque values normally cannot be retraced.  This variant is for contexts,
    such as handles, where the traced opaque value only needs to be recovered
    as a pointer via ``retrace::toPointer``.'''

    def visitOpaque(self, opaque, lvalue, rvalue):
        statement = ' %s = static_cast<%s>(retrace::toPointer(%s));' % (lvalue, opaque, rvalue)
        print(statement)
class SwizzledValueRegistrator(stdapi.Visitor):
    '''Type visitor which will register (un)swizzled value pairs, to later be
    swapped.

    ``rvalue`` is the value recorded in the trace and ``lvalue`` the value
    produced by the current run.  The emitted C++ records their correspondence:
    - object pointers via ``retrace::addObj``;
    - linear memory regions via ``retrace::addRegion``;
    - handles via the ``_<name>_map`` tables.
    Plain value types need no registration.'''

    # Counter giving generated struct temporaries distinct names.
    seq = 0

    def visitLiteral(self, literal, lvalue, rvalue):
        pass

    def visitConst(self, const, lvalue, rvalue):
        # Mirrors ValueAllocator/ValueDeserializer.  Registration only reads
        # lvalue, so unwrapping the const is safe.
        self.visit(const.type, lvalue, rvalue)

    def visitAlias(self, alias, lvalue, rvalue):
        self.visit(alias.type, lvalue, rvalue)

    def visitEnum(self, enum, lvalue, rvalue):
        pass

    def visitBitmask(self, bitmask, lvalue, rvalue):
        pass

    def visitArray(self, array, lvalue, rvalue):
        print(' const trace::Array *_a%s = dynamic_cast<const trace::Array *>(&%s);' % (array.tag, rvalue))
        print(' if (_a%s) {' % (array.tag,))
        length = '_a%s->values.size()' % array.tag
        index = '_j' + array.tag
        print(' for (size_t {i} = 0; {i} < {length}; ++{i}) {{'.format(i=index, length=length))
        try:
            self.visit(array.type, '%s[%s]' % (lvalue, index), '*_a%s->values[%s]' % (array.tag, index))
        finally:
            # Keep the emitted braces balanced even if the element type raised.
            print(' }')
            print(' }')

    def visitPointer(self, pointer, lvalue, rvalue):
        print(' const trace::Array *_a%s = dynamic_cast<const trace::Array *>(&%s);' % (pointer.tag, rvalue))
        print(' if (_a%s) {' % (pointer.tag,))
        try:
            self.visit(pointer.type, '%s[0]' % (lvalue,), '*_a%s->values[0]' % (pointer.tag,))
        finally:
            print(' }')

    def visitIntPointer(self, pointer, lvalue, rvalue):
        pass

    def visitObjPointer(self, pointer, lvalue, rvalue):
        print(r' retrace::addObj(%s, %s);' % (rvalue, lvalue))

    def visitLinearPointer(self, pointer, lvalue, rvalue):
        # A region can only be registered when its size is known.  The `if`
        # keeps this safe when assertions are disabled (python -O).
        assert pointer.size is not None
        if pointer.size is not None:
            print(r' retrace::addRegion((%s).toUIntPtr(), %s, %s);' % (rvalue, lvalue, pointer.size))

    def visitReference(self, reference, lvalue, rvalue):
        pass

    def visitHandle(self, handle, lvalue, rvalue):
        print(' %s _origResult;' % handle.type)
        # Opaque handle types are accepted: the traced value is only needed
        # as the key of the handle map.
        OpaqueValueDeserializer().visit(handle.type, '_origResult', rvalue)
        if handle.range is None:
            rvalue = "_origResult"
            entry = lookupHandle(handle, rvalue)
            print(" %s = %s;" % (entry, lvalue))
            print(' if (retrace::verbosity >= 2) {')
            print(' std::cout << "{handle.name} " << {rvalue} << " -> " << {lvalue} << "\\n";'.format(**locals()))
            print(' }')
        else:
            # A ranged handle covers handle.range consecutive values starting
            # at the returned one; map each of them.
            i = '_h' + handle.tag
            lvalue = "%s + %s" % (lvalue, i)
            rvalue = "_origResult + %s" % (i,)
            entry = lookupHandle(handle, rvalue)
            print(' for ({handle.type} {i} = 0; {i} < {handle.range}; ++{i}) {{'.format(**locals()))
            print(' {entry} = {lvalue};'.format(**locals()))
            print(' if (retrace::verbosity >= 2) {')
            print(' std::cout << "{handle.name} " << ({rvalue}) << " -> " << ({lvalue}) << "\\n";'.format(**locals()))
            print(' }')
            print(' }')

    def visitBlob(self, blob, lvalue, rvalue):
        pass

    def visitString(self, string, lvalue, rvalue):
        pass

    def visitStruct(self, struct, lvalue, rvalue):
        tmp = '_s_' + struct.tag + '_' + str(self.seq)
        self.seq += 1

        print(' const trace::Struct *%s = dynamic_cast<const trace::Struct *>(&%s);' % (tmp, rvalue))
        print(' assert(%s);' % (tmp,))
        print(' (void)%s;' % (tmp,))
        for i, (member_type, member_name) in enumerate(struct.members):
            self.visit(member_type, '%s.%s' % (lvalue, member_name), '*%s->members[%s]' % (tmp, i))

    def visitPolymorphic(self, polymorphic, lvalue, rvalue):
        self.visit(polymorphic.defaultType, lvalue, rvalue)

    def visitOpaque(self, opaque, lvalue, rvalue):
        pass
306 def retraceFunction(self, function):
307 print 'static void retrace_%s(trace::Call &call) {' % function.name
308 self.retraceFunctionBody(function)
312 def retraceInterfaceMethod(self, interface, method):
313 print 'static void retrace_%s__%s(trace::Call &call) {' % (interface.name, method.name)
314 self.retraceInterfaceMethodBody(interface, method)
318 def retraceFunctionBody(self, function):
319 assert function.sideeffects
321 if function.type is not stdapi.Void:
322 self.checkOrigResult(function)
324 self.deserializeArgs(function)
326 self.declareRet(function)
327 self.invokeFunction(function)
329 self.swizzleValues(function)
331 def retraceInterfaceMethodBody(self, interface, method):
332 assert method.sideeffects
334 if method.type is not stdapi.Void:
335 self.checkOrigResult(method)
337 self.deserializeThisPointer(interface)
339 self.deserializeArgs(method)
341 self.declareRet(method)
342 self.invokeInterfaceMethod(interface, method)
344 self.swizzleValues(method)
346 def checkOrigResult(self, function):
347 '''Hook for checking the original result, to prevent succeeding now
348 where the original did not, which would cause diversion and potentially
349 unpredictable results.'''
351 assert function.type is not stdapi.Void
353 if str(function.type) == 'HRESULT':
354 print r' if (call.ret && FAILED(call.ret->toSInt())) {'
358 def deserializeThisPointer(self, interface):
359 print r' %s *_this;' % (interface.name,)
360 print r' _this = static_cast<%s *>(retrace::toObjPointer(call.arg(0)));' % (interface.name,)
361 print r' if (!_this) {'
362 print r' retrace::warning(call) << "NULL this pointer\n";'
366 def deserializeArgs(self, function):
367 print ' retrace::ScopedAllocator _allocator;'
368 print ' (void)_allocator;'
370 for arg in function.args:
371 arg_type = arg.type.mutable()
372 print ' %s %s;' % (arg_type, arg.name)
373 rvalue = 'call.arg(%u)' % (arg.index,)
376 self.extractArg(function, arg, arg_type, lvalue, rvalue)
377 except UnsupportedType:
379 print ' memset(&%s, 0, sizeof %s); // FIXME' % (arg.name, arg.name)
384 self.failFunction(function)
385 if function.name[-1].islower():
386 sys.stderr.write('warning: unsupported %s call\n' % function.name)
389 def swizzleValues(self, function):
390 for arg in function.args:
392 arg_type = arg.type.mutable()
393 rvalue = 'call.arg(%u)' % (arg.index,)
396 self.regiterSwizzledValue(arg_type, lvalue, rvalue)
397 except UnsupportedType:
398 print ' // XXX: %s' % arg.name
399 if function.type is not stdapi.Void:
403 self.regiterSwizzledValue(function.type, lvalue, rvalue)
404 except UnsupportedType:
406 print ' // XXX: result'
408 def failFunction(self, function):
409 print ' if (retrace::verbosity >= 0) {'
410 print ' retrace::unsupported(call);'
414 def extractArg(self, function, arg, arg_type, lvalue, rvalue):
415 ValueAllocator().visit(arg_type, lvalue, rvalue)
417 ValueDeserializer().visit(arg_type, lvalue, rvalue)
419 def extractOpaqueArg(self, function, arg, arg_type, lvalue, rvalue):
421 ValueAllocator().visit(arg_type, lvalue, rvalue)
422 except UnsupportedType:
424 OpaqueValueDeserializer().visit(arg_type, lvalue, rvalue)
426 def regiterSwizzledValue(self, type, lvalue, rvalue):
427 visitor = SwizzledValueRegistrator()
428 visitor.visit(type, lvalue, rvalue)
430 def declareRet(self, function):
431 if function.type is not stdapi.Void:
432 print ' %s _result;' % (function.type)
434 def invokeFunction(self, function):
435 arg_names = ", ".join(function.argNames())
436 if function.type is not stdapi.Void:
437 print ' _result = %s(%s);' % (function.name, arg_names)
438 print ' (void)_result;'
440 print ' %s(%s);' % (function.name, arg_names)
442 def invokeInterfaceMethod(self, interface, method):
443 # On release our reference when we reach Release() == 0 call in the
445 if method.name == 'Release':
446 print ' if (call.ret->toUInt()) {'
449 print ' retrace::delObj(call.arg(0));'
451 arg_names = ", ".join(method.argNames())
452 if method.type is not stdapi.Void:
453 print ' _result = _this->%s(%s);' % (method.name, arg_names)
454 print ' (void)_result;'
456 print ' _this->%s(%s);' % (method.name, arg_names)
458 def filterFunction(self, function):
461 table_name = 'retrace::callbacks'
463 def retraceApi(self, api):
465 print '#include "os_time.hpp"'
466 print '#include "trace_parser.hpp"'
467 print '#include "retrace.hpp"'
468 print '#include "retrace_swizzle.hpp"'
471 types = api.getAllTypes()
472 handles = [type for type in types if isinstance(type, stdapi.Handle)]
474 for handle in handles:
475 if handle.name not in handle_names:
476 if handle.key is None:
477 print 'static retrace::map<%s> _%s_map;' % (handle.type, handle.name)
479 key_name, key_type = handle.key
480 print 'static std::map<%s, retrace::map<%s> > _%s_map;' % (key_type, handle.type, handle.name)
481 handle_names.add(handle.name)
484 functions = filter(self.filterFunction, api.functions)
485 for function in functions:
486 if function.sideeffects and not function.internal:
487 self.retraceFunction(function)
488 interfaces = api.getAllInterfaces()
489 for interface in interfaces:
490 for method in interface.iterMethods():
491 if method.sideeffects and not method.internal:
492 self.retraceInterfaceMethod(interface, method)
494 print 'const retrace::Entry %s[] = {' % self.table_name
495 for function in functions:
496 if not function.internal:
497 if function.sideeffects:
498 print ' {"%s", &retrace_%s},' % (function.name, function.name)
500 print ' {"%s", &retrace::ignore},' % (function.name,)
501 for interface in interfaces:
502 for method in interface.iterMethods():
503 if method.sideeffects:
504 print ' {"%s::%s", &retrace_%s__%s},' % (interface.name, method.name, interface.name, method.name)
506 print ' {"%s::%s", &retrace::ignore},' % (interface.name, method.name)
507 print ' {NULL, NULL}'