X-Git-Url: https://git.cworth.org/git?a=blobdiff_plain;f=wrappers%2Ftrace.py;h=8ff2b9174efbdcb863f3aee73116c70749ed6f3a;hb=912abd59063df00d901d765fb15a01e8486689c8;hp=da61e2a9660c5319abd6e6508bb47d38e86d3344;hpb=4fb1ab04af9593b7aea4be4ef8b5e56eec4f5f0d;p=apitrace diff --git a/wrappers/trace.py b/wrappers/trace.py index da61e2a..8ff2b91 100644 --- a/wrappers/trace.py +++ b/wrappers/trace.py @@ -40,44 +40,6 @@ def getWrapperInterfaceName(interface): -class ExpanderMixin: - '''Mixin class that provides a bunch of methods to expand C expressions - from the specifications.''' - - __structs = None - __indices = None - - def expand(self, expr): - # Expand a C expression, replacing certain variables - if not isinstance(expr, basestring): - return expr - variables = {} - - if self.__structs is not None: - variables['self'] = '(%s)' % self.__structs[0] - if self.__indices is not None: - variables['i'] = self.__indices[0] - - expandedExpr = expr.format(**variables) - if expandedExpr != expr and 0: - sys.stderr.write(" %r -> %r\n" % (expr, expandedExpr)) - return expandedExpr - - def visitMember(self, structInstance, member_type, *args, **kwargs): - self.__structs = (structInstance, self.__structs) - try: - return self.visit(member_type, *args, **kwargs) - finally: - _, self.__structs = self.__structs - - def visitElement(self, element_index, element_type, *args, **kwargs): - self.__indices = (element_index, self.__indices) - try: - return self.visit(element_type, *args, **kwargs) - finally: - _, self.__indices = self.__indices - - class ComplexValueSerializer(stdapi.OnceVisitor): '''Type visitors which generates serialization functions for complex types. @@ -104,10 +66,17 @@ class ComplexValueSerializer(stdapi.OnceVisitor): def visitStruct(self, struct): print 'static const char * _struct%s_members[%u] = {' % (struct.tag, len(struct.members)) for type, name, in struct.members: - print ' "%s",' % (name,) + if name is None: + print ' "",' + else: + print ' "%s",' % (name,) print '};' print 'static const trace::StructSig _struct%s_sig = {' % (struct.tag,) - print ' %u, "%s", %u, _struct%s_members' % (struct.id, struct.name, len(struct.members), struct.tag) + if struct.name is None: + structName = '""' + else: + structName = '"%s"' % struct.name + print ' %u, %s, %u, _struct%s_members' % (struct.id, structName, len(struct.members), struct.tag) print '};' print @@ -120,22 +89,22 @@ class ComplexValueSerializer(stdapi.OnceVisitor): def visitEnum(self, enum): print 'static const trace::EnumValue _enum%s_values[] = {' % (enum.tag) for value in enum.values: - print ' {"%s", %s},' % (value, value) + print ' {"%s", %s},' % (value, value) print '};' print print 'static const trace::EnumSig _enum%s_sig = {' % (enum.tag) - print ' %u, %u, _enum%s_values' % (enum.id, len(enum.values), enum.tag) + print ' %u, %u, _enum%s_values' % (enum.id, len(enum.values), enum.tag) print '};' print def visitBitmask(self, bitmask): print 'static const trace::BitmaskFlag _bitmask%s_flags[] = {' % (bitmask.tag) for value in bitmask.values: - print ' {"%s", %s},' % (value, value) + print ' {"%s", %s},' % (value, value) print '};' print print 'static const trace::BitmaskSig _bitmask%s_sig = {' % (bitmask.tag) - print ' %u, %u, _bitmask%s_flags' % (bitmask.id, len(bitmask.values), bitmask.tag) + print ' %u, %u, _bitmask%s_flags' % (bitmask.id, len(bitmask.values), bitmask.tag) print '};' print @@ -181,7 +150,7 @@ class ComplexValueSerializer(stdapi.OnceVisitor): print -class ValueSerializer(stdapi.Visitor, ExpanderMixin): +class ValueSerializer(stdapi.Visitor, stdapi.ExpanderMixin): '''Visitor which generates code to serialize any type. Simple types are serialized inline here, whereas the serialization of @@ -189,11 +158,6 @@ class ValueSerializer(stdapi.Visitor, ExpanderMixin): ComplexValueSerializer visitor above. ''' - def __init__(self): - #stdapi.Visitor.__init__(self) - self.indices = [] - self.instances = [] - def visitLiteral(self, literal, instance): print ' trace::localWriter.write%s(%s);' % (literal.kind, instance) @@ -208,7 +172,7 @@ class ValueSerializer(stdapi.Visitor, ExpanderMixin): # reinterpret_cast is necessary for GLubyte * <=> char * instance = 'reinterpret_cast<%s>(%s)' % (cast, instance) if string.length is not None: - length = ', %s' % string.length + length = ', %s' % self.expand(string.length) else: length = '' print ' trace::localWriter.write%s(%s%s);' % (suffix, instance, length) @@ -218,8 +182,8 @@ class ValueSerializer(stdapi.Visitor, ExpanderMixin): def visitStruct(self, struct, instance): print ' trace::localWriter.beginStruct(&_struct%s_sig);' % (struct.tag,) - for type, name in struct.members: - self.visitMember(instance, type, '(%s).%s' % (instance, name,)) + for member in struct.members: + self.visitMember(member, instance) print ' trace::localWriter.endStruct();' def visitArray(self, array, instance): @@ -287,12 +251,21 @@ class ValueSerializer(stdapi.Visitor, ExpanderMixin): if polymorphic.contextLess: print ' _write__%s(%s, %s);' % (polymorphic.tag, polymorphic.switchExpr, instance) else: - print ' switch (%s) {' % polymorphic.switchExpr + switchExpr = self.expand(polymorphic.switchExpr) + print ' switch (%s) {' % switchExpr for cases, type in polymorphic.iterSwitch(): for case in cases: print ' %s:' % case - self.visit(type, 'static_cast<%s>(%s)' % (type, instance)) + caseInstance = instance + if type.expr is not None: + caseInstance = 'static_cast<%s>(%s)' % (type, caseInstance) + self.visit(type, caseInstance) print ' break;' + if polymorphic.defaultType is None: + print r' default:' + print r' os::log("apitrace: warning: %%s: unexpected polymorphic case %%i\n", __FUNCTION__, (int)%s);' % (switchExpr,) + print r' trace::localWriter.writeNull();' + print r' break;' print ' }' @@ -312,7 +285,7 @@ class WrapDecider(stdapi.Traverser): self.needsWrapping = True -class ValueWrapper(stdapi.Traverser, ExpanderMixin): +class ValueWrapper(stdapi.Traverser, stdapi.ExpanderMixin): '''Type visitor which will generate the code to wrap an instance. Wrapping is necessary mostly for interfaces, however interface pointers can @@ -320,8 +293,8 @@ class ValueWrapper(stdapi.Traverser, ExpanderMixin): ''' def visitStruct(self, struct, instance): - for type, name in struct.members: - self.visitMember(instance, type, "(%s).%s" % (instance, name)) + for member in struct.members: + self.visitMember(member, instance) def visitArray(self, array, instance): array_length = self.expand(array.length) @@ -340,6 +313,8 @@ class ValueWrapper(stdapi.Traverser, ExpanderMixin): elem_type = pointer.type.mutable() if isinstance(elem_type, stdapi.Interface): self.visitInterfacePointer(elem_type, instance) + elif isinstance(elem_type, stdapi.Alias) and isinstance(elem_type.type, stdapi.Interface): + self.visitInterfacePointer(elem_type.type, instance) else: self.visitPointer(pointer, instance) @@ -348,7 +323,7 @@ class ValueWrapper(stdapi.Traverser, ExpanderMixin): def visitInterfacePointer(self, interface, instance): print " if (%s) {" % instance - print " %s = new %s(%s);" % (instance, getWrapperInterfaceName(interface), instance) + print " %s = %s::_Create(__FUNCTION__, %s);" % (instance, getWrapperInterfaceName(interface), instance) print " }" def visitPolymorphic(self, type, instance): @@ -361,6 +336,23 @@ class ValueUnwrapper(ValueWrapper): allocated = False + def visitStruct(self, struct, instance): + if not self.allocated: + # Argument is constant. We need to create a non const + print ' {' + print " %s * _t = static_cast<%s *>(alloca(sizeof *_t));" % (struct, struct) + print ' *_t = %s;' % (instance,) + assert instance.startswith('*') + print ' %s = _t;' % (instance[1:],) + instance = '*_t' + self.allocated = True + try: + return ValueWrapper.visitStruct(self, struct, instance) + finally: + print ' }' + else: + return ValueWrapper.visitStruct(self, struct, instance) + def visitArray(self, array, instance): if self.allocated or isinstance(instance, stdapi.Interface): return ValueWrapper.visitArray(self, array, instance) @@ -390,6 +382,9 @@ class ValueUnwrapper(ValueWrapper): class Tracer: '''Base class to orchestrate the code generation of API tracing.''' + # 0-3 are reserved to memcpy, malloc, free, and realloc + __id = 4 + def __init__(self): self.api = None @@ -407,8 +402,9 @@ class Tracer: self.header(api) # Includes - for header in api.headers: - print header + for module in api.modules: + for header in module.headers: + print header print # Generate the serializer functions @@ -423,8 +419,10 @@ class Tracer: # Function wrappers self.interface = None self.base = None - map(self.traceFunctionDecl, api.functions) - map(self.traceFunctionImpl, api.functions) + for function in api.getAllFunctions(): + self.traceFunctionDecl(function) + for function in api.getAllFunctions(): + self.traceFunctionImpl(function) print self.footer(api) @@ -441,6 +439,7 @@ class Tracer: print print '#include "trace.hpp"' print + print 'static std::map g_WrappedObjects;' def footer(self, api): pass @@ -453,9 +452,14 @@ class Tracer: print 'static const char * _%s_args[%u] = {%s};' % (function.name, len(function.args), ', '.join(['"%s"' % arg.name for arg in function.args])) else: print 'static const char ** _%s_args = NULL;' % (function.name,) - print 'static const trace::FunctionSig _%s_sig = {%u, "%s", %u, _%s_args};' % (function.name, function.id, function.name, len(function.args), function.name) + print 'static const trace::FunctionSig _%s_sig = {%u, "%s", %u, _%s_args};' % (function.name, self.getFunctionSigId(), function.name, len(function.args), function.name) print + def getFunctionSigId(self): + id = Tracer.__id + Tracer.__id += 1 + return id + def isFunctionPublic(self, function): return True @@ -502,9 +506,9 @@ class Tracer: print ' }' if function.type is not stdapi.Void: self.serializeRet(function, "_result") - print ' trace::localWriter.endLeave();' if function.type is not stdapi.Void: self.wrapRet(function, "_result") + print ' trace::localWriter.endLeave();' def invokeFunction(self, function, prefix='_', suffix=''): if function.type is stdapi.Void: @@ -519,7 +523,7 @@ class Tracer: return 'true' if str(function.type) == 'HRESULT': return 'SUCCEEDED(_result)' - return 'false' + return 'true' def serializeArg(self, function, arg): print ' trace::localWriter.beginArg(%u);' % (arg.index,) @@ -560,9 +564,6 @@ class Tracer: def wrapRet(self, function, instance): self.wrapValue(function.type, instance) - def unwrapRet(self, function, instance): - self.unwrapValue(function.type, instance) - def needsWrapping(self, type): visitor = WrapDecider() visitor.visit(type) @@ -590,9 +591,11 @@ class Tracer: def declareWrapperInterface(self, interface): print "class %s : public %s " % (getWrapperInterfaceName(interface), interface.name) print "{" - print "public:" + print "private:" print " %s(%s * pInstance);" % (getWrapperInterfaceName(interface), interface.name) print " virtual ~%s();" % getWrapperInterfaceName(interface) + print "public:" + print " static %s* _Create(const char *functionName, %s * pInstance);" % (getWrapperInterfaceName(interface), interface.name) print for method in interface.iterMethods(): print " " + method.prototype() + ";" @@ -600,6 +603,11 @@ class Tracer: #print "private:" for type, name, value in self.enumWrapperInterfaceVariables(interface): print ' %s %s;' % (type, name) + for i in range(64): + print r' virtual void _dummy%i(void) const {' % i + print r' os::log("error: %s: unexpected virtual method\n");' % interface.name + print r' os::abort();' + print r' }' print "};" print @@ -607,17 +615,45 @@ class Tracer: return [ ("DWORD", "m_dwMagic", "0xd8365d6c"), ("%s *" % interface.name, "m_pInstance", "pInstance"), + ("void *", "m_pVtbl", "*(void **)pInstance"), + ("UINT", "m_NumMethods", len(list(interface.iterBaseMethods()))), ] def implementWrapperInterface(self, interface): self.interface = interface + # Private constructor print '%s::%s(%s * pInstance) {' % (getWrapperInterfaceName(interface), getWrapperInterfaceName(interface), interface.name) for type, name, value in self.enumWrapperInterfaceVariables(interface): print ' %s = %s;' % (name, value) print '}' print + + # Public constructor + print '%s *%s::_Create(const char *functionName, %s * pInstance) {' % (getWrapperInterfaceName(interface), getWrapperInterfaceName(interface), interface.name) + print r' std::map::const_iterator it = g_WrappedObjects.find(pInstance);' + print r' if (it != g_WrappedObjects.end()) {' + print r' Wrap%s *pWrapper = (Wrap%s *)it->second;' % (interface.name, interface.name) + print r' assert(pWrapper);' + print r' assert(pWrapper->m_dwMagic == 0xd8365d6c);' + print r' assert(pWrapper->m_pInstance == pInstance);' + print r' if (pWrapper->m_pVtbl == *(void **)pInstance &&' + print r' pWrapper->m_NumMethods >= %s) {' % len(list(interface.iterBaseMethods())) + #print r' os::log("%s: fetched pvObj=%p pWrapper=%p pVtbl=%p\n", functionName, pInstance, pWrapper, pWrapper->m_pVtbl);' + print r' return pWrapper;' + print r' }' + print r' }' + print r' Wrap%s *pWrapper = new Wrap%s(pInstance);' % (interface.name, interface.name) + #print r' os::log("%%s: created %s pvObj=%%p pWrapper=%%p pVtbl=%%p\n", functionName, pInstance, pWrapper, pWrapper->m_pVtbl);' % interface.name + print r' g_WrappedObjects[pInstance] = pWrapper;' + print r' return pWrapper;' + print '}' + print + + # Destructor print '%s::~%s() {' % (getWrapperInterfaceName(interface), getWrapperInterfaceName(interface)) + #print r' os::log("%s::Release: deleted pvObj=%%p pWrapper=%%p pVtbl=%%p\n", m_pInstance, this, m_pVtbl);' % interface.name + print r' g_WrappedObjects.erase(m_pInstance);' print '}' print @@ -629,6 +665,10 @@ class Tracer: def implementWrapperInterfaceMethod(self, interface, base, method): print method.prototype(getWrapperInterfaceName(interface) + '::' + method.name) + ' {' + + if False: + print r' os::log("%%s(%%p -> %%p)\n", "%s", this, m_pInstance);' % (getWrapperInterfaceName(interface) + '::' + method.name) + if method.type is not stdapi.Void: print ' %s _result;' % method.type @@ -643,7 +683,7 @@ class Tracer: assert not method.internal print ' static const char * _args[%u] = {%s};' % (len(method.args) + 1, ', '.join(['"this"'] + ['"%s"' % arg.name for arg in method.args])) - print ' static const trace::FunctionSig _sig = {%u, "%s", %u, _args};' % (method.id, interface.name + '::' + method.name, len(method.args) + 1) + print ' static const trace::FunctionSig _sig = {%u, "%s", %u, _args};' % (self.getFunctionSigId(), interface.name + '::' + method.name, len(method.args) + 1) print ' %s *_this = static_cast<%s *>(m_pInstance);' % (base, base) @@ -670,14 +710,16 @@ class Tracer: if method.type is not stdapi.Void: self.serializeRet(method, '_result') - print ' trace::localWriter.endLeave();' if method.type is not stdapi.Void: self.wrapRet(method, '_result') if method.name == 'Release': assert method.type is not stdapi.Void - print ' if (!_result)' - print ' delete this;' + print r' if (!_result) {' + print r' delete this;' + print r' }' + + print ' trace::localWriter.endLeave();' def implementIidWrapper(self, api): print r'static void' @@ -696,7 +738,7 @@ class Tracer: else_ = '' for iface in api.getAllInterfaces(): print r' %sif (riid == IID_%s) {' % (else_, iface.name) - print r' *ppvObj = new Wrap%s((%s *) *ppvObj);' % (iface.name, iface.name) + print r' *ppvObj = Wrap%s::_Create(functionName, (%s *) *ppvObj);' % (iface.name, iface.name) print r' }' else_ = 'else ' print r' %s{' % else_