]> git.cworth.org Git - apitrace/blobdiff - wrappers/trace.py
dxva: Eliminate the globals hack.
[apitrace] / wrappers / trace.py
index 30825617ceca523567d69cbc876f8fa0d3a294ef..5d0a566f413b77119cf26ccf8b4831d038e88090 100644 (file)
@@ -39,6 +39,45 @@ def getWrapperInterfaceName(interface):
     return "Wrap" + interface.expr
 
 
+
+class ExpanderMixin:
+    '''Mixin class that provides a bunch of methods to expand C expressions
+    from the specifications.'''
+
+    __structs = None
+    __indices = None
+
+    def expand(self, expr):
+        # Expand a C expression, replacing certain variables
+        if not isinstance(expr, basestring):
+            return expr
+        variables = {}
+
+        if self.__structs is not None:
+            variables['self'] = '(%s)' % self.__structs[0]
+        if self.__indices is not None:
+            variables['i'] = self.__indices[0]
+
+        expandedExpr = expr.format(**variables)
+        if expandedExpr != expr and 0:
+            sys.stderr.write("  %r -> %r\n" % (expr, expandedExpr))
+        return expandedExpr
+
+    def visitMember(self, structInstance, member_type, *args, **kwargs):
+        self.__structs = (structInstance, self.__structs)
+        try:
+            return self.visit(member_type, *args, **kwargs)
+        finally:
+            _, self.__structs = self.__structs
+
+    def visitElement(self, element_index, element_type, *args, **kwargs):
+        self.__indices = (element_index, self.__indices)
+        try:
+            return self.visit(element_type, *args, **kwargs)
+        finally:
+            _, self.__indices = self.__indices
+
+
 class ComplexValueSerializer(stdapi.OnceVisitor):
     '''Type visitors which generates serialization functions for
     complex types.
@@ -142,7 +181,7 @@ class ComplexValueSerializer(stdapi.OnceVisitor):
         print
 
 
-class ValueSerializer(stdapi.Visitor):
+class ValueSerializer(stdapi.Visitor, ExpanderMixin):
     '''Visitor which generates code to serialize any type.
     
     Simple types are serialized inline here, whereas the serialization of
@@ -155,22 +194,6 @@ class ValueSerializer(stdapi.Visitor):
         self.indices = []
         self.instances = []
 
-    def expand(self, expr):
-        # Expand a C expression, replacing certain variables
-        variables = {}
-        try:
-            variables['self'] = self.instances[-1]
-        except IndexError:
-            pass
-        try:
-            variables['i'] = self.indices[-1]
-        except IndexError:
-            pass
-        expandedExpr = expr.format(**variables)
-        if expandedExpr != expr:
-            sys.stderr.write("  %r -> %r\n" % (expr, expandedExpr))
-        return expandedExpr
-
     def visitLiteral(self, literal, instance):
         print '    trace::localWriter.write%s(%s);' % (literal.kind, instance)
 
@@ -195,27 +218,20 @@ class ValueSerializer(stdapi.Visitor):
 
     def visitStruct(self, struct, instance):
         print '    trace::localWriter.beginStruct(&_struct%s_sig);' % (struct.tag,)
-        self.instances.append(instance)
-        try:
-            for type, name in struct.members:
-                self.visit(type, '(%s).%s' % (instance, name,))
-        finally:
-            self.instances.pop()
+        for type, name in struct.members:
+            self.visitMember(instance, type, '(%s).%s' % (instance, name,))
         print '    trace::localWriter.endStruct();'
 
     def visitArray(self, array, instance):
         length = '_c' + array.type.tag
         index = '_i' + array.type.tag
+        array_length = self.expand(array.length)
         print '    if (%s) {' % instance
-        print '        size_t %s = %s > 0 ? %s : 0;' % (length, array.length, array.length)
+        print '        size_t %s = %s > 0 ? %s : 0;' % (length, array_length, array_length)
         print '        trace::localWriter.beginArray(%s);' % length
         print '        for (size_t %s = 0; %s < %s; ++%s) {' % (index, index, length, index)
         print '            trace::localWriter.beginElement();'
-        self.indices.append(index)
-        try:
-            self.visit(array.type, '(%s)[%s]' % (instance, index))
-        finally:
-            self.indices.pop()
+        self.visitElement(index, array.type, '(%s)[%s]' % (instance, index))
         print '            trace::localWriter.endElement();'
         print '        }'
         print '        trace::localWriter.endArray();'
@@ -296,7 +312,7 @@ class WrapDecider(stdapi.Traverser):
         self.needsWrapping = True
 
 
-class ValueWrapper(stdapi.Traverser):
+class ValueWrapper(stdapi.Traverser, ExpanderMixin):
     '''Type visitor which will generate the code to wrap an instance.
     
     Wrapping is necessary mostly for interfaces, however interface pointers can
@@ -305,12 +321,13 @@ class ValueWrapper(stdapi.Traverser):
 
     def visitStruct(self, struct, instance):
         for type, name in struct.members:
-            self.visit(type, "(%s).%s" % (instance, name))
+            self.visitMember(instance, type, "(%s).%s" % (instance, name))
 
     def visitArray(self, array, instance):
+        array_length = self.expand(array.length)
         print "    if (%s) {" % instance
-        print "        for (size_t _i = 0, _s = %s; _i < _s; ++_i) {" % array.length
-        self.visit(array.type, instance + "[_i]")
+        print "        for (size_t _i = 0, _s = %s; _i < _s; ++_i) {" % array_length
+        self.visitElement('_i', array.type, instance + "[_i]")
         print "        }"
         print "    }"
 
@@ -347,10 +364,11 @@ class ValueUnwrapper(ValueWrapper):
     def visitArray(self, array, instance):
         if self.allocated or isinstance(instance, stdapi.Interface):
             return ValueWrapper.visitArray(self, array, instance)
+        array_length = self.expand(array.length)
         elem_type = array.type.mutable()
-        print "    if (%s && %s) {" % (instance, array.length)
-        print "        %s * _t = static_cast<%s *>(alloca(%s * sizeof *_t));" % (elem_type, elem_type, array.length)
-        print "        for (size_t _i = 0, _s = %s; _i < _s; ++_i) {" % array.length
+        print "    if (%s && %s) {" % (instance, array_length)
+        print "        %s * _t = static_cast<%s *>(alloca(%s * sizeof *_t));" % (elem_type, elem_type, array_length)
+        print "        for (size_t _i = 0, _s = %s; _i < _s; ++_i) {" % array_length
         print "            _t[_i] = %s[_i];" % instance 
         self.allocated = True
         self.visit(array.type, "_t[_i]")