]> git.cworth.org Git - apitrace/blobdiff - wrappers/trace.py
Handle REFIIDs on functions too.
[apitrace] / wrappers / trace.py
index 23a537a1e669d56359845b28bcab2c0c135ff089..706d3479dcb15473d5254b1ad56e6a840e179571 100644 (file)
@@ -371,6 +371,8 @@ class Tracer:
         print
 
         # Function wrappers
+        self.interface = None
+        self.base = None
         map(self.traceFunctionDecl, api.functions)
         map(self.traceFunctionImpl, api.functions)
         print
@@ -446,6 +448,19 @@ class Tracer:
 
     def wrapArg(self, function, arg):
         assert not isinstance(arg.type, stdapi.ObjPointer)
+
+        from specs.winapi import REFIID
+        riid = None
+        for other_arg in function.args:
+            if not other_arg.output and other_arg.type is REFIID:
+                riid = other_arg
+        if riid is not None and isinstance(arg.type, stdapi.Pointer):
+            assert isinstance(arg.type.type, stdapi.ObjPointer)
+            obj_type = arg.type.type.type
+            assert obj_type is stdapi.Void
+            self.wrapIid(function, riid, arg)
+            return
+
         self.wrapValue(arg.type, arg.name)
 
     def unwrapArg(self, function, arg):
@@ -501,6 +516,8 @@ class Tracer:
         print "    %s * m_pInstance;" % (interface.name,)
 
     def implementWrapperInterface(self, interface):
+        self.interface = interface
+
         print '%s::%s(%s * pInstance) {' % (getWrapperInterfaceName(interface), getWrapperInterfaceName(interface), interface.name)
         print '    m_dwMagic = 0xd8365d6c;'
         print '    m_pInstance = pInstance;'
@@ -509,8 +526,11 @@ class Tracer:
         print '%s::~%s() {' % (getWrapperInterfaceName(interface), getWrapperInterfaceName(interface))
         print '}'
         print
+        
         for base, method in interface.iterBaseMethods():
+            self.base = base
             self.implementWrapperInterfaceMethod(interface, base, method)
+
         print
 
     def implementWrapperInterfaceMethod(self, interface, base, method):
@@ -532,16 +552,10 @@ class Tracer:
         print '    trace::localWriter.beginArg(0);'
         print '    trace::localWriter.writeOpaque((const void *)m_pInstance);'
         print '    trace::localWriter.endArg();'
-
-        from specs.winapi import REFIID
-
-        riid = None
         for arg in method.args:
             if not arg.output:
                 self.unwrapArg(method, arg)
                 self.serializeArg(method, arg)
-                if arg.type is REFIID:
-                    riid = arg
         print '    trace::localWriter.endEnter();'
         
         self.invokeMethod(interface, base, method)
@@ -551,13 +565,6 @@ class Tracer:
             if arg.output:
                 self.serializeArg(method, arg)
                 self.wrapArg(method, arg)
-                if riid is not None and isinstance(arg.type, stdapi.Pointer):
-                    assert isinstance(arg.type.type, stdapi.ObjPointer)
-                    obj_type = arg.type.type.type
-                    assert obj_type is stdapi.Void
-                    self.wrapIid(interface, method, riid, arg)
-                    riid = None
-        assert riid is None
 
         if method.type is not stdapi.Void:
             print '    trace::localWriter.beginReturn();'
@@ -570,29 +577,35 @@ class Tracer:
             print '    if (!__result)'
             print '        delete this;'
 
-    def wrapIid(self, interface, method, riid, out):
-            print '    if (%s && *%s) {' % (out.name, out.name)
-            print '        if (*%s == m_pInstance) {' % (out.name,)
-            print '            *%s = this;' % (out.name,)
-            print '        }'
-            for iface in self.api.getAllInterfaces():
-                print r'        else if (%s == IID_%s) {' % (riid.name, iface.name)
-                print r'            *%s = new Wrap%s((%s *) *%s);' % (out.name, iface.name, iface.name, out.name)
-                print r'        }'
-            print r'        else {'
-            print r'            os::log("apitrace: warning: %s::%s: unknown IID {0x%08lX,0x%04X,0x%04X,{0x%02X,0x%02X,0x%02X,0x%02X,0x%02X,0x%02X,0x%02X,0x%02X}}\n",'
-            print r'                    "%s", "%s",' % (interface.name, method.name)
-            print r'                    %s.Data1, %s.Data2, %s.Data3,' % (riid.name, riid.name, riid.name)
-            print r'                    %s.Data4[0],' % (riid.name,)
-            print r'                    %s.Data4[1],' % (riid.name,)
-            print r'                    %s.Data4[2],' % (riid.name,)
-            print r'                    %s.Data4[3],' % (riid.name,)
-            print r'                    %s.Data4[4],' % (riid.name,)
-            print r'                    %s.Data4[5],' % (riid.name,)
-            print r'                    %s.Data4[6],' % (riid.name,)
-            print r'                    %s.Data4[7]);' % (riid.name,)
+    def wrapIid(self, function, riid, out):
+        print r'    if (%s && *%s) {' % (out.name, out.name)
+        function_name = function.name
+        else_ = ''
+        if self.interface is not None:
+            function_name = self.interface.name + '::' + function_name
+            print r'        %sif (*%s == m_pInstance) {' % (else_, out.name,)
+            print r'            *%s = this;' % (out.name,)
             print r'        }'
-            print '    }'
+            else_ = 'else '
+        for iface in self.api.getAllInterfaces():
+            print r'        %sif (%s == IID_%s) {' % (else_, riid.name, iface.name)
+            print r'            *%s = new Wrap%s((%s *) *%s);' % (out.name, iface.name, iface.name, out.name)
+            print r'        }'
+            else_ = 'else '
+        print r'        %s{' % else_
+        print r'            os::log("apitrace: warning: %s: unknown IID {0x%08lX,0x%04X,0x%04X,{0x%02X,0x%02X,0x%02X,0x%02X,0x%02X,0x%02X,0x%02X,0x%02X}}\n",'
+        print r'                    "%s",' % (function_name)
+        print r'                    %s.Data1, %s.Data2, %s.Data3,' % (riid.name, riid.name, riid.name)
+        print r'                    %s.Data4[0],' % (riid.name,)
+        print r'                    %s.Data4[1],' % (riid.name,)
+        print r'                    %s.Data4[2],' % (riid.name,)
+        print r'                    %s.Data4[3],' % (riid.name,)
+        print r'                    %s.Data4[4],' % (riid.name,)
+        print r'                    %s.Data4[5],' % (riid.name,)
+        print r'                    %s.Data4[6],' % (riid.name,)
+        print r'                    %s.Data4[7]);' % (riid.name,)
+        print r'        }'
+        print r'    }'
 
     def invokeMethod(self, interface, base, method):
         if method.type is stdapi.Void: