]> git.cworth.org Git - apitrace/blobdiff - retrace/retrace.py
Merge branch 'modules'
[apitrace] / retrace / retrace.py
index 6206905dcd787d166f95617a4a851a99466ce09d..5b5812faa101feb04c651608c24561bf22d9f0af 100644 (file)
@@ -34,7 +34,6 @@ sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
 
 
 import specs.stdapi as stdapi
-import specs.glapi as glapi
 
 
 class UnsupportedType(Exception):
@@ -151,12 +150,7 @@ class ValueDeserializer(stdapi.Visitor):
         print '    %s = static_cast<%s>((%s).toPointer());' % (lvalue, pointer, rvalue)
 
     def visitObjPointer(self, pointer, lvalue, rvalue):
-        old_lvalue = '(%s).toUIntPtr()' % (rvalue,)
-        new_lvalue = '_obj_map[%s]' % (old_lvalue,)
-        print '    if (retrace::verbosity >= 2) {'
-        print '        std::cout << std::hex << "obj 0x" << size_t(%s) << " <- 0x" << size_t(%s) << std::dec <<"\\n";' % (old_lvalue, new_lvalue)
-        print '    }'
-        print '    %s = static_cast<%s>(%s);' % (lvalue, pointer, new_lvalue)
+        print '    %s = static_cast<%s>(retrace::toObjPointer(call, %s));' % (lvalue, pointer, rvalue)
 
     def visitLinearPointer(self, pointer, lvalue, rvalue):
         print '    %s = static_cast<%s>(retrace::toPointer(%s));' % (lvalue, pointer, rvalue)
@@ -248,7 +242,7 @@ class SwizzledValueRegistrator(stdapi.Visitor):
         pass
     
     def visitObjPointer(self, pointer, lvalue, rvalue):
-        print r'    _obj_map[(%s).toUIntPtr()] = %s;' % (rvalue, lvalue)
+        print r'    retrace::addObj(call, %s, %s);' % (rvalue, lvalue)
     
     def visitLinearPointer(self, pointer, lvalue, rvalue):
         assert pointer.size is not None
@@ -323,8 +317,12 @@ class Retracer:
     def retraceFunctionBody(self, function):
         assert function.sideeffects
 
+        if function.type is not stdapi.Void:
+            self.checkOrigResult(function)
+
         self.deserializeArgs(function)
         
+        self.declareRet(function)
         self.invokeFunction(function)
 
         self.swizzleValues(function)
@@ -332,19 +330,34 @@ class Retracer:
     def retraceInterfaceMethodBody(self, interface, method):
         assert method.sideeffects
 
+        if method.type is not stdapi.Void:
+            self.checkOrigResult(method)
+
         self.deserializeThisPointer(interface)
 
         self.deserializeArgs(method)
         
+        self.declareRet(method)
         self.invokeInterfaceMethod(interface, method)
 
         self.swizzleValues(method)
 
+    def checkOrigResult(self, function):
+        '''Hook for checking the original result, to prevent succeeding now
+        where the original did not, which would cause diversion and potentially
+        unpredictable results.'''
+
+        assert function.type is not stdapi.Void
+
+        if str(function.type) == 'HRESULT':
+            print r'    if (call.ret && FAILED(call.ret->toSInt())) {'
+            print r'        return;'
+            print r'    }'
+
     def deserializeThisPointer(self, interface):
         print r'    %s *_this;' % (interface.name,)
-        print r'    _this = static_cast<%s *>(_obj_map[call.arg(0).toUIntPtr()]);' % (interface.name,)
+        print r'    _this = static_cast<%s *>(retrace::toObjPointer(call, call.arg(0)));' % (interface.name,)
         print r'    if (!_this) {'
-        print r'        retrace::warning(call) << "NULL this pointer\n";'
         print r'        return;'
         print r'    }'
 
@@ -412,10 +425,13 @@ class Retracer:
         visitor = SwizzledValueRegistrator()
         visitor.visit(type, lvalue, rvalue)
 
+    def declareRet(self, function):
+        if function.type is not stdapi.Void:
+            print '    %s _result;' % (function.type)
+
     def invokeFunction(self, function):
         arg_names = ", ".join(function.argNames())
         if function.type is not stdapi.Void:
-            print '    %s _result;' % (function.type)
             print '    _result = %s(%s);' % (function.name, arg_names)
             print '    (void)_result;'
         else:
@@ -428,11 +444,10 @@ class Retracer:
             print '    if (call.ret->toUInt()) {'
             print '        return;'
             print '    }'
-            print '    _obj_map.erase(call.arg(0).toUIntPtr());'
+            print '    retrace::delObj(call.arg(0));'
 
         arg_names = ", ".join(method.argNames())
         if method.type is not stdapi.Void:
-            print '    %s _result;' % (method.type)
             print '    _result = _this->%s(%s);' % (method.name, arg_names)
             print '    (void)_result;'
         else:
@@ -448,6 +463,7 @@ class Retracer:
         print '#include "os_time.hpp"'
         print '#include "trace_parser.hpp"'
         print '#include "retrace.hpp"'
+        print '#include "retrace_swizzle.hpp"'
         print
 
         types = api.getAllTypes()
@@ -463,25 +479,23 @@ class Retracer:
                 handle_names.add(handle.name)
         print
 
-        print 'static std::map<unsigned long long, void *> _obj_map;'
-        print
-
-        functions = filter(self.filterFunction, api.functions)
+        functions = filter(self.filterFunction, api.getAllFunctions())
         for function in functions:
-            if function.sideeffects:
+            if function.sideeffects and not function.internal:
                 self.retraceFunction(function)
         interfaces = api.getAllInterfaces()
         for interface in interfaces:
             for method in interface.iterMethods():
-                if method.sideeffects:
+                if method.sideeffects and not method.internal:
                     self.retraceInterfaceMethod(interface, method)
 
         print 'const retrace::Entry %s[] = {' % self.table_name
         for function in functions:
-            if function.sideeffects:
-                print '    {"%s", &retrace_%s},' % (function.name, function.name)
-            else:
-                print '    {"%s", &retrace::ignore},' % (function.name,)
+            if not function.internal:
+                if function.sideeffects:
+                    print '    {"%s", &retrace_%s},' % (function.name, function.name)
+                else:
+                    print '    {"%s", &retrace::ignore},' % (function.name,)
         for interface in interfaces:
             for method in interface.iterMethods():                
                 if method.sideeffects: