]> git.cworth.org Git - apitrace-tests/blobdiff - checker.py
Improve checker.
[apitrace-tests] / checker.py
old mode 100644 (file)
new mode 100755 (executable)
index 32c9663..c6eda92
 
 
 import sys
+import optparse
 import re
 
 
-class ValueMatcher:
+class Matcher:
 
     def match(self, value):
         raise NotImplementedError
 
+    def _matchSequence(self, refValues, srcValues):
+        if not isinstance(srcValues, (list, tuple)):
+            return False
+
+        if len(refValues) != len(srcValues):
+            return False
+
+        for refValue, srcValue in zip(refValues, srcValues):
+            if not refValue.match(srcValue):
+                return False
+        return True
+
     def __str__(self):
         raise NotImplementerError
 
+    def __repr__(self):
+        return str(self)
+
 
-class WildcardMatcher(ValueMatcher):
+class WildcardMatcher(Matcher):
 
     def match(self, value):
-        return true
+        return True
 
     def __str__(self):
         return '*'
 
 
-class LiteralValueMatcher(ValueMatcher):
+class LiteralMatcher(Matcher):
 
     def __init__(self, refValue):
         self.refValue = refValue
@@ -59,14 +75,16 @@ class LiteralValueMatcher(ValueMatcher):
         return repr(self.refValue)
 
 
-class ApproxValueMatcher(ValueMatcher):
-
+class ApproxMatcher(Matcher):
 
     def __init__(self, refValue, tolerance = 2**-23):
         self.refValue = refValue
         self.tolerance = tolerance
 
     def match(self, value):
+        if not isinstance(value, float):
+            return 
+
         error = abs(self.refValue - value)
         if self.refValue:
             error = error / self.refValue
@@ -76,29 +94,31 @@ class ApproxValueMatcher(ValueMatcher):
         return repr(self.refValue)
 
 
-class ArrayMatcher(ValueMatcher):
+class BitmaskMatcher(Matcher):
 
     def __init__(self, refElements):
         self.refElements = refElements
 
     def match(self, value):
-        if not isinstance(value, list):
-            return False
+        return self._matchSequence(self.refElements, value)
 
-        if len(value) != len(self.refElements):
-            return False
+    def __str__(self):
+        return ' | '.join(map(str, self.refElements))
 
-        for refElement, element in zip(self.refElements, value):
-            if not refElement.match(element):
-                return False
 
-        return True
+class ArrayMatcher(Matcher):
+
+    def __init__(self, refElements):
+        self.refElements = refElements
+
+    def match(self, value):
+        return self._matchSequence(self.refElements, value)
 
     def __str__(self):
         return '{' + ', '.join(map(str, self.refElements)) + '}'
 
 
-class StructMatcher(ValueMatcher):
+class StructMatcher(Matcher):
 
     def __init__(self, refMembers):
         self.refMembers = refMembers
@@ -126,49 +146,71 @@ class StructMatcher(ValueMatcher):
         return '{' + ', '.join(['%s = %s' % refMember for refMember in self.refMembers.iteritems()]) + '}'
 
 
-class CallMatcher:
+class CallMatcher(Matcher):
 
-    def __init__(self, refFunctionName, refArgs, refRet = None):
-        self.refFunctionName = refFunctionName
-        self.refArgs = refArgs
-        self.refRet = refRet
+    def __init__(self, functionName, args, ret = None):
+        self.functionName = functionName
+        self.args = args
+        self.ret = ret
 
-    def match(self, functionName, args, ret = None):
-        if refFunctionName != functionName:
-            return False
+    def match(self, call):
+        srcFunctionName, srcArgs, srcRet = call
 
-        if len(self.refArgs) != len(args):
+        if self.functionName != srcFunctionName:
             return False
 
-        for (refArgName, refArg), (argName, arg) in zip(self.refArgs, args):
-            if not refArg.match(arg):
-                return False
+        refArgs = [value for name, value in self.args]
+        srcArgs = [value for name, value in srcArgs]
 
-        if self.refRet is None:
-            if ret is not None:
+        if not self._matchSequence(refArgs, srcArgs):
+            return False
+
+        if self.ret is None:
+            if srcRet is not None:
                 return False
         else:
-            if not self.refRet.match(ret):
+            if not self.ret.match(srcRet):
                 return False
 
         return True
 
     def __str__(self):
-        s = self.refFunctionName
-        s += '(' + ', '.join(['%s = %s' % refArg for refArg in self.refArgs]) + ')'
-        if self.refRet is not None:
-            s += ' = ' + str(self.refRet)
+        s = self.functionName
+        s += '(' + ', '.join(['%s = %s' % refArg for refArg in self.args]) + ')'
+        if self.ret is not None:
+            s += ' = ' + str(self.ret)
         return s
 
 
 class TraceMatcher:
 
-    def __init__(self, refCalls):
-        self.refCalls = refCalls
+    def __init__(self, calls):
+        self.calls = calls
+
+    def match(self, trace):
+
+        srcCalls = iter(trace.calls)
+        for refCall in self.calls:
+            skippedSrcCalls = []
+            while True:
+                try:
+                    srcCall = srcCalls.next()
+                except StopIteration:
+                    if skippedSrcCalls:
+                        raise Exception('missing call `%s` (found `%s`)' % (refCall, skippedSrcCalls[0]))
+                    else:
+                        raise Exception('missing call %s' % refCall)
+                if refCall.match(srcCall):
+                    break
+                else:
+                    skippedSrcCalls.append(srcCall)
+        return True
 
     def __str__(self):
-        return ''.join(['%s\n' % refCall for refCall in self.refCalls])
+        return ''.join(['%s\n' % call for call in self.calls])
+
 
+#######################################################################
 
 EOF = -1
 SKIP = -2
@@ -342,20 +384,9 @@ class Parser:
         return token
 
 
-ID = 0
-NUMBER = 1
-HEXNUM = 2
-STRING = 3
-
-LPAREN = 4
-RPAREN = 5
-LCURLY = 6
-RCURLY = 7
-COMMA = 8
-AMP = 9
-EQUAL = 11
+#######################################################################
 
-BLOB = 12
+ID, NUMBER, HEXNUM, STRING, PRAGMA, LPAREN, RPAREN, LCURLY, RCURLY, COMMA, AMP, EQUAL, VERT, BLOB = xrange(14)
 
 
 class CallScanner(Scanner):
@@ -366,16 +397,19 @@ class CallScanner(Scanner):
         (SKIP, r'[ \t\f\r\n\v]+', False),
 
         # Alphanumeric IDs
-        (ID, r'[a-zA-Z_\x80-\xff][a-zA-Z0-9_\x80-\xff]*', True),
+        (ID, r'[a-zA-Z_][a-zA-Z0-9_]*(?:::[a-zA-Z_][a-zA-Z0-9_]*)?', True),
 
         # Numeric IDs
         (HEXNUM, r'-?0x[0-9a-fA-F]+', False),
         
         # Numeric IDs
-        (NUMBER, r'-?(?:\.[0-9]+|[0-9]+(?:\.[0-9]*)?)', False),
+        (NUMBER, r'-?(?:\.[0-9]+|[0-9]+(?:\.[0-9]*)?)(?:[eE][-+][0-9]+)?', False),
 
         # String IDs
         (STRING, r'"[^"\\]*(?:\\.[^"\\]*)*"', False),
+        
+        # Pragma
+        (PRAGMA, r'#[^\r\n]*', False),
     ]
 
     # symbol table
@@ -387,6 +421,7 @@ class CallScanner(Scanner):
         ',': COMMA,
         '&': AMP,
         '=': EQUAL,
+        '|': VERT,
     }
 
     # literal table
@@ -416,17 +451,34 @@ class CallLexer(Lexer):
         return type, text
 
 
-class CallParser(Parser):
+class TraceParser(Parser):
 
     def __init__(self, stream):
         lexer = CallLexer(fp = stream)
         Parser.__init__(self, lexer)
 
+    def eof(self):
+        return self.match(EOF)
+
     def parse(self):
-        while not self.match(EOF):
+        while not self.eof():
+            self.parse_element()
+        return TraceMatcher(self.calls)
+
+    def parse_element(self):
+        if self.lookahead.type == PRAGMA:
+            # TODO
+            token = self.consume()
+            self.handlePragma(token.text)
+        else:
             self.parse_call()
 
     def parse_call(self):
+        while self.lookahead.type == PRAGMA:
+            # TODO
+            token = self.consume()
+            print token.text
+
         if self.lookahead.type == NUMBER:
             token = self.consume()
             callNo = int(token.text)
@@ -443,7 +495,7 @@ class CallParser(Parser):
         else:
             ret = None
 
-        self.handle_call(callNo, functionName, args, ret)
+        self.handleCall(functionName, args, ret)
 
     def parse_pair(self):
         '''Parse a `name = value` pair.'''
@@ -471,39 +523,51 @@ class CallParser(Parser):
             return name, value
 
     def parse_value(self):
+        value = self._parse_value()
+        if self.match(VERT):
+            flags = [value]
+            while self.match(VERT):
+                self.consume()
+                value = self._parse_value()
+                flags.append(value)
+            return self.handleBitmask(flags)
+        else:
+            return value
+
+    def _parse_value(self):
         if self.match(AMP):
             self.consume()
             value = [self.parse_value()]
-            return ArrayMatcher(value)
+            return self.handleArray(value)
         elif self.match(ID):
             token = self.consume()
             value = token.text
-            return LiteralValueMatcher(value)
+            return self.handleID(value)
         elif self.match(STRING):
             token = self.consume()
             value = token.text
-            return LiteralValueMatcher(value)
+            return self.handleString(value)
         elif self.match(NUMBER):
             token = self.consume()
             value = float(token.text)
-            return ApproxValueMatcher(value)
+            return self.handleFloat(value)
         elif self.match(HEXNUM):
             token = self.consume()
             value = int(token.text, 16)
-            return LiteralValueMatcher(value)
+            return self.handleInt(value)
         elif self.match(LCURLY):
             value = self.parse_sequence(LCURLY, RCURLY, self.parse_opt_pair)
             if len(value) and isinstance(value[0], tuple):
-                return StructMatcher(dict(value))
+                value = dict(value)
+                return self.handleStruct(value)
             else:
-                return ArrayMatcher(value)
+                return self.handleArray(value)
         elif self.match(BLOB):
             token = self.consume()
             self.consume(LPAREN)
             length = self.consume()
             self.consume(RPAREN)
-            # TODO
-            return WildcardMatcher()
+            return self.handleBlob(length)
         else:
             self.error()
 
@@ -524,19 +588,141 @@ class CallParser(Parser):
         self.consume(rtype)
 
         return elements
+    
+    def handleID(self, value):
+        raise NotImplementedError
+
+    def handleInt(self, value):
+        raise NotImplementedError
+
+    def handleFloat(self, value):
+        raise NotImplementedError
+
+    def handleString(self, value):
+        raise NotImplementedError
+
+    def handleBitmask(self, value):
+        raise NotImplementedError
+
+    def handleArray(self, value):
+        raise NotImplementedError
+
+    def handleStruct(self, value):
+        raise NotImplementedError
+
+    def handleBlob(self, length):
+        raise NotImplementedError
+        # TODO
+        return WildcardMatcher()
+
+    def handleCall(self, functionName, args, ret):
+        raise NotImplementedError
+
+    def handlePragma(self, line):
+        pass
+
+
+class RefTraceParser(TraceParser):
+
+    def __init__(self, stream):
+        TraceParser.__init__(self, stream)
+        self.calls = []
+
+    def parse(self):
+        TraceParser.parse(self)
+        return TraceMatcher(self.calls)
+
+    def handleID(self, value):
+        return LiteralMatcher(value)
+
+    def handleInt(self, value):
+        return LiteralMatcher(value)
+
+    def handleFloat(self, value):
+        return ApproxMatcher(value)
+
+    def handleString(self, value):
+        return LiteralMatcher(value)
+
+    def handleBitmask(self, value):
+        return BitmaskMatcher(value)
+
+    def handleArray(self, value):
+        return ArrayMatcher(value)
+
+    def handleStruct(self, value):
+        return StructMatcher(value)
+
+    def handleBlob(self, length):
+        # TODO
+        return WildcardMatcher()
+
+    def handleCall(self, functionName, args, ret):
+        call = CallMatcher(functionName, args, ret)
+        self.calls.append(call)
+
+
+class SrcTraceParser(TraceParser):
+
+    def __init__(self, stream):
+        TraceParser.__init__(self, stream)
+        self.calls = []
+
+    def parse(self):
+        TraceParser.parse(self)
+        return TraceMatcher(self.calls)
+
+    def handleID(self, value):
+        return value
+
+    def handleInt(self, value):
+        return int(value)
+
+    def handleFloat(self, value):
+        return float(value)
+
+    def handleString(self, value):
+        return value
+
+    def handleBitmask(self, value):
+        return value
+
+    def handleArray(self, elements):
+        return list(elements)
+
+    def handleStruct(self, members):
+        return dict(members)
 
-    def handle_call(self, callNo, functionName, args, ret):
-        matcher = CallMatcher(functionName, args, ret)
+    def handleBlob(self, length):
+        # TODO
+        return None
 
-        if callNo is not None:
-            sys.stdout.write('%u ' % callNo)
-        sys.stdout.write(str(matcher))
-        sys.stdout.write('\n')
+    def handleCall(self, functionName, args, ret):
+        call = (functionName, args, ret)
+        self.calls.append(call)
 
 
 def main():
-    parser = CallParser(sys.stdin)
-    parser.parse()
+    # Parse command line options
+    optparser = optparse.OptionParser(
+        usage='\n\t%prog [OPTIONS] REF_TRACE SRC_TRACE',
+        version='%%prog')
+    optparser.add_option(
+        '-v', '--verbose',
+        action="store_true",
+        dest="verbose", default=False,
+        help="verbose output")
+    (options, args) = optparser.parse_args(sys.argv[1:])
+
+    if len(args) != 2:
+        optparser.error('wrong number of arguments')
+
+    refParser = RefTraceParser(open(args[0], 'rt'))
+    refTrace = refParser.parse()
+    sys.stdout.write(str(refTrace))
+    srcParser = SrcTraceParser(open(args[1], 'rt'))
+    srcTrace = srcParser.parse()
+    refTrace.match(srcTrace)
 
 
 if __name__ == '__main__':