]> git.cworth.org Git - apitrace-tests/blobdiff - checker.py
Improve gl map buffer test.
[apitrace-tests] / checker.py
old mode 100644 (file)
new mode 100755 (executable)
index 32c9663..68949b1
 
 
 import sys
+import optparse
+import os
 import re
+import subprocess
 
 
-class ValueMatcher:
+class MatchObject:
 
-    def match(self, value):
+    def __init__(self):
+        self.params = {}
+
+
+class Matcher:
+
+    def match(self, value, mo):
         raise NotImplementedError
 
+    def _matchSequence(self, refValues, srcValues, mo):
+        if not isinstance(srcValues, (list, tuple)):
+            return False
+
+        if len(refValues) != len(srcValues):
+            return False
+
+        for refValue, srcValue in zip(refValues, srcValues):
+            if not refValue.match(srcValue, mo):
+                return False
+        return True
+
     def __str__(self):
         raise NotImplementerError
 
+    def __repr__(self):
+        return str(self)
+
 
-class WildcardMatcher(ValueMatcher):
+class WildcardMatcher(Matcher):
 
-    def match(self, value):
-        return true
+    def __init__(self, name = ''):
+        self.name = name
+
+    def match(self, value, mo):
+        if self.name:
+            try:
+                refValue = mo.params[self.name]
+            except KeyError:
+                mo.params[self.name] = value
+            else:
+                return refValue == value
+        return True
 
     def __str__(self):
-        return '*'
+        return '<' + self.name + '>'
 
 
-class LiteralValueMatcher(ValueMatcher):
+class LiteralMatcher(Matcher):
 
     def __init__(self, refValue):
         self.refValue = refValue
 
-    def match(self, value):
+    def match(self, value, mo):
         return self.refValue == value
 
     def __str__(self):
         return repr(self.refValue)
 
 
-class ApproxValueMatcher(ValueMatcher):
-
+class ApproxMatcher(Matcher):
 
     def __init__(self, refValue, tolerance = 2**-23):
         self.refValue = refValue
         self.tolerance = tolerance
 
-    def match(self, value):
+    def match(self, value, mo):
+        if not isinstance(value, float):
+            return 
+
         error = abs(self.refValue - value)
         if self.refValue:
             error = error / self.refValue
@@ -76,34 +112,49 @@ class ApproxValueMatcher(ValueMatcher):
         return repr(self.refValue)
 
 
-class ArrayMatcher(ValueMatcher):
+class BitmaskMatcher(Matcher):
 
     def __init__(self, refElements):
         self.refElements = refElements
 
-    def match(self, value):
-        if not isinstance(value, list):
-            return False
+    def match(self, value, mo):
+        return self._matchSequence(self.refElements, value, mo)
 
-        if len(value) != len(self.refElements):
-            return False
+    def __str__(self):
+        return ' | '.join(map(str, self.refElements))
 
-        for refElement, element in zip(self.refElements, value):
-            if not refElement.match(element):
-                return False
 
-        return True
+class OffsetMatcher(Matcher):
+
+    def __init__(self, refValue, offset):
+        self.refValue = refValue
+        self.offset = offset
+
+    def match(self, value, mo):
+        return self.refValue.match(value - self.offset, mo)
+
+    def __str__(self):
+        return '%s + %i' % (self.refValue, self.offset)
+
+
+class ArrayMatcher(Matcher):
+
+    def __init__(self, refElements):
+        self.refElements = refElements
+
+    def match(self, value, mo):
+        return self._matchSequence(self.refElements, value, mo)
 
     def __str__(self):
         return '{' + ', '.join(map(str, self.refElements)) + '}'
 
 
-class StructMatcher(ValueMatcher):
+class StructMatcher(Matcher):
 
     def __init__(self, refMembers):
         self.refMembers = refMembers
 
-    def match(self, value):
+    def match(self, value, mo):
         if not isinstance(value, dict):
             return False
 
@@ -116,59 +167,94 @@ class StructMatcher(ValueMatcher):
             except KeyError:
                 return False
             else:
-                if not refMember.match(member):
+                if not refMember.match(member, mo):
                     return False
 
         return True
 
     def __str__(self):
-        print self.refMembers
         return '{' + ', '.join(['%s = %s' % refMember for refMember in self.refMembers.iteritems()]) + '}'
 
 
-class CallMatcher:
+class CallMatcher(Matcher):
 
-    def __init__(self, refFunctionName, refArgs, refRet = None):
-        self.refFunctionName = refFunctionName
-        self.refArgs = refArgs
-        self.refRet = refRet
+    def __init__(self, callNo, functionName, args, ret):
+        self.callNo = callNo
+        self.functionName = functionName
+        self.args = args
+        self.ret = ret
 
-    def match(self, functionName, args, ret = None):
-        if refFunctionName != functionName:
-            return False
+    def match(self, call, mo):
+        callNo, srcFunctionName, srcArgs, srcRet = call
 
-        if len(self.refArgs) != len(args):
+        if self.functionName != srcFunctionName:
             return False
 
-        for (refArgName, refArg), (argName, arg) in zip(self.refArgs, args):
-            if not refArg.match(arg):
-                return False
+        refArgs = [value for name, value in self.args]
+        srcArgs = [value for name, value in srcArgs]
 
-        if self.refRet is None:
-            if ret is not None:
+        if not self._matchSequence(refArgs, srcArgs, mo):
+            return False
+
+        if self.ret is None:
+            if srcRet is not None:
                 return False
         else:
-            if not self.refRet.match(ret):
+            if not self.ret.match(srcRet, mo):
+                return False
+
+        if self.callNo is not None:
+            if not self.callNo.match(callNo, mo):
                 return False
 
         return True
 
     def __str__(self):
-        s = self.refFunctionName
-        s += '(' + ', '.join(['%s = %s' % refArg for refArg in self.refArgs]) + ')'
-        if self.refRet is not None:
-            s += ' = ' + str(self.refRet)
+        s = self.functionName
+        s += '(' + ', '.join(['%s = %s' % refArg for refArg in self.args]) + ')'
+        if self.ret is not None:
+            s += ' = ' + str(self.ret)
         return s
 
 
+class TraceMismatch(Exception):
+
+    pass
+
+
 class TraceMatcher:
 
-    def __init__(self, refCalls):
-        self.refCalls = refCalls
+    def __init__(self, calls):
+        self.calls = calls
+
+    def match(self, calls, verbose = False):
+        mo = MatchObject()
+        srcCalls = iter(calls)
+        for refCall in self.calls:
+            if verbose:
+                print refCall
+            skippedSrcCalls = []
+            while True:
+                try:
+                    srcCall = srcCalls.next()
+                except StopIteration:
+                    if skippedSrcCalls:
+                        raise TraceMismatch('missing call `%s` (found `%s`)' % (refCall, skippedSrcCalls[0]))
+                    else:
+                        raise TraceMismatch('missing call %s' % refCall)
+                if verbose:
+                    print '\t%s %s%r = %r' % srcCall
+                if refCall.match(srcCall, mo):
+                    break
+                else:
+                    skippedSrcCalls.append(srcCall)
+        return mo
 
     def __str__(self):
-        return ''.join(['%s\n' % refCall for refCall in self.refCalls])
+        return ''.join(['%s\n' % call for call in self.calls])
+
 
+#######################################################################
 
 EOF = -1
 SKIP = -2
@@ -342,20 +428,9 @@ class Parser:
         return token
 
 
-ID = 0
-NUMBER = 1
-HEXNUM = 2
-STRING = 3
+#######################################################################
 
-LPAREN = 4
-RPAREN = 5
-LCURLY = 6
-RCURLY = 7
-COMMA = 8
-AMP = 9
-EQUAL = 11
-
-BLOB = 12
+ID, NUMBER, HEXNUM, STRING, WILDCARD, PRAGMA, LPAREN, RPAREN, LCURLY, RCURLY, COMMA, AMP, EQUAL, PLUS, VERT, BLOB = xrange(16)
 
 
 class CallScanner(Scanner):
@@ -366,16 +441,22 @@ class CallScanner(Scanner):
         (SKIP, r'[ \t\f\r\n\v]+', False),
 
         # Alphanumeric IDs
-        (ID, r'[a-zA-Z_\x80-\xff][a-zA-Z0-9_\x80-\xff]*', True),
+        (ID, r'[a-zA-Z_][a-zA-Z0-9_]*(?:::[a-zA-Z_][a-zA-Z0-9_]*)?', True),
 
         # Numeric IDs
         (HEXNUM, r'-?0x[0-9a-fA-F]+', False),
         
         # Numeric IDs
-        (NUMBER, r'-?(?:\.[0-9]+|[0-9]+(?:\.[0-9]*)?)', False),
+        (NUMBER, r'-?(?:\.[0-9]+|[0-9]+(?:\.[0-9]*)?)(?:[eE][-+][0-9]+)?', False),
 
         # String IDs
         (STRING, r'"[^"\\]*(?:\\.[^"\\]*)*"', False),
+        
+        # Wildcard
+        (WILDCARD, r'<[^>]*>', False),
+        
+        # Pragma
+        (PRAGMA, r'#[^\r\n]*', False),
     ]
 
     # symbol table
@@ -387,6 +468,8 @@ class CallScanner(Scanner):
         ',': COMMA,
         '&': AMP,
         '=': EQUAL,
+        '+': PLUS,
+        '|': VERT,
     }
 
     # literal table
@@ -416,20 +499,34 @@ class CallLexer(Lexer):
         return type, text
 
 
-class CallParser(Parser):
+class TraceParser(Parser):
 
     def __init__(self, stream):
         lexer = CallLexer(fp = stream)
         Parser.__init__(self, lexer)
 
+    def eof(self):
+        return self.match(EOF)
+
     def parse(self):
-        while not self.match(EOF):
+        while not self.eof():
+            self.parse_element()
+        return TraceMatcher(self.calls)
+
+    def parse_element(self):
+        if self.lookahead.type == PRAGMA:
+            token = self.consume()
+            self.handlePragma(token.text)
+        else:
             self.parse_call()
 
     def parse_call(self):
         if self.lookahead.type == NUMBER:
             token = self.consume()
-            callNo = int(token.text)
+            callNo = self.handleInt(int(token.text))
+        elif self.lookahead.type == WILDCARD:
+            token = self.consume()
+            callNo = self.handleWildcard((token.text[1:-1]))
         else:
             callNo = None
         
@@ -443,7 +540,7 @@ class CallParser(Parser):
         else:
             ret = None
 
-        self.handle_call(callNo, functionName, args, ret)
+        self.handleCall(callNo, functionName, args, ret)
 
     def parse_pair(self):
         '''Parse a `name = value` pair.'''
@@ -471,39 +568,70 @@ class CallParser(Parser):
             return name, value
 
     def parse_value(self):
+        value = self._parse_value()
+        if self.match(VERT):
+            flags = [value]
+            while self.match(VERT):
+                self.consume()
+                value = self._parse_value()
+                flags.append(value)
+            return self.handleBitmask(flags)
+        elif self.match(PLUS):
+            self.consume()
+            if self.match(NUMBER):
+                token = self.consume()
+                offset = int(token.text)
+            elif self.match(HEXNUM):
+                token = self.consume()
+                offset = int(token.text, 16)
+            else:
+                assert 0
+            return self.handleOffset(value, offset)
+        else:
+            return value
+
+    def _parse_value(self):
         if self.match(AMP):
             self.consume()
             value = [self.parse_value()]
-            return ArrayMatcher(value)
+            return self.handleArray(value)
         elif self.match(ID):
             token = self.consume()
             value = token.text
-            return LiteralValueMatcher(value)
+            return self.handleID(value)
         elif self.match(STRING):
             token = self.consume()
             value = token.text
-            return LiteralValueMatcher(value)
+            return self.handleString(value)
         elif self.match(NUMBER):
             token = self.consume()
-            value = float(token.text)
-            return ApproxValueMatcher(value)
+            if '.' in token.text:
+                value = float(token.text)
+                return self.handleFloat(value)
+            else:
+                value = int(token.text)
+                return self.handleInt(value)
         elif self.match(HEXNUM):
             token = self.consume()
             value = int(token.text, 16)
-            return LiteralValueMatcher(value)
+            return self.handleInt(value)
         elif self.match(LCURLY):
             value = self.parse_sequence(LCURLY, RCURLY, self.parse_opt_pair)
             if len(value) and isinstance(value[0], tuple):
-                return StructMatcher(dict(value))
+                value = dict(value)
+                return self.handleStruct(value)
             else:
-                return ArrayMatcher(value)
+                return self.handleArray(value)
         elif self.match(BLOB):
             token = self.consume()
             self.consume(LPAREN)
-            length = self.consume()
+            token = self.consume()
+            length = int(token.text)
             self.consume(RPAREN)
-            # TODO
-            return WildcardMatcher()
+            return self.handleBlob(length)
+        elif self.match(WILDCARD):
+            token = self.consume()
+            return self.handleWildcard(token.text[1:-1])
         else:
             self.error()
 
@@ -524,19 +652,179 @@ class CallParser(Parser):
         self.consume(rtype)
 
         return elements
+    
+    def handleID(self, value):
+        raise NotImplementedError
 
-    def handle_call(self, callNo, functionName, args, ret):
-        matcher = CallMatcher(functionName, args, ret)
+    def handleInt(self, value):
+        raise NotImplementedError
 
-        if callNo is not None:
-            sys.stdout.write('%u ' % callNo)
-        sys.stdout.write(str(matcher))
-        sys.stdout.write('\n')
+    def handleFloat(self, value):
+        raise NotImplementedError
+
+    def handleString(self, value):
+        raise NotImplementedError
+
+    def handleBitmask(self, value):
+        raise NotImplementedError
+
+    def handleOffset(self, value, offset):
+        raise NotImplementedError
+
+    def handleArray(self, value):
+        raise NotImplementedError
+
+    def handleStruct(self, value):
+        raise NotImplementedError
+
+    def handleBlob(self, length):
+        return self.handleID('blob(%u)' % length)
+
+    def handleWildcard(self, name):
+        raise NotImplementedError
+
+    def handleCall(self, callNo, functionName, args, ret):
+        raise NotImplementedError
+
+    def handlePragma(self, line):
+        raise NotImplementedError
+
+
+class RefTraceParser(TraceParser):
+
+    def __init__(self, stream):
+        TraceParser.__init__(self, stream)
+        self.calls = []
+
+    def parse(self):
+        TraceParser.parse(self)
+        return TraceMatcher(self.calls)
+
+    def handleID(self, value):
+        return LiteralMatcher(value)
+
+    def handleInt(self, value):
+        return LiteralMatcher(value)
+
+    def handleFloat(self, value):
+        return ApproxMatcher(value)
+
+    def handleString(self, value):
+        return LiteralMatcher(value)
+
+    def handleBitmask(self, value):
+        return BitmaskMatcher(value)
+
+    def handleOffset(self, value, offset):
+        return OffsetMatcher(value, offset)
+
+    def handleArray(self, value):
+        return ArrayMatcher(value)
+
+    def handleStruct(self, value):
+        return StructMatcher(value)
+
+    def handleWildcard(self, name):
+        return WildcardMatcher(name)
+
+    def handleCall(self, callNo, functionName, args, ret):
+        call = CallMatcher(callNo, functionName, args, ret)
+        self.calls.append(call)
+    
+    def handlePragma(self, line):
+        pass
+
+
+class SrcTraceParser(TraceParser):
+
+    def __init__(self, stream):
+        TraceParser.__init__(self, stream)
+        self.calls = []
+
+    def parse(self):
+        TraceParser.parse(self)
+        return self.calls
+
+    def handleID(self, value):
+        return value
+
+    def handleInt(self, value):
+        return int(value)
+
+    def handleFloat(self, value):
+        return float(value)
+
+    def handleString(self, value):
+        return value
+
+    def handleBitmask(self, value):
+        return value
+
+    def handleArray(self, elements):
+        return list(elements)
+
+    def handleStruct(self, members):
+        return dict(members)
+
+    def handleCall(self, callNo, functionName, args, ret):
+        call = (callNo, functionName, args, ret)
+        self.calls.append(call)
 
 
 def main():
-    parser = CallParser(sys.stdin)
-    parser.parse()
+    # Parse command line options
+    optparser = optparse.OptionParser(
+        usage='\n\t%prog [OPTIONS] REF_TXT SRC_TRACE',
+        version='%%prog')
+    optparser.add_option(
+        '--apitrace', metavar='PROGRAM',
+        type='string', dest='apitrace', default=os.environ.get('APITRACE', 'apitrace'),
+        help='path to apitrace executable')
+    optparser.add_option(
+        '-v', '--verbose',
+        action="store_true",
+        dest="verbose", default=True,
+        help="verbose output")
+    (options, args) = optparser.parse_args(sys.argv[1:])
+
+    if len(args) != 2:
+        optparser.error('wrong number of arguments')
+
+    refFileName, srcFileName = args
+
+    refStream = open(refFileName, 'rt')
+    refParser = RefTraceParser(refStream)
+    refTrace = refParser.parse()
+    if options.verbose:
+        sys.stdout.write('// Reference\n')
+        sys.stdout.write(str(refTrace))
+        sys.stdout.write('\n')
+
+    if srcFileName.endswith('.trace'):
+        cmd = [options.apitrace, 'dump', '--color=never', srcFileName]
+        p = subprocess.Popen(cmd, stdout=subprocess.PIPE)
+        srcStream = p.stdout
+    else:
+        srcStream = open(srcFileName, 'rt')
+    srcParser = SrcTraceParser(srcStream)
+    srcTrace = srcParser.parse()
+    if options.verbose:
+        sys.stdout.write('// Source\n')
+        sys.stdout.write(''.join(['%s %s%r = %r\n' % call for call in srcTrace]))
+        sys.stdout.write('\n')
+
+    if options.verbose:
+        sys.stdout.write('// Matching\n')
+    mo = refTrace.match(srcTrace, options.verbose)
+    if options.verbose:
+        sys.stdout.write('\n')
+
+    if options.verbose:
+        sys.stdout.write('// Parameters\n')
+        paramNames = mo.params.keys()
+        paramNames.sort()
+        for paramName in paramNames:
+            print '%s = %r' % (paramName, mo.params[paramName])
 
 
 if __name__ == '__main__':