]> git.cworth.org Git - apitrace/blobdiff - trace.py
Merge branch 'd3dretrace'
[apitrace] / trace.py
index 8fd5840736f85208ab1555191d8614fef1762f76..4d6d97450501b12280740aecbd06b45a9d22979e 100644 (file)
--- a/trace.py
+++ b/trace.py
@@ -108,12 +108,18 @@ class ComplexValueSerializer(stdapi.OnceVisitor):
     def visitIntPointer(self, pointer):
         pass
 
+    def visitObjPointer(self, pointer):
+        self.visit(pointer.type)
+
     def visitLinearPointer(self, pointer):
         self.visit(pointer.type)
 
     def visitHandle(self, handle):
         self.visit(handle.type)
 
+    def visitReference(self, reference):
+        self.visit(reference.type)
+
     def visitAlias(self, alias):
         self.visit(alias.type)
 
@@ -208,9 +214,15 @@ class ValueSerializer(stdapi.Visitor):
     def visitIntPointer(self, pointer, instance):
         print '    trace::localWriter.writeOpaque((const void *)%s);' % instance
 
+    def visitObjPointer(self, pointer, instance):
+        print '    trace::localWriter.writeOpaque((const void *)%s);' % instance
+
     def visitLinearPointer(self, pointer, instance):
         print '    trace::localWriter.writeOpaque((const void *)%s);' % instance
 
+    def visitReference(self, reference, instance):
+        self.visit(reference.type, instance)
+
     def visitHandle(self, handle, instance):
         self.visit(handle.type, instance)
 
@@ -271,9 +283,17 @@ class ValueWrapper(stdapi.Visitor):
     def visitIntPointer(self, pointer, instance):
         pass
 
+    def visitObjPointer(self, pointer, instance):
+        print "    if (%s) {" % instance
+        self.visit(pointer.type, "*" + instance)
+        print "    }"
+    
     def visitLinearPointer(self, pointer, instance):
         pass
 
+    def visitReference(self, reference, instance):
+        self.visit(reference.type, instance)
+    
     def visitHandle(self, handle, instance):
         self.visit(handle.type, instance)
 
@@ -478,23 +498,23 @@ class Tracer:
         print '%s::~%s() {' % (getWrapperInterfaceName(interface), getWrapperInterfaceName(interface))
         print '}'
         print
-        for method in interface.iterMethods():
-            self.implementWrapperInterfaceMethod(interface, method)
+        for base, method in interface.iterBaseMethods():
+            self.implementWrapperInterfaceMethod(interface, base, method)
         print
 
-    def implementWrapperInterfaceMethod(self, interface, method):
+    def implementWrapperInterfaceMethod(self, interface, base, method):
         print method.prototype(getWrapperInterfaceName(interface) + '::' + method.name) + ' {'
         if method.type is not stdapi.Void:
             print '    %s __result;' % method.type
     
-        self.implementWrapperInterfaceMethodBody(interface, method)
+        self.implementWrapperInterfaceMethodBody(interface, base, method)
     
         if method.type is not stdapi.Void:
             print '    return __result;'
         print '}'
         print
 
-    def implementWrapperInterfaceMethodBody(self, interface, method):
+    def implementWrapperInterfaceMethodBody(self, interface, base, method):
         print '    static const char * __args[%u] = {%s};' % (len(method.args) + 1, ', '.join(['"this"'] + ['"%s"' % arg.name for arg in method.args]))
         print '    static const trace::FunctionSig __sig = {%u, "%s", %u, __args};' % (method.id, interface.name + '::' + method.name, len(method.args) + 1)
         print '    unsigned __call = trace::localWriter.beginEnter(&__sig);'
@@ -514,7 +534,7 @@ class Tracer:
                     riid = arg
         print '    trace::localWriter.endEnter();'
         
-        self.invokeMethod(interface, method)
+        self.invokeMethod(interface, base, method)
 
         print '    trace::localWriter.beginLeave(__call);'
         for arg in method.args:
@@ -523,7 +543,7 @@ class Tracer:
                 self.wrapArg(method, arg)
                 if riid is not None and isinstance(arg.type, Pointer):
                     if isinstance(arg.type.type, Opaque):
-                        self.wrapIid(riid, arg)
+                        self.wrapIid(interface, method, riid, arg)
                     else:
                         assert isinstance(arg.type.type, Pointer)
                         assert isinstance(arg.type.type.type, Interface)
@@ -539,18 +559,18 @@ class Tracer:
             print '    if (!__result)'
             print '        delete this;'
 
-    def wrapIid(self, riid, out):
+    def wrapIid(self, interface, method, riid, out):
             print '    if (%s && *%s) {' % (out.name, out.name)
             print '        if (*%s == m_pInstance) {' % (out.name,)
             print '            *%s = this;' % (out.name,)
             print '        }'
-            for iface in self.api.interfaces:
+            for iface in self.api.getAllInterfaces():
                 print r'        else if (%s == IID_%s) {' % (riid.name, iface.name)
                 print r'            *%s = new Wrap%s((%s *) *%s);' % (out.name, iface.name, iface.name, out.name)
                 print r'        }'
             print r'        else {'
-            print r'            os::log("apitrace: warning: %s: unknown REFIID {0x%08lX,0x%04X,0x%04X,{0x%02X,0x%02X,0x%02X,0x%02X,0x%02X,0x%02X,0x%02X,0x%02X}}\n",'
-            print r'                    __FUNCTION__,'
+            print r'            os::log("apitrace: warning: %s::%s: unknown IID {0x%08lX,0x%04X,0x%04X,{0x%02X,0x%02X,0x%02X,0x%02X,0x%02X,0x%02X,0x%02X,0x%02X}}\n",'
+            print r'                    "%s", "%s",' % (interface.name, method.name)
             print r'                    %s.Data1, %s.Data2, %s.Data3,' % (riid.name, riid.name, riid.name)
             print r'                    %s.Data4[0],' % (riid.name,)
             print r'                    %s.Data4[1],' % (riid.name,)
@@ -563,12 +583,12 @@ class Tracer:
             print r'        }'
             print '    }'
 
-    def invokeMethod(self, interface, method):
+    def invokeMethod(self, interface, base, method):
         if method.type is stdapi.Void:
             result = ''
         else:
             result = '__result = '
-        print '    %sm_pInstance->%s(%s);' % (result, method.name, ', '.join([str(arg.name) for arg in method.args]))
+        print '    %sstatic_cast<%s *>(m_pInstance)->%s(%s);' % (result, base, method.name, ', '.join([str(arg.name) for arg in method.args]))
     
     def emit_memcpy(self, dest, src, length):
         print '        unsigned __call = trace::localWriter.beginEnter(&trace::memcpy_sig);'