Improve checker.
authorJosé Fonseca <jose.r.fonseca@gmail.com>
Thu, 22 Nov 2012 09:55:15 +0000 (09:55 +0000)
committerJosé Fonseca <jose.r.fonseca@gmail.com>
Thu, 22 Nov 2012 10:28:57 +0000 (10:28 +0000)
checker.py [changed mode: 0644->0755]

old mode 100644 (file)
new mode 100755 (executable)
index be1331d..c6eda92
 
 
 import sys
+import optparse
 import re
 
 
-class ValueMatcher:
+class Matcher:
 
     def match(self, value):
         raise NotImplementedError
 
+    def _matchSequence(self, refValues, srcValues):
+        if not isinstance(srcValues, (list, tuple)):
+            return False
+
+        if len(refValues) != len(srcValues):
+            return False
+
+        for refValue, srcValue in zip(refValues, srcValues):
+            if not refValue.match(srcValue):
+                return False
+        return True
+
     def __str__(self):
         raise NotImplementerError
 
@@ -41,16 +54,16 @@ class ValueMatcher:
         return str(self)
 
 
-class WildcardMatcher(ValueMatcher):
+class WildcardMatcher(Matcher):
 
     def match(self, value):
-        return true
+        return True
 
     def __str__(self):
         return '*'
 
 
-class LiteralValueMatcher(ValueMatcher):
+class LiteralMatcher(Matcher):
 
     def __init__(self, refValue):
         self.refValue = refValue
@@ -62,13 +75,16 @@ class LiteralValueMatcher(ValueMatcher):
         return repr(self.refValue)
 
 
-class ApproxValueMatcher(ValueMatcher):
+class ApproxMatcher(Matcher):
 
     def __init__(self, refValue, tolerance = 2**-23):
         self.refValue = refValue
         self.tolerance = tolerance
 
     def match(self, value):
+        if not isinstance(value, float):
+            return 
+
         error = abs(self.refValue - value)
         if self.refValue:
             error = error / self.refValue
@@ -78,29 +94,31 @@ class ApproxValueMatcher(ValueMatcher):
         return repr(self.refValue)
 
 
-class ArrayMatcher(ValueMatcher):
+class BitmaskMatcher(Matcher):
 
     def __init__(self, refElements):
         self.refElements = refElements
 
     def match(self, value):
-        if not isinstance(value, list):
-            return False
+        return self._matchSequence(self.refElements, value)
 
-        if len(value) != len(self.refElements):
-            return False
+    def __str__(self):
+        return ' | '.join(map(str, self.refElements))
 
-        for refElement, element in zip(self.refElements, value):
-            if not refElement.match(element):
-                return False
 
-        return True
+class ArrayMatcher(Matcher):
+
+    def __init__(self, refElements):
+        self.refElements = refElements
+
+    def match(self, value):
+        return self._matchSequence(self.refElements, value)
 
     def __str__(self):
         return '{' + ', '.join(map(str, self.refElements)) + '}'
 
 
-class StructMatcher(ValueMatcher):
+class StructMatcher(Matcher):
 
     def __init__(self, refMembers):
         self.refMembers = refMembers
@@ -128,48 +146,68 @@ class StructMatcher(ValueMatcher):
         return '{' + ', '.join(['%s = %s' % refMember for refMember in self.refMembers.iteritems()]) + '}'
 
 
-class CallMatcher:
+class CallMatcher(Matcher):
 
-    def __init__(self, refFunctionName, refArgs, refRet = None):
-        self.refFunctionName = refFunctionName
-        self.refArgs = refArgs
-        self.refRet = refRet
+    def __init__(self, functionName, args, ret = None):
+        self.functionName = functionName
+        self.args = args
+        self.ret = ret
 
-    def match(self, functionName, args, ret = None):
-        if refFunctionName != functionName:
-            return False
+    def match(self, call):
+        srcFunctionName, srcArgs, srcRet = call
 
-        if len(self.refArgs) != len(args):
+        if self.functionName != srcFunctionName:
             return False
 
-        for (refArgName, refArg), (argName, arg) in zip(self.refArgs, args):
-            if not refArg.match(arg):
-                return False
+        refArgs = [value for name, value in self.args]
+        srcArgs = [value for name, value in srcArgs]
 
-        if self.refRet is None:
-            if ret is not None:
+        if not self._matchSequence(refArgs, srcArgs):
+            return False
+
+        if self.ret is None:
+            if srcRet is not None:
                 return False
         else:
-            if not self.refRet.match(ret):
+            if not self.ret.match(srcRet):
                 return False
 
         return True
 
     def __str__(self):
-        s = self.refFunctionName
-        s += '(' + ', '.join(['%s = %s' % refArg for refArg in self.refArgs]) + ')'
-        if self.refRet is not None:
-            s += ' = ' + str(self.refRet)
+        s = self.functionName
+        s += '(' + ', '.join(['%s = %s' % refArg for refArg in self.args]) + ')'
+        if self.ret is not None:
+            s += ' = ' + str(self.ret)
         return s
 
 
 class TraceMatcher:
 
-    def __init__(self, refCalls):
-        self.refCalls = refCalls
+    def __init__(self, calls):
+        self.calls = calls
+
+    def match(self, trace):
+
+        srcCalls = iter(trace.calls)
+        for refCall in self.calls:
+            skippedSrcCalls = []
+            while True:
+                try:
+                    srcCall = srcCalls.next()
+                except StopIteration:
+                    if skippedSrcCalls:
+                        raise Exception('missing call `%s` (found `%s`)' % (refCall, skippedSrcCalls[0]))
+                    else:
+                        raise Exception('missing call %s' % refCall)
+                if refCall.match(srcCall):
+                    break
+                else:
+                    skippedSrcCalls.append(srcCall)
+        return True
 
     def __str__(self):
-        return ''.join(['%s\n' % refCall for refCall in self.refCalls])
+        return ''.join(['%s\n' % call for call in self.calls])
 
 
 #######################################################################
@@ -348,7 +386,7 @@ class Parser:
 
 #######################################################################
 
-ID, NUMBER, HEXNUM, STRING, PRAGMA, LPAREN, RPAREN, LCURLY, RCURLY, COMMA, AMP, EQUAL, BLOB = xrange(13)
+ID, NUMBER, HEXNUM, STRING, PRAGMA, LPAREN, RPAREN, LCURLY, RCURLY, COMMA, AMP, EQUAL, VERT, BLOB = xrange(14)
 
 
 class CallScanner(Scanner):
@@ -359,13 +397,13 @@ class CallScanner(Scanner):
         (SKIP, r'[ \t\f\r\n\v]+', False),
 
         # Alphanumeric IDs
-        (ID, r'[a-zA-Z_\x80-\xff][a-zA-Z0-9_\x80-\xff]*', True),
+        (ID, r'[a-zA-Z_][a-zA-Z0-9_]*(?:::[a-zA-Z_][a-zA-Z0-9_]*)?', True),
 
         # Numeric IDs
         (HEXNUM, r'-?0x[0-9a-fA-F]+', False),
         
         # Numeric IDs
-        (NUMBER, r'-?(?:\.[0-9]+|[0-9]+(?:\.[0-9]*)?)', False),
+        (NUMBER, r'-?(?:\.[0-9]+|[0-9]+(?:\.[0-9]*)?)(?:[eE][-+][0-9]+)?', False),
 
         # String IDs
         (STRING, r'"[^"\\]*(?:\\.[^"\\]*)*"', False),
@@ -383,6 +421,7 @@ class CallScanner(Scanner):
         ',': COMMA,
         '&': AMP,
         '=': EQUAL,
+        '|': VERT,
     }
 
     # literal table
@@ -412,7 +451,7 @@ class CallLexer(Lexer):
         return type, text
 
 
-class CallParser(Parser):
+class TraceParser(Parser):
 
     def __init__(self, stream):
         lexer = CallLexer(fp = stream)
@@ -424,6 +463,7 @@ class CallParser(Parser):
     def parse(self):
         while not self.eof():
             self.parse_element()
+        return TraceMatcher(self.calls)
 
     def parse_element(self):
         if self.lookahead.type == PRAGMA:
@@ -455,7 +495,7 @@ class CallParser(Parser):
         else:
             ret = None
 
-        return self.handleCall(callNo, functionName, args, ret)
+        self.handleCall(functionName, args, ret)
 
     def parse_pair(self):
         '''Parse a `name = value` pair.'''
@@ -483,6 +523,18 @@ class CallParser(Parser):
             return name, value
 
     def parse_value(self):
+        value = self._parse_value()
+        if self.match(VERT):
+            flags = [value]
+            while self.match(VERT):
+                self.consume()
+                value = self._parse_value()
+                flags.append(value)
+            return self.handleBitmask(flags)
+        else:
+            return value
+
+    def _parse_value(self):
         if self.match(AMP):
             self.consume()
             value = [self.parse_value()]
@@ -515,7 +567,7 @@ class CallParser(Parser):
             self.consume(LPAREN)
             length = self.consume()
             self.consume(RPAREN)
-            return self.handleBlob()
+            return self.handleBlob(length)
         else:
             self.error()
 
@@ -538,16 +590,62 @@ class CallParser(Parser):
         return elements
     
     def handleID(self, value):
-        return LiteralValueMatcher(value)
+        raise NotImplementedError
 
     def handleInt(self, value):
-        return LiteralValueMatcher(value)
+        raise NotImplementedError
 
     def handleFloat(self, value):
-        return ApproxValueMatcher(value)
+        raise NotImplementedError
 
     def handleString(self, value):
-        return LiteralValueMatcher(value)
+        raise NotImplementedError
+
+    def handleBitmask(self, value):
+        raise NotImplementedError
+
+    def handleArray(self, value):
+        raise NotImplementedError
+
+    def handleStruct(self, value):
+        raise NotImplementedError
+
+    def handleBlob(self, length):
+        raise NotImplementedError
+        # TODO
+        return WildcardMatcher()
+
+    def handleCall(self, functionName, args, ret):
+        raise NotImplementedError
+
+    def handlePragma(self, line):
+        pass
+
+
+class RefTraceParser(TraceParser):
+
+    def __init__(self, stream):
+        TraceParser.__init__(self, stream)
+        self.calls = []
+
+    def parse(self):
+        TraceParser.parse(self)
+        return TraceMatcher(self.calls)
+
+    def handleID(self, value):
+        return LiteralMatcher(value)
+
+    def handleInt(self, value):
+        return LiteralMatcher(value)
+
+    def handleFloat(self, value):
+        return ApproxMatcher(value)
+
+    def handleString(self, value):
+        return LiteralMatcher(value)
+
+    def handleBitmask(self, value):
+        return BitmaskMatcher(value)
 
     def handleArray(self, value):
         return ArrayMatcher(value)
@@ -555,26 +653,76 @@ class CallParser(Parser):
     def handleStruct(self, value):
         return StructMatcher(value)
 
-    def handleBlob(self, value):
+    def handleBlob(self, length):
         # TODO
         return WildcardMatcher()
 
-    def handleCall(self, callNo, functionName, args, ret):
-        matcher = CallMatcher(functionName, args, ret)
+    def handleCall(self, functionName, args, ret):
+        call = CallMatcher(functionName, args, ret)
+        self.calls.append(call)
 
-        if callNo is not None:
-            sys.stdout.write('%u ' % callNo)
-        sys.stdout.write(str(matcher))
-        sys.stdout.write('\n')
 
-    def handlePragma(self, line):
-        sys.stdout.write(line)
-        sys.stdout.write('\n')
+class SrcTraceParser(TraceParser):
+
+    def __init__(self, stream):
+        TraceParser.__init__(self, stream)
+        self.calls = []
+
+    def parse(self):
+        TraceParser.parse(self)
+        return TraceMatcher(self.calls)
+
+    def handleID(self, value):
+        return value
+
+    def handleInt(self, value):
+        return int(value)
+
+    def handleFloat(self, value):
+        return float(value)
+
+    def handleString(self, value):
+        return value
+
+    def handleBitmask(self, value):
+        return value
+
+    def handleArray(self, elements):
+        return list(elements)
+
+    def handleStruct(self, members):
+        return dict(members)
+
+    def handleBlob(self, length):
+        # TODO
+        return None
+
+    def handleCall(self, functionName, args, ret):
+        call = (functionName, args, ret)
+        self.calls.append(call)
 
 
 def main():
-    parser = CallParser(sys.stdin)
-    parser.parse()
+    # Parse command line options
+    optparser = optparse.OptionParser(
+        usage='\n\t%prog [OPTIONS] REF_TRACE SRC_TRACE',
+        version='%%prog')
+    optparser.add_option(
+        '-v', '--verbose',
+        action="store_true",
+        dest="verbose", default=False,
+        help="verbose output")
+    (options, args) = optparser.parse_args(sys.argv[1:])
+
+    if len(args) != 2:
+        optparser.error('wrong number of arguments')
+
+    refParser = RefTraceParser(open(args[0], 'rt'))
+    refTrace = refParser.parse()
+    sys.stdout.write(str(refTrace))
+    srcParser = SrcTraceParser(open(args[1], 'rt'))
+    srcTrace = srcParser.parse()
+    refTrace.match(srcTrace)
 
 
 if __name__ == '__main__':