]> git.cworth.org Git - apitrace/blobdiff - trace.py
Don't abuse NotImplementedError.
[apitrace] / trace.py
index bb209e408f197d3046e60b41c2685ee58379a30c..4d6d97450501b12280740aecbd06b45a9d22979e 100644 (file)
--- a/trace.py
+++ b/trace.py
@@ -29,7 +29,7 @@
 import specs.stdapi as stdapi
 
 
-def interface_wrap_name(interface):
+def getWrapperInterfaceName(interface):
     return "Wrap" + interface.expr
 
 
@@ -105,9 +105,21 @@ class ComplexValueSerializer(stdapi.OnceVisitor):
     def visitPointer(self, pointer):
         self.visit(pointer.type)
 
+    def visitIntPointer(self, pointer):
+        pass
+
+    def visitObjPointer(self, pointer):
+        self.visit(pointer.type)
+
+    def visitLinearPointer(self, pointer):
+        self.visit(pointer.type)
+
     def visitHandle(self, handle):
         self.visit(handle.type)
 
+    def visitReference(self, reference):
+        self.visit(reference.type)
+
     def visitAlias(self, alias):
         self.visit(alias.type)
 
@@ -115,19 +127,7 @@ class ComplexValueSerializer(stdapi.OnceVisitor):
         pass
 
     def visitInterface(self, interface):
-        print "class %s : public %s " % (interface_wrap_name(interface), interface.name)
-        print "{"
-        print "public:"
-        print "    %s(%s * pInstance);" % (interface_wrap_name(interface), interface.name)
-        print "    virtual ~%s();" % interface_wrap_name(interface)
-        print
-        for method in interface.iterMethods():
-            print "    " + method.prototype() + ";"
-        print
-        #print "private:"
-        print "    %s * m_pInstance;" % (interface.name,)
-        print "};"
-        print
+        pass
 
     def visitPolymorphic(self, polymorphic):
         print 'static void _write__%s(int selector, const %s & value) {' % (polymorphic.tag, polymorphic.expr)
@@ -154,10 +154,20 @@ class ValueSerializer(stdapi.Visitor):
         print '    trace::localWriter.write%s(%s);' % (literal.kind, instance)
 
     def visitString(self, string, instance):
+        if string.kind == 'String':
+            cast = 'const char *'
+        elif string.kind == 'WString':
+            cast = 'const wchar_t *'
+        else:
+            assert False
+        if cast != string.expr:
+            # reinterpret_cast is necessary for GLubyte * <=> char *
+            instance = 'reinterpret_cast<%s>(%s)' % (cast, instance)
         if string.length is not None:
-            print '    trace::localWriter.writeString((const char *)%s, %s);' % (instance, string.length)
+            length = ', %s' % string.length
         else:
-            print '    trace::localWriter.writeString((const char *)%s);' % instance
+            length = ''
+        print '    trace::localWriter.write%s(%s%s);' % (string.kind, instance, length)
 
     def visitConst(self, const, instance):
         self.visit(const.type, instance)
@@ -201,6 +211,18 @@ class ValueSerializer(stdapi.Visitor):
         print '        trace::localWriter.writeNull();'
         print '    }'
 
+    def visitIntPointer(self, pointer, instance):
+        print '    trace::localWriter.writeOpaque((const void *)%s);' % instance
+
+    def visitObjPointer(self, pointer, instance):
+        print '    trace::localWriter.writeOpaque((const void *)%s);' % instance
+
+    def visitLinearPointer(self, pointer, instance):
+        print '    trace::localWriter.writeOpaque((const void *)%s);' % instance
+
+    def visitReference(self, reference, instance):
+        self.visit(reference.type, instance)
+
     def visitHandle(self, handle, instance):
         self.visit(handle.type, instance)
 
@@ -257,7 +279,21 @@ class ValueWrapper(stdapi.Visitor):
         print "    if (%s) {" % instance
         self.visit(pointer.type, "*" + instance)
         print "    }"
+    
+    def visitIntPointer(self, pointer, instance):
+        pass
+
+    def visitObjPointer(self, pointer, instance):
+        print "    if (%s) {" % instance
+        self.visit(pointer.type, "*" + instance)
+        print "    }"
+    
+    def visitLinearPointer(self, pointer, instance):
+        pass
 
+    def visitReference(self, reference, instance):
+        self.visit(reference.type, instance)
+    
     def visitHandle(self, handle, instance):
         self.visit(handle.type, instance)
 
@@ -271,7 +307,7 @@ class ValueWrapper(stdapi.Visitor):
         assert instance.startswith('*')
         instance = instance[1:]
         print "    if (%s) {" % instance
-        print "        %s = new %s(%s);" % (instance, interface_wrap_name(interface), instance)
+        print "        %s = new %s(%s);" % (instance, getWrapperInterfaceName(interface), instance)
         print "    }"
     
     def visitPolymorphic(self, type, instance):
@@ -285,9 +321,14 @@ class ValueUnwrapper(ValueWrapper):
     def visitInterface(self, interface, instance):
         assert instance.startswith('*')
         instance = instance[1:]
-        print "    if (%s) {" % instance
-        print "        %s = static_cast<%s *>(%s)->m_pInstance;" % (instance, interface_wrap_name(interface), instance)
-        print "    }"
+        print r'    if (%s) {' % instance
+        print r'        %s *pWrapper = static_cast<%s*>(%s);' % (getWrapperInterfaceName(interface), getWrapperInterfaceName(interface), instance)
+        print r'        if (pWrapper && pWrapper->m_dwMagic == 0xd8365d6c) {'
+        print r'            %s = pWrapper->m_pInstance;' % (instance,)
+        print r'        } else {'
+        print r'            os::log("apitrace: warning: %%s: unexpected %%s pointer\n", __FUNCTION__, "%s");' % interface.name
+        print r'        }'
+        print r'    }'
 
 
 class Tracer:
@@ -315,14 +356,15 @@ class Tracer:
         print
 
         # Generate the serializer functions
-        types = api.all_types()
+        types = api.getAllTypes()
         visitor = ComplexValueSerializer(self.serializerFactory())
         map(visitor.visit, types)
         print
 
         # Interfaces wrapers
-        interfaces = [type for type in types if isinstance(type, stdapi.Interface)]
-        map(self.traceInterfaceImpl, interfaces)
+        interfaces = api.getAllInterfaces()
+        map(self.declareWrapperInterface, interfaces)
+        map(self.implementWrapperInterface, interfaces)
         print
 
         # Function wrappers
@@ -428,76 +470,138 @@ class Tracer:
         visitor = ValueUnwrapper()
         visitor.visit(type, instance)
 
-    def traceInterfaceImpl(self, interface):
-        print '%s::%s(%s * pInstance) {' % (interface_wrap_name(interface), interface_wrap_name(interface), interface.name)
+    def declareWrapperInterface(self, interface):
+        print "class %s : public %s " % (getWrapperInterfaceName(interface), interface.name)
+        print "{"
+        print "public:"
+        print "    %s(%s * pInstance);" % (getWrapperInterfaceName(interface), interface.name)
+        print "    virtual ~%s();" % getWrapperInterfaceName(interface)
+        print
+        for method in interface.iterMethods():
+            print "    " + method.prototype() + ";"
+        print
+        self.declareWrapperInterfaceVariables(interface)
+        print "};"
+        print
+
+    def declareWrapperInterfaceVariables(self, interface):
+        #print "private:"
+        print "    DWORD m_dwMagic;"
+        print "    %s * m_pInstance;" % (interface.name,)
+
+    def implementWrapperInterface(self, interface):
+        print '%s::%s(%s * pInstance) {' % (getWrapperInterfaceName(interface), getWrapperInterfaceName(interface), interface.name)
+        print '    m_dwMagic = 0xd8365d6c;'
         print '    m_pInstance = pInstance;'
         print '}'
         print
-        print '%s::~%s() {' % (interface_wrap_name(interface), interface_wrap_name(interface))
+        print '%s::~%s() {' % (getWrapperInterfaceName(interface), getWrapperInterfaceName(interface))
         print '}'
         print
-        for method in interface.iterMethods():
-            self.traceMethod(interface, method)
+        for base, method in interface.iterBaseMethods():
+            self.implementWrapperInterfaceMethod(interface, base, method)
+        print
+
+    def implementWrapperInterfaceMethod(self, interface, base, method):
+        print method.prototype(getWrapperInterfaceName(interface) + '::' + method.name) + ' {'
+        if method.type is not stdapi.Void:
+            print '    %s __result;' % method.type
+    
+        self.implementWrapperInterfaceMethodBody(interface, base, method)
+    
+        if method.type is not stdapi.Void:
+            print '    return __result;'
+        print '}'
         print
 
-    def traceMethod(self, interface, method):
-        print method.prototype(interface_wrap_name(interface) + '::' + method.name) + ' {'
+    def implementWrapperInterfaceMethodBody(self, interface, base, method):
         print '    static const char * __args[%u] = {%s};' % (len(method.args) + 1, ', '.join(['"this"'] + ['"%s"' % arg.name for arg in method.args]))
         print '    static const trace::FunctionSig __sig = {%u, "%s", %u, __args};' % (method.id, interface.name + '::' + method.name, len(method.args) + 1)
         print '    unsigned __call = trace::localWriter.beginEnter(&__sig);'
         print '    trace::localWriter.beginArg(0);'
         print '    trace::localWriter.writeOpaque((const void *)m_pInstance);'
         print '    trace::localWriter.endArg();'
+
+        from specs.winapi import REFIID
+        from specs.stdapi import Pointer, Opaque, Interface
+
+        riid = None
         for arg in method.args:
             if not arg.output:
                 self.unwrapArg(method, arg)
                 self.serializeArg(method, arg)
-        if method.type is stdapi.Void:
-            result = ''
-        else:
-            print '    %s __result;' % method.type
-            result = '__result = '
+                if arg.type is REFIID:
+                    riid = arg
         print '    trace::localWriter.endEnter();'
-        print '    %sm_pInstance->%s(%s);' % (result, method.name, ', '.join([str(arg.name) for arg in method.args]))
+        
+        self.invokeMethod(interface, base, method)
+
         print '    trace::localWriter.beginLeave(__call);'
         for arg in method.args:
             if arg.output:
                 self.serializeArg(method, arg)
                 self.wrapArg(method, arg)
+                if riid is not None and isinstance(arg.type, Pointer):
+                    if isinstance(arg.type.type, Opaque):
+                        self.wrapIid(interface, method, riid, arg)
+                    else:
+                        assert isinstance(arg.type.type, Pointer)
+                        assert isinstance(arg.type.type.type, Interface)
+
         if method.type is not stdapi.Void:
             print '    trace::localWriter.beginReturn();'
             self.serializeValue(method.type, "__result")
             print '    trace::localWriter.endReturn();'
             self.wrapValue(method.type, '__result')
         print '    trace::localWriter.endLeave();'
-        if method.name == 'QueryInterface':
-            print '    if (ppvObj && *ppvObj) {'
-            print '        if (*ppvObj == m_pInstance) {'
-            print '            *ppvObj = this;'
+        if method.name == 'Release':
+            assert method.type is not stdapi.Void
+            print '    if (!__result)'
+            print '        delete this;'
+
+    def wrapIid(self, interface, method, riid, out):
+            print '    if (%s && *%s) {' % (out.name, out.name)
+            print '        if (*%s == m_pInstance) {' % (out.name,)
+            print '            *%s = this;' % (out.name,)
             print '        }'
-            for iface in self.api.interfaces:
-                print r'        else if (riid == IID_%s) {' % iface.name
-                print r'            *ppvObj = new Wrap%s((%s *) *ppvObj);' % (iface.name, iface.name)
+            for iface in self.api.getAllInterfaces():
+                print r'        else if (%s == IID_%s) {' % (riid.name, iface.name)
+                print r'            *%s = new Wrap%s((%s *) *%s);' % (out.name, iface.name, iface.name, out.name)
                 print r'        }'
             print r'        else {'
-            print r'            os::log("apitrace: warning: unknown REFIID {0x%08lX,0x%04X,0x%04X,{0x%02X,0x%02X,0x%02X,0x%02X,0x%02X,0x%02X,0x%02X,0x%02X}}\n",'
-            print r'                    riid.Data1, riid.Data2, riid.Data3,'
-            print r'                    riid.Data4[0],'
-            print r'                    riid.Data4[1],'
-            print r'                    riid.Data4[2],'
-            print r'                    riid.Data4[3],'
-            print r'                    riid.Data4[4],'
-            print r'                    riid.Data4[5],'
-            print r'                    riid.Data4[6],'
-            print r'                    riid.Data4[7]);'
+            print r'            os::log("apitrace: warning: %s::%s: unknown IID {0x%08lX,0x%04X,0x%04X,{0x%02X,0x%02X,0x%02X,0x%02X,0x%02X,0x%02X,0x%02X,0x%02X}}\n",'
+            print r'                    "%s", "%s",' % (interface.name, method.name)
+            print r'                    %s.Data1, %s.Data2, %s.Data3,' % (riid.name, riid.name, riid.name)
+            print r'                    %s.Data4[0],' % (riid.name,)
+            print r'                    %s.Data4[1],' % (riid.name,)
+            print r'                    %s.Data4[2],' % (riid.name,)
+            print r'                    %s.Data4[3],' % (riid.name,)
+            print r'                    %s.Data4[4],' % (riid.name,)
+            print r'                    %s.Data4[5],' % (riid.name,)
+            print r'                    %s.Data4[6],' % (riid.name,)
+            print r'                    %s.Data4[7]);' % (riid.name,)
             print r'        }'
             print '    }'
-        if method.name == 'Release':
-            assert method.type is not stdapi.Void
-            print '    if (!__result)'
-            print '        delete this;'
-        if method.type is not stdapi.Void:
-            print '    return __result;'
-        print '}'
-        print
 
+    def invokeMethod(self, interface, base, method):
+        if method.type is stdapi.Void:
+            result = ''
+        else:
+            result = '__result = '
+        print '    %sstatic_cast<%s *>(m_pInstance)->%s(%s);' % (result, base, method.name, ', '.join([str(arg.name) for arg in method.args]))
+    
+    def emit_memcpy(self, dest, src, length):
+        print '        unsigned __call = trace::localWriter.beginEnter(&trace::memcpy_sig);'
+        print '        trace::localWriter.beginArg(0);'
+        print '        trace::localWriter.writeOpaque(%s);' % dest
+        print '        trace::localWriter.endArg();'
+        print '        trace::localWriter.beginArg(1);'
+        print '        trace::localWriter.writeBlob(%s, %s);' % (src, length)
+        print '        trace::localWriter.endArg();'
+        print '        trace::localWriter.beginArg(2);'
+        print '        trace::localWriter.writeUInt(%s);' % length
+        print '        trace::localWriter.endArg();'
+        print '        trace::localWriter.endEnter();'
+        print '        trace::localWriter.beginLeave(__call);'
+        print '        trace::localWriter.endLeave();'
+