]> git.cworth.org Git - apitrace/blobdiff - specs/stdapi.py
Handle variations of LockRect.
[apitrace] / specs / stdapi.py
index 2057124c9edf9bd4610a0f423fd7e7f359fee71b..57e9aa6eb1fcacda7aa77e976b6f1c06dbb9334c 100644 (file)
@@ -68,6 +68,12 @@ class Type:
     def visit(self, visitor, *args, **kwargs):
         raise NotImplementedError
 
+    def mutable(self):
+        '''Return a mutable version of this type.
+
+        Convenience wrapper around MutableRebuilder.'''
+        visitor = MutableRebuilder()
+        return visitor.visit(self)
 
 
 class _Void(Type):
@@ -395,6 +401,12 @@ class Interface(Type):
     def visit(self, visitor, *args, **kwargs):
         return visitor.visitInterface(self, *args, **kwargs)
 
+    def getMethodByName(self, name):
+        for method in self.iterMethods():
+            if method.name == name:
+                return method
+        return None
+
     def iterMethods(self):
         if self.base is not None:
             for method in self.base.iterMethods():
@@ -403,6 +415,13 @@ class Interface(Type):
             yield method
         raise StopIteration
 
+    def iterBases(self):
+        iface = self
+        while iface is not None:
+            yield iface
+            iface = iface.base
+        raise StopIteration
+
     def iterBaseMethods(self):
         if self.base is not None:
             for iface, method in self.base.iterBaseMethods():
@@ -426,6 +445,7 @@ class Method(Function):
             s += ' const'
         return s
 
+
 def StdMethod(*args, **kwargs):
     kwargs.setdefault('call', '__stdcall')
     return Method(*args, **kwargs)
@@ -464,11 +484,12 @@ def OpaqueBlob(type, size):
 
 class Polymorphic(Type):
 
-    def __init__(self, defaultType, switchExpr, switchTypes):
+    def __init__(self, switchExpr, switchTypes, defaultType, contextLess=True):
         Type.__init__(self, defaultType.expr)
-        self.defaultType = defaultType
         self.switchExpr = switchExpr
         self.switchTypes = switchTypes
+        self.defaultType = defaultType
+        self.contextLess = contextLess
 
     def visit(self, visitor, *args, **kwargs):
         return visitor.visitPolymorphic(self, *args, **kwargs)
@@ -490,6 +511,13 @@ class Polymorphic(Type):
         return zip(cases, types)
 
 
+def EnumPolymorphic(enumName, switchExpr, switchTypes, defaultType, contextLess=True):
+    enumValues = [expr for expr, type in switchTypes]
+    enum = Enum(enumName, enumValues)
+    polymorphic = Polymorphic(switchExpr, switchTypes, defaultType, contextLess)
+    return enum, polymorphic
+
+
 class Visitor:
     '''Abstract visitor for the type hierarchy.'''
 
@@ -661,90 +689,114 @@ class Rebuilder(Visitor):
         return interface
 
     def visitPolymorphic(self, polymorphic):
-        defaultType = self.visit(polymorphic.defaultType)
         switchExpr = polymorphic.switchExpr
         switchTypes = [(expr, self.visit(type)) for expr, type in polymorphic.switchTypes]
-        return Polymorphic(defaultType, switchExpr, switchTypes)
+        defaultType = self.visit(polymorphic.defaultType)
+        return Polymorphic(switchExpr, switchTypes, defaultType, polymorphic.contextLess)
 
 
-class Collector(Visitor):
-    '''Visitor which collects all unique types as it traverses them.'''
+class MutableRebuilder(Rebuilder):
+    '''Type visitor which derives a mutable type.'''
 
-    def __init__(self):
-        self.__visited = set()
-        self.types = []
+    def visitConst(self, const):
+        # Strip out const qualifier
+        return const.type
 
-    def visit(self, type):
-        if type in self.__visited:
-            return
-        self.__visited.add(type)
-        Visitor.visit(self, type)
-        self.types.append(type)
+    def visitAlias(self, alias):
+        # Tear the alias on type changes
+        type = self.visit(alias.type)
+        if type is alias.type:
+            return alias
+        return type
 
-    def visitVoid(self, literal):
+    def visitReference(self, reference):
+        # Strip out references
+        return reference.type
+
+
+class Traverser(Visitor):
+    '''Visitor which all types.'''
+
+    def visitVoid(self, void, *args, **kwargs):
         pass
 
-    def visitLiteral(self, literal):
+    def visitLiteral(self, literal, *args, **kwargs):
         pass
 
-    def visitString(self, string):
+    def visitString(self, string, *args, **kwargs):
         pass
 
-    def visitConst(self, const):
-        self.visit(const.type)
+    def visitConst(self, const, *args, **kwargs):
+        self.visit(const.type, *args, **kwargs)
 
-    def visitStruct(self, struct):
+    def visitStruct(self, struct, *args, **kwargs):
         for type, name in struct.members:
-            self.visit(type)
+            self.visit(type, *args, **kwargs)
 
-    def visitArray(self, array):
-        self.visit(array.type)
+    def visitArray(self, array, *args, **kwargs):
+        self.visit(array.type, *args, **kwargs)
 
-    def visitBlob(self, array):
+    def visitBlob(self, array, *args, **kwargs):
         pass
 
-    def visitEnum(self, enum):
+    def visitEnum(self, enum, *args, **kwargs):
         pass
 
-    def visitBitmask(self, bitmask):
-        self.visit(bitmask.type)
+    def visitBitmask(self, bitmask, *args, **kwargs):
+        self.visit(bitmask.type, *args, **kwargs)
 
-    def visitPointer(self, pointer):
-        self.visit(pointer.type)
+    def visitPointer(self, pointer, *args, **kwargs):
+        self.visit(pointer.type, *args, **kwargs)
 
-    def visitIntPointer(self, pointer):
+    def visitIntPointer(self, pointer, *args, **kwargs):
         pass
 
-    def visitObjPointer(self, pointer):
-        self.visit(pointer.type)
+    def visitObjPointer(self, pointer, *args, **kwargs):
+        self.visit(pointer.type, *args, **kwargs)
 
-    def visitLinearPointer(self, pointer):
-        self.visit(pointer.type)
+    def visitLinearPointer(self, pointer, *args, **kwargs):
+        self.visit(pointer.type, *args, **kwargs)
 
-    def visitReference(self, reference):
-        self.visit(reference.type)
+    def visitReference(self, reference, *args, **kwargs):
+        self.visit(reference.type, *args, **kwargs)
 
-    def visitHandle(self, handle):
-        self.visit(handle.type)
+    def visitHandle(self, handle, *args, **kwargs):
+        self.visit(handle.type, *args, **kwargs)
 
-    def visitAlias(self, alias):
-        self.visit(alias.type)
+    def visitAlias(self, alias, *args, **kwargs):
+        self.visit(alias.type, *args, **kwargs)
 
-    def visitOpaque(self, opaque):
+    def visitOpaque(self, opaque, *args, **kwargs):
         pass
 
-    def visitInterface(self, interface):
+    def visitInterface(self, interface, *args, **kwargs):
         if interface.base is not None:
-            self.visit(interface.base)
+            self.visit(interface.base, *args, **kwargs)
         for method in interface.iterMethods():
             for arg in method.args:
-                self.visit(arg.type)
-            self.visit(method.type)
+                self.visit(arg.type, *args, **kwargs)
+            self.visit(method.type, *args, **kwargs)
 
-    def visitPolymorphic(self, polymorphic):
-        self.visit(polymorphic.defaultType)
+    def visitPolymorphic(self, polymorphic, *args, **kwargs):
+        self.visit(polymorphic.defaultType, *args, **kwargs)
         for expr, type in polymorphic.switchTypes:
-            self.visit(type)
+            self.visit(type, *args, **kwargs)
+
+
+class Collector(Traverser):
+    '''Visitor which collects all unique types as it traverses them.'''
+
+    def __init__(self):
+        self.__visited = set()
+        self.types = []
+
+    def visit(self, type):
+        if type in self.__visited:
+            return
+        self.__visited.add(type)
+        Visitor.visit(self, type)
+        self.types.append(type)
+
 
 
 class API:
@@ -799,7 +851,7 @@ class API:
         self.addFunctions(api.functions)
         self.addInterfaces(api.interfaces)
 
-    def get_function_by_name(self, name):
+    def getFunctionByName(self, name):
         for function in self.functions:
             if function.name == name:
                 return function