]> git.cworth.org Git - apitrace/blobdiff - retrace.py
D3D retrace checkpoint.
[apitrace] / retrace.py
index 22ce5eda3066ff321b578bbf09659e2ea0e8bd46..15cdaf5038e328d05e8646c564670d9ffe8c13fa 100644 (file)
@@ -34,15 +34,16 @@ import specs.glapi as glapi
 
 
 class ConstRemover(stdapi.Rebuilder):
 
 
 class ConstRemover(stdapi.Rebuilder):
+    '''Type visitor which strips out const qualifiers from types.'''
 
 
-    def visit_const(self, const):
+    def visitConst(self, const):
         return const.type
 
         return const.type
 
-    def visit_opaque(self, opaque):
+    def visitOpaque(self, opaque):
         return opaque
 
 
         return opaque
 
 
-def handle_entry(handle, value):
+def lookupHandle(handle, value):
     if handle.key is None:
         return "__%s_map[%s]" % (handle.name, value)
     else:
     if handle.key is None:
         return "__%s_map[%s]" % (handle.name, value)
     else:
@@ -50,26 +51,25 @@ def handle_entry(handle, value):
         return "__%s_map[%s][%s]" % (handle.name, key_name, value)
 
 
         return "__%s_map[%s][%s]" % (handle.name, key_name, value)
 
 
-class ValueExtractor(stdapi.Visitor):
+class ValueDeserializer(stdapi.Visitor):
 
 
-    def visit_literal(self, literal, lvalue, rvalue):
-        #if literal.format in ('Bool', 'UInt'):
-        print '    %s = (%s).to%s();' % (lvalue, rvalue, literal.format)
+    def visitLiteral(self, literal, lvalue, rvalue):
+        print '    %s = (%s).to%s();' % (lvalue, rvalue, literal.kind)
 
 
-    def visit_const(self, const, lvalue, rvalue):
+    def visitConst(self, const, lvalue, rvalue):
         self.visit(const.type, lvalue, rvalue)
 
         self.visit(const.type, lvalue, rvalue)
 
-    def visit_alias(self, alias, lvalue, rvalue):
+    def visitAlias(self, alias, lvalue, rvalue):
         self.visit(alias.type, lvalue, rvalue)
     
         self.visit(alias.type, lvalue, rvalue)
     
-    def visit_enum(self, enum, lvalue, rvalue):
+    def visitEnum(self, enum, lvalue, rvalue):
         print '    %s = (%s).toSInt();' % (lvalue, rvalue)
 
         print '    %s = (%s).toSInt();' % (lvalue, rvalue)
 
-    def visit_bitmask(self, bitmask, lvalue, rvalue):
+    def visitBitmask(self, bitmask, lvalue, rvalue):
         self.visit(bitmask.type, lvalue, rvalue)
 
         self.visit(bitmask.type, lvalue, rvalue)
 
-    def visit_array(self, array, lvalue, rvalue):
-        print '    const Trace::Array *__a%s = dynamic_cast<const Trace::Array *>(&%s);' % (array.tag, rvalue)
+    def visitArray(self, array, lvalue, rvalue):
+        print '    const trace::Array *__a%s = dynamic_cast<const trace::Array *>(&%s);' % (array.tag, rvalue)
         print '    if (__a%s) {' % (array.tag)
         length = '__a%s->values.size()' % array.tag
         print '        %s = new %s[%s];' % (lvalue, array.type, length)
         print '    if (__a%s) {' % (array.tag)
         length = '__a%s->values.size()' % array.tag
         print '        %s = new %s[%s];' % (lvalue, array.type, length)
@@ -83,8 +83,8 @@ class ValueExtractor(stdapi.Visitor):
             print '        %s = NULL;' % lvalue
             print '    }'
     
             print '        %s = NULL;' % lvalue
             print '    }'
     
-    def visit_pointer(self, pointer, lvalue, rvalue):
-        print '    const Trace::Array *__a%s = dynamic_cast<const Trace::Array *>(&%s);' % (pointer.tag, rvalue)
+    def visitPointer(self, pointer, lvalue, rvalue):
+        print '    const trace::Array *__a%s = dynamic_cast<const trace::Array *>(&%s);' % (pointer.tag, rvalue)
         print '    if (__a%s) {' % (pointer.tag)
         print '        %s = new %s;' % (lvalue, pointer.type)
         try:
         print '    if (__a%s) {' % (pointer.tag)
         print '        %s = new %s;' % (lvalue, pointer.type)
         try:
@@ -94,47 +94,56 @@ class ValueExtractor(stdapi.Visitor):
             print '        %s = NULL;' % lvalue
             print '    }'
 
             print '        %s = NULL;' % lvalue
             print '    }'
 
-    def visit_handle(self, handle, lvalue, rvalue):
-        OpaqueValueExtractor().visit(handle.type, lvalue, rvalue);
-        new_lvalue = handle_entry(handle, lvalue)
+    def visitIntPointer(self, pointer, lvalue, rvalue):
+        print '    %s = static_cast<%s>((%s).toPointer());' % (lvalue, pointer, rvalue)
+
+    def visitLinearPointer(self, pointer, lvalue, rvalue):
+        print '    %s = static_cast<%s>(retrace::toPointer(%s));' % (lvalue, pointer, rvalue)
+
+    def visitHandle(self, handle, lvalue, rvalue):
+        #OpaqueValueDeserializer().visit(handle.type, lvalue, rvalue);
+        self.visit(handle.type, lvalue, rvalue);
+        new_lvalue = lookupHandle(handle, lvalue)
         print '    if (retrace::verbosity >= 2) {'
         print '        std::cout << "%s " << size_t(%s) << " <- " << size_t(%s) << "\\n";' % (handle.name, lvalue, new_lvalue)
         print '    }'
         print '    %s = %s;' % (lvalue, new_lvalue)
     
         print '    if (retrace::verbosity >= 2) {'
         print '        std::cout << "%s " << size_t(%s) << " <- " << size_t(%s) << "\\n";' % (handle.name, lvalue, new_lvalue)
         print '    }'
         print '    %s = %s;' % (lvalue, new_lvalue)
     
-    def visit_blob(self, blob, lvalue, rvalue):
+    def visitBlob(self, blob, lvalue, rvalue):
         print '    %s = static_cast<%s>((%s).toPointer());' % (lvalue, blob, rvalue)
     
         print '    %s = static_cast<%s>((%s).toPointer());' % (lvalue, blob, rvalue)
     
-    def visit_string(self, string, lvalue, rvalue):
+    def visitString(self, string, lvalue, rvalue):
         print '    %s = (%s)((%s).toString());' % (lvalue, string.expr, rvalue)
 
 
         print '    %s = (%s)((%s).toString());' % (lvalue, string.expr, rvalue)
 
 
-class OpaqueValueExtractor(ValueExtractor):
+class OpaqueValueDeserializer(ValueDeserializer):
     '''Value extractor that also understands opaque values.
 
     Normally opaque values can't be retraced, unless they are being extracted
     in the context of handles.'''
 
     '''Value extractor that also understands opaque values.
 
     Normally opaque values can't be retraced, unless they are being extracted
     in the context of handles.'''
 
-    def visit_opaque(self, opaque, lvalue, rvalue):
+    def visitOpaque(self, opaque, lvalue, rvalue):
         print '    %s = static_cast<%s>(retrace::toPointer(%s));' % (lvalue, opaque, rvalue)
 
 
         print '    %s = static_cast<%s>(retrace::toPointer(%s));' % (lvalue, opaque, rvalue)
 
 
-class ValueWrapper(stdapi.Visitor):
+class SwizzledValueRegistrator(stdapi.Visitor):
+    '''Type visitor which will register (un)swizzled value pairs, to later be
+    swizzled.'''
 
 
-    def visit_literal(self, literal, lvalue, rvalue):
+    def visitLiteral(self, literal, lvalue, rvalue):
         pass
 
         pass
 
-    def visit_alias(self, alias, lvalue, rvalue):
+    def visitAlias(self, alias, lvalue, rvalue):
         self.visit(alias.type, lvalue, rvalue)
     
         self.visit(alias.type, lvalue, rvalue)
     
-    def visit_enum(self, enum, lvalue, rvalue):
+    def visitEnum(self, enum, lvalue, rvalue):
         pass
 
         pass
 
-    def visit_bitmask(self, bitmask, lvalue, rvalue):
+    def visitBitmask(self, bitmask, lvalue, rvalue):
         pass
 
         pass
 
-    def visit_array(self, array, lvalue, rvalue):
-        print '    const Trace::Array *__a%s = dynamic_cast<const Trace::Array *>(&%s);' % (array.tag, rvalue)
+    def visitArray(self, array, lvalue, rvalue):
+        print '    const trace::Array *__a%s = dynamic_cast<const trace::Array *>(&%s);' % (array.tag, rvalue)
         print '    if (__a%s) {' % (array.tag)
         length = '__a%s->values.size()' % array.tag
         index = '__j' + array.tag
         print '    if (__a%s) {' % (array.tag)
         length = '__a%s->values.size()' % array.tag
         index = '__j' + array.tag
@@ -145,20 +154,28 @@ class ValueWrapper(stdapi.Visitor):
             print '        }'
             print '    }'
     
             print '        }'
             print '    }'
     
-    def visit_pointer(self, pointer, lvalue, rvalue):
-        print '    const Trace::Array *__a%s = dynamic_cast<const Trace::Array *>(&%s);' % (pointer.tag, rvalue)
+    def visitPointer(self, pointer, lvalue, rvalue):
+        print '    const trace::Array *__a%s = dynamic_cast<const trace::Array *>(&%s);' % (pointer.tag, rvalue)
         print '    if (__a%s) {' % (pointer.tag)
         try:
             self.visit(pointer.type, '%s[0]' % (lvalue,), '*__a%s->values[0]' % (pointer.tag,))
         finally:
             print '    }'
     
         print '    if (__a%s) {' % (pointer.tag)
         try:
             self.visit(pointer.type, '%s[0]' % (lvalue,), '*__a%s->values[0]' % (pointer.tag,))
         finally:
             print '    }'
     
-    def visit_handle(self, handle, lvalue, rvalue):
+    def visitIntPointer(self, pointer, lvalue, rvalue):
+        pass
+    
+    def visitLinearPointer(self, pointer, lvalue, rvalue):
+        assert pointer.size is not None
+        if pointer.size is not None:
+            print r'    retrace::addRegion((%s).toUIntPtr(), %s, %s);' % (rvalue, lvalue, pointer.size)
+
+    def visitHandle(self, handle, lvalue, rvalue):
         print '    %s __orig_result;' % handle.type
         print '    %s __orig_result;' % handle.type
-        OpaqueValueExtractor().visit(handle.type, '__orig_result', rvalue);
+        OpaqueValueDeserializer().visit(handle.type, '__orig_result', rvalue);
         if handle.range is None:
             rvalue = "__orig_result"
         if handle.range is None:
             rvalue = "__orig_result"
-            entry = handle_entry(handle, rvalue) 
+            entry = lookupHandle(handle, rvalue) 
             print "    %s = %s;" % (entry, lvalue)
             print '    if (retrace::verbosity >= 2) {'
             print '        std::cout << "{handle.name} " << {rvalue} << " -> " << {lvalue} << "\\n";'.format(**locals())
             print "    %s = %s;" % (entry, lvalue)
             print '    if (retrace::verbosity >= 2) {'
             print '        std::cout << "{handle.name} " << {rvalue} << " -> " << {lvalue} << "\\n";'.format(**locals())
@@ -167,7 +184,7 @@ class ValueWrapper(stdapi.Visitor):
             i = '__h' + handle.tag
             lvalue = "%s + %s" % (lvalue, i)
             rvalue = "__orig_result + %s" % (i,)
             i = '__h' + handle.tag
             lvalue = "%s + %s" % (lvalue, i)
             rvalue = "__orig_result + %s" % (i,)
-            entry = handle_entry(handle, rvalue) 
+            entry = lookupHandle(handle, rvalue) 
             print '    for ({handle.type} {i} = 0; {i} < {handle.range}; ++{i}) {{'.format(**locals())
             print '        {entry} = {lvalue};'.format(**locals())
             print '        if (retrace::verbosity >= 2) {'
             print '    for ({handle.type} {i} = 0; {i} < {handle.range}; ++{i}) {{'.format(**locals())
             print '        {entry} = {lvalue};'.format(**locals())
             print '        if (retrace::verbosity >= 2) {'
@@ -175,26 +192,56 @@ class ValueWrapper(stdapi.Visitor):
             print '        }'
             print '    }'
     
             print '        }'
             print '    }'
     
-    def visit_blob(self, blob, lvalue, rvalue):
+    def visitBlob(self, blob, lvalue, rvalue):
         pass
     
         pass
     
-    def visit_string(self, string, lvalue, rvalue):
+    def visitString(self, string, lvalue, rvalue):
         pass
 
 
 class Retracer:
 
         pass
 
 
 class Retracer:
 
-    def retrace_function(self, function):
-        print 'static void retrace_%s(Trace::Call &call) {' % function.name
-        self.retrace_function_body(function)
+    def retraceFunction(self, function):
+        print 'static void retrace_%s(trace::Call &call) {' % function.name
+        self.retraceFunctionBody(function)
+        print '}'
+        print
+
+    def retraceInterfaceMethod(self, interface, method):
+        print 'static void retrace_%s__%s(trace::Call &call) {' % (interface.name, method.name)
+        self.retraceInterfaceMethodBody(interface, method)
         print '}'
         print
 
         print '}'
         print
 
-    def retrace_function_body(self, function):
+    def retraceFunctionBody(self, function):
         if not function.sideeffects:
             print '    (void)call;'
             return
 
         if not function.sideeffects:
             print '    (void)call;'
             return
 
+        self.deserializeArgs(function)
+        
+        self.invokeFunction(function)
+
+        self.swizzleValues(function)
+
+    def retraceInterfaceMethodBody(self, interface, method):
+        if not method.sideeffects:
+            print '    (void)call;'
+            return
+
+        self.deserializeThisPointer(interface)
+
+        self.deserializeArgs(method)
+        
+        self.invokeInterfaceMethod(interface, method)
+
+        self.swizzleValues(method)
+
+    def deserializeThisPointer(self, interface):
+        print '    %s *_this;' % (interface.name,)
+        # FIXME
+
+    def deserializeArgs(self, function):
         success = True
         for arg in function.args:
             arg_type = ConstRemover().visit(arg.type)
         success = True
         for arg in function.args:
             arg_type = ConstRemover().visit(arg.type)
@@ -203,49 +250,54 @@ class Retracer:
             rvalue = 'call.arg(%u)' % (arg.index,)
             lvalue = arg.name
             try:
             rvalue = 'call.arg(%u)' % (arg.index,)
             lvalue = arg.name
             try:
-                self.extract_arg(function, arg, arg_type, lvalue, rvalue)
+                self.extractArg(function, arg, arg_type, lvalue, rvalue)
             except NotImplementedError:
             except NotImplementedError:
-                success = False
+                success =  False
                 print '    %s = 0; // FIXME' % arg.name
                 print '    %s = 0; // FIXME' % arg.name
+
         if not success:
             print '    if (1) {'
         if not success:
             print '    if (1) {'
-            self.fail_function(function)
+            self.failFunction(function)
+            if function.name[-1].islower():
+                sys.stderr.write('warning: unsupported %s call\n' % function.name)
             print '    }'
             print '    }'
-        self.call_function(function)
+
+    def swizzleValues(self, function):
         for arg in function.args:
             if arg.output:
                 arg_type = ConstRemover().visit(arg.type)
                 rvalue = 'call.arg(%u)' % (arg.index,)
                 lvalue = arg.name
                 try:
         for arg in function.args:
             if arg.output:
                 arg_type = ConstRemover().visit(arg.type)
                 rvalue = 'call.arg(%u)' % (arg.index,)
                 lvalue = arg.name
                 try:
-                    ValueWrapper().visit(arg_type, lvalue, rvalue)
+                    self.regiterSwizzledValue(arg_type, lvalue, rvalue)
                 except NotImplementedError:
                     print '    // XXX: %s' % arg.name
         if function.type is not stdapi.Void:
             rvalue = '*call.ret'
             lvalue = '__result'
             try:
                 except NotImplementedError:
                     print '    // XXX: %s' % arg.name
         if function.type is not stdapi.Void:
             rvalue = '*call.ret'
             lvalue = '__result'
             try:
-                ValueWrapper().visit(function.type, lvalue, rvalue)
+                self.regiterSwizzledValue(function.type, lvalue, rvalue)
             except NotImplementedError:
                 print '    // XXX: result'
             except NotImplementedError:
                 print '    // XXX: result'
-        if not success:
-            if function.name[-1].islower():
-                sys.stderr.write('warning: unsupported %s call\n' % function.name)
 
 
-    def fail_function(self, function):
+    def failFunction(self, function):
         print '    if (retrace::verbosity >= 0) {'
         print '        retrace::unsupported(call);'
         print '    }'
         print '    return;'
 
         print '    if (retrace::verbosity >= 0) {'
         print '        retrace::unsupported(call);'
         print '    }'
         print '    return;'
 
-    def extract_arg(self, function, arg, arg_type, lvalue, rvalue):
-        ValueExtractor().visit(arg_type, lvalue, rvalue)
+    def extractArg(self, function, arg, arg_type, lvalue, rvalue):
+        ValueDeserializer().visit(arg_type, lvalue, rvalue)
     
     
-    def extract_opaque_arg(self, function, arg, arg_type, lvalue, rvalue):
-        OpaqueValueExtractor().visit(arg_type, lvalue, rvalue)
+    def extractOpaqueArg(self, function, arg, arg_type, lvalue, rvalue):
+        OpaqueValueDeserializer().visit(arg_type, lvalue, rvalue)
 
 
-    def call_function(self, function):
-        arg_names = ", ".join([arg.name for arg in function.args])
+    def regiterSwizzledValue(self, type, lvalue, rvalue):
+        visitor = SwizzledValueRegistrator()
+        visitor.visit(type, lvalue, rvalue)
+
+    def invokeFunction(self, function):
+        arg_names = ", ".join(function.argNames())
         if function.type is not stdapi.Void:
             print '    %s __result;' % (function.type)
             print '    __result = %s(%s);' % (function.name, arg_names)
         if function.type is not stdapi.Void:
             print '    %s __result;' % (function.type)
             print '    __result = %s(%s);' % (function.name, arg_names)
@@ -253,32 +305,27 @@ class Retracer:
         else:
             print '    %s(%s);' % (function.name, arg_names)
 
         else:
             print '    %s(%s);' % (function.name, arg_names)
 
-    def filter_function(self, function):
+    def invokeInterfaceMethod(self, interface, method):
+        arg_names = ", ".join(method.argNames())
+        if method.type is not stdapi.Void:
+            print '    %s __result;' % (method.type)
+            print '    __result = _this->%s(%s);' % (method.name, arg_names)
+            print '    (void)__result;'
+        else:
+            print '    _this->%s(%s);' % (method.name, arg_names)
+
+    def filterFunction(self, function):
         return True
 
     table_name = 'retrace::callbacks'
 
         return True
 
     table_name = 'retrace::callbacks'
 
-    def retrace_functions(self, functions):
-        functions = filter(self.filter_function, functions)
-
-        for function in functions:
-            self.retrace_function(function)
-
-        print 'const retrace::Entry %s[] = {' % self.table_name
-        for function in functions:
-            print '    {"%s", &retrace_%s},' % (function.name, function.name)
-        print '    {NULL, NULL}'
-        print '};'
-        print
-
-
-    def retrace_api(self, api):
+    def retraceApi(self, api):
 
         print '#include "trace_parser.hpp"'
         print '#include "retrace.hpp"'
         print
 
 
         print '#include "trace_parser.hpp"'
         print '#include "retrace.hpp"'
         print
 
-        types = api.all_types()
+        types = api.getAllTypes()
         handles = [type for type in types if isinstance(type, stdapi.Handle)]
         handle_names = set()
         for handle in handles:
         handles = [type for type in types if isinstance(type, stdapi.Handle)]
         handle_names = set()
         for handle in handles:
@@ -291,5 +338,21 @@ class Retracer:
                 handle_names.add(handle.name)
         print
 
                 handle_names.add(handle.name)
         print
 
-        self.retrace_functions(api.functions)
+        functions = filter(self.filterFunction, api.functions)
+        for function in functions:
+            self.retraceFunction(function)
+        interfaces = api.getAllInterfaces()
+        for interface in interfaces:
+            for method in interface.iterMethods():
+                self.retraceInterfaceMethod(interface, method)
+
+        print 'const retrace::Entry %s[] = {' % self.table_name
+        for function in functions:
+            print '    {"%s", &retrace_%s},' % (function.name, function.name)
+        for interface in interfaces:
+            for method in interface.iterMethods():
+                print '    {"%s::%s", &retrace_%s__%s},' % (interface.name, method.name, interface.name, method.name)
+        print '    {NULL, NULL}'
+        print '};'
+        print