git.cworth.org Git - apitrace/blobdiff - scripts/unpickle.py
Optimize tracediff2.
[apitrace] / scripts / unpickle.py
index 98e2c30b4505d27d6b779200d38382a74004addf..ab2dc4db0311292dce0a0a8ab23474cdc3d9b300 100755 (executable)
@@ -33,10 +33,131 @@ Run as:
 '''
 
 
+import itertools
 import optparse
-import cPickle as pickle
 import sys
 import time
+import re
+import cPickle as pickle
+
+
+class Visitor:
+    '''Dispatches visit(obj) to a handler picked by the exact type of obj.'''
+
+    def __init__(self):
+        # Exact-type lookup table; long shares the int handler.
+        self.dispatch = {
+            type(None): self.visitNone, bool: self.visitBool,
+            int: self.visitInt, long: self.visitInt,
+            float: self.visitFloat, str: self.visitStr,
+            tuple: self.visitTuple, list: self.visitList,
+            dict: self.visitDict, bytearray: self.visitByteArray,
+        }
+
+    def visit(self, obj):
+        '''Returns the result of the handler registered for type(obj).'''
+        handler = self.dispatch.get(type(obj), self.visitObj)
+        return handler(obj)
+
+    def visitObj(self, obj):
+        # Catch-all: unregistered types and the default atom/iterable chain.
+        raise NotImplementedError
+
+    def visitAtom(self, obj):
+        return self.visitObj(obj)
+
+    def visitNone(self, obj):
+        return self.visitAtom(obj)
+
+    def visitBool(self, obj):
+        return self.visitAtom(obj)
+
+    def visitInt(self, obj):
+        return self.visitAtom(obj)
+
+    def visitFloat(self, obj):
+        return self.visitAtom(obj)
+
+    def visitStr(self, obj):
+        return self.visitAtom(obj)
+
+    def visitIterable(self, obj):
+        return self.visitObj(obj)
+
+    def visitTuple(self, obj):
+        return self.visitIterable(obj)
+
+    def visitList(self, obj):
+        return self.visitIterable(obj)
+
+    def visitDict(self, obj):
+        raise NotImplementedError
+
+    def visitByteArray(self, obj):
+        raise NotImplementedError
+
+
+class Dumper(Visitor):
+    '''Renders an object tree as a compact, human-readable string.'''
+    id_re = re.compile('^[_A-Za-z][_A-Za-z0-9]*$')
+
+    def visitObj(self, obj):
+        return repr(obj)
+
+    def visitStr(self, obj):
+        # Identifier-like strings are printed bare, anything else quoted.
+        if not self.id_re.match(obj):
+            return repr(obj)
+        return obj
+
+    def visitTuple(self, obj):  # NOTE(review): tuples get [], lists get () -- confirm
+        return '[' + ', '.join(self.visit(item) for item in obj) + ']'
+
+    def visitList(self, obj):
+        return '(' + ', '.join(self.visit(item) for item in obj) + ')'
+
+    def visitByteArray(self, obj):
+        return 'blob(%u)' % len(obj)
+
+
+class Hasher(Visitor):
+    '''Produces a hashable stand-in for the objtree.'''
+
+    def visitObj(self, obj):
+        return obj  # unregistered types are passed through untouched
+
+    def visitAtom(self, obj):
+        return obj
+
+    def visitIterable(self, obj):
+        return tuple(self.visit(item) for item in obj)  # lists become tuples
+
+    def visitByteArray(self, obj):
+        return str(obj)  # bytearray is unhashable, str is not
+
+
+class Rebuilder(Visitor):
+    '''Rebuilds the objtree, reusing containers whose items are unchanged.'''
+
+    def visitAtom(self, obj):
+        return obj
+
+    def visitIterable(self, obj):
+        '''Visits each item; returns obj itself unless some item changed.'''
+        changed = False
+        newItems = []
+        for oldItem in obj:
+            newItem = self.visit(oldItem)
+            if newItem is not oldItem:
+                changed = True
+            newItems.append(newItem)
+        if not changed:
+            return obj
+        # Build the copy from the visited items, keeping the container type.
+        return type(obj)(newItems)
+
+    def visitByteArray(self, obj):
+        return obj
 
 
 class Call:
@@ -49,10 +170,11 @@ class Call:
         s = self.functionName
         if self.no is not None:
             s = str(self.no) + ' ' + s
-        s += '(' + ', '.join(map(repr, self.args)) + ')'
+        dumper = Dumper()
+        s += '(' + ', '.join(itertools.imap(dumper.visit, self.args)) + ')'
         if self.ret is not None:
             s += ' = '
-            s += repr(self.ret)
+            s += dumper.visit(self.ret)
         return s
 
     def __eq__(self, other):
@@ -63,9 +185,9 @@ class Call:
 
     def __hash__(self):
         if self._hash is None:
-            # XXX: hack due to unhashable types
-            #self._hash = hash(self.functionName) ^ hash(tuple(self.args)) ^ hash(self.ret)
-            self._hash = hash(self.functionName) ^ hash(repr(self.args)) ^ hash(repr(self.ret))
+            hasher = Hasher()  # turns lists/bytearrays into hashable tuples/strs
+            hashable = hasher.visit(self.functionName), hasher.visit(self.args), hasher.visit(self.ret)
+            self._hash = hash(hashable)  # cached so later calls skip the tree walk
         return self._hash
 
 
@@ -112,7 +234,7 @@ def main():
     optparser = optparse.OptionParser(
         usage="\n\tapitrace pickle trace. %prog [options]")
     optparser.add_option(
-        '--quiet',
+        '-q', '--quiet',
         action="store_true", dest="quiet", default=False,
         help="don't dump calls to stdout")