]> git.cworth.org Git - apitrace/commitdiff
specs: Initial attempt to support unions.
authorJosé Fonseca <jose.r.fonseca@gmail.com>
Tue, 20 Nov 2012 11:08:08 +0000 (11:08 +0000)
committerJosé Fonseca <jose.r.fonseca@gmail.com>
Tue, 20 Nov 2012 11:08:08 +0000 (11:08 +0000)
retrace/retrace.py
specs/d3d11.py
specs/stdapi.py
wrappers/trace.py

index c4ad2d2bb494aceb861dcf163818bb702af58622..f1c5bb0ff810b4fd40702b7aa584a30e622561c6 100644 (file)
@@ -96,6 +96,9 @@ class ValueAllocator(stdapi.Visitor):
         pass
 
     def visitPolymorphic(self, polymorphic, lvalue, rvalue):
+        if polymorphic.defaultType is None:
+            # FIXME
+            raise UnsupportedType
         self.visit(polymorphic.defaultType, lvalue, rvalue)
 
     def visitOpaque(self, opaque, lvalue, rvalue):
@@ -186,6 +189,9 @@ class ValueDeserializer(stdapi.Visitor):
             self.visit(member_type, '%s.%s' % (lvalue, member_name), '*%s->members[%s]' % (tmp, i))
 
     def visitPolymorphic(self, polymorphic, lvalue, rvalue):
+        if polymorphic.defaultType is None:
+            # FIXME
+            raise UnsupportedType
         self.visit(polymorphic.defaultType, lvalue, rvalue)
     
     def visitOpaque(self, opaque, lvalue, rvalue):
@@ -294,6 +300,9 @@ class SwizzledValueRegistrator(stdapi.Visitor):
             self.visit(member_type, '%s.%s' % (lvalue, member_name), '*%s->members[%s]' % (tmp, i))
     
     def visitPolymorphic(self, polymorphic, lvalue, rvalue):
+        if polymorphic.defaultType is None:
+            # FIXME
+            raise UnsupportedType
         self.visit(polymorphic.defaultType, lvalue, rvalue)
     
     def visitOpaque(self, opaque, lvalue, rvalue):
@@ -380,8 +389,7 @@ class Retracer:
         if not success:
             print '    if (1) {'
             self.failFunction(function)
-            if function.name[-1].islower():
-                sys.stderr.write('warning: unsupported %s call\n' % function.name)
+            sys.stderr.write('warning: unsupported %s call\n' % function.name)
             print '    }'
 
     def swizzleValues(self, function):
index 8ee25f22486e0501cbfb1e2b2499191836e3345a..6592bcbf7d48778662f9dd7f8623994f8877230f 100644 (file)
@@ -677,15 +677,15 @@ D3D11_TEX3D_RTV = Struct("D3D11_TEX3D_RTV", [
 D3D11_RENDER_TARGET_VIEW_DESC = Struct("D3D11_RENDER_TARGET_VIEW_DESC", [
     (DXGI_FORMAT, "Format"),
     (D3D11_RTV_DIMENSION, "ViewDimension"),
-    (Union(None, [
-        (D3D11_BUFFER_RTV, "Buffer"),
-        (D3D11_TEX1D_RTV, "Texture1D"),
-        (D3D11_TEX1D_ARRAY_RTV, "Texture1DArray"),
-        (D3D11_TEX2D_RTV, "Texture2D"),
-        (D3D11_TEX2D_ARRAY_RTV, "Texture2DArray"),
-        (D3D11_TEX2DMS_RTV, "Texture2DMS"),
-        (D3D11_TEX2DMS_ARRAY_RTV, "Texture2DMSArray"),
-        (D3D11_TEX3D_RTV, "Texture3D"),
+    (Union_("{self}.ViewDimension", [
+        ("D3D11_RTV_DIMENSION_BUFFER", D3D11_BUFFER_RTV, "Buffer"),
+        ("D3D11_RTV_DIMENSION_TEXTURE1D", D3D11_TEX1D_RTV, "Texture1D"),
+        ("D3D11_RTV_DIMENSION_TEXTURE1DARRAY", D3D11_TEX1D_ARRAY_RTV, "Texture1DArray"),
+        ("D3D11_RTV_DIMENSION_TEXTURE2D", D3D11_TEX2D_RTV, "Texture2D"),
+        ("D3D11_RTV_DIMENSION_TEXTURE2DARRAY", D3D11_TEX2D_ARRAY_RTV, "Texture2DArray"),
+        ("D3D11_RTV_DIMENSION_TEXTURE2DMS", D3D11_TEX2DMS_RTV, "Texture2DMS"),
+        ("D3D11_RTV_DIMENSION_TEXTURE2DMSARRAY", D3D11_TEX2DMS_ARRAY_RTV, "Texture2DMSArray"),
+        ("D3D11_RTV_DIMENSION_TEXTURE3D", D3D11_TEX3D_RTV, "Texture3D"),
     ]), None),
 ])
 
index 1097347f38ce3a97a26b9312c7de183c469e8c0f..4b4808e458641173cf0b350a06a91cc715f85881 100644 (file)
@@ -296,7 +296,7 @@ class Struct(Type):
 
         # Eliminate anonymous unions
         for type, name in members:
-            if name is not None:
+            if name is not None or isinstance(type, Polymorphic):
                 self.members.append((type, name))
             else:
                 assert isinstance(type, Union)
@@ -320,6 +320,13 @@ class Union(Type):
         self.name = name
         self.members = members
 
+def Union_(kindExpr, kindTypes, contextLess=True):
+    switchTypes = []
+    for kindCase, kindType, kindMemberName in kindTypes:
+        switchType = Struct(None, [(kindType, kindMemberName)])
+        switchTypes.append((kindCase, switchType))
+    return Polymorphic(kindExpr, switchTypes, contextLess=contextLess)
+
 
 class Alias(Type):
 
@@ -515,8 +522,12 @@ def OpaqueBlob(type, size):
 
 class Polymorphic(Type):
 
-    def __init__(self, switchExpr, switchTypes, defaultType, contextLess=True):
-        Type.__init__(self, defaultType.expr)
+    def __init__(self, switchExpr, switchTypes, defaultType=None, contextLess=True):
+        if defaultType is None:
+            Type.__init__(self, None)
+            contextLess = False
+        else:
+            Type.__init__(self, defaultType.expr)
         self.switchExpr = switchExpr
         self.switchTypes = switchTypes
         self.defaultType = defaultType
@@ -526,8 +537,12 @@ class Polymorphic(Type):
         return visitor.visitPolymorphic(self, *args, **kwargs)
 
     def iterSwitch(self):
-        cases = [['default']]
-        types = [self.defaultType]
+        cases = []
+        types = []
+
+        if self.defaultType is not None:
+            cases.append(['default'])
+            types.append(self.defaultType)
 
         for expr, type in self.switchTypes:
             case = 'case %s' % expr
@@ -726,7 +741,10 @@ class Rebuilder(Visitor):
     def visitPolymorphic(self, polymorphic):
         switchExpr = polymorphic.switchExpr
         switchTypes = [(expr, self.visit(type)) for expr, type in polymorphic.switchTypes]
-        defaultType = self.visit(polymorphic.defaultType)
+        if polymorphic.defaultType is None:
+            defaultType = None
+        else:
+            defaultType = self.visit(polymorphic.defaultType)
         return Polymorphic(switchExpr, switchTypes, defaultType, polymorphic.contextLess)
 
 
@@ -816,9 +834,10 @@ class Traverser(Visitor):
             self.visit(method.type, *args, **kwargs)
 
     def visitPolymorphic(self, polymorphic, *args, **kwargs):
-        self.visit(polymorphic.defaultType, *args, **kwargs)
         for expr, type in polymorphic.switchTypes:
             self.visit(type, *args, **kwargs)
+        if polymorphic.defaultType is not None:
+            self.visit(polymorphic.defaultType, *args, **kwargs)
 
 
 class Collector(Traverser):
index f1aaf82c67474ab8263c707b1fa8f0dc10aed11e..34f4c6930bce7717f60f4ee45aa12f24eafb96ab 100644 (file)
@@ -104,10 +104,17 @@ class ComplexValueSerializer(stdapi.OnceVisitor):
     def visitStruct(self, struct):
         print 'static const char * _struct%s_members[%u] = {' % (struct.tag, len(struct.members))
         for type, name,  in struct.members:
-            print '    "%s",' % (name,)
+            if name is None:
+                print '    "",'
+            else:
+                print '    "%s",' % (name,)
         print '};'
         print 'static const trace::StructSig _struct%s_sig = {' % (struct.tag,)
-        print '   %u, "%s", %u, _struct%s_members' % (struct.id, struct.name, len(struct.members), struct.tag)
+        if struct.name is None:
+            structName = '""'
+        else:
+            structName = '"%s"' % struct.name
+        print '    %u, %s, %u, _struct%s_members' % (struct.id, structName, len(struct.members), struct.tag)
         print '};'
         print
 
@@ -120,22 +127,22 @@ class ComplexValueSerializer(stdapi.OnceVisitor):
     def visitEnum(self, enum):
         print 'static const trace::EnumValue _enum%s_values[] = {' % (enum.tag)
         for value in enum.values:
-            print '   {"%s", %s},' % (value, value)
+            print '    {"%s", %s},' % (value, value)
         print '};'
         print
         print 'static const trace::EnumSig _enum%s_sig = {' % (enum.tag)
-        print '   %u, %u, _enum%s_values' % (enum.id, len(enum.values), enum.tag)
+        print '    %u, %u, _enum%s_values' % (enum.id, len(enum.values), enum.tag)
         print '};'
         print
 
     def visitBitmask(self, bitmask):
         print 'static const trace::BitmaskFlag _bitmask%s_flags[] = {' % (bitmask.tag)
         for value in bitmask.values:
-            print '   {"%s", %s},' % (value, value)
+            print '    {"%s", %s},' % (value, value)
         print '};'
         print
         print 'static const trace::BitmaskSig _bitmask%s_sig = {' % (bitmask.tag)
-        print '   %u, %u, _bitmask%s_flags' % (bitmask.id, len(bitmask.values), bitmask.tag)
+        print '    %u, %u, _bitmask%s_flags' % (bitmask.id, len(bitmask.values), bitmask.tag)
         print '};'
         print
 
@@ -189,11 +196,6 @@ class ValueSerializer(stdapi.Visitor, ExpanderMixin):
     ComplexValueSerializer visitor above.
     '''
 
-    def __init__(self):
-        #stdapi.Visitor.__init__(self)
-        self.indices = []
-        self.instances = []
-
     def visitLiteral(self, literal, instance):
         print '    trace::localWriter.write%s(%s);' % (literal.kind, instance)
 
@@ -219,7 +221,12 @@ class ValueSerializer(stdapi.Visitor, ExpanderMixin):
     def visitStruct(self, struct, instance):
         print '    trace::localWriter.beginStruct(&_struct%s_sig);' % (struct.tag,)
         for type, name in struct.members:
-            self.visitMember(instance, type, '(%s).%s' % (instance, name,))
+            if name is None:
+                # Anonymous structure/union member
+                memberInstance = instance
+            else:
+                memberInstance = '(%s).%s' % (instance, name)
+            self.visitMember(instance, type, memberInstance)
         print '    trace::localWriter.endStruct();'
 
     def visitArray(self, array, instance):
@@ -287,11 +294,14 @@ class ValueSerializer(stdapi.Visitor, ExpanderMixin):
         if polymorphic.contextLess:
             print '    _write__%s(%s, %s);' % (polymorphic.tag, polymorphic.switchExpr, instance)
         else:
-            print '    switch (%s) {' % polymorphic.switchExpr
+            print '    switch (%s) {' % self.expand(polymorphic.switchExpr)
             for cases, type in polymorphic.iterSwitch():
                 for case in cases:
                     print '    %s:' % case
-                self.visit(type, 'static_cast<%s>(%s)' % (type, instance))
+                caseInstance = instance
+                if type.expr is not None:
+                    caseInstance = 'static_cast<%s>(%s)' % (type, caseInstance)
+                self.visit(type, caseInstance)
                 print '        break;'
             print '    }'
 
@@ -321,7 +331,12 @@ class ValueWrapper(stdapi.Traverser, ExpanderMixin):
 
     def visitStruct(self, struct, instance):
         for type, name in struct.members:
-            self.visitMember(instance, type, "(%s).%s" % (instance, name))
+            if name is None:
+                # Anonymous structure/union member
+                memberInstance = instance
+            else:
+                memberInstance = '(%s).%s' % (instance, name)
+            self.visitMember(instance, type, memberInstance)
 
     def visitArray(self, array, instance):
         array_length = self.expand(array.length)