]> git.cworth.org Git - apitrace/blobdiff - wrappers/trace.py
Cleanup a bit more IID handling.
[apitrace] / wrappers / trace.py
index e2e11ec36c07433329fc2e69e1f8c15c94efcb33..eb9e4affa213a7a18989d0415c4541f90dba9044 100644 (file)
@@ -245,90 +245,87 @@ class ValueSerializer(stdapi.Visitor):
         print '    _write__%s(%s, %s);' % (polymorphic.tag, polymorphic.switchExpr, instance)
 
 
-class ValueWrapper(stdapi.Visitor):
-    '''Type visitor which will generate the code to wrap an instance.
+class WrapDecider(stdapi.Traverser):
+    '''Type visitor which will decide wheter this type will need wrapping or not.
     
-    Wrapping is necessary mostly for interfaces, however interface pointers can
-    appear anywhere inside complex types.
+    For complex types (arrays, structures), we need to know this before hand.
     '''
 
-    def visitVoid(self, type, instance):
-        raise NotImplementedError
+    def __init__(self):
+        self.needsWrapping = False
 
-    def visitLiteral(self, type, instance):
+    def visitLinearPointer(self, void):
         pass
 
-    def visitString(self, type, instance):
-        pass
+    def visitInterface(self, interface):
+        self.needsWrapping = True
 
-    def visitConst(self, type, instance):
-        pass
+
+class ValueWrapper(stdapi.Traverser):
+    '''Type visitor which will generate the code to wrap an instance.
+    
+    Wrapping is necessary mostly for interfaces, however interface pointers can
+    appear anywhere inside complex types.
+    '''
 
     def visitStruct(self, struct, instance):
         for type, name in struct.members:
             self.visit(type, "(%s).%s" % (instance, name))
 
     def visitArray(self, array, instance):
-        # XXX: actually it is possible to return an array of pointers
-        pass
-
-    def visitBlob(self, blob, instance):
-        pass
-
-    def visitEnum(self, enum, instance):
-        pass
-
-    def visitBitmask(self, bitmask, instance):
-        pass
-
-    def visitPointer(self, pointer, instance):
         print "    if (%s) {" % instance
-        self.visit(pointer.type, "*" + instance)
+        print "        for (size_t _i = 0, _s = %s; _i < _s; ++_i) {" % array.length
+        self.visit(array.type, instance + "[_i]")
+        print "        }"
         print "    }"
-    
-    def visitIntPointer(self, pointer, instance):
-        pass
 
-    def visitObjPointer(self, pointer, instance):
+    def visitPointer(self, pointer, instance):
         print "    if (%s) {" % instance
         self.visit(pointer.type, "*" + instance)
         print "    }"
     
-    def visitLinearPointer(self, pointer, instance):
-        pass
-
-    def visitReference(self, reference, instance):
-        self.visit(reference.type, instance)
-    
-    def visitHandle(self, handle, instance):
-        self.visit(handle.type, instance)
-
-    def visitAlias(self, alias, instance):
-        self.visit(alias.type, instance)
-
-    def visitOpaque(self, opaque, instance):
-        pass
+    def visitObjPointer(self, pointer, instance):
+        elem_type = pointer.type.mutable()
+        if isinstance(elem_type, stdapi.Interface):
+            self.visitInterfacePointer(elem_type, instance)
+        else:
+            self.visitPointer(self, pointer, instance)
     
     def visitInterface(self, interface, instance):
-        assert instance.startswith('*')
-        instance = instance[1:]
+        raise NotImplementedError
+
+    def visitInterfacePointer(self, interface, instance):
         print "    if (%s) {" % instance
         print "        %s = new %s(%s);" % (instance, getWrapperInterfaceName(interface), instance)
         print "    }"
     
     def visitPolymorphic(self, type, instance):
         # XXX: There might be polymorphic values that need wrapping in the future
-        pass
+        raise NotImplementedError
 
 
 class ValueUnwrapper(ValueWrapper):
     '''Reverse of ValueWrapper.'''
 
-    def visitInterface(self, interface, instance):
-        assert instance.startswith('*')
-        instance = instance[1:]
+    allocated = False
+
+    def visitArray(self, array, instance):
+        if self.allocated or isinstance(instance, stdapi.Interface):
+            return ValueWrapper.visitArray(self, array, instance)
+        elem_type = array.type.mutable()
+        print "    if (%s && %s) {" % (instance, array.length)
+        print "        %s * _t = static_cast<%s *>(alloca(%s * sizeof *_t));" % (elem_type, elem_type, array.length)
+        print "        for (size_t _i = 0, _s = %s; _i < _s; ++_i) {" % array.length
+        print "            _t[_i] = %s[_i];" % instance 
+        self.allocated = True
+        self.visit(array.type, "_t[_i]")
+        print "        }"
+        print "        %s = _t;" % instance
+        print "    }"
+
+    def visitInterfacePointer(self, interface, instance):
         print r'    if (%s) {' % instance
-        print r'        %s *pWrapper = static_cast<%s*>(%s);' % (getWrapperInterfaceName(interface), getWrapperInterfaceName(interface), instance)
+        print r'        const %s *pWrapper = static_cast<const %s*>(%s);' % (getWrapperInterfaceName(interface), getWrapperInterfaceName(interface), instance)
         print r'        if (pWrapper && pWrapper->m_dwMagic == 0xd8365d6c) {'
         print r'            %s = pWrapper->m_pInstance;' % (instance,)
         print r'        } else {'
@@ -368,12 +365,11 @@ class Tracer:
         print
 
         # Interfaces wrapers
-        interfaces = api.getAllInterfaces()
-        map(self.declareWrapperInterface, interfaces)
-        map(self.implementWrapperInterface, interfaces)
-        print
+        self.traceInterfaces(api)
 
         # Function wrappers
+        self.interface = None
+        self.base = None
         map(self.traceFunctionDecl, api.functions)
         map(self.traceFunctionImpl, api.functions)
         print
@@ -448,6 +444,20 @@ class Tracer:
         self.serializeValue(arg.type, arg.name)
 
     def wrapArg(self, function, arg):
+        assert not isinstance(arg.type, stdapi.ObjPointer)
+
+        from specs.winapi import REFIID
+        riid = None
+        for other_arg in function.args:
+            if not other_arg.output and other_arg.type is REFIID:
+                riid = other_arg
+        if riid is not None and isinstance(arg.type, stdapi.Pointer):
+            assert isinstance(arg.type.type, stdapi.ObjPointer)
+            obj_type = arg.type.type.type
+            assert obj_type is stdapi.Void
+            self.wrapIid(function, riid, arg)
+            return
+
         self.wrapValue(arg.type, arg.name)
 
     def unwrapArg(self, function, arg):
@@ -468,13 +478,29 @@ class Tracer:
     def unwrapRet(self, function, instance):
         self.unwrapValue(function.type, instance)
 
+    def needsWrapping(self, type):
+        visitor = WrapDecider()
+        visitor.visit(type)
+        return visitor.needsWrapping
+
     def wrapValue(self, type, instance):
-        visitor = ValueWrapper()
-        visitor.visit(type, instance)
+        if self.needsWrapping(type):
+            visitor = ValueWrapper()
+            visitor.visit(type, instance)
 
     def unwrapValue(self, type, instance):
-        visitor = ValueUnwrapper()
-        visitor.visit(type, instance)
+        if self.needsWrapping(type):
+            visitor = ValueUnwrapper()
+            visitor.visit(type, instance)
+
+    def traceInterfaces(self, api):
+        interfaces = api.getAllInterfaces()
+        if not interfaces:
+            return
+        map(self.declareWrapperInterface, interfaces)
+        self.implementIidWrapper(api)
+        map(self.implementWrapperInterface, interfaces)
+        print
 
     def declareWrapperInterface(self, interface):
         print "class %s : public %s " % (getWrapperInterfaceName(interface), interface.name)
@@ -496,6 +522,8 @@ class Tracer:
         print "    %s * m_pInstance;" % (interface.name,)
 
     def implementWrapperInterface(self, interface):
+        self.interface = interface
+
         print '%s::%s(%s * pInstance) {' % (getWrapperInterfaceName(interface), getWrapperInterfaceName(interface), interface.name)
         print '    m_dwMagic = 0xd8365d6c;'
         print '    m_pInstance = pInstance;'
@@ -504,8 +532,11 @@ class Tracer:
         print '%s::~%s() {' % (getWrapperInterfaceName(interface), getWrapperInterfaceName(interface))
         print '}'
         print
+        
         for base, method in interface.iterBaseMethods():
+            self.base = base
             self.implementWrapperInterfaceMethod(interface, base, method)
+
         print
 
     def implementWrapperInterfaceMethod(self, interface, base, method):
@@ -527,17 +558,10 @@ class Tracer:
         print '    trace::localWriter.beginArg(0);'
         print '    trace::localWriter.writeOpaque((const void *)m_pInstance);'
         print '    trace::localWriter.endArg();'
-
-        from specs.winapi import REFIID
-        from specs.stdapi import Pointer, Opaque, Interface
-
-        riid = None
         for arg in method.args:
             if not arg.output:
                 self.unwrapArg(method, arg)
                 self.serializeArg(method, arg)
-                if arg.type is REFIID:
-                    riid = arg
         print '    trace::localWriter.endEnter();'
         
         self.invokeMethod(interface, base, method)
@@ -547,12 +571,6 @@ class Tracer:
             if arg.output:
                 self.serializeArg(method, arg)
                 self.wrapArg(method, arg)
-                if riid is not None and isinstance(arg.type, Pointer):
-                    if isinstance(arg.type.type, Opaque):
-                        self.wrapIid(interface, method, riid, arg)
-                    else:
-                        assert isinstance(arg.type.type, Pointer)
-                        assert isinstance(arg.type.type.type, Interface)
 
         if method.type is not stdapi.Void:
             print '    trace::localWriter.beginReturn();'
@@ -565,29 +583,49 @@ class Tracer:
             print '    if (!__result)'
             print '        delete this;'
 
-    def wrapIid(self, interface, method, riid, out):
-            print '    if (%s && *%s) {' % (out.name, out.name)
-            print '        if (*%s == m_pInstance) {' % (out.name,)
-            print '            *%s = this;' % (out.name,)
-            print '        }'
-            for iface in self.api.getAllInterfaces():
-                print r'        else if (%s == IID_%s) {' % (riid.name, iface.name)
-                print r'            *%s = new Wrap%s((%s *) *%s);' % (out.name, iface.name, iface.name, out.name)
-                print r'        }'
-            print r'        else {'
-            print r'            os::log("apitrace: warning: %s::%s: unknown IID {0x%08lX,0x%04X,0x%04X,{0x%02X,0x%02X,0x%02X,0x%02X,0x%02X,0x%02X,0x%02X,0x%02X}}\n",'
-            print r'                    "%s", "%s",' % (interface.name, method.name)
-            print r'                    %s.Data1, %s.Data2, %s.Data3,' % (riid.name, riid.name, riid.name)
-            print r'                    %s.Data4[0],' % (riid.name,)
-            print r'                    %s.Data4[1],' % (riid.name,)
-            print r'                    %s.Data4[2],' % (riid.name,)
-            print r'                    %s.Data4[3],' % (riid.name,)
-            print r'                    %s.Data4[4],' % (riid.name,)
-            print r'                    %s.Data4[5],' % (riid.name,)
-            print r'                    %s.Data4[6],' % (riid.name,)
-            print r'                    %s.Data4[7]);' % (riid.name,)
+    def implementIidWrapper(self, api):
+        print r'static void'
+        print r'warnIID(const char *functionName, REFIID riid, const char *reason) {'
+        print r'    os::log("apitrace: warning: %s: %s IID {0x%08lX,0x%04X,0x%04X,{0x%02X,0x%02X,0x%02X,0x%02X,0x%02X,0x%02X,0x%02X,0x%02X}}\n",'
+        print r'            functionName, reason,'
+        print r'            riid.Data1, riid.Data2, riid.Data3,'
+        print r'            riid.Data4[0], riid.Data4[1], riid.Data4[2], riid.Data4[3], riid.Data4[4], riid.Data4[5], riid.Data4[6], riid.Data4[7]);'
+        print r'}'
+        print 
+        print r'static void'
+        print r'wrapIID(const char *functionName, REFIID riid, void * * ppvObj) {'
+        print r'    if (!ppvObj || !*ppvObj) {'
+        print r'        return;'
+        print r'    }'
+        else_ = ''
+        for iface in api.getAllInterfaces():
+            print r'    %sif (riid == IID_%s) {' % (else_, iface.name)
+            print r'        *ppvObj = new Wrap%s((%s *) *ppvObj);' % (iface.name, iface.name)
+            print r'    }'
+            else_ = 'else '
+        print r'    %s{' % else_
+        print r'        warnIID(functionName, riid, "unknown");'
+        print r'    }'
+        print r'}'
+        print
+
+    def wrapIid(self, function, riid, out):
+        print r'    if (%s && *%s) {' % (out.name, out.name)
+        functionName = function.name
+        else_ = ''
+        if self.interface is not None:
+            functionName = self.interface.name + '::' + functionName
+            print r'        %sif (*%s == m_pInstance) {' % (else_, out.name,)
+            print r'            *%s = this;' % (out.name,)
+            print r'            if (%s) {' % ' && '.join('%s != IID_%s' % (riid.name, iface.name) for iface in self.interface.iterBases()) 
+            print r'                warnIID("%s", %s, "unexpected");' % (functionName, riid.name)
+            print r'            }'
             print r'        }'
-            print '    }'
+            else_ = 'else '
+        print r'        %s{' % else_
+        print r'             wrapIID("%s", %s, %s);' % (functionName, riid.name, out.name) 
+        print r'        }'
+        print r'    }'
 
     def invokeMethod(self, interface, base, method):
         if method.type is stdapi.Void: