From: José Fonseca Date: Thu, 22 Nov 2012 22:03:17 +0000 (+0000) Subject: Further improvements to checker. X-Git-Url: https://git.cworth.org/git?p=apitrace-tests;a=commitdiff_plain;h=09690a098cc072323bc1364119be0b9e1ce60d11 Further improvements to checker. --- diff --git a/checker.py b/checker.py index c6eda92..08ef9d3 100755 --- a/checker.py +++ b/checker.py @@ -30,12 +30,18 @@ import optparse import re +class MatchObject: + + def __init__(self): + self.params = {} + + class Matcher: - def match(self, value): + def match(self, value, mo): raise NotImplementedError - def _matchSequence(self, refValues, srcValues): + def _matchSequence(self, refValues, srcValues, mo): if not isinstance(srcValues, (list, tuple)): return False @@ -43,7 +49,7 @@ class Matcher: return False for refValue, srcValue in zip(refValues, srcValues): - if not refValue.match(srcValue): + if not refValue.match(srcValue, mo): return False return True @@ -56,11 +62,21 @@ class Matcher: class WildcardMatcher(Matcher): - def match(self, value): + def __init__(self, name = ''): + self.name = name + + def match(self, value, mo): + if self.name: + try: + refValue = mo.params[self.name] + except KeyError: + mo.params[self.name] = value + else: + return refValue == value return True def __str__(self): - return '*' + return '<' + self.name + '>' class LiteralMatcher(Matcher): @@ -68,7 +84,7 @@ class LiteralMatcher(Matcher): def __init__(self, refValue): self.refValue = refValue - def match(self, value): + def match(self, value, mo): return self.refValue == value def __str__(self): @@ -81,7 +97,7 @@ class ApproxMatcher(Matcher): self.refValue = refValue self.tolerance = tolerance - def match(self, value): + def match(self, value, mo): if not isinstance(value, float): return @@ -99,8 +115,8 @@ class BitmaskMatcher(Matcher): def __init__(self, refElements): self.refElements = refElements - def match(self, value): - return self._matchSequence(self.refElements, value) + def match(self, value, mo): + return self._matchSequence(self.refElements, value, mo) def __str__(self): return ' | '.join(map(str, self.refElements)) @@ -111,8 +127,8 @@ class ArrayMatcher(Matcher): def __init__(self, refElements): self.refElements = refElements - def match(self, value): - return self._matchSequence(self.refElements, value) + def match(self, value, mo): + return self._matchSequence(self.refElements, value, mo) def __str__(self): return '{' + ', '.join(map(str, self.refElements)) + '}' @@ -123,7 +139,7 @@ class StructMatcher(Matcher): def __init__(self, refMembers): self.refMembers = refMembers - def match(self, value): + def match(self, value, mo): if not isinstance(value, dict): return False @@ -136,25 +152,25 @@ class StructMatcher(Matcher): except KeyError: return False else: - if not refMember.match(member): + if not refMember.match(member, mo): return False return True def __str__(self): - print self.refMembers return '{' + ', '.join(['%s = %s' % refMember for refMember in self.refMembers.iteritems()]) + '}' class CallMatcher(Matcher): - def __init__(self, functionName, args, ret = None): + def __init__(self, callNo, functionName, args, ret): + self.callNo = callNo self.functionName = functionName self.args = args self.ret = ret - def match(self, call): - srcFunctionName, srcArgs, srcRet = call + def match(self, call, mo): + callNo, srcFunctionName, srcArgs, srcRet = call if self.functionName != srcFunctionName: return False @@ -162,14 +178,18 @@ class CallMatcher(Matcher): refArgs = [value for name, value in self.args] srcArgs = [value for name, value in srcArgs] - if not self._matchSequence(refArgs, srcArgs): + if not self._matchSequence(refArgs, srcArgs, mo): return False if self.ret is None: if srcRet is not None: return False else: - if not self.ret.match(srcRet): + if not self.ret.match(srcRet, mo): + return False + + if self.callNo is not None: + if not self.callNo.match(callNo, mo): return False return True @@ -188,7 +208,7 @@ class TraceMatcher: self.calls = calls def match(self, trace): - + mo = MatchObject() srcCalls = iter(trace.calls) for refCall in self.calls: skippedSrcCalls = [] @@ -200,11 +220,11 @@ class TraceMatcher: raise Exception('missing call `%s` (found `%s`)' % (refCall, skippedSrcCalls[0])) else: raise Exception('missing call %s' % refCall) - if refCall.match(srcCall): + if refCall.match(srcCall, mo): break else: skippedSrcCalls.append(srcCall) - return True + return mo def __str__(self): return ''.join(['%s\n' % call for call in self.calls]) @@ -386,7 +406,7 @@ class Parser: ####################################################################### -ID, NUMBER, HEXNUM, STRING, PRAGMA, LPAREN, RPAREN, LCURLY, RCURLY, COMMA, AMP, EQUAL, VERT, BLOB = xrange(14) +ID, NUMBER, HEXNUM, STRING, WILDCARD, PRAGMA, LPAREN, RPAREN, LCURLY, RCURLY, COMMA, AMP, EQUAL, VERT, BLOB = xrange(15) class CallScanner(Scanner): @@ -408,6 +428,9 @@ class CallScanner(Scanner): # String IDs (STRING, r'"[^"\\]*(?:\\.[^"\\]*)*"', False), + # Wildcard + (WILDCARD, r'<[^>]*>', False), + # Pragma (PRAGMA, r'#[^\r\n]*', False), ] @@ -477,11 +500,13 @@ class TraceParser(Parser): while self.lookahead.type == PRAGMA: # TODO token = self.consume() - print token.text if self.lookahead.type == NUMBER: token = self.consume() - callNo = int(token.text) + callNo = self.handleInt(int(token.text)) + elif self.lookahead.type == WILDCARD: + token = self.consume() + callNo = self.handleWildcard((token.text[1:-1])) else: callNo = None @@ -495,7 +520,7 @@ class TraceParser(Parser): else: ret = None - self.handleCall(functionName, args, ret) + self.handleCall(callNo, functionName, args, ret) def parse_pair(self): '''Parse a `name = value` pair.''' @@ -565,9 +590,13 @@ class TraceParser(Parser): elif self.match(BLOB): token = self.consume() self.consume(LPAREN) - length = self.consume() + token = self.consume() + length = int(token.text) self.consume(RPAREN) return self.handleBlob(length) + elif self.match(WILDCARD): + token = self.consume() + return self.handleWildcard(token.text[1:-1]) else: self.error() @@ -611,11 +640,12 @@ class TraceParser(Parser): raise NotImplementedError def handleBlob(self, length): + return self.handleID('blob(%u)' % length) + + def handleWildcard(self, name): raise NotImplementedError - # TODO - return WildcardMatcher() - def handleCall(self, functionName, args, ret): + def handleCall(self, callNo, functionName, args, ret): raise NotImplementedError def handlePragma(self, line): @@ -653,12 +683,11 @@ class RefTraceParser(TraceParser): def handleStruct(self, value): return StructMatcher(value) - def handleBlob(self, length): - # TODO - return WildcardMatcher() + def handleWildcard(self, name): + return WildcardMatcher(name) - def handleCall(self, functionName, args, ret): - call = CallMatcher(functionName, args, ret) + def handleCall(self, callNo, functionName, args, ret): + call = CallMatcher(callNo, functionName, args, ret) self.calls.append(call) @@ -693,12 +722,8 @@ class SrcTraceParser(TraceParser): def handleStruct(self, members): return dict(members) - def handleBlob(self, length): - # TODO - return None - - def handleCall(self, functionName, args, ret): - call = (functionName, args, ret) + def handleCall(self, callNo, functionName, args, ret): + call = (callNo, functionName, args, ret) self.calls.append(call) @@ -722,7 +747,12 @@ def main(): sys.stdout.write(str(refTrace)) srcParser = SrcTraceParser(open(args[1], 'rt')) srcTrace = srcParser.parse() - refTrace.match(srcTrace) + mo = refTrace.match(srcTrace) + + paramNames = mo.params.keys() + paramNames.sort() + for paramName in paramNames: + print '%s = %r' % (paramName, mo.params[paramName]) if __name__ == '__main__':