]> git.cworth.org Git - apitrace/blobdiff - retrace/retrace.py
Retrace IUnknown::AddRef/Release correctly.
[apitrace] / retrace / retrace.py
index 369c13b44b90bb98c759b65d5cf706e840372807..2e0350107a64f7e17ced6f9145acc462053816e2 100644 (file)
@@ -41,25 +41,6 @@ class UnsupportedType(Exception):
     pass
 
 
-class MutableRebuilder(stdapi.Rebuilder):
-    '''Type visitor which derives a mutable type.'''
-
-    def visitConst(self, const):
-        # Strip out const qualifier
-        return const.type
-
-    def visitAlias(self, alias):
-        # Tear the alias on type changes
-        type = self.visit(alias.type)
-        if type is alias.type:
-            return alias
-        return type
-
-    def visitReference(self, reference):
-        # Strip out references
-        return reference.type
-
-
 def lookupHandle(handle, value):
     if handle.key is None:
         return "__%s_map[%s]" % (handle.name, value)
@@ -372,7 +353,7 @@ class Retracer:
         print '    (void)_allocator;'
         success = True
         for arg in function.args:
-            arg_type = MutableRebuilder().visit(arg.type)
+            arg_type = arg.type.mutable()
             print '    %s %s;' % (arg_type, arg.name)
             rvalue = 'call.arg(%u)' % (arg.index,)
             lvalue = arg.name
@@ -393,7 +374,7 @@ class Retracer:
     def swizzleValues(self, function):
         for arg in function.args:
             if arg.output:
-                arg_type = MutableRebuilder().visit(arg.type)
+                arg_type = arg.type.mutable()
                 rvalue = 'call.arg(%u)' % (arg.index,)
                 lvalue = arg.name
                 try:
@@ -441,6 +422,14 @@ class Retracer:
             print '    %s(%s);' % (function.name, arg_names)
 
     def invokeInterfaceMethod(self, interface, method):
+        # On release our reference when we reach Release() == 0 call in the
+        # trace.
+        if method.name == 'Release':
+            print '    if (call.ret->toUInt()) {'
+            print '        return;'
+            print '    }'
+            print '    _obj_map.erase(call.arg(0).toUIntPtr());'
+
         arg_names = ", ".join(method.argNames())
         if method.type is not stdapi.Void:
             print '    %s __result;' % (method.type)