]> git.cworth.org Git - apitrace/blobdiff - wrappers/trace.py
d2d: Eliminate custom tracers.
[apitrace] / wrappers / trace.py
index 5d0a566f413b77119cf26ccf8b4831d038e88090..79156acafbd328086f6ea021c41bf5c64a8ce2d7 100644 (file)
@@ -340,6 +340,8 @@ class ValueWrapper(stdapi.Traverser, ExpanderMixin):
         elem_type = pointer.type.mutable()
         if isinstance(elem_type, stdapi.Interface):
             self.visitInterfacePointer(elem_type, instance)
+        elif isinstance(elem_type, stdapi.Alias) and isinstance(elem_type.type, stdapi.Interface):
+            self.visitInterfacePointer(elem_type.type, instance)
         else:
             self.visitPointer(pointer, instance)
     
@@ -361,6 +363,23 @@ class ValueUnwrapper(ValueWrapper):
 
     allocated = False
 
+    def visitStruct(self, struct, instance):
+        if not self.allocated:
+            # Argument is constant. We need to create a non const
+            print '    {'
+            print "        %s * _t = static_cast<%s *>(alloca(sizeof *_t));" % (struct, struct)
+            print '        *_t = %s;' % (instance,)
+            assert instance.startswith('*')
+            print '        %s = _t;' % (instance[1:],)
+            instance = '*_t'
+            self.allocated = True
+            try:
+                return ValueWrapper.visitStruct(self, struct, instance)
+            finally:
+                print '    }'
+        else:
+            return ValueWrapper.visitStruct(self, struct, instance)
+
     def visitArray(self, array, instance):
         if self.allocated or isinstance(instance, stdapi.Interface):
             return ValueWrapper.visitArray(self, array, instance)
@@ -407,8 +426,9 @@ class Tracer:
         self.header(api)
 
         # Includes
-        for header in api.headers:
-            print header
+        for module in api.modules:
+            for header in module.headers:
+                print header
         print
 
         # Generate the serializer functions
@@ -423,8 +443,10 @@ class Tracer:
         # Function wrappers
         self.interface = None
         self.base = None
-        map(self.traceFunctionDecl, api.functions)
-        map(self.traceFunctionImpl, api.functions)
+        for function in api.getAllFunctions():
+            self.traceFunctionDecl(function)
+        for function in api.getAllFunctions():
+            self.traceFunctionImpl(function)
         print
 
         self.footer(api)
@@ -494,10 +516,12 @@ class Tracer:
         self.invokeFunction(function)
         if not function.internal:
             print '    trace::localWriter.beginLeave(_call);'
+            print '    if (%s) {' % self.wasFunctionSuccessful(function)
             for arg in function.args:
                 if arg.output:
                     self.serializeArg(function, arg)
                     self.wrapArg(function, arg)
+            print '    }'
             if function.type is not stdapi.Void:
                 self.serializeRet(function, "_result")
             print '    trace::localWriter.endLeave();'
@@ -512,6 +536,13 @@ class Tracer:
         dispatch = prefix + function.name + suffix
         print '    %s%s(%s);' % (result, dispatch, ', '.join([str(arg.name) for arg in function.args]))
 
+    def wasFunctionSuccessful(self, function):
+        if function.type is stdapi.Void:
+            return 'true'
+        if str(function.type) == 'HRESULT':
+            return 'SUCCEEDED(_result)'
+        return 'false'
+
     def serializeArg(self, function, arg):
         print '    trace::localWriter.beginArg(%u);' % (arg.index,)
         self.serializeArgValue(function, arg)
@@ -651,10 +682,13 @@ class Tracer:
         self.invokeMethod(interface, base, method)
 
         print '    trace::localWriter.beginLeave(_call);'
+
+        print '    if (%s) {' % self.wasFunctionSuccessful(method)
         for arg in method.args:
             if arg.output:
                 self.serializeArg(method, arg)
                 self.wrapArg(method, arg)
+        print '    }'
 
         if method.type is not stdapi.Void:
             self.serializeRet(method, '_result')