]> git.cworth.org Git - apitrace-tests/commitdiff
Further improvements to checker.
authorJosé Fonseca <jose.r.fonseca@gmail.com>
Thu, 22 Nov 2012 22:03:17 +0000 (22:03 +0000)
committerJosé Fonseca <jose.r.fonseca@gmail.com>
Thu, 22 Nov 2012 22:03:17 +0000 (22:03 +0000)
checker.py

index c6eda92bf8d62437e22d8191c3b8b9f8326ff219..08ef9d324c9d12bc1a2134a491131d3e11376e3a 100755 (executable)
@@ -30,12 +30,18 @@ import optparse
 import re
 
 
 import re
 
 
+class MatchObject:
+
+    def __init__(self):
+        self.params = {}
+
+
 class Matcher:
 
 class Matcher:
 
-    def match(self, value):
+    def match(self, value, mo):
         raise NotImplementedError
 
         raise NotImplementedError
 
-    def _matchSequence(self, refValues, srcValues):
+    def _matchSequence(self, refValues, srcValues, mo):
         if not isinstance(srcValues, (list, tuple)):
             return False
 
         if not isinstance(srcValues, (list, tuple)):
             return False
 
@@ -43,7 +49,7 @@ class Matcher:
             return False
 
         for refValue, srcValue in zip(refValues, srcValues):
             return False
 
         for refValue, srcValue in zip(refValues, srcValues):
-            if not refValue.match(srcValue):
+            if not refValue.match(srcValue, mo):
                 return False
         return True
 
                 return False
         return True
 
@@ -56,11 +62,21 @@ class Matcher:
 
 class WildcardMatcher(Matcher):
 
 
 class WildcardMatcher(Matcher):
 
-    def match(self, value):
+    def __init__(self, name = ''):
+        self.name = name
+
+    def match(self, value, mo):
+        if self.name:
+            try:
+                refValue = mo.params[self.name]
+            except KeyError:
+                mo.params[self.name] = value
+            else:
+                return refValue == value
         return True
 
     def __str__(self):
         return True
 
     def __str__(self):
-        return '*'
+        return '<' + self.name + '>'
 
 
 class LiteralMatcher(Matcher):
 
 
 class LiteralMatcher(Matcher):
@@ -68,7 +84,7 @@ class LiteralMatcher(Matcher):
     def __init__(self, refValue):
         self.refValue = refValue
 
     def __init__(self, refValue):
         self.refValue = refValue
 
-    def match(self, value):
+    def match(self, value, mo):
         return self.refValue == value
 
     def __str__(self):
         return self.refValue == value
 
     def __str__(self):
@@ -81,7 +97,7 @@ class ApproxMatcher(Matcher):
         self.refValue = refValue
         self.tolerance = tolerance
 
         self.refValue = refValue
         self.tolerance = tolerance
 
-    def match(self, value):
+    def match(self, value, mo):
         if not isinstance(value, float):
             return 
 
         if not isinstance(value, float):
             return 
 
@@ -99,8 +115,8 @@ class BitmaskMatcher(Matcher):
     def __init__(self, refElements):
         self.refElements = refElements
 
     def __init__(self, refElements):
         self.refElements = refElements
 
-    def match(self, value):
-        return self._matchSequence(self.refElements, value)
+    def match(self, value, mo):
+        return self._matchSequence(self.refElements, value, mo)
 
     def __str__(self):
         return ' | '.join(map(str, self.refElements))
 
     def __str__(self):
         return ' | '.join(map(str, self.refElements))
@@ -111,8 +127,8 @@ class ArrayMatcher(Matcher):
     def __init__(self, refElements):
         self.refElements = refElements
 
     def __init__(self, refElements):
         self.refElements = refElements
 
-    def match(self, value):
-        return self._matchSequence(self.refElements, value)
+    def match(self, value, mo):
+        return self._matchSequence(self.refElements, value, mo)
 
     def __str__(self):
         return '{' + ', '.join(map(str, self.refElements)) + '}'
 
     def __str__(self):
         return '{' + ', '.join(map(str, self.refElements)) + '}'
@@ -123,7 +139,7 @@ class StructMatcher(Matcher):
     def __init__(self, refMembers):
         self.refMembers = refMembers
 
     def __init__(self, refMembers):
         self.refMembers = refMembers
 
-    def match(self, value):
+    def match(self, value, mo):
         if not isinstance(value, dict):
             return False
 
         if not isinstance(value, dict):
             return False
 
@@ -136,25 +152,25 @@ class StructMatcher(Matcher):
             except KeyError:
                 return False
             else:
             except KeyError:
                 return False
             else:
-                if not refMember.match(member):
+                if not refMember.match(member, mo):
                     return False
 
         return True
 
     def __str__(self):
                     return False
 
         return True
 
     def __str__(self):
-        print self.refMembers
         return '{' + ', '.join(['%s = %s' % refMember for refMember in self.refMembers.iteritems()]) + '}'
 
 
 class CallMatcher(Matcher):
 
         return '{' + ', '.join(['%s = %s' % refMember for refMember in self.refMembers.iteritems()]) + '}'
 
 
 class CallMatcher(Matcher):
 
-    def __init__(self, functionName, args, ret = None):
+    def __init__(self, callNo, functionName, args, ret):
+        self.callNo = callNo
         self.functionName = functionName
         self.args = args
         self.ret = ret
 
         self.functionName = functionName
         self.args = args
         self.ret = ret
 
-    def match(self, call):
-        srcFunctionName, srcArgs, srcRet = call
+    def match(self, call, mo):
+        callNo, srcFunctionName, srcArgs, srcRet = call
 
         if self.functionName != srcFunctionName:
             return False
 
         if self.functionName != srcFunctionName:
             return False
@@ -162,14 +178,18 @@ class CallMatcher(Matcher):
         refArgs = [value for name, value in self.args]
         srcArgs = [value for name, value in srcArgs]
 
         refArgs = [value for name, value in self.args]
         srcArgs = [value for name, value in srcArgs]
 
-        if not self._matchSequence(refArgs, srcArgs):
+        if not self._matchSequence(refArgs, srcArgs, mo):
             return False
 
         if self.ret is None:
             if srcRet is not None:
                 return False
         else:
             return False
 
         if self.ret is None:
             if srcRet is not None:
                 return False
         else:
-            if not self.ret.match(srcRet):
+            if not self.ret.match(srcRet, mo):
+                return False
+
+        if self.callNo is not None:
+            if not self.callNo.match(callNo, mo):
                 return False
 
         return True
                 return False
 
         return True
@@ -188,7 +208,7 @@ class TraceMatcher:
         self.calls = calls
 
     def match(self, trace):
         self.calls = calls
 
     def match(self, trace):
-
+        mo = MatchObject()
         srcCalls = iter(trace.calls)
         for refCall in self.calls:
             skippedSrcCalls = []
         srcCalls = iter(trace.calls)
         for refCall in self.calls:
             skippedSrcCalls = []
@@ -200,11 +220,11 @@ class TraceMatcher:
                         raise Exception('missing call `%s` (found `%s`)' % (refCall, skippedSrcCalls[0]))
                     else:
                         raise Exception('missing call %s' % refCall)
                         raise Exception('missing call `%s` (found `%s`)' % (refCall, skippedSrcCalls[0]))
                     else:
                         raise Exception('missing call %s' % refCall)
-                if refCall.match(srcCall):
+                if refCall.match(srcCall, mo):
                     break
                 else:
                     skippedSrcCalls.append(srcCall)
                     break
                 else:
                     skippedSrcCalls.append(srcCall)
-        return True
+        return mo
 
     def __str__(self):
         return ''.join(['%s\n' % call for call in self.calls])
 
     def __str__(self):
         return ''.join(['%s\n' % call for call in self.calls])
@@ -386,7 +406,7 @@ class Parser:
 
 #######################################################################
 
 
 #######################################################################
 
-ID, NUMBER, HEXNUM, STRING, PRAGMA, LPAREN, RPAREN, LCURLY, RCURLY, COMMA, AMP, EQUAL, VERT, BLOB = xrange(14)
+ID, NUMBER, HEXNUM, STRING, WILDCARD, PRAGMA, LPAREN, RPAREN, LCURLY, RCURLY, COMMA, AMP, EQUAL, VERT, BLOB = xrange(15)
 
 
 class CallScanner(Scanner):
 
 
 class CallScanner(Scanner):
@@ -408,6 +428,9 @@ class CallScanner(Scanner):
         # String IDs
         (STRING, r'"[^"\\]*(?:\\.[^"\\]*)*"', False),
         
         # String IDs
         (STRING, r'"[^"\\]*(?:\\.[^"\\]*)*"', False),
         
+        # Wildcard
+        (WILDCARD, r'<[^>]*>', False),
+        
         # Pragma
         (PRAGMA, r'#[^\r\n]*', False),
     ]
         # Pragma
         (PRAGMA, r'#[^\r\n]*', False),
     ]
@@ -477,11 +500,13 @@ class TraceParser(Parser):
         while self.lookahead.type == PRAGMA:
             # TODO
             token = self.consume()
         while self.lookahead.type == PRAGMA:
             # TODO
             token = self.consume()
-            print token.text
 
         if self.lookahead.type == NUMBER:
             token = self.consume()
 
         if self.lookahead.type == NUMBER:
             token = self.consume()
-            callNo = int(token.text)
+            callNo = self.handleInt(int(token.text))
+        elif self.lookahead.type == WILDCARD:
+            token = self.consume()
+            callNo = self.handleWildcard((token.text[1:-1]))
         else:
             callNo = None
         
         else:
             callNo = None
         
@@ -495,7 +520,7 @@ class TraceParser(Parser):
         else:
             ret = None
 
         else:
             ret = None
 
-        self.handleCall(functionName, args, ret)
+        self.handleCall(callNo, functionName, args, ret)
 
     def parse_pair(self):
         '''Parse a `name = value` pair.'''
 
     def parse_pair(self):
         '''Parse a `name = value` pair.'''
@@ -565,9 +590,13 @@ class TraceParser(Parser):
         elif self.match(BLOB):
             token = self.consume()
             self.consume(LPAREN)
         elif self.match(BLOB):
             token = self.consume()
             self.consume(LPAREN)
-            length = self.consume()
+            token = self.consume()
+            length = int(token.text)
             self.consume(RPAREN)
             return self.handleBlob(length)
             self.consume(RPAREN)
             return self.handleBlob(length)
+        elif self.match(WILDCARD):
+            token = self.consume()
+            return self.handleWildcard(token.text[1:-1])
         else:
             self.error()
 
         else:
             self.error()
 
@@ -611,11 +640,12 @@ class TraceParser(Parser):
         raise NotImplementedError
 
     def handleBlob(self, length):
         raise NotImplementedError
 
     def handleBlob(self, length):
+        return self.handleID('blob(%u)' % length)
+
+    def handleWildcard(self, name):
         raise NotImplementedError
         raise NotImplementedError
-        # TODO
-        return WildcardMatcher()
 
 
-    def handleCall(self, functionName, args, ret):
+    def handleCall(self, callNo, functionName, args, ret):
         raise NotImplementedError
 
     def handlePragma(self, line):
         raise NotImplementedError
 
     def handlePragma(self, line):
@@ -653,12 +683,11 @@ class RefTraceParser(TraceParser):
     def handleStruct(self, value):
         return StructMatcher(value)
 
     def handleStruct(self, value):
         return StructMatcher(value)
 
-    def handleBlob(self, length):
-        # TODO
-        return WildcardMatcher()
+    def handleWildcard(self, name):
+        return WildcardMatcher(name)
 
 
-    def handleCall(self, functionName, args, ret):
-        call = CallMatcher(functionName, args, ret)
+    def handleCall(self, callNo, functionName, args, ret):
+        call = CallMatcher(callNo, functionName, args, ret)
         self.calls.append(call)
 
 
         self.calls.append(call)
 
 
@@ -693,12 +722,8 @@ class SrcTraceParser(TraceParser):
     def handleStruct(self, members):
         return dict(members)
 
     def handleStruct(self, members):
         return dict(members)
 
-    def handleBlob(self, length):
-        # TODO
-        return None
-
-    def handleCall(self, functionName, args, ret):
-        call = (functionName, args, ret)
+    def handleCall(self, callNo, functionName, args, ret):
+        call = (callNo, functionName, args, ret)
         self.calls.append(call)
 
 
         self.calls.append(call)
 
 
@@ -722,7 +747,12 @@ def main():
     sys.stdout.write(str(refTrace))
     srcParser = SrcTraceParser(open(args[1], 'rt'))
     srcTrace = srcParser.parse()
     sys.stdout.write(str(refTrace))
     srcParser = SrcTraceParser(open(args[1], 'rt'))
     srcTrace = srcParser.parse()
-    refTrace.match(srcTrace)
+    mo = refTrace.match(srcTrace)
+
+    paramNames = mo.params.keys()
+    paramNames.sort()
+    for paramName in paramNames:
+        print '%s = %r' % (paramName, mo.params[paramName])
 
 
 if __name__ == '__main__':
 
 
 if __name__ == '__main__':