]> git.cworth.org Git - apitrace/blobdiff - specs/stdapi.py
Handle variations of LockRect.
[apitrace] / specs / stdapi.py
index b9d6c64f5c1d83c834e28b32f75fd893a2b0c7cc..57e9aa6eb1fcacda7aa77e976b6f1c06dbb9334c 100644 (file)
@@ -41,7 +41,10 @@ class Type:
         # on this type, so it should preferrably be something representative of
         # the type.
         if tag is None:
-            tag = ''.join([c for c in expr if c.isalnum() or c in '_'])
+            if expr is not None:
+                tag = ''.join([c for c in expr if c.isalnum() or c in '_'])
+            else:
+                tag = 'anonynoums'
         else:
             for c in tag:
                 assert c.isalnum() or c in '_'
@@ -65,6 +68,12 @@ class Type:
     def visit(self, visitor, *args, **kwargs):
         raise NotImplementedError
 
+    def mutable(self):
+        '''Return a mutable version of this type.
+
+        Convenience wrapper around MutableRebuilder.'''
+        visitor = MutableRebuilder()
+        return visitor.visit(self)
 
 
 class _Void(Type):
@@ -127,6 +136,47 @@ class Pointer(Type):
         return visitor.visitPointer(self, *args, **kwargs)
 
 
+class IntPointer(Type):
+    '''Integer encoded as a pointer.'''
+
+    def visit(self, visitor, *args, **kwargs):
+        return visitor.visitIntPointer(self, *args, **kwargs)
+
+
+class ObjPointer(Type):
+    '''Pointer to an object.'''
+
+    def __init__(self, type):
+        Type.__init__(self, type.expr + " *", 'P' + type.tag)
+        self.type = type
+
+    def visit(self, visitor, *args, **kwargs):
+        return visitor.visitObjPointer(self, *args, **kwargs)
+
+
+class LinearPointer(Type):
+    '''Pointer to a linear range of memory.'''
+
+    def __init__(self, type, size = None):
+        Type.__init__(self, type.expr + " *", 'P' + type.tag)
+        self.type = type
+        self.size = size
+
+    def visit(self, visitor, *args, **kwargs):
+        return visitor.visitLinearPointer(self, *args, **kwargs)
+
+
+class Reference(Type):
+    '''C++ references.'''
+
+    def __init__(self, type):
+        Type.__init__(self, type.expr + " &", 'R' + type.tag)
+        self.type = type
+
+    def visit(self, visitor, *args, **kwargs):
+        return visitor.visitReference(self, *args, **kwargs)
+
+
 class Handle(Type):
 
     def __init__(self, name, type, range=None, key=None):
@@ -216,12 +266,35 @@ class Struct(Type):
         Struct.__id += 1
 
         self.name = name
-        self.members = members
+        self.members = []
+
+        # Eliminate anonymous unions
+        for type, name in members:
+            if name is not None:
+                self.members.append((type, name))
+            else:
+                assert isinstance(type, Union)
+                assert type.name is None
+                self.members.extend(type.members)
 
     def visit(self, visitor, *args, **kwargs):
         return visitor.visitStruct(self, *args, **kwargs)
 
 
+class Union(Type):
+
+    __id = 0
+
+    def __init__(self, name, members):
+        Type.__init__(self, name)
+
+        self.id = Union.__id
+        Union.__id += 1
+
+        self.name = name
+        self.members = members
+
+
 class Alias(Type):
 
     def __init__(self, expr, type):
@@ -231,17 +304,12 @@ class Alias(Type):
     def visit(self, visitor, *args, **kwargs):
         return visitor.visitAlias(self, *args, **kwargs)
 
-
-def Out(type, name):
-    arg = Arg(type, name, output=True)
-    return arg
-
-
 class Arg:
 
-    def __init__(self, type, name, output=False):
+    def __init__(self, type, name, input=True, output=False):
         self.type = type
         self.name = name
+        self.input = input
         self.output = output
         self.index = None
 
@@ -249,6 +317,16 @@ class Arg:
         return '%s %s' % (self.type, self.name)
 
 
+def In(type, name):
+    return Arg(type, name, input=True, output=False)
+
+def Out(type, name):
+    return Arg(type, name, input=False, output=True)
+
+def InOut(type, name):
+    return Arg(type, name, input=True, output=True)
+
+
 class Function:
 
     # 0-3 are reserved to memcpy, malloc, free, and realloc
@@ -298,6 +376,9 @@ class Function:
         s += ")"
         return s
 
+    def argNames(self):
+        return [arg.name for arg in self.args]
+
 
 def StdFunction(*args, **kwargs):
     kwargs.setdefault('call', '__stdcall')
@@ -320,6 +401,12 @@ class Interface(Type):
     def visit(self, visitor, *args, **kwargs):
         return visitor.visitInterface(self, *args, **kwargs)
 
+    def getMethodByName(self, name):
+        for method in self.iterMethods():
+            if method.name == name:
+                return method
+        return None
+
     def iterMethods(self):
         if self.base is not None:
             for method in self.base.iterMethods():
@@ -328,27 +415,52 @@ class Interface(Type):
             yield method
         raise StopIteration
 
+    def iterBases(self):
+        iface = self
+        while iface is not None:
+            yield iface
+            iface = iface.base
+        raise StopIteration
+
+    def iterBaseMethods(self):
+        if self.base is not None:
+            for iface, method in self.base.iterBaseMethods():
+                yield iface, method
+        for method in self.methods:
+            yield self, method
+        raise StopIteration
+
 
 class Method(Function):
 
-    def __init__(self, type, name, args):
-        Function.__init__(self, type, name, args, call = '__stdcall')
+    def __init__(self, type, name, args, call = '__stdcall', const=False, sideeffects=True):
+        Function.__init__(self, type, name, args, call = call, sideeffects=sideeffects)
         for index in range(len(self.args)):
             self.args[index].index = index + 1
+        self.const = const
+
+    def prototype(self, name=None):
+        s = Function.prototype(self, name)
+        if self.const:
+            s += ' const'
+        return s
+
+
+def StdMethod(*args, **kwargs):
+    kwargs.setdefault('call', '__stdcall')
+    return Method(*args, **kwargs)
 
 
 class String(Type):
 
-    def __init__(self, expr = "char *", length = None):
+    def __init__(self, expr = "char *", length = None, kind = 'String'):
         Type.__init__(self, expr)
         self.length = length
+        self.kind = kind
 
     def visit(self, visitor, *args, **kwargs):
         return visitor.visitString(self, *args, **kwargs)
 
-# C string (i.e., zero terminated)
-CString = String()
-
 
 class Opaque(Type):
     '''Opaque pointer.'''
@@ -372,11 +484,12 @@ def OpaqueBlob(type, size):
 
 class Polymorphic(Type):
 
-    def __init__(self, defaultType, switchExpr, switchTypes):
+    def __init__(self, switchExpr, switchTypes, defaultType, contextLess=True):
         Type.__init__(self, defaultType.expr)
-        self.defaultType = defaultType
         self.switchExpr = switchExpr
         self.switchTypes = switchTypes
+        self.defaultType = defaultType
+        self.contextLess = contextLess
 
     def visit(self, visitor, *args, **kwargs):
         return visitor.visitPolymorphic(self, *args, **kwargs)
@@ -398,6 +511,13 @@ class Polymorphic(Type):
         return zip(cases, types)
 
 
+def EnumPolymorphic(enumName, switchExpr, switchTypes, defaultType, contextLess=True):
+    enumValues = [expr for expr, type in switchTypes]
+    enum = Enum(enumName, enumValues)
+    polymorphic = Polymorphic(switchExpr, switchTypes, defaultType, contextLess)
+    return enum, polymorphic
+
+
 class Visitor:
     '''Abstract visitor for the type hierarchy.'''
 
@@ -434,6 +554,18 @@ class Visitor:
     def visitPointer(self, pointer, *args, **kwargs):
         raise NotImplementedError
 
+    def visitIntPointer(self, pointer, *args, **kwargs):
+        raise NotImplementedError
+
+    def visitObjPointer(self, pointer, *args, **kwargs):
+        raise NotImplementedError
+
+    def visitLinearPointer(self, pointer, *args, **kwargs):
+        raise NotImplementedError
+
+    def visitReference(self, reference, *args, **kwargs):
+        raise NotImplementedError
+
     def visitHandle(self, handle, *args, **kwargs):
         raise NotImplementedError
 
@@ -480,7 +612,11 @@ class Rebuilder(Visitor):
         return string
 
     def visitConst(self, const):
-        return Const(const.type)
+        const_type = self.visit(const.type)
+        if const_type is const.type:
+            return const
+        else:
+            return Const(const_type)
 
     def visitStruct(self, struct):
         members = [(self.visit(type), name) for type, name in struct.members]
@@ -502,93 +638,165 @@ class Rebuilder(Visitor):
         return Bitmask(type, bitmask.values)
 
     def visitPointer(self, pointer):
-        type = self.visit(pointer.type)
-        return Pointer(type)
+        pointer_type = self.visit(pointer.type)
+        if pointer_type is pointer.type:
+            return pointer
+        else:
+            return Pointer(pointer_type)
+
+    def visitIntPointer(self, pointer):
+        return pointer
+
+    def visitObjPointer(self, pointer):
+        pointer_type = self.visit(pointer.type)
+        if pointer_type is pointer.type:
+            return pointer
+        else:
+            return ObjPointer(pointer_type)
+
+    def visitLinearPointer(self, pointer):
+        pointer_type = self.visit(pointer.type)
+        if pointer_type is pointer.type:
+            return pointer
+        else:
+            return LinearPointer(pointer_type)
+
+    def visitReference(self, reference):
+        reference_type = self.visit(reference.type)
+        if reference_type is reference.type:
+            return reference
+        else:
+            return Reference(reference_type)
 
     def visitHandle(self, handle):
-        type = self.visit(handle.type)
-        return Handle(handle.name, type, range=handle.range, key=handle.key)
+        handle_type = self.visit(handle.type)
+        if handle_type is handle.type:
+            return handle
+        else:
+            return Handle(handle.name, handle_type, range=handle.range, key=handle.key)
 
     def visitAlias(self, alias):
-        type = self.visit(alias.type)
-        return Alias(alias.expr, type)
+        alias_type = self.visit(alias.type)
+        if alias_type is alias.type:
+            return alias
+        else:
+            return Alias(alias.expr, alias_type)
 
     def visitOpaque(self, opaque):
         return opaque
 
+    def visitInterface(self, interface, *args, **kwargs):
+        return interface
+
     def visitPolymorphic(self, polymorphic):
-        defaultType = self.visit(polymorphic.defaultType)
         switchExpr = polymorphic.switchExpr
         switchTypes = [(expr, self.visit(type)) for expr, type in polymorphic.switchTypes]
-        return Polymorphic(defaultType, switchExpr, switchTypes)
+        defaultType = self.visit(polymorphic.defaultType)
+        return Polymorphic(switchExpr, switchTypes, defaultType, polymorphic.contextLess)
 
 
-class Collector(Visitor):
-    '''Visitor which collects all unique types as it traverses them.'''
+class MutableRebuilder(Rebuilder):
+    '''Type visitor which derives a mutable type.'''
 
-    def __init__(self):
-        self.__visited = set()
-        self.types = []
+    def visitConst(self, const):
+        # Strip out const qualifier
+        return const.type
 
-    def visit(self, type):
-        if type in self.__visited:
-            return
-        self.__visited.add(type)
-        Visitor.visit(self, type)
-        self.types.append(type)
+    def visitAlias(self, alias):
+        # Tear the alias on type changes
+        type = self.visit(alias.type)
+        if type is alias.type:
+            return alias
+        return type
+
+    def visitReference(self, reference):
+        # Strip out references
+        return reference.type
+
+
+class Traverser(Visitor):
+    '''Visitor which all types.'''
 
-    def visitVoid(self, literal):
+    def visitVoid(self, void, *args, **kwargs):
         pass
 
-    def visitLiteral(self, literal):
+    def visitLiteral(self, literal, *args, **kwargs):
         pass
 
-    def visitString(self, string):
+    def visitString(self, string, *args, **kwargs):
         pass
 
-    def visitConst(self, const):
-        self.visit(const.type)
+    def visitConst(self, const, *args, **kwargs):
+        self.visit(const.type, *args, **kwargs)
 
-    def visitStruct(self, struct):
+    def visitStruct(self, struct, *args, **kwargs):
         for type, name in struct.members:
-            self.visit(type)
+            self.visit(type, *args, **kwargs)
 
-    def visitArray(self, array):
-        self.visit(array.type)
+    def visitArray(self, array, *args, **kwargs):
+        self.visit(array.type, *args, **kwargs)
 
-    def visitBlob(self, array):
+    def visitBlob(self, array, *args, **kwargs):
         pass
 
-    def visitEnum(self, enum):
+    def visitEnum(self, enum, *args, **kwargs):
         pass
 
-    def visitBitmask(self, bitmask):
-        self.visit(bitmask.type)
+    def visitBitmask(self, bitmask, *args, **kwargs):
+        self.visit(bitmask.type, *args, **kwargs)
 
-    def visitPointer(self, pointer):
-        self.visit(pointer.type)
+    def visitPointer(self, pointer, *args, **kwargs):
+        self.visit(pointer.type, *args, **kwargs)
 
-    def visitHandle(self, handle):
-        self.visit(handle.type)
+    def visitIntPointer(self, pointer, *args, **kwargs):
+        pass
 
-    def visitAlias(self, alias):
-        self.visit(alias.type)
+    def visitObjPointer(self, pointer, *args, **kwargs):
+        self.visit(pointer.type, *args, **kwargs)
 
-    def visitOpaque(self, opaque):
+    def visitLinearPointer(self, pointer, *args, **kwargs):
+        self.visit(pointer.type, *args, **kwargs)
+
+    def visitReference(self, reference, *args, **kwargs):
+        self.visit(reference.type, *args, **kwargs)
+
+    def visitHandle(self, handle, *args, **kwargs):
+        self.visit(handle.type, *args, **kwargs)
+
+    def visitAlias(self, alias, *args, **kwargs):
+        self.visit(alias.type, *args, **kwargs)
+
+    def visitOpaque(self, opaque, *args, **kwargs):
         pass
 
-    def visitInterface(self, interface):
+    def visitInterface(self, interface, *args, **kwargs):
         if interface.base is not None:
-            self.visit(interface.base)
+            self.visit(interface.base, *args, **kwargs)
         for method in interface.iterMethods():
             for arg in method.args:
-                self.visit(arg.type)
-            self.visit(method.type)
+                self.visit(arg.type, *args, **kwargs)
+            self.visit(method.type, *args, **kwargs)
 
-    def visitPolymorphic(self, polymorphic):
-        self.visit(polymorphic.defaultType)
+    def visitPolymorphic(self, polymorphic, *args, **kwargs):
+        self.visit(polymorphic.defaultType, *args, **kwargs)
         for expr, type in polymorphic.switchTypes:
-            self.visit(type)
+            self.visit(type, *args, **kwargs)
+
+
+class Collector(Traverser):
+    '''Visitor which collects all unique types as it traverses them.'''
+
+    def __init__(self):
+        self.__visited = set()
+        self.types = []
+
+    def visit(self, type):
+        if type in self.__visited:
+            return
+        self.__visited.add(type)
+        Visitor.visit(self, type)
+        self.types.append(type)
+
 
 
 class API:
@@ -603,7 +811,7 @@ class API:
         self.functions = []
         self.interfaces = []
 
-    def all_types(self):
+    def getAllTypes(self):
         collector = Collector()
         for function in self.functions:
             for arg in function.args:
@@ -617,6 +825,14 @@ class API:
                 collector.visit(method.type)
         return collector.types
 
+    def getAllInterfaces(self):
+        types = self.getAllTypes()
+        interfaces = [type for type in types if isinstance(type, Interface)]
+        for interface in self.interfaces:
+            if interface not in interfaces:
+                interfaces.append(interface)
+        return interfaces
+
     def addFunction(self, function):
         self.functions.append(function)
 
@@ -635,7 +851,7 @@ class API:
         self.addFunctions(api.functions)
         self.addInterfaces(api.interfaces)
 
-    def get_function_by_name(self, name):
+    def getFunctionByName(self, name):
         for function in self.functions:
             if function.name == name:
                 return function
@@ -656,7 +872,10 @@ ULongLong = Literal("unsigned long long", "UInt")
 Float = Literal("float", "Float")
 Double = Literal("double", "Double")
 SizeT = Literal("size_t", "UInt")
-WString = Literal("wchar_t *", "WString")
+
+# C string (i.e., zero terminated)
+CString = String()
+WString = String("wchar_t *", kind="WString")
 
 Int8 = Literal("int8_t", "SInt")
 UInt8 = Literal("uint8_t", "UInt")