]> git.cworth.org Git - apitrace/blobdiff - wrappers/trace.py
Cleanup a bit more IID handling.
[apitrace] / wrappers / trace.py
index ef86f8cc09bc276be4bb62b8cb56f6087728de53..eb9e4affa213a7a18989d0415c4541f90dba9044 100644 (file)
@@ -254,9 +254,6 @@ class WrapDecider(stdapi.Traverser):
     def __init__(self):
         self.needsWrapping = False
 
-    def visitVoid(self, void):
-        raise NotImplementedError
-
     def visitLinearPointer(self, void):
         pass
 
@@ -288,13 +285,16 @@ class ValueWrapper(stdapi.Traverser):
         print "    }"
     
     def visitObjPointer(self, pointer, instance):
-        print "    if (%s) {" % instance
-        self.visit(pointer.type, "*" + instance)
-        print "    }"
+        elem_type = pointer.type.mutable()
+        if isinstance(elem_type, stdapi.Interface):
+            self.visitInterfacePointer(elem_type, instance)
+        else:
+            self.visitPointer(self, pointer, instance)
     
     def visitInterface(self, interface, instance):
-        assert instance.startswith('*')
-        instance = instance[1:]
+        raise NotImplementedError
+
+    def visitInterfacePointer(self, interface, instance):
         print "    if (%s) {" % instance
         print "        %s = new %s(%s);" % (instance, getWrapperInterfaceName(interface), instance)
         print "    }"
@@ -323,9 +323,7 @@ class ValueUnwrapper(ValueWrapper):
         print "        %s = _t;" % instance
         print "    }"
 
-    def visitInterface(self, interface, instance):
-        assert instance.startswith('*')
-        instance = instance[1:]
+    def visitInterfacePointer(self, interface, instance):
         print r'    if (%s) {' % instance
         print r'        const %s *pWrapper = static_cast<const %s*>(%s);' % (getWrapperInterfaceName(interface), getWrapperInterfaceName(interface), instance)
         print r'        if (pWrapper && pWrapper->m_dwMagic == 0xd8365d6c) {'
@@ -367,12 +365,11 @@ class Tracer:
         print
 
         # Interfaces wrapers
-        interfaces = api.getAllInterfaces()
-        map(self.declareWrapperInterface, interfaces)
-        map(self.implementWrapperInterface, interfaces)
-        print
+        self.traceInterfaces(api)
 
         # Function wrappers
+        self.interface = None
+        self.base = None
         map(self.traceFunctionDecl, api.functions)
         map(self.traceFunctionImpl, api.functions)
         print
@@ -447,6 +444,20 @@ class Tracer:
         self.serializeValue(arg.type, arg.name)
 
     def wrapArg(self, function, arg):
+        assert not isinstance(arg.type, stdapi.ObjPointer)
+
+        from specs.winapi import REFIID
+        riid = None
+        for other_arg in function.args:
+            if not other_arg.output and other_arg.type is REFIID:
+                riid = other_arg
+        if riid is not None and isinstance(arg.type, stdapi.Pointer):
+            assert isinstance(arg.type.type, stdapi.ObjPointer)
+            obj_type = arg.type.type.type
+            assert obj_type is stdapi.Void
+            self.wrapIid(function, riid, arg)
+            return
+
         self.wrapValue(arg.type, arg.name)
 
     def unwrapArg(self, function, arg):
@@ -482,6 +493,15 @@ class Tracer:
             visitor = ValueUnwrapper()
             visitor.visit(type, instance)
 
+    def traceInterfaces(self, api):
+        interfaces = api.getAllInterfaces()
+        if not interfaces:
+            return
+        map(self.declareWrapperInterface, interfaces)
+        self.implementIidWrapper(api)
+        map(self.implementWrapperInterface, interfaces)
+        print
+
     def declareWrapperInterface(self, interface):
         print "class %s : public %s " % (getWrapperInterfaceName(interface), interface.name)
         print "{"
@@ -502,6 +522,8 @@ class Tracer:
         print "    %s * m_pInstance;" % (interface.name,)
 
     def implementWrapperInterface(self, interface):
+        self.interface = interface
+
         print '%s::%s(%s * pInstance) {' % (getWrapperInterfaceName(interface), getWrapperInterfaceName(interface), interface.name)
         print '    m_dwMagic = 0xd8365d6c;'
         print '    m_pInstance = pInstance;'
@@ -510,8 +532,11 @@ class Tracer:
         print '%s::~%s() {' % (getWrapperInterfaceName(interface), getWrapperInterfaceName(interface))
         print '}'
         print
+        
         for base, method in interface.iterBaseMethods():
+            self.base = base
             self.implementWrapperInterfaceMethod(interface, base, method)
+
         print
 
     def implementWrapperInterfaceMethod(self, interface, base, method):
@@ -533,17 +558,10 @@ class Tracer:
         print '    trace::localWriter.beginArg(0);'
         print '    trace::localWriter.writeOpaque((const void *)m_pInstance);'
         print '    trace::localWriter.endArg();'
-
-        from specs.winapi import REFIID
-        from specs.stdapi import Pointer, Opaque, Interface
-
-        riid = None
         for arg in method.args:
             if not arg.output:
                 self.unwrapArg(method, arg)
                 self.serializeArg(method, arg)
-                if arg.type is REFIID:
-                    riid = arg
         print '    trace::localWriter.endEnter();'
         
         self.invokeMethod(interface, base, method)
@@ -553,12 +571,6 @@ class Tracer:
             if arg.output:
                 self.serializeArg(method, arg)
                 self.wrapArg(method, arg)
-                if riid is not None and isinstance(arg.type, Pointer):
-                    if isinstance(arg.type.type, Opaque):
-                        self.wrapIid(interface, method, riid, arg)
-                    else:
-                        assert isinstance(arg.type.type, Pointer)
-                        assert isinstance(arg.type.type.type, Interface)
 
         if method.type is not stdapi.Void:
             print '    trace::localWriter.beginReturn();'
@@ -571,29 +583,49 @@ class Tracer:
             print '    if (!__result)'
             print '        delete this;'
 
-    def wrapIid(self, interface, method, riid, out):
-            print '    if (%s && *%s) {' % (out.name, out.name)
-            print '        if (*%s == m_pInstance) {' % (out.name,)
-            print '            *%s = this;' % (out.name,)
-            print '        }'
-            for iface in self.api.getAllInterfaces():
-                print r'        else if (%s == IID_%s) {' % (riid.name, iface.name)
-                print r'            *%s = new Wrap%s((%s *) *%s);' % (out.name, iface.name, iface.name, out.name)
-                print r'        }'
-            print r'        else {'
-            print r'            os::log("apitrace: warning: %s::%s: unknown IID {0x%08lX,0x%04X,0x%04X,{0x%02X,0x%02X,0x%02X,0x%02X,0x%02X,0x%02X,0x%02X,0x%02X}}\n",'
-            print r'                    "%s", "%s",' % (interface.name, method.name)
-            print r'                    %s.Data1, %s.Data2, %s.Data3,' % (riid.name, riid.name, riid.name)
-            print r'                    %s.Data4[0],' % (riid.name,)
-            print r'                    %s.Data4[1],' % (riid.name,)
-            print r'                    %s.Data4[2],' % (riid.name,)
-            print r'                    %s.Data4[3],' % (riid.name,)
-            print r'                    %s.Data4[4],' % (riid.name,)
-            print r'                    %s.Data4[5],' % (riid.name,)
-            print r'                    %s.Data4[6],' % (riid.name,)
-            print r'                    %s.Data4[7]);' % (riid.name,)
+    def implementIidWrapper(self, api):
+        print r'static void'
+        print r'warnIID(const char *functionName, REFIID riid, const char *reason) {'
+        print r'    os::log("apitrace: warning: %s: %s IID {0x%08lX,0x%04X,0x%04X,{0x%02X,0x%02X,0x%02X,0x%02X,0x%02X,0x%02X,0x%02X,0x%02X}}\n",'
+        print r'            functionName, reason,'
+        print r'            riid.Data1, riid.Data2, riid.Data3,'
+        print r'            riid.Data4[0], riid.Data4[1], riid.Data4[2], riid.Data4[3], riid.Data4[4], riid.Data4[5], riid.Data4[6], riid.Data4[7]);'
+        print r'}'
+        print 
+        print r'static void'
+        print r'wrapIID(const char *functionName, REFIID riid, void * * ppvObj) {'
+        print r'    if (!ppvObj || !*ppvObj) {'
+        print r'        return;'
+        print r'    }'
+        else_ = ''
+        for iface in api.getAllInterfaces():
+            print r'    %sif (riid == IID_%s) {' % (else_, iface.name)
+            print r'        *ppvObj = new Wrap%s((%s *) *ppvObj);' % (iface.name, iface.name)
+            print r'    }'
+            else_ = 'else '
+        print r'    %s{' % else_
+        print r'        warnIID(functionName, riid, "unknown");'
+        print r'    }'
+        print r'}'
+        print
+
+    def wrapIid(self, function, riid, out):
+        print r'    if (%s && *%s) {' % (out.name, out.name)
+        functionName = function.name
+        else_ = ''
+        if self.interface is not None:
+            functionName = self.interface.name + '::' + functionName
+            print r'        %sif (*%s == m_pInstance) {' % (else_, out.name,)
+            print r'            *%s = this;' % (out.name,)
+            print r'            if (%s) {' % ' && '.join('%s != IID_%s' % (riid.name, iface.name) for iface in self.interface.iterBases()) 
+            print r'                warnIID("%s", %s, "unexpected");' % (functionName, riid.name)
+            print r'            }'
             print r'        }'
-            print '    }'
+            else_ = 'else '
+        print r'        %s{' % else_
+        print r'             wrapIID("%s", %s, %s);' % (functionName, riid.name, out.name) 
+        print r'        }'
+        print r'    }'
 
     def invokeMethod(self, interface, base, method):
         if method.type is stdapi.Void: