]> git.cworth.org Git - apitrace/blobdiff - wrappers/trace.py
Merge branch 'modules'
[apitrace] / wrappers / trace.py
index c838dc78a83df0e8e61190b1f30ef774b20e5ba8..e35108b273920561b0c87a950dbd2a41606d2086 100644 (file)
@@ -39,6 +39,45 @@ def getWrapperInterfaceName(interface):
     return "Wrap" + interface.expr
 
 
     return "Wrap" + interface.expr
 
 
+
+class ExpanderMixin:
+    '''Mixin class that provides a bunch of methods to expand C expressions
+    from the specifications.'''
+
+    __structs = None
+    __indices = None
+
+    def expand(self, expr):
+        # Expand a C expression, replacing certain variables
+        if not isinstance(expr, basestring):
+            return expr
+        variables = {}
+
+        if self.__structs is not None:
+            variables['self'] = '(%s)' % self.__structs[0]
+        if self.__indices is not None:
+            variables['i'] = self.__indices[0]
+
+        expandedExpr = expr.format(**variables)
+        if expandedExpr != expr and 0:
+            sys.stderr.write("  %r -> %r\n" % (expr, expandedExpr))
+        return expandedExpr
+
+    def visitMember(self, structInstance, member_type, *args, **kwargs):
+        self.__structs = (structInstance, self.__structs)
+        try:
+            return self.visit(member_type, *args, **kwargs)
+        finally:
+            _, self.__structs = self.__structs
+
+    def visitElement(self, element_index, element_type, *args, **kwargs):
+        self.__indices = (element_index, self.__indices)
+        try:
+            return self.visit(element_type, *args, **kwargs)
+        finally:
+            _, self.__indices = self.__indices
+
+
 class ComplexValueSerializer(stdapi.OnceVisitor):
     '''Type visitors which generates serialization functions for
     complex types.
 class ComplexValueSerializer(stdapi.OnceVisitor):
     '''Type visitors which generates serialization functions for
     complex types.
@@ -63,21 +102,13 @@ class ComplexValueSerializer(stdapi.OnceVisitor):
         self.visit(const.type)
 
     def visitStruct(self, struct):
         self.visit(const.type)
 
     def visitStruct(self, struct):
-        for type, name in struct.members:
-            self.visit(type)
-        print 'static void _write__%s(const %s &value) {' % (struct.tag, struct.expr)
-        print '    static const char * members[%u] = {' % (len(struct.members),)
+        print 'static const char * _struct%s_members[%u] = {' % (struct.tag, len(struct.members))
         for type, name,  in struct.members:
         for type, name,  in struct.members:
-            print '        "%s",' % (name,)
-        print '    };'
-        print '    static const trace::StructSig sig = {'
-        print '       %u, "%s", %u, members' % (struct.id, struct.name, len(struct.members))
-        print '    };'
-        print '    trace::localWriter.beginStruct(&sig);'
-        for type, name in struct.members:
-            self.serializer.visit(type, 'value.%s' % (name,))
-        print '    trace::localWriter.endStruct();'
-        print '}'
+            print '    "%s",' % (name,)
+        print '};'
+        print 'static const trace::StructSig _struct%s_sig = {' % (struct.tag,)
+        print '   %u, "%s", %u, _struct%s_members' % (struct.id, struct.name, len(struct.members), struct.tag)
+        print '};'
         print
 
     def visitArray(self, array):
         print
 
     def visitArray(self, array):
@@ -150,7 +181,7 @@ class ComplexValueSerializer(stdapi.OnceVisitor):
         print
 
 
         print
 
 
-class ValueSerializer(stdapi.Visitor):
+class ValueSerializer(stdapi.Visitor, ExpanderMixin):
     '''Visitor which generates code to serialize any type.
     
     Simple types are serialized inline here, whereas the serialization of
     '''Visitor which generates code to serialize any type.
     
     Simple types are serialized inline here, whereas the serialization of
@@ -158,16 +189,21 @@ class ValueSerializer(stdapi.Visitor):
     ComplexValueSerializer visitor above.
     '''
 
     ComplexValueSerializer visitor above.
     '''
 
+    def __init__(self):
+        #stdapi.Visitor.__init__(self)
+        self.indices = []
+        self.instances = []
+
     def visitLiteral(self, literal, instance):
         print '    trace::localWriter.write%s(%s);' % (literal.kind, instance)
 
     def visitString(self, string, instance):
     def visitLiteral(self, literal, instance):
         print '    trace::localWriter.write%s(%s);' % (literal.kind, instance)
 
     def visitString(self, string, instance):
-        if string.kind == 'String':
+        if not string.wide:
             cast = 'const char *'
             cast = 'const char *'
-        elif string.kind == 'WString':
-            cast = 'const wchar_t *'
+            suffix = 'String'
         else:
         else:
-            assert False
+            cast = 'const wchar_t *'
+            suffix = 'WString'
         if cast != string.expr:
             # reinterpret_cast is necessary for GLubyte * <=> char *
             instance = 'reinterpret_cast<%s>(%s)' % (cast, instance)
         if cast != string.expr:
             # reinterpret_cast is necessary for GLubyte * <=> char *
             instance = 'reinterpret_cast<%s>(%s)' % (cast, instance)
@@ -175,23 +211,27 @@ class ValueSerializer(stdapi.Visitor):
             length = ', %s' % string.length
         else:
             length = ''
             length = ', %s' % string.length
         else:
             length = ''
-        print '    trace::localWriter.write%s(%s%s);' % (string.kind, instance, length)
+        print '    trace::localWriter.write%s(%s%s);' % (suffix, instance, length)
 
     def visitConst(self, const, instance):
         self.visit(const.type, instance)
 
     def visitStruct(self, struct, instance):
 
     def visitConst(self, const, instance):
         self.visit(const.type, instance)
 
     def visitStruct(self, struct, instance):
-        print '    _write__%s(%s);' % (struct.tag, instance)
+        print '    trace::localWriter.beginStruct(&_struct%s_sig);' % (struct.tag,)
+        for type, name in struct.members:
+            self.visitMember(instance, type, '(%s).%s' % (instance, name,))
+        print '    trace::localWriter.endStruct();'
 
     def visitArray(self, array, instance):
         length = '_c' + array.type.tag
         index = '_i' + array.type.tag
 
     def visitArray(self, array, instance):
         length = '_c' + array.type.tag
         index = '_i' + array.type.tag
+        array_length = self.expand(array.length)
         print '    if (%s) {' % instance
         print '    if (%s) {' % instance
-        print '        size_t %s = %s;' % (length, array.length)
+        print '        size_t %s = %s > 0 ? %s : 0;' % (length, array_length, array_length)
         print '        trace::localWriter.beginArray(%s);' % length
         print '        for (size_t %s = 0; %s < %s; ++%s) {' % (index, index, length, index)
         print '            trace::localWriter.beginElement();'
         print '        trace::localWriter.beginArray(%s);' % length
         print '        for (size_t %s = 0; %s < %s; ++%s) {' % (index, index, length, index)
         print '            trace::localWriter.beginElement();'
-        self.visit(array.type, '(%s)[%s]' % (instance, index))
+        self.visitElement(index, array.type, '(%s)[%s]' % (instance, index))
         print '            trace::localWriter.endElement();'
         print '        }'
         print '        trace::localWriter.endArray();'
         print '            trace::localWriter.endElement();'
         print '        }'
         print '        trace::localWriter.endArray();'
@@ -200,7 +240,7 @@ class ValueSerializer(stdapi.Visitor):
         print '    }'
 
     def visitBlob(self, blob, instance):
         print '    }'
 
     def visitBlob(self, blob, instance):
-        print '    trace::localWriter.writeBlob(%s, %s);' % (instance, blob.size)
+        print '    trace::localWriter.writeBlob(%s, %s);' % (instance, self.expand(blob.size))
 
     def visitEnum(self, enum, instance):
         print '    trace::localWriter.writeEnum(&_enum%s_sig, %s);' % (enum.tag, instance)
 
     def visitEnum(self, enum, instance):
         print '    trace::localWriter.writeEnum(&_enum%s_sig, %s);' % (enum.tag, instance)
@@ -272,7 +312,7 @@ class WrapDecider(stdapi.Traverser):
         self.needsWrapping = True
 
 
         self.needsWrapping = True
 
 
-class ValueWrapper(stdapi.Traverser):
+class ValueWrapper(stdapi.Traverser, ExpanderMixin):
     '''Type visitor which will generate the code to wrap an instance.
     
     Wrapping is necessary mostly for interfaces, however interface pointers can
     '''Type visitor which will generate the code to wrap an instance.
     
     Wrapping is necessary mostly for interfaces, however interface pointers can
@@ -281,12 +321,13 @@ class ValueWrapper(stdapi.Traverser):
 
     def visitStruct(self, struct, instance):
         for type, name in struct.members:
 
     def visitStruct(self, struct, instance):
         for type, name in struct.members:
-            self.visit(type, "(%s).%s" % (instance, name))
+            self.visitMember(instance, type, "(%s).%s" % (instance, name))
 
     def visitArray(self, array, instance):
 
     def visitArray(self, array, instance):
+        array_length = self.expand(array.length)
         print "    if (%s) {" % instance
         print "    if (%s) {" % instance
-        print "        for (size_t _i = 0, _s = %s; _i < _s; ++_i) {" % array.length
-        self.visit(array.type, instance + "[_i]")
+        print "        for (size_t _i = 0, _s = %s; _i < _s; ++_i) {" % array_length
+        self.visitElement('_i', array.type, instance + "[_i]")
         print "        }"
         print "    }"
 
         print "        }"
         print "    }"
 
@@ -299,6 +340,8 @@ class ValueWrapper(stdapi.Traverser):
         elem_type = pointer.type.mutable()
         if isinstance(elem_type, stdapi.Interface):
             self.visitInterfacePointer(elem_type, instance)
         elem_type = pointer.type.mutable()
         if isinstance(elem_type, stdapi.Interface):
             self.visitInterfacePointer(elem_type, instance)
+        elif isinstance(elem_type, stdapi.Alias) and isinstance(elem_type.type, stdapi.Interface):
+            self.visitInterfacePointer(elem_type.type, instance)
         else:
             self.visitPointer(pointer, instance)
     
         else:
             self.visitPointer(pointer, instance)
     
@@ -320,13 +363,31 @@ class ValueUnwrapper(ValueWrapper):
 
     allocated = False
 
 
     allocated = False
 
+    def visitStruct(self, struct, instance):
+        if not self.allocated:
+            # Argument is constant. We need to create a non const
+            print '    {'
+            print "        %s * _t = static_cast<%s *>(alloca(sizeof *_t));" % (struct, struct)
+            print '        *_t = %s;' % (instance,)
+            assert instance.startswith('*')
+            print '        %s = _t;' % (instance[1:],)
+            instance = '*_t'
+            self.allocated = True
+            try:
+                return ValueWrapper.visitStruct(self, struct, instance)
+            finally:
+                print '    }'
+        else:
+            return ValueWrapper.visitStruct(self, struct, instance)
+
     def visitArray(self, array, instance):
         if self.allocated or isinstance(instance, stdapi.Interface):
             return ValueWrapper.visitArray(self, array, instance)
     def visitArray(self, array, instance):
         if self.allocated or isinstance(instance, stdapi.Interface):
             return ValueWrapper.visitArray(self, array, instance)
+        array_length = self.expand(array.length)
         elem_type = array.type.mutable()
         elem_type = array.type.mutable()
-        print "    if (%s && %s) {" % (instance, array.length)
-        print "        %s * _t = static_cast<%s *>(alloca(%s * sizeof *_t));" % (elem_type, elem_type, array.length)
-        print "        for (size_t _i = 0, _s = %s; _i < _s; ++_i) {" % array.length
+        print "    if (%s && %s) {" % (instance, array_length)
+        print "        %s * _t = static_cast<%s *>(alloca(%s * sizeof *_t));" % (elem_type, elem_type, array_length)
+        print "        for (size_t _i = 0, _s = %s; _i < _s; ++_i) {" % array_length
         print "            _t[_i] = %s[_i];" % instance 
         self.allocated = True
         self.visit(array.type, "_t[_i]")
         print "            _t[_i] = %s[_i];" % instance 
         self.allocated = True
         self.visit(array.type, "_t[_i]")
@@ -365,8 +426,9 @@ class Tracer:
         self.header(api)
 
         # Includes
         self.header(api)
 
         # Includes
-        for header in api.headers:
-            print header
+        for module in api.modules:
+            for header in module.headers:
+                print header
         print
 
         # Generate the serializer functions
         print
 
         # Generate the serializer functions
@@ -381,8 +443,10 @@ class Tracer:
         # Function wrappers
         self.interface = None
         self.base = None
         # Function wrappers
         self.interface = None
         self.base = None
-        map(self.traceFunctionDecl, api.functions)
-        map(self.traceFunctionImpl, api.functions)
+        for function in api.getAllFunctions():
+            self.traceFunctionDecl(function)
+        for function in api.getAllFunctions():
+            self.traceFunctionImpl(function)
         print
 
         self.footer(api)
         print
 
         self.footer(api)
@@ -396,6 +460,9 @@ class Tracer:
         print '#else'
         print '#  include <alloca.h> // alloca'
         print '#endif'
         print '#else'
         print '#  include <alloca.h> // alloca'
         print '#endif'
+        print
+        print '#include "trace.hpp"'
+        print
 
     def footer(self, api):
         pass
 
     def footer(self, api):
         pass
@@ -422,6 +489,16 @@ class Tracer:
         print function.prototype() + ' {'
         if function.type is not stdapi.Void:
             print '    %s _result;' % function.type
         print function.prototype() + ' {'
         if function.type is not stdapi.Void:
             print '    %s _result;' % function.type
+
+        # No-op if tracing is disabled
+        print '    if (!trace::isTracingEnabled()) {'
+        Tracer.invokeFunction(self, function)
+        if function.type is not stdapi.Void:
+            print '        return _result;'
+        else:
+            print '        return;'
+        print '    }'
+
         self.traceFunctionImplBody(function)
         if function.type is not stdapi.Void:
             print '    return _result;'
         self.traceFunctionImplBody(function)
         if function.type is not stdapi.Void:
             print '    return _result;'
@@ -439,10 +516,12 @@ class Tracer:
         self.invokeFunction(function)
         if not function.internal:
             print '    trace::localWriter.beginLeave(_call);'
         self.invokeFunction(function)
         if not function.internal:
             print '    trace::localWriter.beginLeave(_call);'
+            print '    if (%s) {' % self.wasFunctionSuccessful(function)
             for arg in function.args:
                 if arg.output:
                     self.serializeArg(function, arg)
                     self.wrapArg(function, arg)
             for arg in function.args:
                 if arg.output:
                     self.serializeArg(function, arg)
                     self.wrapArg(function, arg)
+            print '    }'
             if function.type is not stdapi.Void:
                 self.serializeRet(function, "_result")
             print '    trace::localWriter.endLeave();'
             if function.type is not stdapi.Void:
                 self.serializeRet(function, "_result")
             print '    trace::localWriter.endLeave();'
@@ -457,6 +536,13 @@ class Tracer:
         dispatch = prefix + function.name + suffix
         print '    %s%s(%s);' % (result, dispatch, ', '.join([str(arg.name) for arg in function.args]))
 
         dispatch = prefix + function.name + suffix
         print '    %s%s(%s);' % (result, dispatch, ', '.join([str(arg.name) for arg in function.args]))
 
+    def wasFunctionSuccessful(self, function):
+        if function.type is stdapi.Void:
+            return 'true'
+        if str(function.type) == 'HRESULT':
+            return 'SUCCEEDED(_result)'
+        return 'true'
+
     def serializeArg(self, function, arg):
         print '    trace::localWriter.beginArg(%u);' % (arg.index,)
         self.serializeArgValue(function, arg)
     def serializeArg(self, function, arg):
         print '    trace::localWriter.beginArg(%u);' % (arg.index,)
         self.serializeArgValue(function, arg)
@@ -596,10 +682,13 @@ class Tracer:
         self.invokeMethod(interface, base, method)
 
         print '    trace::localWriter.beginLeave(_call);'
         self.invokeMethod(interface, base, method)
 
         print '    trace::localWriter.beginLeave(_call);'
+
+        print '    if (%s) {' % self.wasFunctionSuccessful(method)
         for arg in method.args:
             if arg.output:
                 self.serializeArg(method, arg)
                 self.wrapArg(method, arg)
         for arg in method.args:
             if arg.output:
                 self.serializeArg(method, arg)
                 self.wrapArg(method, arg)
+        print '    }'
 
         if method.type is not stdapi.Void:
             self.serializeRet(method, '_result')
 
         if method.type is not stdapi.Void:
             self.serializeRet(method, '_result')
@@ -682,4 +771,15 @@ class Tracer:
         print '        trace::localWriter.endEnter();'
         print '        trace::localWriter.beginLeave(_call);'
         print '        trace::localWriter.endLeave();'
         print '        trace::localWriter.endEnter();'
         print '        trace::localWriter.beginLeave(_call);'
         print '        trace::localWriter.endLeave();'
+    
+    def fake_call(self, function, args):
+        print '            unsigned _fake_call = trace::localWriter.beginEnter(&_%s_sig);' % (function.name,)
+        for arg, instance in zip(function.args, args):
+            assert not arg.output
+            print '            trace::localWriter.beginArg(%u);' % (arg.index,)
+            self.serializeValue(arg.type, instance)
+            print '            trace::localWriter.endArg();'
+        print '            trace::localWriter.endEnter();'
+        print '            trace::localWriter.beginLeave(_fake_call);'
+        print '            trace::localWriter.endLeave();'