]> git.cworth.org Git - apitrace/blobdiff - retrace.py
Cleanup unicode support.
[apitrace] / retrace.py
index 9092f351366a2377231761c962e2c1daec382879..25e26093597df8fad5ff9ab5d0217d1ffb1c6069 100644 (file)
 
 """Generic retracing code generator."""
 
 
 """Generic retracing code generator."""
 
+
+import sys
+
 import specs.stdapi as stdapi
 import specs.glapi as glapi
 
 
 class ConstRemover(stdapi.Rebuilder):
 import specs.stdapi as stdapi
 import specs.glapi as glapi
 
 
 class ConstRemover(stdapi.Rebuilder):
+    '''Type visitor which strips out const qualifiers from types.'''
 
 
-    def visit_const(self, const):
+    def visitConst(self, const):
         return const.type
 
         return const.type
 
-    def visit_opaque(self, opaque):
+    def visitOpaque(self, opaque):
         return opaque
 
 
         return opaque
 
 
-def handle_entry(handle, value):
+def lookupHandle(handle, value):
     if handle.key is None:
         return "__%s_map[%s]" % (handle.name, value)
     else:
     if handle.key is None:
         return "__%s_map[%s]" % (handle.name, value)
     else:
@@ -47,124 +51,140 @@ def handle_entry(handle, value):
         return "__%s_map[%s][%s]" % (handle.name, key_name, value)
 
 
         return "__%s_map[%s][%s]" % (handle.name, key_name, value)
 
 
-class ValueExtractor(stdapi.Visitor):
+class ValueDeserializer(stdapi.Visitor):
 
 
-    def visit_literal(self, literal, lvalue, rvalue):
-        #if literal.format in ('Bool', 'UInt'):
-        print '    %s = (%s).to%s();' % (lvalue, rvalue, literal.format)
+    def visitLiteral(self, literal, lvalue, rvalue):
+        print '    %s = (%s).to%s();' % (lvalue, rvalue, literal.kind)
 
 
-    def visit_const(self, const, lvalue, rvalue):
+    def visitConst(self, const, lvalue, rvalue):
         self.visit(const.type, lvalue, rvalue)
 
         self.visit(const.type, lvalue, rvalue)
 
-    def visit_alias(self, alias, lvalue, rvalue):
+    def visitAlias(self, alias, lvalue, rvalue):
         self.visit(alias.type, lvalue, rvalue)
     
         self.visit(alias.type, lvalue, rvalue)
     
-    def visit_enum(self, enum, lvalue, rvalue):
+    def visitEnum(self, enum, lvalue, rvalue):
         print '    %s = (%s).toSInt();' % (lvalue, rvalue)
 
         print '    %s = (%s).toSInt();' % (lvalue, rvalue)
 
-    def visit_bitmask(self, bitmask, lvalue, rvalue):
+    def visitBitmask(self, bitmask, lvalue, rvalue):
         self.visit(bitmask.type, lvalue, rvalue)
 
         self.visit(bitmask.type, lvalue, rvalue)
 
-    def visit_array(self, array, lvalue, rvalue):
-        print '    const Trace::Array *__a%s = dynamic_cast<const Trace::Array *>(&%s);' % (array.id, rvalue)
-        print '    if (__a%s) {' % (array.id)
-        length = '__a%s->values.size()' % array.id
+    def visitArray(self, array, lvalue, rvalue):
+        print '    const trace::Array *__a%s = dynamic_cast<const trace::Array *>(&%s);' % (array.tag, rvalue)
+        print '    if (__a%s) {' % (array.tag)
+        length = '__a%s->values.size()' % array.tag
         print '        %s = new %s[%s];' % (lvalue, array.type, length)
         print '        %s = new %s[%s];' % (lvalue, array.type, length)
-        index = '__j' + array.id
+        index = '__j' + array.tag
         print '        for (size_t {i} = 0; {i} < {length}; ++{i}) {{'.format(i = index, length = length)
         try:
         print '        for (size_t {i} = 0; {i} < {length}; ++{i}) {{'.format(i = index, length = length)
         try:
-            self.visit(array.type, '%s[%s]' % (lvalue, index), '*__a%s->values[%s]' % (array.id, index))
+            self.visit(array.type, '%s[%s]' % (lvalue, index), '*__a%s->values[%s]' % (array.tag, index))
         finally:
             print '        }'
             print '    } else {'
             print '        %s = NULL;' % lvalue
             print '    }'
     
         finally:
             print '        }'
             print '    } else {'
             print '        %s = NULL;' % lvalue
             print '    }'
     
-    def visit_pointer(self, pointer, lvalue, rvalue):
-        print '    const Trace::Array *__a%s = dynamic_cast<const Trace::Array *>(&%s);' % (pointer.id, rvalue)
-        print '    if (__a%s) {' % (pointer.id)
+    def visitPointer(self, pointer, lvalue, rvalue):
+        print '    const trace::Array *__a%s = dynamic_cast<const trace::Array *>(&%s);' % (pointer.tag, rvalue)
+        print '    if (__a%s) {' % (pointer.tag)
         print '        %s = new %s;' % (lvalue, pointer.type)
         try:
         print '        %s = new %s;' % (lvalue, pointer.type)
         try:
-            self.visit(pointer.type, '%s[0]' % (lvalue,), '*__a%s->values[0]' % (pointer.id,))
+            self.visit(pointer.type, '%s[0]' % (lvalue,), '*__a%s->values[0]' % (pointer.tag,))
         finally:
             print '    } else {'
             print '        %s = NULL;' % lvalue
             print '    }'
 
         finally:
             print '    } else {'
             print '        %s = NULL;' % lvalue
             print '    }'
 
-    def visit_handle(self, handle, lvalue, rvalue):
-        OpaqueValueExtractor().visit(handle.type, lvalue, rvalue);
-        new_lvalue = handle_entry(handle, lvalue)
+    def visitIntPointer(self, pointer, lvalue, rvalue):
+        print '    %s = static_cast<%s>((%s).toPointer());' % (lvalue, pointer, rvalue)
+
+    def visitLinearPointer(self, pointer, lvalue, rvalue):
+        print '    %s = static_cast<%s>(retrace::toPointer(%s));' % (lvalue, pointer, rvalue)
+
+    def visitHandle(self, handle, lvalue, rvalue):
+        #OpaqueValueDeserializer().visit(handle.type, lvalue, rvalue);
+        self.visit(handle.type, lvalue, rvalue);
+        new_lvalue = lookupHandle(handle, lvalue)
         print '    if (retrace::verbosity >= 2) {'
         print '        std::cout << "%s " << size_t(%s) << " <- " << size_t(%s) << "\\n";' % (handle.name, lvalue, new_lvalue)
         print '    }'
         print '    %s = %s;' % (lvalue, new_lvalue)
     
         print '    if (retrace::verbosity >= 2) {'
         print '        std::cout << "%s " << size_t(%s) << " <- " << size_t(%s) << "\\n";' % (handle.name, lvalue, new_lvalue)
         print '    }'
         print '    %s = %s;' % (lvalue, new_lvalue)
     
-    def visit_blob(self, blob, lvalue, rvalue):
+    def visitBlob(self, blob, lvalue, rvalue):
         print '    %s = static_cast<%s>((%s).toPointer());' % (lvalue, blob, rvalue)
     
         print '    %s = static_cast<%s>((%s).toPointer());' % (lvalue, blob, rvalue)
     
-    def visit_string(self, string, lvalue, rvalue):
+    def visitString(self, string, lvalue, rvalue):
         print '    %s = (%s)((%s).toString());' % (lvalue, string.expr, rvalue)
 
 
         print '    %s = (%s)((%s).toString());' % (lvalue, string.expr, rvalue)
 
 
-class OpaqueValueExtractor(ValueExtractor):
+class OpaqueValueDeserializer(ValueDeserializer):
     '''Value extractor that also understands opaque values.
 
     Normally opaque values can't be retraced, unless they are being extracted
     in the context of handles.'''
 
     '''Value extractor that also understands opaque values.
 
     Normally opaque values can't be retraced, unless they are being extracted
     in the context of handles.'''
 
-    def visit_opaque(self, opaque, lvalue, rvalue):
-        print '    %s = static_cast<%s>((%s).toPointer());' % (lvalue, opaque, rvalue)
+    def visitOpaque(self, opaque, lvalue, rvalue):
+        print '    %s = static_cast<%s>(retrace::toPointer(%s));' % (lvalue, opaque, rvalue)
 
 
 
 
-class ValueWrapper(stdapi.Visitor):
+class SwizzledValueRegistrator(stdapi.Visitor):
+    '''Type visitor which will register (un)swizzled value pairs, to later be
+    swizzled.'''
 
 
-    def visit_literal(self, literal, lvalue, rvalue):
+    def visitLiteral(self, literal, lvalue, rvalue):
         pass
 
         pass
 
-    def visit_alias(self, alias, lvalue, rvalue):
+    def visitAlias(self, alias, lvalue, rvalue):
         self.visit(alias.type, lvalue, rvalue)
     
         self.visit(alias.type, lvalue, rvalue)
     
-    def visit_enum(self, enum, lvalue, rvalue):
+    def visitEnum(self, enum, lvalue, rvalue):
         pass
 
         pass
 
-    def visit_bitmask(self, bitmask, lvalue, rvalue):
+    def visitBitmask(self, bitmask, lvalue, rvalue):
         pass
 
         pass
 
-    def visit_array(self, array, lvalue, rvalue):
-        print '    const Trace::Array *__a%s = dynamic_cast<const Trace::Array *>(&%s);' % (array.id, rvalue)
-        print '    if (__a%s) {' % (array.id)
-        length = '__a%s->values.size()' % array.id
-        index = '__j' + array.id
+    def visitArray(self, array, lvalue, rvalue):
+        print '    const trace::Array *__a%s = dynamic_cast<const trace::Array *>(&%s);' % (array.tag, rvalue)
+        print '    if (__a%s) {' % (array.tag)
+        length = '__a%s->values.size()' % array.tag
+        index = '__j' + array.tag
         print '        for (size_t {i} = 0; {i} < {length}; ++{i}) {{'.format(i = index, length = length)
         try:
         print '        for (size_t {i} = 0; {i} < {length}; ++{i}) {{'.format(i = index, length = length)
         try:
-            self.visit(array.type, '%s[%s]' % (lvalue, index), '*__a%s->values[%s]' % (array.id, index))
+            self.visit(array.type, '%s[%s]' % (lvalue, index), '*__a%s->values[%s]' % (array.tag, index))
         finally:
             print '        }'
             print '    }'
     
         finally:
             print '        }'
             print '    }'
     
-    def visit_pointer(self, pointer, lvalue, rvalue):
-        print '    const Trace::Array *__a%s = dynamic_cast<const Trace::Array *>(&%s);' % (pointer.id, rvalue)
-        print '    if (__a%s) {' % (pointer.id)
+    def visitPointer(self, pointer, lvalue, rvalue):
+        print '    const trace::Array *__a%s = dynamic_cast<const trace::Array *>(&%s);' % (pointer.tag, rvalue)
+        print '    if (__a%s) {' % (pointer.tag)
         try:
         try:
-            self.visit(pointer.type, '%s[0]' % (lvalue,), '*__a%s->values[0]' % (pointer.id,))
+            self.visit(pointer.type, '%s[0]' % (lvalue,), '*__a%s->values[0]' % (pointer.tag,))
         finally:
             print '    }'
     
         finally:
             print '    }'
     
-    def visit_handle(self, handle, lvalue, rvalue):
+    def visitIntPointer(self, pointer, lvalue, rvalue):
+        pass
+    
+    def visitLinearPointer(self, pointer, lvalue, rvalue):
+        assert pointer.size is not None
+        if pointer.size is not None:
+            print r'    retrace::addRegion((%s).toUIntPtr(), %s, %s);' % (rvalue, lvalue, pointer.size)
+
+    def visitHandle(self, handle, lvalue, rvalue):
         print '    %s __orig_result;' % handle.type
         print '    %s __orig_result;' % handle.type
-        OpaqueValueExtractor().visit(handle.type, '__orig_result', rvalue);
+        OpaqueValueDeserializer().visit(handle.type, '__orig_result', rvalue);
         if handle.range is None:
             rvalue = "__orig_result"
         if handle.range is None:
             rvalue = "__orig_result"
-            entry = handle_entry(handle, rvalue) 
+            entry = lookupHandle(handle, rvalue) 
             print "    %s = %s;" % (entry, lvalue)
             print '    if (retrace::verbosity >= 2) {'
             print '        std::cout << "{handle.name} " << {rvalue} << " -> " << {lvalue} << "\\n";'.format(**locals())
             print '    }'
         else:
             print "    %s = %s;" % (entry, lvalue)
             print '    if (retrace::verbosity >= 2) {'
             print '        std::cout << "{handle.name} " << {rvalue} << " -> " << {lvalue} << "\\n";'.format(**locals())
             print '    }'
         else:
-            i = '__h' + handle.id
+            i = '__h' + handle.tag
             lvalue = "%s + %s" % (lvalue, i)
             rvalue = "__orig_result + %s" % (i,)
             lvalue = "%s + %s" % (lvalue, i)
             rvalue = "__orig_result + %s" % (i,)
-            entry = handle_entry(handle, rvalue) 
+            entry = lookupHandle(handle, rvalue) 
             print '    for ({handle.type} {i} = 0; {i} < {handle.range}; ++{i}) {{'.format(**locals())
             print '        {entry} = {lvalue};'.format(**locals())
             print '        if (retrace::verbosity >= 2) {'
             print '    for ({handle.type} {i} = 0; {i} < {handle.range}; ++{i}) {{'.format(**locals())
             print '        {entry} = {lvalue};'.format(**locals())
             print '        if (retrace::verbosity >= 2) {'
@@ -172,22 +192,22 @@ class ValueWrapper(stdapi.Visitor):
             print '        }'
             print '    }'
     
             print '        }'
             print '    }'
     
-    def visit_blob(self, blob, lvalue, rvalue):
+    def visitBlob(self, blob, lvalue, rvalue):
         pass
     
         pass
     
-    def visit_string(self, string, lvalue, rvalue):
+    def visitString(self, string, lvalue, rvalue):
         pass
 
 
 class Retracer:
 
         pass
 
 
 class Retracer:
 
-    def retrace_function(self, function):
-        print 'static void retrace_%s(Trace::Call &call) {' % function.name
-        self.retrace_function_body(function)
+    def retraceFunction(self, function):
+        print 'static void retrace_%s(trace::Call &call) {' % function.name
+        self.retraceFunctionBody(function)
         print '}'
         print
 
         print '}'
         print
 
-    def retrace_function_body(self, function):
+    def retraceFunctionBody(self, function):
         if not function.sideeffects:
             print '    (void)call;'
             return
         if not function.sideeffects:
             print '    (void)call;'
             return
@@ -200,45 +220,53 @@ class Retracer:
             rvalue = 'call.arg(%u)' % (arg.index,)
             lvalue = arg.name
             try:
             rvalue = 'call.arg(%u)' % (arg.index,)
             lvalue = arg.name
             try:
-                self.extract_arg(function, arg, arg_type, lvalue, rvalue)
+                self.extractArg(function, arg, arg_type, lvalue, rvalue)
             except NotImplementedError:
                 success = False
                 print '    %s = 0; // FIXME' % arg.name
         if not success:
             print '    if (1) {'
             except NotImplementedError:
                 success = False
                 print '    %s = 0; // FIXME' % arg.name
         if not success:
             print '    if (1) {'
-            self.fail_function(function)
+            self.failFunction(function)
             print '    }'
             print '    }'
-        self.call_function(function)
+        self.invokeFunction(function)
         for arg in function.args:
             if arg.output:
                 arg_type = ConstRemover().visit(arg.type)
                 rvalue = 'call.arg(%u)' % (arg.index,)
                 lvalue = arg.name
                 try:
         for arg in function.args:
             if arg.output:
                 arg_type = ConstRemover().visit(arg.type)
                 rvalue = 'call.arg(%u)' % (arg.index,)
                 lvalue = arg.name
                 try:
-                    ValueWrapper().visit(arg_type, lvalue, rvalue)
+                    self.regiterSwizzledValue(arg_type, lvalue, rvalue)
                 except NotImplementedError:
                 except NotImplementedError:
-                    print '   // FIXME: %s' % arg.name
+                    print '    // XXX: %s' % arg.name
         if function.type is not stdapi.Void:
             rvalue = '*call.ret'
             lvalue = '__result'
             try:
         if function.type is not stdapi.Void:
             rvalue = '*call.ret'
             lvalue = '__result'
             try:
-                ValueWrapper().visit(function.type, lvalue, rvalue)
+                self.regiterSwizzledValue(function.type, lvalue, rvalue)
             except NotImplementedError:
             except NotImplementedError:
-                print '   // FIXME: result'
+                print '    // XXX: result'
+        if not success:
+            if function.name[-1].islower():
+                sys.stderr.write('warning: unsupported %s call\n' % function.name)
 
 
-    def fail_function(self, function):
-        print '    if (retrace::verbosity >= 0)'
-        print '        std::cerr << "warning: unsupported call %s\\n";' % function.name
+    def failFunction(self, function):
+        print '    if (retrace::verbosity >= 0) {'
+        print '        retrace::unsupported(call);'
+        print '    }'
         print '    return;'
 
         print '    return;'
 
-    def extract_arg(self, function, arg, arg_type, lvalue, rvalue):
-        ValueExtractor().visit(arg_type, lvalue, rvalue)
+    def extractArg(self, function, arg, arg_type, lvalue, rvalue):
+        ValueDeserializer().visit(arg_type, lvalue, rvalue)
     
     
-    def extract_opaque_arg(self, function, arg, arg_type, lvalue, rvalue):
-        OpaqueValueExtractor().visit(arg_type, lvalue, rvalue)
+    def extractOpaqueArg(self, function, arg, arg_type, lvalue, rvalue):
+        OpaqueValueDeserializer().visit(arg_type, lvalue, rvalue)
+
+    def regiterSwizzledValue(self, type, lvalue, rvalue):
+        visitor = SwizzledValueRegistrator()
+        visitor.visit(type, lvalue, rvalue)
 
 
-    def call_function(self, function):
-        arg_names = ", ".join([arg.name for arg in function.args])
+    def invokeFunction(self, function):
+        arg_names = ", ".join(function.argNames())
         if function.type is not stdapi.Void:
             print '    %s __result;' % (function.type)
             print '    __result = %s(%s);' % (function.name, arg_names)
         if function.type is not stdapi.Void:
             print '    %s __result;' % (function.type)
             print '    __result = %s(%s);' % (function.name, arg_names)
@@ -246,16 +274,16 @@ class Retracer:
         else:
             print '    %s(%s);' % (function.name, arg_names)
 
         else:
             print '    %s(%s);' % (function.name, arg_names)
 
-    def filter_function(self, function):
+    def filterFunction(self, function):
         return True
 
     table_name = 'retrace::callbacks'
 
         return True
 
     table_name = 'retrace::callbacks'
 
-    def retrace_functions(self, functions):
-        functions = filter(self.filter_function, functions)
+    def retraceFunctions(self, functions):
+        functions = filter(self.filterFunction, functions)
 
         for function in functions:
 
         for function in functions:
-            self.retrace_function(function)
+            self.retraceFunction(function)
 
         print 'const retrace::Entry %s[] = {' % self.table_name
         for function in functions:
 
         print 'const retrace::Entry %s[] = {' % self.table_name
         for function in functions:
@@ -265,13 +293,13 @@ class Retracer:
         print
 
 
         print
 
 
-    def retrace_api(self, api):
+    def retraceApi(self, api):
 
         print '#include "trace_parser.hpp"'
         print '#include "retrace.hpp"'
         print
 
 
         print '#include "trace_parser.hpp"'
         print '#include "retrace.hpp"'
         print
 
-        types = api.all_types()
+        types = api.getAllTypes()
         handles = [type for type in types if isinstance(type, stdapi.Handle)]
         handle_names = set()
         for handle in handles:
         handles = [type for type in types if isinstance(type, stdapi.Handle)]
         handle_names = set()
         for handle in handles:
@@ -284,5 +312,5 @@ class Retracer:
                 handle_names.add(handle.name)
         print
 
                 handle_names.add(handle.name)
         print
 
-        self.retrace_functions(api.functions)
+        self.retraceFunctions(api.functions)