]> git.cworth.org Git - apitrace/blobdiff - wrappers/trace.py
d3d10/d3d11: Complete union support.
[apitrace] / wrappers / trace.py
index cbdc0c07122e2a9ee4390263f83746a68decb107..506570dca018ed9bc0dde652940cd33babd83fb6 100644 (file)
@@ -40,44 +40,6 @@ def getWrapperInterfaceName(interface):
 
 
 
-class ExpanderMixin:
-    '''Mixin class that provides a bunch of methods to expand C expressions
-    from the specifications.'''
-
-    __structs = None
-    __indices = None
-
-    def expand(self, expr):
-        # Expand a C expression, replacing certain variables
-        if not isinstance(expr, basestring):
-            return expr
-        variables = {}
-
-        if self.__structs is not None:
-            variables['self'] = '(%s)' % self.__structs[0]
-        if self.__indices is not None:
-            variables['i'] = self.__indices[0]
-
-        expandedExpr = expr.format(**variables)
-        if expandedExpr != expr and 0:
-            sys.stderr.write("  %r -> %r\n" % (expr, expandedExpr))
-        return expandedExpr
-
-    def visitMember(self, structInstance, member_type, *args, **kwargs):
-        self.__structs = (structInstance, self.__structs)
-        try:
-            return self.visit(member_type, *args, **kwargs)
-        finally:
-            _, self.__structs = self.__structs
-
-    def visitElement(self, element_index, element_type, *args, **kwargs):
-        self.__indices = (element_index, self.__indices)
-        try:
-            return self.visit(element_type, *args, **kwargs)
-        finally:
-            _, self.__indices = self.__indices
-
-
 class ComplexValueSerializer(stdapi.OnceVisitor):
     '''Type visitors which generates serialization functions for
     complex types.
@@ -104,10 +66,17 @@ class ComplexValueSerializer(stdapi.OnceVisitor):
     def visitStruct(self, struct):
         print 'static const char * _struct%s_members[%u] = {' % (struct.tag, len(struct.members))
         for type, name,  in struct.members:
-            print '    "%s",' % (name,)
+            if name is None:
+                print '    "",'
+            else:
+                print '    "%s",' % (name,)
         print '};'
         print 'static const trace::StructSig _struct%s_sig = {' % (struct.tag,)
-        print '   %u, "%s", %u, _struct%s_members' % (struct.id, struct.name, len(struct.members), struct.tag)
+        if struct.name is None:
+            structName = '""'
+        else:
+            structName = '"%s"' % struct.name
+        print '    %u, %s, %u, _struct%s_members' % (struct.id, structName, len(struct.members), struct.tag)
         print '};'
         print
 
@@ -120,22 +89,22 @@ class ComplexValueSerializer(stdapi.OnceVisitor):
     def visitEnum(self, enum):
         print 'static const trace::EnumValue _enum%s_values[] = {' % (enum.tag)
         for value in enum.values:
-            print '   {"%s", %s},' % (value, value)
+            print '    {"%s", %s},' % (value, value)
         print '};'
         print
         print 'static const trace::EnumSig _enum%s_sig = {' % (enum.tag)
-        print '   %u, %u, _enum%s_values' % (enum.id, len(enum.values), enum.tag)
+        print '    %u, %u, _enum%s_values' % (enum.id, len(enum.values), enum.tag)
         print '};'
         print
 
     def visitBitmask(self, bitmask):
         print 'static const trace::BitmaskFlag _bitmask%s_flags[] = {' % (bitmask.tag)
         for value in bitmask.values:
-            print '   {"%s", %s},' % (value, value)
+            print '    {"%s", %s},' % (value, value)
         print '};'
         print
         print 'static const trace::BitmaskSig _bitmask%s_sig = {' % (bitmask.tag)
-        print '   %u, %u, _bitmask%s_flags' % (bitmask.id, len(bitmask.values), bitmask.tag)
+        print '    %u, %u, _bitmask%s_flags' % (bitmask.id, len(bitmask.values), bitmask.tag)
         print '};'
         print
 
@@ -181,7 +150,7 @@ class ComplexValueSerializer(stdapi.OnceVisitor):
         print
 
 
-class ValueSerializer(stdapi.Visitor, ExpanderMixin):
+class ValueSerializer(stdapi.Visitor, stdapi.ExpanderMixin):
     '''Visitor which generates code to serialize any type.
     
     Simple types are serialized inline here, whereas the serialization of
@@ -189,11 +158,6 @@ class ValueSerializer(stdapi.Visitor, ExpanderMixin):
     ComplexValueSerializer visitor above.
     '''
 
-    def __init__(self):
-        #stdapi.Visitor.__init__(self)
-        self.indices = []
-        self.instances = []
-
     def visitLiteral(self, literal, instance):
         print '    trace::localWriter.write%s(%s);' % (literal.kind, instance)
 
@@ -208,7 +172,7 @@ class ValueSerializer(stdapi.Visitor, ExpanderMixin):
             # reinterpret_cast is necessary for GLubyte * <=> char *
             instance = 'reinterpret_cast<%s>(%s)' % (cast, instance)
         if string.length is not None:
-            length = ', %s' % string.length
+            length = ', %s' % self.expand(string.length)
         else:
             length = ''
         print '    trace::localWriter.write%s(%s%s);' % (suffix, instance, length)
@@ -218,8 +182,8 @@ class ValueSerializer(stdapi.Visitor, ExpanderMixin):
 
     def visitStruct(self, struct, instance):
         print '    trace::localWriter.beginStruct(&_struct%s_sig);' % (struct.tag,)
-        for type, name in struct.members:
-            self.visitMember(instance, type, '(%s).%s' % (instance, name,))
+        for member in struct.members:
+            self.visitMember(member, instance)
         print '    trace::localWriter.endStruct();'
 
     def visitArray(self, array, instance):
@@ -287,12 +251,21 @@ class ValueSerializer(stdapi.Visitor, ExpanderMixin):
         if polymorphic.contextLess:
             print '    _write__%s(%s, %s);' % (polymorphic.tag, polymorphic.switchExpr, instance)
         else:
-            print '    switch (%s) {' % polymorphic.switchExpr
+            switchExpr = self.expand(polymorphic.switchExpr)
+            print '    switch (%s) {' % switchExpr
             for cases, type in polymorphic.iterSwitch():
                 for case in cases:
                     print '    %s:' % case
-                self.visit(type, 'static_cast<%s>(%s)' % (type, instance))
+                caseInstance = instance
+                if type.expr is not None:
+                    caseInstance = 'static_cast<%s>(%s)' % (type, caseInstance)
+                self.visit(type, caseInstance)
                 print '        break;'
+            if polymorphic.defaultType is None:
+                print r'    default:'
+                print r'        os::log("apitrace: warning: %%s: unexpected polymorphic case %%i\n", __FUNCTION__, (int)%s);' % (switchExpr,)
+                print r'        trace::localWriter.writeNull();'
+                print r'        break;'
             print '    }'
 
 
@@ -312,7 +285,7 @@ class WrapDecider(stdapi.Traverser):
         self.needsWrapping = True
 
 
-class ValueWrapper(stdapi.Traverser, ExpanderMixin):
+class ValueWrapper(stdapi.Traverser, stdapi.ExpanderMixin):
     '''Type visitor which will generate the code to wrap an instance.
     
     Wrapping is necessary mostly for interfaces, however interface pointers can
@@ -320,8 +293,8 @@ class ValueWrapper(stdapi.Traverser, ExpanderMixin):
     '''
 
     def visitStruct(self, struct, instance):
-        for type, name in struct.members:
-            self.visitMember(instance, type, "(%s).%s" % (instance, name))
+        for member in struct.members:
+            self.visitMember(member, instance)
 
     def visitArray(self, array, instance):
         array_length = self.expand(array.length)
@@ -426,8 +399,9 @@ class Tracer:
         self.header(api)
 
         # Includes
-        for header in api.headers:
-            print header
+        for module in api.modules:
+            for header in module.headers:
+                print header
         print
 
         # Generate the serializer functions
@@ -442,8 +416,10 @@ class Tracer:
         # Function wrappers
         self.interface = None
         self.base = None
-        map(self.traceFunctionDecl, api.functions)
-        map(self.traceFunctionImpl, api.functions)
+        for function in api.getAllFunctions():
+            self.traceFunctionDecl(function)
+        for function in api.getAllFunctions():
+            self.traceFunctionImpl(function)
         print
 
         self.footer(api)
@@ -538,7 +514,7 @@ class Tracer:
             return 'true'
         if str(function.type) == 'HRESULT':
             return 'SUCCEEDED(_result)'
-        return 'false'
+        return 'true'
 
     def serializeArg(self, function, arg):
         print '    trace::localWriter.beginArg(%u);' % (arg.index,)