]> git.cworth.org Git - apitrace/blobdiff - specs/stdapi.py
Properly (un)wrap array arguments.
[apitrace] / specs / stdapi.py
index 2057124c9edf9bd4610a0f423fd7e7f359fee71b..2e4e5d86c2b89c52a798735e7040dd439c3aa0d3 100644 (file)
@@ -68,6 +68,12 @@ class Type:
     def visit(self, visitor, *args, **kwargs):
         raise NotImplementedError
 
+    def mutable(self):
+        '''Return a mutable version of this type.
+
+        Convenience wrapper around MutableRebuilder.'''
+        visitor = MutableRebuilder()
+        return visitor.visit(self)
 
 
 class _Void(Type):
@@ -667,84 +673,108 @@ class Rebuilder(Visitor):
         return Polymorphic(defaultType, switchExpr, switchTypes)
 
 
-class Collector(Visitor):
-    '''Visitor which collects all unique types as it traverses them.'''
+class MutableRebuilder(Rebuilder):
+    '''Type visitor which derives a mutable type.'''
 
-    def __init__(self):
-        self.__visited = set()
-        self.types = []
+    def visitConst(self, const):
+        # Strip out const qualifier
+        return const.type
+
+    def visitAlias(self, alias):
+        # Tear the alias on type changes
+        type = self.visit(alias.type)
+        if type is alias.type:
+            return alias
+        return type
+
+    def visitReference(self, reference):
+        # Strip out references
+        return reference.type
 
-    def visit(self, type):
-        if type in self.__visited:
-            return
-        self.__visited.add(type)
-        Visitor.visit(self, type)
-        self.types.append(type)
 
-    def visitVoid(self, literal):
+class Traverser(Visitor):
+    '''Visitor which all types.'''
+
+    def visitVoid(self, void, *args, **kwargs):
         pass
 
-    def visitLiteral(self, literal):
+    def visitLiteral(self, literal, *args, **kwargs):
         pass
 
-    def visitString(self, string):
+    def visitString(self, string, *args, **kwargs):
         pass
 
-    def visitConst(self, const):
-        self.visit(const.type)
+    def visitConst(self, const, *args, **kwargs):
+        self.visit(const.type, *args, **kwargs)
 
-    def visitStruct(self, struct):
+    def visitStruct(self, struct, *args, **kwargs):
         for type, name in struct.members:
-            self.visit(type)
+            self.visit(type, *args, **kwargs)
 
-    def visitArray(self, array):
-        self.visit(array.type)
+    def visitArray(self, array, *args, **kwargs):
+        self.visit(array.type, *args, **kwargs)
 
-    def visitBlob(self, array):
+    def visitBlob(self, array, *args, **kwargs):
         pass
 
-    def visitEnum(self, enum):
+    def visitEnum(self, enum, *args, **kwargs):
         pass
 
-    def visitBitmask(self, bitmask):
-        self.visit(bitmask.type)
+    def visitBitmask(self, bitmask, *args, **kwargs):
+        self.visit(bitmask.type, *args, **kwargs)
 
-    def visitPointer(self, pointer):
-        self.visit(pointer.type)
+    def visitPointer(self, pointer, *args, **kwargs):
+        self.visit(pointer.type, *args, **kwargs)
 
-    def visitIntPointer(self, pointer):
+    def visitIntPointer(self, pointer, *args, **kwargs):
         pass
 
-    def visitObjPointer(self, pointer):
-        self.visit(pointer.type)
+    def visitObjPointer(self, pointer, *args, **kwargs):
+        self.visit(pointer.type, *args, **kwargs)
 
-    def visitLinearPointer(self, pointer):
-        self.visit(pointer.type)
+    def visitLinearPointer(self, pointer, *args, **kwargs):
+        self.visit(pointer.type, *args, **kwargs)
 
-    def visitReference(self, reference):
-        self.visit(reference.type)
+    def visitReference(self, reference, *args, **kwargs):
+        self.visit(reference.type, *args, **kwargs)
 
-    def visitHandle(self, handle):
-        self.visit(handle.type)
+    def visitHandle(self, handle, *args, **kwargs):
+        self.visit(handle.type, *args, **kwargs)
 
-    def visitAlias(self, alias):
-        self.visit(alias.type)
+    def visitAlias(self, alias, *args, **kwargs):
+        self.visit(alias.type, *args, **kwargs)
 
-    def visitOpaque(self, opaque):
+    def visitOpaque(self, opaque, *args, **kwargs):
         pass
 
-    def visitInterface(self, interface):
+    def visitInterface(self, interface, *args, **kwargs):
         if interface.base is not None:
-            self.visit(interface.base)
+            self.visit(interface.base, *args, **kwargs)
         for method in interface.iterMethods():
             for arg in method.args:
-                self.visit(arg.type)
-            self.visit(method.type)
+                self.visit(arg.type, *args, **kwargs)
+            self.visit(method.type, *args, **kwargs)
 
-    def visitPolymorphic(self, polymorphic):
-        self.visit(polymorphic.defaultType)
+    def visitPolymorphic(self, polymorphic, *args, **kwargs):
+        self.visit(polymorphic.defaultType, *args, **kwargs)
         for expr, type in polymorphic.switchTypes:
-            self.visit(type)
+            self.visit(type, *args, **kwargs)
+
+
+class Collector(Traverser):
+    '''Visitor which collects all unique types as it traverses them.'''
+
+    def __init__(self):
+        self.__visited = set()
+        self.types = []
+
+    def visit(self, type):
+        if type in self.__visited:
+            return
+        self.__visited.add(type)
+        Visitor.visit(self, type)
+        self.types.append(type)
+
 
 
 class API: