]> git.cworth.org Git - apitrace/blobdiff - wrappers/trace.py
trace: Unwrap all args before serializing them.
[apitrace] / wrappers / trace.py
index 30668365ee3fc27eded75d3f4111f64d0951b0e7..61a2cb95fcf8249c13719464a36bee5891f8f76a 100644 (file)
@@ -39,6 +39,7 @@ def getWrapperInterfaceName(interface):
     return "Wrap" + interface.expr
 
 
+
 class ComplexValueSerializer(stdapi.OnceVisitor):
     '''Type visitors which generates serialization functions for
     complex types.
@@ -63,21 +64,20 @@ class ComplexValueSerializer(stdapi.OnceVisitor):
         self.visit(const.type)
 
     def visitStruct(self, struct):
-        for type, name in struct.members:
-            self.visit(type)
-        print 'static void _write__%s(const %s &value) {' % (struct.tag, struct.expr)
-        print '    static const char * members[%u] = {' % (len(struct.members),)
+        print 'static const char * _struct%s_members[%u] = {' % (struct.tag, len(struct.members))
         for type, name,  in struct.members:
-            print '        "%s",' % (name,)
-        print '    };'
-        print '    static const trace::StructSig sig = {'
-        print '       %u, "%s", %u, members' % (struct.id, struct.name, len(struct.members))
-        print '    };'
-        print '    trace::localWriter.beginStruct(&sig);'
-        for type, name in struct.members:
-            self.serializer.visit(type, 'value.%s' % (name,))
-        print '    trace::localWriter.endStruct();'
-        print '}'
+            if name is None:
+                print '    "",'
+            else:
+                print '    "%s",' % (name,)
+        print '};'
+        print 'static const trace::StructSig _struct%s_sig = {' % (struct.tag,)
+        if struct.name is None:
+            structName = '""'
+        else:
+            structName = '"%s"' % struct.name
+        print '    %u, %s, %u, _struct%s_members' % (struct.id, structName, len(struct.members), struct.tag)
+        print '};'
         print
 
     def visitArray(self, array):
@@ -89,22 +89,22 @@ class ComplexValueSerializer(stdapi.OnceVisitor):
     def visitEnum(self, enum):
         print 'static const trace::EnumValue _enum%s_values[] = {' % (enum.tag)
         for value in enum.values:
-            print '   {"%s", %s},' % (value, value)
+            print '    {"%s", %s},' % (value, value)
         print '};'
         print
         print 'static const trace::EnumSig _enum%s_sig = {' % (enum.tag)
-        print '   %u, %u, _enum%s_values' % (enum.id, len(enum.values), enum.tag)
+        print '    %u, %u, _enum%s_values' % (enum.id, len(enum.values), enum.tag)
         print '};'
         print
 
     def visitBitmask(self, bitmask):
         print 'static const trace::BitmaskFlag _bitmask%s_flags[] = {' % (bitmask.tag)
         for value in bitmask.values:
-            print '   {"%s", %s},' % (value, value)
+            print '    {"%s", %s},' % (value, value)
         print '};'
         print
         print 'static const trace::BitmaskSig _bitmask%s_sig = {' % (bitmask.tag)
-        print '   %u, %u, _bitmask%s_flags' % (bitmask.id, len(bitmask.values), bitmask.tag)
+        print '    %u, %u, _bitmask%s_flags' % (bitmask.id, len(bitmask.values), bitmask.tag)
         print '};'
         print
 
@@ -150,7 +150,7 @@ class ComplexValueSerializer(stdapi.OnceVisitor):
         print
 
 
-class ValueSerializer(stdapi.Visitor):
+class ValueSerializer(stdapi.Visitor, stdapi.ExpanderMixin):
     '''Visitor which generates code to serialize any type.
     
     Simple types are serialized inline here, whereas the serialization of
@@ -172,7 +172,7 @@ class ValueSerializer(stdapi.Visitor):
             # reinterpret_cast is necessary for GLubyte * <=> char *
             instance = 'reinterpret_cast<%s>(%s)' % (cast, instance)
         if string.length is not None:
-            length = ', %s' % string.length
+            length = ', %s' % self.expand(string.length)
         else:
             length = ''
         print '    trace::localWriter.write%s(%s%s);' % (suffix, instance, length)
@@ -181,17 +181,21 @@ class ValueSerializer(stdapi.Visitor):
         self.visit(const.type, instance)
 
     def visitStruct(self, struct, instance):
-        print '    _write__%s(%s);' % (struct.tag, instance)
+        print '    trace::localWriter.beginStruct(&_struct%s_sig);' % (struct.tag,)
+        for member in struct.members:
+            self.visitMember(member, instance)
+        print '    trace::localWriter.endStruct();'
 
     def visitArray(self, array, instance):
         length = '_c' + array.type.tag
         index = '_i' + array.type.tag
+        array_length = self.expand(array.length)
         print '    if (%s) {' % instance
-        print '        size_t %s = %s > 0 ? %s : 0;' % (length, array.length, array.length)
+        print '        size_t %s = %s > 0 ? %s : 0;' % (length, array_length, array_length)
         print '        trace::localWriter.beginArray(%s);' % length
         print '        for (size_t %s = 0; %s < %s; ++%s) {' % (index, index, length, index)
         print '            trace::localWriter.beginElement();'
-        self.visit(array.type, '(%s)[%s]' % (instance, index))
+        self.visitElement(index, array.type, '(%s)[%s]' % (instance, index))
         print '            trace::localWriter.endElement();'
         print '        }'
         print '        trace::localWriter.endArray();'
@@ -200,7 +204,7 @@ class ValueSerializer(stdapi.Visitor):
         print '    }'
 
     def visitBlob(self, blob, instance):
-        print '    trace::localWriter.writeBlob(%s, %s);' % (instance, blob.size)
+        print '    trace::localWriter.writeBlob(%s, %s);' % (instance, self.expand(blob.size))
 
     def visitEnum(self, enum, instance):
         print '    trace::localWriter.writeEnum(&_enum%s_sig, %s);' % (enum.tag, instance)
@@ -247,12 +251,21 @@ class ValueSerializer(stdapi.Visitor):
         if polymorphic.contextLess:
             print '    _write__%s(%s, %s);' % (polymorphic.tag, polymorphic.switchExpr, instance)
         else:
-            print '    switch (%s) {' % polymorphic.switchExpr
+            switchExpr = self.expand(polymorphic.switchExpr)
+            print '    switch (%s) {' % switchExpr
             for cases, type in polymorphic.iterSwitch():
                 for case in cases:
                     print '    %s:' % case
-                self.visit(type, 'static_cast<%s>(%s)' % (type, instance))
+                caseInstance = instance
+                if type.expr is not None:
+                    caseInstance = 'static_cast<%s>(%s)' % (type, caseInstance)
+                self.visit(type, caseInstance)
                 print '        break;'
+            if polymorphic.defaultType is None:
+                print r'    default:'
+                print r'        os::log("apitrace: warning: %%s: unexpected polymorphic case %%i\n", __FUNCTION__, (int)%s);' % (switchExpr,)
+                print r'        trace::localWriter.writeNull();'
+                print r'        break;'
             print '    }'
 
 
@@ -272,7 +285,7 @@ class WrapDecider(stdapi.Traverser):
         self.needsWrapping = True
 
 
-class ValueWrapper(stdapi.Traverser):
+class ValueWrapper(stdapi.Traverser, stdapi.ExpanderMixin):
     '''Type visitor which will generate the code to wrap an instance.
     
     Wrapping is necessary mostly for interfaces, however interface pointers can
@@ -280,13 +293,14 @@ class ValueWrapper(stdapi.Traverser):
     '''
 
     def visitStruct(self, struct, instance):
-        for type, name in struct.members:
-            self.visit(type, "(%s).%s" % (instance, name))
+        for member in struct.members:
+            self.visitMember(member, instance)
 
     def visitArray(self, array, instance):
+        array_length = self.expand(array.length)
         print "    if (%s) {" % instance
-        print "        for (size_t _i = 0, _s = %s; _i < _s; ++_i) {" % array.length
-        self.visit(array.type, instance + "[_i]")
+        print "        for (size_t _i = 0, _s = %s; _i < _s; ++_i) {" % array_length
+        self.visitElement('_i', array.type, instance + "[_i]")
         print "        }"
         print "    }"
 
@@ -299,6 +313,8 @@ class ValueWrapper(stdapi.Traverser):
         elem_type = pointer.type.mutable()
         if isinstance(elem_type, stdapi.Interface):
             self.visitInterfacePointer(elem_type, instance)
+        elif isinstance(elem_type, stdapi.Alias) and isinstance(elem_type.type, stdapi.Interface):
+            self.visitInterfacePointer(elem_type.type, instance)
         else:
             self.visitPointer(pointer, instance)
     
@@ -307,7 +323,7 @@ class ValueWrapper(stdapi.Traverser):
 
     def visitInterfacePointer(self, interface, instance):
         print "    if (%s) {" % instance
-        print "        %s = new %s(%s);" % (instance, getWrapperInterfaceName(interface), instance)
+        print "        %s = %s::_Create(__FUNCTION__, %s);" % (instance, getWrapperInterfaceName(interface), instance)
         print "    }"
     
     def visitPolymorphic(self, type, instance):
@@ -320,13 +336,31 @@ class ValueUnwrapper(ValueWrapper):
 
     allocated = False
 
+    def visitStruct(self, struct, instance):
+        if not self.allocated:
+            # Argument is constant. We need to create a non const
+            print '    {'
+            print "        %s * _t = static_cast<%s *>(alloca(sizeof *_t));" % (struct, struct)
+            print '        *_t = %s;' % (instance,)
+            assert instance.startswith('*')
+            print '        %s = _t;' % (instance[1:],)
+            instance = '*_t'
+            self.allocated = True
+            try:
+                return ValueWrapper.visitStruct(self, struct, instance)
+            finally:
+                print '    }'
+        else:
+            return ValueWrapper.visitStruct(self, struct, instance)
+
     def visitArray(self, array, instance):
         if self.allocated or isinstance(instance, stdapi.Interface):
             return ValueWrapper.visitArray(self, array, instance)
+        array_length = self.expand(array.length)
         elem_type = array.type.mutable()
-        print "    if (%s && %s) {" % (instance, array.length)
-        print "        %s * _t = static_cast<%s *>(alloca(%s * sizeof *_t));" % (elem_type, elem_type, array.length)
-        print "        for (size_t _i = 0, _s = %s; _i < _s; ++_i) {" % array.length
+        print "    if (%s && %s) {" % (instance, array_length)
+        print "        %s * _t = static_cast<%s *>(alloca(%s * sizeof *_t));" % (elem_type, elem_type, array_length)
+        print "        for (size_t _i = 0, _s = %s; _i < _s; ++_i) {" % array_length
         print "            _t[_i] = %s[_i];" % instance 
         self.allocated = True
         self.visit(array.type, "_t[_i]")
@@ -348,6 +382,9 @@ class ValueUnwrapper(ValueWrapper):
 class Tracer:
     '''Base class to orchestrate the code generation of API tracing.'''
 
+    # 0-3 are reserved to memcpy, malloc, free, and realloc
+    __id = 4
+
     def __init__(self):
         self.api = None
 
@@ -365,8 +402,9 @@ class Tracer:
         self.header(api)
 
         # Includes
-        for header in api.headers:
-            print header
+        for module in api.modules:
+            for header in module.headers:
+                print header
         print
 
         # Generate the serializer functions
@@ -381,8 +419,10 @@ class Tracer:
         # Function wrappers
         self.interface = None
         self.base = None
-        map(self.traceFunctionDecl, api.functions)
-        map(self.traceFunctionImpl, api.functions)
+        for function in api.getAllFunctions():
+            self.traceFunctionDecl(function)
+        for function in api.getAllFunctions():
+            self.traceFunctionImpl(function)
         print
 
         self.footer(api)
@@ -399,6 +439,7 @@ class Tracer:
         print
         print '#include "trace.hpp"'
         print
+        print 'static std::map<void *, void *> g_WrappedObjects;'
 
     def footer(self, api):
         pass
@@ -411,9 +452,14 @@ class Tracer:
                 print 'static const char * _%s_args[%u] = {%s};' % (function.name, len(function.args), ', '.join(['"%s"' % arg.name for arg in function.args]))
             else:
                 print 'static const char ** _%s_args = NULL;' % (function.name,)
-            print 'static const trace::FunctionSig _%s_sig = {%u, "%s", %u, _%s_args};' % (function.name, function.id, function.name, len(function.args), function.name)
+            print 'static const trace::FunctionSig _%s_sig = {%u, "%s", %u, _%s_args};' % (function.name, self.getFunctionSigId(), function.name, len(function.args), function.name)
             print
 
+    def getFunctionSigId(self):
+        id = Tracer.__id
+        Tracer.__id += 1
+        return id
+
     def isFunctionPublic(self, function):
         return True
 
@@ -447,20 +493,24 @@ class Tracer:
             for arg in function.args:
                 if not arg.output:
                     self.unwrapArg(function, arg)
+            for arg in function.args:
+                if not arg.output:
                     self.serializeArg(function, arg)
             print '    trace::localWriter.endEnter();'
         self.invokeFunction(function)
         if not function.internal:
             print '    trace::localWriter.beginLeave(_call);'
+            print '    if (%s) {' % self.wasFunctionSuccessful(function)
             for arg in function.args:
                 if arg.output:
                     self.serializeArg(function, arg)
                     self.wrapArg(function, arg)
+            print '    }'
             if function.type is not stdapi.Void:
                 self.serializeRet(function, "_result")
-            print '    trace::localWriter.endLeave();'
             if function.type is not stdapi.Void:
                 self.wrapRet(function, "_result")
+            print '    trace::localWriter.endLeave();'
 
     def invokeFunction(self, function, prefix='_', suffix=''):
         if function.type is stdapi.Void:
@@ -470,6 +520,13 @@ class Tracer:
         dispatch = prefix + function.name + suffix
         print '    %s%s(%s);' % (result, dispatch, ', '.join([str(arg.name) for arg in function.args]))
 
+    def wasFunctionSuccessful(self, function):
+        if function.type is stdapi.Void:
+            return 'true'
+        if str(function.type) == 'HRESULT':
+            return 'SUCCEEDED(_result)'
+        return 'true'
+
     def serializeArg(self, function, arg):
         print '    trace::localWriter.beginArg(%u);' % (arg.index,)
         self.serializeArgValue(function, arg)
@@ -509,9 +566,6 @@ class Tracer:
     def wrapRet(self, function, instance):
         self.wrapValue(function.type, instance)
 
-    def unwrapRet(self, function, instance):
-        self.unwrapValue(function.type, instance)
-
     def needsWrapping(self, type):
         visitor = WrapDecider()
         visitor.visit(type)
@@ -539,9 +593,11 @@ class Tracer:
     def declareWrapperInterface(self, interface):
         print "class %s : public %s " % (getWrapperInterfaceName(interface), interface.name)
         print "{"
-        print "public:"
+        print "private:"
         print "    %s(%s * pInstance);" % (getWrapperInterfaceName(interface), interface.name)
         print "    virtual ~%s();" % getWrapperInterfaceName(interface)
+        print "public:"
+        print "    static %s* _Create(const char *functionName, %s * pInstance);" % (getWrapperInterfaceName(interface), interface.name)
         print
         for method in interface.iterMethods():
             print "    " + method.prototype() + ";"
@@ -549,6 +605,11 @@ class Tracer:
         #print "private:"
         for type, name, value in self.enumWrapperInterfaceVariables(interface):
             print '    %s %s;' % (type, name)
+        for i in range(64):
+            print r'    virtual void _dummy%i(void) const {' % i
+            print r'        os::log("error: %s: unexpected virtual method\n");' % interface.name
+            print r'        os::abort();'
+            print r'    }'
         print "};"
         print
 
@@ -556,17 +617,45 @@ class Tracer:
         return [
             ("DWORD", "m_dwMagic", "0xd8365d6c"),
             ("%s *" % interface.name, "m_pInstance", "pInstance"),
+            ("void *", "m_pVtbl", "*(void **)pInstance"),
+            ("UINT", "m_NumMethods", len(list(interface.iterBaseMethods()))),
         ] 
 
     def implementWrapperInterface(self, interface):
         self.interface = interface
 
+        # Private constructor
         print '%s::%s(%s * pInstance) {' % (getWrapperInterfaceName(interface), getWrapperInterfaceName(interface), interface.name)
         for type, name, value in self.enumWrapperInterfaceVariables(interface):
             print '    %s = %s;' % (name, value)
         print '}'
         print
+
+        # Public constructor
+        print '%s *%s::_Create(const char *functionName, %s * pInstance) {' % (getWrapperInterfaceName(interface), getWrapperInterfaceName(interface), interface.name)
+        print r'    std::map<void *, void *>::const_iterator it = g_WrappedObjects.find(pInstance);'
+        print r'    if (it != g_WrappedObjects.end()) {'
+        print r'        Wrap%s *pWrapper = (Wrap%s *)it->second;' % (interface.name, interface.name)
+        print r'        assert(pWrapper);'
+        print r'        assert(pWrapper->m_dwMagic == 0xd8365d6c);'
+        print r'        assert(pWrapper->m_pInstance == pInstance);'
+        print r'        if (pWrapper->m_pVtbl == *(void **)pInstance &&'
+        print r'            pWrapper->m_NumMethods >= %s) {' % len(list(interface.iterBaseMethods()))
+        #print r'            os::log("%s: fetched pvObj=%p pWrapper=%p pVtbl=%p\n", functionName, pInstance, pWrapper, pWrapper->m_pVtbl);'
+        print r'            return pWrapper;'
+        print r'        }'
+        print r'    }'
+        print r'    Wrap%s *pWrapper = new Wrap%s(pInstance);' % (interface.name, interface.name)
+        #print r'    os::log("%%s: created %s pvObj=%%p pWrapper=%%p pVtbl=%%p\n", functionName, pInstance, pWrapper, pWrapper->m_pVtbl);' % interface.name
+        print r'    g_WrappedObjects[pInstance] = pWrapper;'
+        print r'    return pWrapper;'
+        print '}'
+        print
+
+        # Destructor
         print '%s::~%s() {' % (getWrapperInterfaceName(interface), getWrapperInterfaceName(interface))
+        #print r'        os::log("%s::Release: deleted pvObj=%%p pWrapper=%%p pVtbl=%%p\n", m_pInstance, this, m_pVtbl);' % interface.name
+        print r'        g_WrappedObjects.erase(m_pInstance);'
         print '}'
         print
         
@@ -578,6 +667,10 @@ class Tracer:
 
     def implementWrapperInterfaceMethod(self, interface, base, method):
         print method.prototype(getWrapperInterfaceName(interface) + '::' + method.name) + ' {'
+
+        if False:
+            print r'    os::log("%%s(%%p -> %%p)\n", "%s", this, m_pInstance);' % (getWrapperInterfaceName(interface) + '::' + method.name)
+
         if method.type is not stdapi.Void:
             print '    %s _result;' % method.type
     
@@ -592,7 +685,7 @@ class Tracer:
         assert not method.internal
 
         print '    static const char * _args[%u] = {%s};' % (len(method.args) + 1, ', '.join(['"this"'] + ['"%s"' % arg.name for arg in method.args]))
-        print '    static const trace::FunctionSig _sig = {%u, "%s", %u, _args};' % (method.id, interface.name + '::' + method.name, len(method.args) + 1)
+        print '    static const trace::FunctionSig _sig = {%u, "%s", %u, _args};' % (self.getFunctionSigId(), interface.name + '::' + method.name, len(method.args) + 1)
 
         print '    %s *_this = static_cast<%s *>(m_pInstance);' % (base, base)
 
@@ -603,27 +696,34 @@ class Tracer:
         for arg in method.args:
             if not arg.output:
                 self.unwrapArg(method, arg)
+        for arg in method.args:
+            if not arg.output:
                 self.serializeArg(method, arg)
         print '    trace::localWriter.endEnter();'
         
         self.invokeMethod(interface, base, method)
 
         print '    trace::localWriter.beginLeave(_call);'
+
+        print '    if (%s) {' % self.wasFunctionSuccessful(method)
         for arg in method.args:
             if arg.output:
                 self.serializeArg(method, arg)
                 self.wrapArg(method, arg)
+        print '    }'
 
         if method.type is not stdapi.Void:
             self.serializeRet(method, '_result')
-        print '    trace::localWriter.endLeave();'
         if method.type is not stdapi.Void:
             self.wrapRet(method, '_result')
 
         if method.name == 'Release':
             assert method.type is not stdapi.Void
-            print '    if (!_result)'
-            print '        delete this;'
+            print r'    if (!_result) {'
+            print r'        delete this;'
+            print r'    }'
+        
+        print '    trace::localWriter.endLeave();'
 
     def implementIidWrapper(self, api):
         print r'static void'
@@ -642,7 +742,7 @@ class Tracer:
         else_ = ''
         for iface in api.getAllInterfaces():
             print r'    %sif (riid == IID_%s) {' % (else_, iface.name)
-            print r'        *ppvObj = new Wrap%s((%s *) *ppvObj);' % (iface.name, iface.name)
+            print r'        *ppvObj = Wrap%s::_Create(functionName, (%s *) *ppvObj);' % (iface.name, iface.name)
             print r'    }'
             else_ = 'else '
         print r'    %s{' % else_