]> git.cworth.org Git - apitrace-tests/blobdiff - checker.py
Improve gl map buffer test.
[apitrace-tests] / checker.py
index c6eda92bf8d62437e22d8191c3b8b9f8326ff219..68949b16ce35fad0d6ee37c1ee0ad94125255ede 100755 (executable)
 
 import sys
 import optparse
+import os
 import re
+import subprocess
+
+
+class MatchObject:
+
+    def __init__(self):
+        self.params = {}
 
 
 class Matcher:
 
-    def match(self, value):
+    def match(self, value, mo):
         raise NotImplementedError
 
-    def _matchSequence(self, refValues, srcValues):
+    def _matchSequence(self, refValues, srcValues, mo):
         if not isinstance(srcValues, (list, tuple)):
             return False
 
@@ -43,7 +51,7 @@ class Matcher:
             return False
 
         for refValue, srcValue in zip(refValues, srcValues):
-            if not refValue.match(srcValue):
+            if not refValue.match(srcValue, mo):
                 return False
         return True
 
@@ -56,11 +64,21 @@ class Matcher:
 
 class WildcardMatcher(Matcher):
 
-    def match(self, value):
+    def __init__(self, name = ''):
+        self.name = name
+
+    def match(self, value, mo):
+        if self.name:
+            try:
+                refValue = mo.params[self.name]
+            except KeyError:
+                mo.params[self.name] = value
+            else:
+                return refValue == value
         return True
 
     def __str__(self):
-        return '*'
+        return '<' + self.name + '>'
 
 
 class LiteralMatcher(Matcher):
@@ -68,7 +86,7 @@ class LiteralMatcher(Matcher):
     def __init__(self, refValue):
         self.refValue = refValue
 
-    def match(self, value):
+    def match(self, value, mo):
         return self.refValue == value
 
     def __str__(self):
@@ -81,7 +99,7 @@ class ApproxMatcher(Matcher):
         self.refValue = refValue
         self.tolerance = tolerance
 
-    def match(self, value):
+    def match(self, value, mo):
         if not isinstance(value, float):
             return 
 
@@ -99,20 +117,33 @@ class BitmaskMatcher(Matcher):
     def __init__(self, refElements):
         self.refElements = refElements
 
-    def match(self, value):
-        return self._matchSequence(self.refElements, value)
+    def match(self, value, mo):
+        return self._matchSequence(self.refElements, value, mo)
 
     def __str__(self):
         return ' | '.join(map(str, self.refElements))
 
 
+class OffsetMatcher(Matcher):
+
+    def __init__(self, refValue, offset):
+        self.refValue = refValue
+        self.offset = offset
+
+    def match(self, value, mo):
+        return self.refValue.match(value - self.offset, mo)
+
+    def __str__(self):
+        return '%s + %i' % (self.refValue, self.offset)
+
+
 class ArrayMatcher(Matcher):
 
     def __init__(self, refElements):
         self.refElements = refElements
 
-    def match(self, value):
-        return self._matchSequence(self.refElements, value)
+    def match(self, value, mo):
+        return self._matchSequence(self.refElements, value, mo)
 
     def __str__(self):
         return '{' + ', '.join(map(str, self.refElements)) + '}'
@@ -123,7 +154,7 @@ class StructMatcher(Matcher):
     def __init__(self, refMembers):
         self.refMembers = refMembers
 
-    def match(self, value):
+    def match(self, value, mo):
         if not isinstance(value, dict):
             return False
 
@@ -136,25 +167,25 @@ class StructMatcher(Matcher):
             except KeyError:
                 return False
             else:
-                if not refMember.match(member):
+                if not refMember.match(member, mo):
                     return False
 
         return True
 
     def __str__(self):
-        print self.refMembers
         return '{' + ', '.join(['%s = %s' % refMember for refMember in self.refMembers.iteritems()]) + '}'
 
 
 class CallMatcher(Matcher):
 
-    def __init__(self, functionName, args, ret = None):
+    def __init__(self, callNo, functionName, args, ret):
+        self.callNo = callNo
         self.functionName = functionName
         self.args = args
         self.ret = ret
 
-    def match(self, call):
-        srcFunctionName, srcArgs, srcRet = call
+    def match(self, call, mo):
+        callNo, srcFunctionName, srcArgs, srcRet = call
 
         if self.functionName != srcFunctionName:
             return False
@@ -162,14 +193,18 @@ class CallMatcher(Matcher):
         refArgs = [value for name, value in self.args]
         srcArgs = [value for name, value in srcArgs]
 
-        if not self._matchSequence(refArgs, srcArgs):
+        if not self._matchSequence(refArgs, srcArgs, mo):
             return False
 
         if self.ret is None:
             if srcRet is not None:
                 return False
         else:
-            if not self.ret.match(srcRet):
+            if not self.ret.match(srcRet, mo):
+                return False
+
+        if self.callNo is not None:
+            if not self.callNo.match(callNo, mo):
                 return False
 
         return True
@@ -182,29 +217,38 @@ class CallMatcher(Matcher):
         return s
 
 
+class TraceMismatch(Exception):
+
+    pass
+
+
 class TraceMatcher:
 
     def __init__(self, calls):
         self.calls = calls
 
-    def match(self, trace):
-
-        srcCalls = iter(trace.calls)
+    def match(self, calls, verbose = False):
+        mo = MatchObject()
+        srcCalls = iter(calls)
         for refCall in self.calls:
+            if verbose:
+                print refCall
             skippedSrcCalls = []
             while True:
                 try:
                     srcCall = srcCalls.next()
                 except StopIteration:
                     if skippedSrcCalls:
-                        raise Exception('missing call `%s` (found `%s`)' % (refCall, skippedSrcCalls[0]))
+                        raise TraceMismatch('missing call `%s` (found `%s`)' % (refCall, skippedSrcCalls[0]))
                     else:
-                        raise Exception('missing call %s' % refCall)
-                if refCall.match(srcCall):
+                        raise TraceMismatch('missing call %s' % refCall)
+                if verbose:
+                    print '\t%s %s%r = %r' % srcCall
+                if refCall.match(srcCall, mo):
                     break
                 else:
                     skippedSrcCalls.append(srcCall)
-        return True
+        return mo
 
     def __str__(self):
         return ''.join(['%s\n' % call for call in self.calls])
@@ -386,7 +430,7 @@ class Parser:
 
 #######################################################################
 
-ID, NUMBER, HEXNUM, STRING, PRAGMA, LPAREN, RPAREN, LCURLY, RCURLY, COMMA, AMP, EQUAL, VERT, BLOB = xrange(14)
+ID, NUMBER, HEXNUM, STRING, WILDCARD, PRAGMA, LPAREN, RPAREN, LCURLY, RCURLY, COMMA, AMP, EQUAL, PLUS, VERT, BLOB = xrange(16)
 
 
 class CallScanner(Scanner):
@@ -408,6 +452,9 @@ class CallScanner(Scanner):
         # String IDs
         (STRING, r'"[^"\\]*(?:\\.[^"\\]*)*"', False),
         
+        # Wildcard
+        (WILDCARD, r'<[^>]*>', False),
+        
         # Pragma
         (PRAGMA, r'#[^\r\n]*', False),
     ]
@@ -421,6 +468,7 @@ class CallScanner(Scanner):
         ',': COMMA,
         '&': AMP,
         '=': EQUAL,
+        '+': PLUS,
         '|': VERT,
     }
 
@@ -467,21 +515,18 @@ class TraceParser(Parser):
 
     def parse_element(self):
         if self.lookahead.type == PRAGMA:
-            # TODO
             token = self.consume()
             self.handlePragma(token.text)
         else:
             self.parse_call()
 
     def parse_call(self):
-        while self.lookahead.type == PRAGMA:
-            # TODO
-            token = self.consume()
-            print token.text
-
         if self.lookahead.type == NUMBER:
             token = self.consume()
-            callNo = int(token.text)
+            callNo = self.handleInt(int(token.text))
+        elif self.lookahead.type == WILDCARD:
+            token = self.consume()
+            callNo = self.handleWildcard((token.text[1:-1]))
         else:
             callNo = None
         
@@ -495,7 +540,7 @@ class TraceParser(Parser):
         else:
             ret = None
 
-        self.handleCall(functionName, args, ret)
+        self.handleCall(callNo, functionName, args, ret)
 
     def parse_pair(self):
         '''Parse a `name = value` pair.'''
@@ -531,6 +576,17 @@ class TraceParser(Parser):
                 value = self._parse_value()
                 flags.append(value)
             return self.handleBitmask(flags)
+        elif self.match(PLUS):
+            self.consume()
+            if self.match(NUMBER):
+                token = self.consume()
+                offset = int(token.text)
+            elif self.match(HEXNUM):
+                token = self.consume()
+                offset = int(token.text, 16)
+            else:
+                assert 0
+            return self.handleOffset(value, offset)
         else:
             return value
 
@@ -549,8 +605,12 @@ class TraceParser(Parser):
             return self.handleString(value)
         elif self.match(NUMBER):
             token = self.consume()
-            value = float(token.text)
-            return self.handleFloat(value)
+            if '.' in token.text:
+                value = float(token.text)
+                return self.handleFloat(value)
+            else:
+                value = int(token.text)
+                return self.handleInt(value)
         elif self.match(HEXNUM):
             token = self.consume()
             value = int(token.text, 16)
@@ -565,9 +625,13 @@ class TraceParser(Parser):
         elif self.match(BLOB):
             token = self.consume()
             self.consume(LPAREN)
-            length = self.consume()
+            token = self.consume()
+            length = int(token.text)
             self.consume(RPAREN)
             return self.handleBlob(length)
+        elif self.match(WILDCARD):
+            token = self.consume()
+            return self.handleWildcard(token.text[1:-1])
         else:
             self.error()
 
@@ -604,6 +668,9 @@ class TraceParser(Parser):
     def handleBitmask(self, value):
         raise NotImplementedError
 
+    def handleOffset(self, value, offset):
+        raise NotImplementedError
+
     def handleArray(self, value):
         raise NotImplementedError
 
@@ -611,15 +678,16 @@ class TraceParser(Parser):
         raise NotImplementedError
 
     def handleBlob(self, length):
+        return self.handleID('blob(%u)' % length)
+
+    def handleWildcard(self, name):
         raise NotImplementedError
-        # TODO
-        return WildcardMatcher()
 
-    def handleCall(self, functionName, args, ret):
+    def handleCall(self, callNo, functionName, args, ret):
         raise NotImplementedError
 
     def handlePragma(self, line):
-        pass
+        raise NotImplementedError
 
 
 class RefTraceParser(TraceParser):
@@ -647,19 +715,24 @@ class RefTraceParser(TraceParser):
     def handleBitmask(self, value):
         return BitmaskMatcher(value)
 
+    def handleOffset(self, value, offset):
+        return OffsetMatcher(value, offset)
+
     def handleArray(self, value):
         return ArrayMatcher(value)
 
     def handleStruct(self, value):
         return StructMatcher(value)
 
-    def handleBlob(self, length):
-        # TODO
-        return WildcardMatcher()
+    def handleWildcard(self, name):
+        return WildcardMatcher(name)
 
-    def handleCall(self, functionName, args, ret):
-        call = CallMatcher(functionName, args, ret)
+    def handleCall(self, callNo, functionName, args, ret):
+        call = CallMatcher(callNo, functionName, args, ret)
         self.calls.append(call)
+    
+    def handlePragma(self, line):
+        pass
 
 
 class SrcTraceParser(TraceParser):
@@ -670,7 +743,7 @@ class SrcTraceParser(TraceParser):
 
     def parse(self):
         TraceParser.parse(self)
-        return TraceMatcher(self.calls)
+        return self.calls
 
     def handleID(self, value):
         return value
@@ -693,36 +766,65 @@ class SrcTraceParser(TraceParser):
     def handleStruct(self, members):
         return dict(members)
 
-    def handleBlob(self, length):
-        # TODO
-        return None
-
-    def handleCall(self, functionName, args, ret):
-        call = (functionName, args, ret)
+    def handleCall(self, callNo, functionName, args, ret):
+        call = (callNo, functionName, args, ret)
         self.calls.append(call)
 
 
 def main():
     # Parse command line options
     optparser = optparse.OptionParser(
-        usage='\n\t%prog [OPTIONS] REF_TRACE SRC_TRACE',
+        usage='\n\t%prog [OPTIONS] REF_TXT SRC_TRACE',
         version='%%prog')
+    optparser.add_option(
+        '--apitrace', metavar='PROGRAM',
+        type='string', dest='apitrace', default=os.environ.get('APITRACE', 'apitrace'),
+        help='path to apitrace executable')
     optparser.add_option(
         '-v', '--verbose',
         action="store_true",
-        dest="verbose", default=False,
+        dest="verbose", default=True,
         help="verbose output")
     (options, args) = optparser.parse_args(sys.argv[1:])
 
     if len(args) != 2:
         optparser.error('wrong number of arguments')
 
-    refParser = RefTraceParser(open(args[0], 'rt'))
+    refFileName, srcFileName = args
+
+    refStream = open(refFileName, 'rt')
+    refParser = RefTraceParser(refStream)
     refTrace = refParser.parse()
-    sys.stdout.write(str(refTrace))
-    srcParser = SrcTraceParser(open(args[1], 'rt'))
+    if options.verbose:
+        sys.stdout.write('// Reference\n')
+        sys.stdout.write(str(refTrace))
+        sys.stdout.write('\n')
+
+    if srcFileName.endswith('.trace'):
+        cmd = [options.apitrace, 'dump', '--color=never', srcFileName]
+        p = subprocess.Popen(cmd, stdout=subprocess.PIPE)
+        srcStream = p.stdout
+    else:
+        srcStream = open(srcFileName, 'rt')
+    srcParser = SrcTraceParser(srcStream)
     srcTrace = srcParser.parse()
-    refTrace.match(srcTrace)
+    if options.verbose:
+        sys.stdout.write('// Source\n')
+        sys.stdout.write(''.join(['%s %s%r = %r\n' % call for call in srcTrace]))
+        sys.stdout.write('\n')
+
+    if options.verbose:
+        sys.stdout.write('// Matching\n')
+    mo = refTrace.match(srcTrace, options.verbose)
+    if options.verbose:
+        sys.stdout.write('\n')
+
+    if options.verbose:
+        sys.stdout.write('// Parameters\n')
+        paramNames = mo.params.keys()
+        paramNames.sort()
+        for paramName in paramNames:
+            print '%s = %r' % (paramName, mo.params[paramName])
 
 
 if __name__ == '__main__':