]> git.cworth.org Git - apitrace/blobdiff - wrappers/trace.py
trace: Unwrap all args before serializing them.
[apitrace] / wrappers / trace.py
index 4563f454f710822887ad9bd5eaa4de5606db7d4f..61a2cb95fcf8249c13719464a36bee5891f8f76a 100644 (file)
@@ -40,44 +40,6 @@ def getWrapperInterfaceName(interface):
 
 
 
-class ExpanderMixin:
-    '''Mixin class that provides a bunch of methods to expand C expressions
-    from the specifications.'''
-
-    __structs = None
-    __indices = None
-
-    def expand(self, expr):
-        # Expand a C expression, replacing certain variables
-        if not isinstance(expr, basestring):
-            return expr
-        variables = {}
-
-        if self.__structs is not None:
-            variables['self'] = '(%s)' % self.__structs[0]
-        if self.__indices is not None:
-            variables['i'] = self.__indices[0]
-
-        expandedExpr = expr.format(**variables)
-        if expandedExpr != expr and 0:
-            sys.stderr.write("  %r -> %r\n" % (expr, expandedExpr))
-        return expandedExpr
-
-    def visitMember(self, structInstance, member_type, *args, **kwargs):
-        self.__structs = (structInstance, self.__structs)
-        try:
-            return self.visit(member_type, *args, **kwargs)
-        finally:
-            _, self.__structs = self.__structs
-
-    def visitElement(self, element_index, element_type, *args, **kwargs):
-        self.__indices = (element_index, self.__indices)
-        try:
-            return self.visit(element_type, *args, **kwargs)
-        finally:
-            _, self.__indices = self.__indices
-
-
 class ComplexValueSerializer(stdapi.OnceVisitor):
     '''Type visitors which generates serialization functions for
     complex types.
@@ -104,10 +66,17 @@ class ComplexValueSerializer(stdapi.OnceVisitor):
     def visitStruct(self, struct):
         print 'static const char * _struct%s_members[%u] = {' % (struct.tag, len(struct.members))
         for type, name,  in struct.members:
-            print '    "%s",' % (name,)
+            if name is None:
+                print '    "",'
+            else:
+                print '    "%s",' % (name,)
         print '};'
         print 'static const trace::StructSig _struct%s_sig = {' % (struct.tag,)
-        print '   %u, "%s", %u, _struct%s_members' % (struct.id, struct.name, len(struct.members), struct.tag)
+        if struct.name is None:
+            structName = '""'
+        else:
+            structName = '"%s"' % struct.name
+        print '    %u, %s, %u, _struct%s_members' % (struct.id, structName, len(struct.members), struct.tag)
         print '};'
         print
 
@@ -120,22 +89,22 @@ class ComplexValueSerializer(stdapi.OnceVisitor):
     def visitEnum(self, enum):
         print 'static const trace::EnumValue _enum%s_values[] = {' % (enum.tag)
         for value in enum.values:
-            print '   {"%s", %s},' % (value, value)
+            print '    {"%s", %s},' % (value, value)
         print '};'
         print
         print 'static const trace::EnumSig _enum%s_sig = {' % (enum.tag)
-        print '   %u, %u, _enum%s_values' % (enum.id, len(enum.values), enum.tag)
+        print '    %u, %u, _enum%s_values' % (enum.id, len(enum.values), enum.tag)
         print '};'
         print
 
     def visitBitmask(self, bitmask):
         print 'static const trace::BitmaskFlag _bitmask%s_flags[] = {' % (bitmask.tag)
         for value in bitmask.values:
-            print '   {"%s", %s},' % (value, value)
+            print '    {"%s", %s},' % (value, value)
         print '};'
         print
         print 'static const trace::BitmaskSig _bitmask%s_sig = {' % (bitmask.tag)
-        print '   %u, %u, _bitmask%s_flags' % (bitmask.id, len(bitmask.values), bitmask.tag)
+        print '    %u, %u, _bitmask%s_flags' % (bitmask.id, len(bitmask.values), bitmask.tag)
         print '};'
         print
 
@@ -181,7 +150,7 @@ class ComplexValueSerializer(stdapi.OnceVisitor):
         print
 
 
-class ValueSerializer(stdapi.Visitor, ExpanderMixin):
+class ValueSerializer(stdapi.Visitor, stdapi.ExpanderMixin):
     '''Visitor which generates code to serialize any type.
     
     Simple types are serialized inline here, whereas the serialization of
@@ -189,11 +158,6 @@ class ValueSerializer(stdapi.Visitor, ExpanderMixin):
     ComplexValueSerializer visitor above.
     '''
 
-    def __init__(self):
-        #stdapi.Visitor.__init__(self)
-        self.indices = []
-        self.instances = []
-
     def visitLiteral(self, literal, instance):
         print '    trace::localWriter.write%s(%s);' % (literal.kind, instance)
 
@@ -208,7 +172,7 @@ class ValueSerializer(stdapi.Visitor, ExpanderMixin):
             # reinterpret_cast is necessary for GLubyte * <=> char *
             instance = 'reinterpret_cast<%s>(%s)' % (cast, instance)
         if string.length is not None:
-            length = ', %s' % string.length
+            length = ', %s' % self.expand(string.length)
         else:
             length = ''
         print '    trace::localWriter.write%s(%s%s);' % (suffix, instance, length)
@@ -218,8 +182,8 @@ class ValueSerializer(stdapi.Visitor, ExpanderMixin):
 
     def visitStruct(self, struct, instance):
         print '    trace::localWriter.beginStruct(&_struct%s_sig);' % (struct.tag,)
-        for type, name in struct.members:
-            self.visitMember(instance, type, '(%s).%s' % (instance, name,))
+        for member in struct.members:
+            self.visitMember(member, instance)
         print '    trace::localWriter.endStruct();'
 
     def visitArray(self, array, instance):
@@ -287,12 +251,21 @@ class ValueSerializer(stdapi.Visitor, ExpanderMixin):
         if polymorphic.contextLess:
             print '    _write__%s(%s, %s);' % (polymorphic.tag, polymorphic.switchExpr, instance)
         else:
-            print '    switch (%s) {' % polymorphic.switchExpr
+            switchExpr = self.expand(polymorphic.switchExpr)
+            print '    switch (%s) {' % switchExpr
             for cases, type in polymorphic.iterSwitch():
                 for case in cases:
                     print '    %s:' % case
-                self.visit(type, 'static_cast<%s>(%s)' % (type, instance))
+                caseInstance = instance
+                if type.expr is not None:
+                    caseInstance = 'static_cast<%s>(%s)' % (type, caseInstance)
+                self.visit(type, caseInstance)
                 print '        break;'
+            if polymorphic.defaultType is None:
+                print r'    default:'
+                print r'        os::log("apitrace: warning: %%s: unexpected polymorphic case %%i\n", __FUNCTION__, (int)%s);' % (switchExpr,)
+                print r'        trace::localWriter.writeNull();'
+                print r'        break;'
             print '    }'
 
 
@@ -312,7 +285,7 @@ class WrapDecider(stdapi.Traverser):
         self.needsWrapping = True
 
 
-class ValueWrapper(stdapi.Traverser, ExpanderMixin):
+class ValueWrapper(stdapi.Traverser, stdapi.ExpanderMixin):
     '''Type visitor which will generate the code to wrap an instance.
     
     Wrapping is necessary mostly for interfaces, however interface pointers can
@@ -320,8 +293,8 @@ class ValueWrapper(stdapi.Traverser, ExpanderMixin):
     '''
 
     def visitStruct(self, struct, instance):
-        for type, name in struct.members:
-            self.visitMember(instance, type, "(%s).%s" % (instance, name))
+        for member in struct.members:
+            self.visitMember(member, instance)
 
     def visitArray(self, array, instance):
         array_length = self.expand(array.length)
@@ -350,7 +323,7 @@ class ValueWrapper(stdapi.Traverser, ExpanderMixin):
 
     def visitInterfacePointer(self, interface, instance):
         print "    if (%s) {" % instance
-        print "        %s = new %s(%s);" % (instance, getWrapperInterfaceName(interface), instance)
+        print "        %s = %s::_Create(__FUNCTION__, %s);" % (instance, getWrapperInterfaceName(interface), instance)
         print "    }"
     
     def visitPolymorphic(self, type, instance):
@@ -409,6 +382,9 @@ class ValueUnwrapper(ValueWrapper):
 class Tracer:
     '''Base class to orchestrate the code generation of API tracing.'''
 
+    # 0-3 are reserved to memcpy, malloc, free, and realloc
+    __id = 4
+
     def __init__(self):
         self.api = None
 
@@ -426,8 +402,9 @@ class Tracer:
         self.header(api)
 
         # Includes
-        for header in api.headers:
-            print header
+        for module in api.modules:
+            for header in module.headers:
+                print header
         print
 
         # Generate the serializer functions
@@ -442,8 +419,10 @@ class Tracer:
         # Function wrappers
         self.interface = None
         self.base = None
-        map(self.traceFunctionDecl, api.functions)
-        map(self.traceFunctionImpl, api.functions)
+        for function in api.getAllFunctions():
+            self.traceFunctionDecl(function)
+        for function in api.getAllFunctions():
+            self.traceFunctionImpl(function)
         print
 
         self.footer(api)
@@ -460,6 +439,7 @@ class Tracer:
         print
         print '#include "trace.hpp"'
         print
+        print 'static std::map<void *, void *> g_WrappedObjects;'
 
     def footer(self, api):
         pass
@@ -472,9 +452,14 @@ class Tracer:
                 print 'static const char * _%s_args[%u] = {%s};' % (function.name, len(function.args), ', '.join(['"%s"' % arg.name for arg in function.args]))
             else:
                 print 'static const char ** _%s_args = NULL;' % (function.name,)
-            print 'static const trace::FunctionSig _%s_sig = {%u, "%s", %u, _%s_args};' % (function.name, function.id, function.name, len(function.args), function.name)
+            print 'static const trace::FunctionSig _%s_sig = {%u, "%s", %u, _%s_args};' % (function.name, self.getFunctionSigId(), function.name, len(function.args), function.name)
             print
 
+    def getFunctionSigId(self):
+        id = Tracer.__id
+        Tracer.__id += 1
+        return id
+
     def isFunctionPublic(self, function):
         return True
 
@@ -508,20 +493,24 @@ class Tracer:
             for arg in function.args:
                 if not arg.output:
                     self.unwrapArg(function, arg)
+            for arg in function.args:
+                if not arg.output:
                     self.serializeArg(function, arg)
             print '    trace::localWriter.endEnter();'
         self.invokeFunction(function)
         if not function.internal:
             print '    trace::localWriter.beginLeave(_call);'
+            print '    if (%s) {' % self.wasFunctionSuccessful(function)
             for arg in function.args:
                 if arg.output:
                     self.serializeArg(function, arg)
                     self.wrapArg(function, arg)
+            print '    }'
             if function.type is not stdapi.Void:
                 self.serializeRet(function, "_result")
-            print '    trace::localWriter.endLeave();'
             if function.type is not stdapi.Void:
                 self.wrapRet(function, "_result")
+            print '    trace::localWriter.endLeave();'
 
     def invokeFunction(self, function, prefix='_', suffix=''):
         if function.type is stdapi.Void:
@@ -531,6 +520,13 @@ class Tracer:
         dispatch = prefix + function.name + suffix
         print '    %s%s(%s);' % (result, dispatch, ', '.join([str(arg.name) for arg in function.args]))
 
+    def wasFunctionSuccessful(self, function):
+        if function.type is stdapi.Void:
+            return 'true'
+        if str(function.type) == 'HRESULT':
+            return 'SUCCEEDED(_result)'
+        return 'true'
+
     def serializeArg(self, function, arg):
         print '    trace::localWriter.beginArg(%u);' % (arg.index,)
         self.serializeArgValue(function, arg)
@@ -570,9 +566,6 @@ class Tracer:
     def wrapRet(self, function, instance):
         self.wrapValue(function.type, instance)
 
-    def unwrapRet(self, function, instance):
-        self.unwrapValue(function.type, instance)
-
     def needsWrapping(self, type):
         visitor = WrapDecider()
         visitor.visit(type)
@@ -600,9 +593,11 @@ class Tracer:
     def declareWrapperInterface(self, interface):
         print "class %s : public %s " % (getWrapperInterfaceName(interface), interface.name)
         print "{"
-        print "public:"
+        print "private:"
         print "    %s(%s * pInstance);" % (getWrapperInterfaceName(interface), interface.name)
         print "    virtual ~%s();" % getWrapperInterfaceName(interface)
+        print "public:"
+        print "    static %s* _Create(const char *functionName, %s * pInstance);" % (getWrapperInterfaceName(interface), interface.name)
         print
         for method in interface.iterMethods():
             print "    " + method.prototype() + ";"
@@ -610,6 +605,11 @@ class Tracer:
         #print "private:"
         for type, name, value in self.enumWrapperInterfaceVariables(interface):
             print '    %s %s;' % (type, name)
+        for i in range(64):
+            print r'    virtual void _dummy%i(void) const {' % i
+            print r'        os::log("error: %s: unexpected virtual method\n");' % interface.name
+            print r'        os::abort();'
+            print r'    }'
         print "};"
         print
 
@@ -617,17 +617,45 @@ class Tracer:
         return [
             ("DWORD", "m_dwMagic", "0xd8365d6c"),
             ("%s *" % interface.name, "m_pInstance", "pInstance"),
+            ("void *", "m_pVtbl", "*(void **)pInstance"),
+            ("UINT", "m_NumMethods", len(list(interface.iterBaseMethods()))),
         ] 
 
     def implementWrapperInterface(self, interface):
         self.interface = interface
 
+        # Private constructor
         print '%s::%s(%s * pInstance) {' % (getWrapperInterfaceName(interface), getWrapperInterfaceName(interface), interface.name)
         for type, name, value in self.enumWrapperInterfaceVariables(interface):
             print '    %s = %s;' % (name, value)
         print '}'
         print
+
+        # Public constructor
+        print '%s *%s::_Create(const char *functionName, %s * pInstance) {' % (getWrapperInterfaceName(interface), getWrapperInterfaceName(interface), interface.name)
+        print r'    std::map<void *, void *>::const_iterator it = g_WrappedObjects.find(pInstance);'
+        print r'    if (it != g_WrappedObjects.end()) {'
+        print r'        Wrap%s *pWrapper = (Wrap%s *)it->second;' % (interface.name, interface.name)
+        print r'        assert(pWrapper);'
+        print r'        assert(pWrapper->m_dwMagic == 0xd8365d6c);'
+        print r'        assert(pWrapper->m_pInstance == pInstance);'
+        print r'        if (pWrapper->m_pVtbl == *(void **)pInstance &&'
+        print r'            pWrapper->m_NumMethods >= %s) {' % len(list(interface.iterBaseMethods()))
+        #print r'            os::log("%s: fetched pvObj=%p pWrapper=%p pVtbl=%p\n", functionName, pInstance, pWrapper, pWrapper->m_pVtbl);'
+        print r'            return pWrapper;'
+        print r'        }'
+        print r'    }'
+        print r'    Wrap%s *pWrapper = new Wrap%s(pInstance);' % (interface.name, interface.name)
+        #print r'    os::log("%%s: created %s pvObj=%%p pWrapper=%%p pVtbl=%%p\n", functionName, pInstance, pWrapper, pWrapper->m_pVtbl);' % interface.name
+        print r'    g_WrappedObjects[pInstance] = pWrapper;'
+        print r'    return pWrapper;'
+        print '}'
+        print
+
+        # Destructor
         print '%s::~%s() {' % (getWrapperInterfaceName(interface), getWrapperInterfaceName(interface))
+        #print r'        os::log("%s::Release: deleted pvObj=%%p pWrapper=%%p pVtbl=%%p\n", m_pInstance, this, m_pVtbl);' % interface.name
+        print r'        g_WrappedObjects.erase(m_pInstance);'
         print '}'
         print
         
@@ -639,6 +667,10 @@ class Tracer:
 
     def implementWrapperInterfaceMethod(self, interface, base, method):
         print method.prototype(getWrapperInterfaceName(interface) + '::' + method.name) + ' {'
+
+        if False:
+            print r'    os::log("%%s(%%p -> %%p)\n", "%s", this, m_pInstance);' % (getWrapperInterfaceName(interface) + '::' + method.name)
+
         if method.type is not stdapi.Void:
             print '    %s _result;' % method.type
     
@@ -653,7 +685,7 @@ class Tracer:
         assert not method.internal
 
         print '    static const char * _args[%u] = {%s};' % (len(method.args) + 1, ', '.join(['"this"'] + ['"%s"' % arg.name for arg in method.args]))
-        print '    static const trace::FunctionSig _sig = {%u, "%s", %u, _args};' % (method.id, interface.name + '::' + method.name, len(method.args) + 1)
+        print '    static const trace::FunctionSig _sig = {%u, "%s", %u, _args};' % (self.getFunctionSigId(), interface.name + '::' + method.name, len(method.args) + 1)
 
         print '    %s *_this = static_cast<%s *>(m_pInstance);' % (base, base)
 
@@ -664,27 +696,34 @@ class Tracer:
         for arg in method.args:
             if not arg.output:
                 self.unwrapArg(method, arg)
+        for arg in method.args:
+            if not arg.output:
                 self.serializeArg(method, arg)
         print '    trace::localWriter.endEnter();'
         
         self.invokeMethod(interface, base, method)
 
         print '    trace::localWriter.beginLeave(_call);'
+
+        print '    if (%s) {' % self.wasFunctionSuccessful(method)
         for arg in method.args:
             if arg.output:
                 self.serializeArg(method, arg)
                 self.wrapArg(method, arg)
+        print '    }'
 
         if method.type is not stdapi.Void:
             self.serializeRet(method, '_result')
-        print '    trace::localWriter.endLeave();'
         if method.type is not stdapi.Void:
             self.wrapRet(method, '_result')
 
         if method.name == 'Release':
             assert method.type is not stdapi.Void
-            print '    if (!_result)'
-            print '        delete this;'
+            print r'    if (!_result) {'
+            print r'        delete this;'
+            print r'    }'
+        
+        print '    trace::localWriter.endLeave();'
 
     def implementIidWrapper(self, api):
         print r'static void'
@@ -703,7 +742,7 @@ class Tracer:
         else_ = ''
         for iface in api.getAllInterfaces():
             print r'    %sif (riid == IID_%s) {' % (else_, iface.name)
-            print r'        *ppvObj = new Wrap%s((%s *) *ppvObj);' % (iface.name, iface.name)
+            print r'        *ppvObj = Wrap%s::_Create(functionName, (%s *) *ppvObj);' % (iface.name, iface.name)
             print r'    }'
             else_ = 'else '
         print r'    %s{' % else_