]> git.cworth.org Git - apitrace/blobdiff - retrace.py
Add missing dependencies to glproc.
[apitrace] / retrace.py
index 4425ad1e99fbd697c854cebfeda1f777a47ef386..96a53b37adad90cd6c317a63d91a5bfd15184a8a 100644 (file)
@@ -33,16 +33,30 @@ import specs.stdapi as stdapi
 import specs.glapi as glapi
 
 
 import specs.glapi as glapi
 
 
-class ConstRemover(stdapi.Rebuilder):
+class MutableRebuilder(stdapi.Rebuilder):
+    '''Type visitor which derives a mutable type.'''
 
 
-    def visit_const(self, const):
+    def visitConst(self, const):
+        # Strip out const qualifier
         return const.type
 
         return const.type
 
-    def visit_opaque(self, opaque):
+    def visitAlias(self, alias):
+        # Tear the alias on type changes
+        type = self.visit(alias.type)
+        if type is alias.type:
+            return alias
+        return type
+
+    def visitReference(self, reference):
+        # Strip out references
+        return reference.type
+
+    def visitOpaque(self, opaque):
+        # Don't recursule
         return opaque
 
 
         return opaque
 
 
-def handle_entry(handle, value):
+def lookupHandle(handle, value):
     if handle.key is None:
         return "__%s_map[%s]" % (handle.name, value)
     else:
     if handle.key is None:
         return "__%s_map[%s]" % (handle.name, value)
     else:
@@ -50,124 +64,175 @@ def handle_entry(handle, value):
         return "__%s_map[%s][%s]" % (handle.name, key_name, value)
 
 
         return "__%s_map[%s][%s]" % (handle.name, key_name, value)
 
 
-class ValueExtractor(stdapi.Visitor):
+class ValueDeserializer(stdapi.Visitor):
 
 
-    def visit_literal(self, literal, lvalue, rvalue):
-        #if literal.format in ('Bool', 'UInt'):
-        print '    %s = (%s).to%s();' % (lvalue, rvalue, literal.format)
+    def visitLiteral(self, literal, lvalue, rvalue):
+        print '    %s = (%s).to%s();' % (lvalue, rvalue, literal.kind)
 
 
-    def visit_const(self, const, lvalue, rvalue):
+    def visitConst(self, const, lvalue, rvalue):
         self.visit(const.type, lvalue, rvalue)
 
         self.visit(const.type, lvalue, rvalue)
 
-    def visit_alias(self, alias, lvalue, rvalue):
+    def visitAlias(self, alias, lvalue, rvalue):
         self.visit(alias.type, lvalue, rvalue)
     
         self.visit(alias.type, lvalue, rvalue)
     
-    def visit_enum(self, enum, lvalue, rvalue):
-        print '    %s = (%s).toSInt();' % (lvalue, rvalue)
+    def visitEnum(self, enum, lvalue, rvalue):
+        print '    %s = static_cast<%s>((%s).toSInt());' % (lvalue, enum, rvalue)
 
 
-    def visit_bitmask(self, bitmask, lvalue, rvalue):
+    def visitBitmask(self, bitmask, lvalue, rvalue):
         self.visit(bitmask.type, lvalue, rvalue)
 
         self.visit(bitmask.type, lvalue, rvalue)
 
-    def visit_array(self, array, lvalue, rvalue):
-        print '    const Trace::Array *__a%s = dynamic_cast<const Trace::Array *>(&%s);' % (array.id, rvalue)
-        print '    if (__a%s) {' % (array.id)
-        length = '__a%s->values.size()' % array.id
-        print '        %s = new %s[%s];' % (lvalue, array.type, length)
-        index = '__j' + array.id
+    allocated = False
+
+    def visitArray(self, array, lvalue, rvalue):
+        print '    const trace::Array *__a%s = dynamic_cast<const trace::Array *>(&%s);' % (array.tag, rvalue)
+        length = '__a%s->values.size()' % array.tag
+        allocated = self.allocated
+        if not allocated:
+            print '    if (__a%s) {' % (array.tag)
+            print '        %s = _allocator.alloc<%s>(%s);' % (lvalue, array.type, length)
+            self.allocated = True
+        index = '__j' + array.tag
         print '        for (size_t {i} = 0; {i} < {length}; ++{i}) {{'.format(i = index, length = length)
         try:
         print '        for (size_t {i} = 0; {i} < {length}; ++{i}) {{'.format(i = index, length = length)
         try:
-            self.visit(array.type, '%s[%s]' % (lvalue, index), '*__a%s->values[%s]' % (array.id, index))
+            self.visit(array.type, '%s[%s]' % (lvalue, index), '*__a%s->values[%s]' % (array.tag, index))
         finally:
             print '        }'
         finally:
             print '        }'
-            print '    } else {'
-            print '        %s = NULL;' % lvalue
-            print '    }'
+            if not allocated:
+                print '    } else {'
+                print '        %s = NULL;' % lvalue
+                print '    }'
     
     
-    def visit_pointer(self, pointer, lvalue, rvalue):
-        print '    const Trace::Array *__a%s = dynamic_cast<const Trace::Array *>(&%s);' % (pointer.id, rvalue)
-        print '    if (__a%s) {' % (pointer.id)
-        print '        %s = new %s;' % (lvalue, pointer.type)
+    def visitPointer(self, pointer, lvalue, rvalue):
+        print '    const trace::Array *__a%s = dynamic_cast<const trace::Array *>(&%s);' % (pointer.tag, rvalue)
+        allocated = self.allocated
+        if not allocated:
+            print '    if (__a%s) {' % (pointer.tag)
+            print '        %s = _allocator.alloc<%s>();' % (lvalue, pointer.type)
+            self.allocated = True
         try:
         try:
-            self.visit(pointer.type, '%s[0]' % (lvalue,), '*__a%s->values[0]' % (pointer.id,))
+            self.visit(pointer.type, '%s[0]' % (lvalue,), '*__a%s->values[0]' % (pointer.tag,))
         finally:
         finally:
-            print '    } else {'
-            print '        %s = NULL;' % lvalue
-            print '    }'
+            if not allocated:
+                print '    } else {'
+                print '        %s = NULL;' % lvalue
+                print '    }'
+
+    def visitIntPointer(self, pointer, lvalue, rvalue):
+        print '    %s = static_cast<%s>((%s).toPointer());' % (lvalue, pointer, rvalue)
+
+    def visitObjPointer(self, pointer, lvalue, rvalue):
+        print '    %s = static_cast<%s>((%s).toPointer());' % (lvalue, pointer, rvalue)
+
+    def visitLinearPointer(self, pointer, lvalue, rvalue):
+        print '    %s = static_cast<%s>(retrace::toPointer(%s));' % (lvalue, pointer, rvalue)
+
+    def visitReference(self, reference, lvalue, rvalue):
+        self.visit(reference.type, lvalue, rvalue);
 
 
-    def visit_handle(self, handle, lvalue, rvalue):
-        OpaqueValueExtractor().visit(handle.type, lvalue, rvalue);
-        new_lvalue = handle_entry(handle, lvalue)
+    def visitHandle(self, handle, lvalue, rvalue):
+        #OpaqueValueDeserializer().visit(handle.type, lvalue, rvalue);
+        self.visit(handle.type, lvalue, rvalue);
+        new_lvalue = lookupHandle(handle, lvalue)
         print '    if (retrace::verbosity >= 2) {'
         print '        std::cout << "%s " << size_t(%s) << " <- " << size_t(%s) << "\\n";' % (handle.name, lvalue, new_lvalue)
         print '    }'
         print '    %s = %s;' % (lvalue, new_lvalue)
     
         print '    if (retrace::verbosity >= 2) {'
         print '        std::cout << "%s " << size_t(%s) << " <- " << size_t(%s) << "\\n";' % (handle.name, lvalue, new_lvalue)
         print '    }'
         print '    %s = %s;' % (lvalue, new_lvalue)
     
-    def visit_blob(self, blob, lvalue, rvalue):
+    def visitBlob(self, blob, lvalue, rvalue):
         print '    %s = static_cast<%s>((%s).toPointer());' % (lvalue, blob, rvalue)
     
         print '    %s = static_cast<%s>((%s).toPointer());' % (lvalue, blob, rvalue)
     
-    def visit_string(self, string, lvalue, rvalue):
+    def visitString(self, string, lvalue, rvalue):
         print '    %s = (%s)((%s).toString());' % (lvalue, string.expr, rvalue)
 
         print '    %s = (%s)((%s).toString());' % (lvalue, string.expr, rvalue)
 
+    seq = 0
 
 
-class OpaqueValueExtractor(ValueExtractor):
+    def visitStruct(self, struct, lvalue, rvalue):
+        tmp = '__s_' + struct.tag + '_' + str(self.seq)
+        self.seq += 1
+
+        print '    const trace::Struct *%s = dynamic_cast<const trace::Struct *>(&%s);' % (tmp, rvalue)
+        print '    assert(%s);' % (tmp)
+        self.allocated = True
+        for i in range(len(struct.members)):
+            member_type, member_name = struct.members[i]
+            self.visit(member_type, '%s.%s' % (lvalue, member_name), '*%s->members[%s]' % (tmp, i))
+
+
+class OpaqueValueDeserializer(ValueDeserializer):
     '''Value extractor that also understands opaque values.
 
     Normally opaque values can't be retraced, unless they are being extracted
     in the context of handles.'''
 
     '''Value extractor that also understands opaque values.
 
     Normally opaque values can't be retraced, unless they are being extracted
     in the context of handles.'''
 
-    def visit_opaque(self, opaque, lvalue, rvalue):
-        print '    %s = static_cast<%s>((%s).toPointer());' % (lvalue, opaque, rvalue)
+    def visitOpaque(self, opaque, lvalue, rvalue):
+        print '    %s = static_cast<%s>(retrace::toPointer(%s));' % (lvalue, opaque, rvalue)
 
 
 
 
-class ValueWrapper(stdapi.Visitor):
+class SwizzledValueRegistrator(stdapi.Visitor):
+    '''Type visitor which will register (un)swizzled value pairs, to later be
+    swizzled.'''
 
 
-    def visit_literal(self, literal, lvalue, rvalue):
+    def visitLiteral(self, literal, lvalue, rvalue):
         pass
 
         pass
 
-    def visit_alias(self, alias, lvalue, rvalue):
+    def visitAlias(self, alias, lvalue, rvalue):
         self.visit(alias.type, lvalue, rvalue)
     
         self.visit(alias.type, lvalue, rvalue)
     
-    def visit_enum(self, enum, lvalue, rvalue):
+    def visitEnum(self, enum, lvalue, rvalue):
         pass
 
         pass
 
-    def visit_bitmask(self, bitmask, lvalue, rvalue):
+    def visitBitmask(self, bitmask, lvalue, rvalue):
         pass
 
         pass
 
-    def visit_array(self, array, lvalue, rvalue):
-        print '    const Trace::Array *__a%s = dynamic_cast<const Trace::Array *>(&%s);' % (array.id, rvalue)
-        print '    if (__a%s) {' % (array.id)
-        length = '__a%s->values.size()' % array.id
-        index = '__j' + array.id
+    def visitArray(self, array, lvalue, rvalue):
+        print '    const trace::Array *__a%s = dynamic_cast<const trace::Array *>(&%s);' % (array.tag, rvalue)
+        print '    if (__a%s) {' % (array.tag)
+        length = '__a%s->values.size()' % array.tag
+        index = '__j' + array.tag
         print '        for (size_t {i} = 0; {i} < {length}; ++{i}) {{'.format(i = index, length = length)
         try:
         print '        for (size_t {i} = 0; {i} < {length}; ++{i}) {{'.format(i = index, length = length)
         try:
-            self.visit(array.type, '%s[%s]' % (lvalue, index), '*__a%s->values[%s]' % (array.id, index))
+            self.visit(array.type, '%s[%s]' % (lvalue, index), '*__a%s->values[%s]' % (array.tag, index))
         finally:
             print '        }'
             print '    }'
     
         finally:
             print '        }'
             print '    }'
     
-    def visit_pointer(self, pointer, lvalue, rvalue):
-        print '    const Trace::Array *__a%s = dynamic_cast<const Trace::Array *>(&%s);' % (pointer.id, rvalue)
-        print '    if (__a%s) {' % (pointer.id)
+    def visitPointer(self, pointer, lvalue, rvalue):
+        print '    const trace::Array *__a%s = dynamic_cast<const trace::Array *>(&%s);' % (pointer.tag, rvalue)
+        print '    if (__a%s) {' % (pointer.tag)
         try:
         try:
-            self.visit(pointer.type, '%s[0]' % (lvalue,), '*__a%s->values[0]' % (pointer.id,))
+            self.visit(pointer.type, '%s[0]' % (lvalue,), '*__a%s->values[0]' % (pointer.tag,))
         finally:
             print '    }'
     
         finally:
             print '    }'
     
-    def visit_handle(self, handle, lvalue, rvalue):
+    def visitIntPointer(self, pointer, lvalue, rvalue):
+        pass
+    
+    def visitObjPointer(self, pointer, lvalue, rvalue):
+        print r'    _obj_map[(%s).toUIntPtr()] = %s;' % (rvalue, lvalue)
+    
+    def visitLinearPointer(self, pointer, lvalue, rvalue):
+        assert pointer.size is not None
+        if pointer.size is not None:
+            print r'    retrace::addRegion((%s).toUIntPtr(), %s, %s);' % (rvalue, lvalue, pointer.size)
+
+    def visitReference(self, reference, lvalue, rvalue):
+        pass
+    
+    def visitHandle(self, handle, lvalue, rvalue):
         print '    %s __orig_result;' % handle.type
         print '    %s __orig_result;' % handle.type
-        OpaqueValueExtractor().visit(handle.type, '__orig_result', rvalue);
+        OpaqueValueDeserializer().visit(handle.type, '__orig_result', rvalue);
         if handle.range is None:
             rvalue = "__orig_result"
         if handle.range is None:
             rvalue = "__orig_result"
-            entry = handle_entry(handle, rvalue) 
+            entry = lookupHandle(handle, rvalue) 
             print "    %s = %s;" % (entry, lvalue)
             print '    if (retrace::verbosity >= 2) {'
             print '        std::cout << "{handle.name} " << {rvalue} << " -> " << {lvalue} << "\\n";'.format(**locals())
             print '    }'
         else:
             print "    %s = %s;" % (entry, lvalue)
             print '    if (retrace::verbosity >= 2) {'
             print '        std::cout << "{handle.name} " << {rvalue} << " -> " << {lvalue} << "\\n";'.format(**locals())
             print '    }'
         else:
-            i = '__h' + handle.id
+            i = '__h' + handle.tag
             lvalue = "%s + %s" % (lvalue, i)
             rvalue = "__orig_result + %s" % (i,)
             lvalue = "%s + %s" % (lvalue, i)
             rvalue = "__orig_result + %s" % (i,)
-            entry = handle_entry(handle, rvalue) 
+            entry = lookupHandle(handle, rvalue) 
             print '    for ({handle.type} {i} = 0; {i} < {handle.range}; ++{i}) {{'.format(**locals())
             print '        {entry} = {lvalue};'.format(**locals())
             print '        if (retrace::verbosity >= 2) {'
             print '    for ({handle.type} {i} = 0; {i} < {handle.range}; ++{i}) {{'.format(**locals())
             print '        {entry} = {lvalue};'.format(**locals())
             print '        if (retrace::verbosity >= 2) {'
@@ -175,78 +240,115 @@ class ValueWrapper(stdapi.Visitor):
             print '        }'
             print '    }'
     
             print '        }'
             print '    }'
     
-    def visit_blob(self, blob, lvalue, rvalue):
+    def visitBlob(self, blob, lvalue, rvalue):
         pass
     
         pass
     
-    def visit_string(self, string, lvalue, rvalue):
+    def visitString(self, string, lvalue, rvalue):
         pass
 
 
 class Retracer:
 
         pass
 
 
 class Retracer:
 
-    def retrace_function(self, function):
-        print 'static void retrace_%s(Trace::Call &call) {' % function.name
-        self.retrace_function_body(function)
+    def retraceFunction(self, function):
+        print 'static void retrace_%s(trace::Call &call) {' % function.name
+        self.retraceFunctionBody(function)
+        print '}'
+        print
+
+    def retraceInterfaceMethod(self, interface, method):
+        print 'static void retrace_%s__%s(trace::Call &call) {' % (interface.name, method.name)
+        self.retraceInterfaceMethodBody(interface, method)
         print '}'
         print
 
         print '}'
         print
 
-    def retrace_function_body(self, function):
+    def retraceFunctionBody(self, function):
         if not function.sideeffects:
             print '    (void)call;'
             return
 
         if not function.sideeffects:
             print '    (void)call;'
             return
 
+        self.deserializeArgs(function)
+        
+        self.invokeFunction(function)
+
+        self.swizzleValues(function)
+
+    def retraceInterfaceMethodBody(self, interface, method):
+        if not method.sideeffects:
+            print '    (void)call;'
+            return
+
+        self.deserializeThisPointer(interface)
+
+        self.deserializeArgs(method)
+        
+        self.invokeInterfaceMethod(interface, method)
+
+        self.swizzleValues(method)
+
+    def deserializeThisPointer(self, interface):
+        print '    %s *_this;' % (interface.name,)
+        print '    _this = static_cast<%s *>(_obj_map[call.arg(0).toUIntPtr()]);' % (interface.name,)
+
+    def deserializeArgs(self, function):
+        print '    retrace::ScopedAllocator _allocator;'
+        print '    (void)_allocator;'
         success = True
         for arg in function.args:
         success = True
         for arg in function.args:
-            arg_type = ConstRemover().visit(arg.type)
+            arg_type = MutableRebuilder().visit(arg.type)
             #print '    // %s ->  %s' % (arg.type, arg_type)
             print '    %s %s;' % (arg_type, arg.name)
             rvalue = 'call.arg(%u)' % (arg.index,)
             lvalue = arg.name
             try:
             #print '    // %s ->  %s' % (arg.type, arg_type)
             print '    %s %s;' % (arg_type, arg.name)
             rvalue = 'call.arg(%u)' % (arg.index,)
             lvalue = arg.name
             try:
-                self.extract_arg(function, arg, arg_type, lvalue, rvalue)
+                self.extractArg(function, arg, arg_type, lvalue, rvalue)
             except NotImplementedError:
             except NotImplementedError:
-                success = False
-                print '    %s = 0; // FIXME' % arg.name
+                success =  False
+                print '    memset(&%s, 0, sizeof %s); // FIXME' % (arg.name, arg.name)
+
         if not success:
             print '    if (1) {'
         if not success:
             print '    if (1) {'
-            self.fail_function(function)
+            self.failFunction(function)
+            if function.name[-1].islower():
+                sys.stderr.write('warning: unsupported %s call\n' % function.name)
             print '    }'
             print '    }'
-        self.call_function(function)
+
+    def swizzleValues(self, function):
         for arg in function.args:
             if arg.output:
         for arg in function.args:
             if arg.output:
-                arg_type = ConstRemover().visit(arg.type)
+                arg_type = MutableRebuilder().visit(arg.type)
                 rvalue = 'call.arg(%u)' % (arg.index,)
                 lvalue = arg.name
                 try:
                 rvalue = 'call.arg(%u)' % (arg.index,)
                 lvalue = arg.name
                 try:
-                    ValueWrapper().visit(arg_type, lvalue, rvalue)
+                    self.regiterSwizzledValue(arg_type, lvalue, rvalue)
                 except NotImplementedError:
                     print '    // XXX: %s' % arg.name
         if function.type is not stdapi.Void:
             rvalue = '*call.ret'
             lvalue = '__result'
             try:
                 except NotImplementedError:
                     print '    // XXX: %s' % arg.name
         if function.type is not stdapi.Void:
             rvalue = '*call.ret'
             lvalue = '__result'
             try:
-                ValueWrapper().visit(function.type, lvalue, rvalue)
+                self.regiterSwizzledValue(function.type, lvalue, rvalue)
             except NotImplementedError:
             except NotImplementedError:
-                success = False
-                print '    // FIXME: result'
-        if not success:
-            if function.name[-1].islower():
-                sys.stderr.write('warning: %s unsupported\n' % function.name)
+                raise
+                print '    // XXX: result'
 
 
-    def fail_function(self, function):
+    def failFunction(self, function):
         print '    if (retrace::verbosity >= 0) {'
         print '    if (retrace::verbosity >= 0) {'
-        print '        retrace::unknown(call);'
+        print '        retrace::unsupported(call);'
         print '    }'
         print '    return;'
 
         print '    }'
         print '    return;'
 
-    def extract_arg(self, function, arg, arg_type, lvalue, rvalue):
-        ValueExtractor().visit(arg_type, lvalue, rvalue)
+    def extractArg(self, function, arg, arg_type, lvalue, rvalue):
+        ValueDeserializer().visit(arg_type, lvalue, rvalue)
     
     
-    def extract_opaque_arg(self, function, arg, arg_type, lvalue, rvalue):
-        OpaqueValueExtractor().visit(arg_type, lvalue, rvalue)
+    def extractOpaqueArg(self, function, arg, arg_type, lvalue, rvalue):
+        OpaqueValueDeserializer().visit(arg_type, lvalue, rvalue)
 
 
-    def call_function(self, function):
-        arg_names = ", ".join([arg.name for arg in function.args])
+    def regiterSwizzledValue(self, type, lvalue, rvalue):
+        visitor = SwizzledValueRegistrator()
+        visitor.visit(type, lvalue, rvalue)
+
+    def invokeFunction(self, function):
+        arg_names = ", ".join(function.argNames())
         if function.type is not stdapi.Void:
             print '    %s __result;' % (function.type)
             print '    __result = %s(%s);' % (function.name, arg_names)
         if function.type is not stdapi.Void:
             print '    %s __result;' % (function.type)
             print '    __result = %s(%s);' % (function.name, arg_names)
@@ -254,32 +356,28 @@ class Retracer:
         else:
             print '    %s(%s);' % (function.name, arg_names)
 
         else:
             print '    %s(%s);' % (function.name, arg_names)
 
-    def filter_function(self, function):
+    def invokeInterfaceMethod(self, interface, method):
+        arg_names = ", ".join(method.argNames())
+        if method.type is not stdapi.Void:
+            print '    %s __result;' % (method.type)
+            print '    __result = _this->%s(%s);' % (method.name, arg_names)
+            print '    (void)__result;'
+        else:
+            print '    _this->%s(%s);' % (method.name, arg_names)
+
+    def filterFunction(self, function):
         return True
 
     table_name = 'retrace::callbacks'
 
         return True
 
     table_name = 'retrace::callbacks'
 
-    def retrace_functions(self, functions):
-        functions = filter(self.filter_function, functions)
-
-        for function in functions:
-            self.retrace_function(function)
-
-        print 'const retrace::Entry %s[] = {' % self.table_name
-        for function in functions:
-            print '    {"%s", &retrace_%s},' % (function.name, function.name)
-        print '    {NULL, NULL}'
-        print '};'
-        print
-
-
-    def retrace_api(self, api):
+    def retraceApi(self, api):
 
 
+        print '#include "os_time.hpp"'
         print '#include "trace_parser.hpp"'
         print '#include "retrace.hpp"'
         print
 
         print '#include "trace_parser.hpp"'
         print '#include "retrace.hpp"'
         print
 
-        types = api.all_types()
+        types = api.getAllTypes()
         handles = [type for type in types if isinstance(type, stdapi.Handle)]
         handle_names = set()
         for handle in handles:
         handles = [type for type in types if isinstance(type, stdapi.Handle)]
         handle_names = set()
         for handle in handles:
@@ -292,5 +390,24 @@ class Retracer:
                 handle_names.add(handle.name)
         print
 
                 handle_names.add(handle.name)
         print
 
-        self.retrace_functions(api.functions)
+        print 'static std::map<unsigned long long, void *> _obj_map;'
+        print
+
+        functions = filter(self.filterFunction, api.functions)
+        for function in functions:
+            self.retraceFunction(function)
+        interfaces = api.getAllInterfaces()
+        for interface in interfaces:
+            for method in interface.iterMethods():
+                self.retraceInterfaceMethod(interface, method)
+
+        print 'const retrace::Entry %s[] = {' % self.table_name
+        for function in functions:
+            print '    {"%s", &retrace_%s},' % (function.name, function.name)
+        for interface in interfaces:
+            for method in interface.iterMethods():
+                print '    {"%s::%s", &retrace_%s__%s},' % (interface.name, method.name, interface.name, method.name)
+        print '    {NULL, NULL}'
+        print '};'
+        print