From: José Fonseca Date: Thu, 22 Nov 2012 09:55:15 +0000 (+0000) Subject: Improve checker. X-Git-Url: https://git.cworth.org/git?p=apitrace-tests;a=commitdiff_plain;h=4918b9940bfff2cd37b9ff3aadd1e066fdb405bc Improve checker. --- diff --git a/checker.py b/checker.py old mode 100644 new mode 100755 index be1331d..c6eda92 --- a/checker.py +++ b/checker.py @@ -26,14 +26,27 @@ import sys +import optparse import re -class ValueMatcher: +class Matcher: def match(self, value): raise NotImplementedError + def _matchSequence(self, refValues, srcValues): + if not isinstance(srcValues, (list, tuple)): + return False + + if len(refValues) != len(srcValues): + return False + + for refValue, srcValue in zip(refValues, srcValues): + if not refValue.match(srcValue): + return False + return True + def __str__(self): raise NotImplementerError @@ -41,16 +54,16 @@ class ValueMatcher: return str(self) -class WildcardMatcher(ValueMatcher): +class WildcardMatcher(Matcher): def match(self, value): - return true + return True def __str__(self): return '*' -class LiteralValueMatcher(ValueMatcher): +class LiteralMatcher(Matcher): def __init__(self, refValue): self.refValue = refValue @@ -62,13 +75,16 @@ class LiteralValueMatcher(ValueMatcher): return repr(self.refValue) -class ApproxValueMatcher(ValueMatcher): +class ApproxMatcher(Matcher): def __init__(self, refValue, tolerance = 2**-23): self.refValue = refValue self.tolerance = tolerance def match(self, value): + if not isinstance(value, float): + return + error = abs(self.refValue - value) if self.refValue: error = error / self.refValue @@ -78,29 +94,31 @@ class ApproxValueMatcher(ValueMatcher): return repr(self.refValue) -class ArrayMatcher(ValueMatcher): +class BitmaskMatcher(Matcher): def __init__(self, refElements): self.refElements = refElements def match(self, value): - if not isinstance(value, list): - return False + return self._matchSequence(self.refElements, value) - if len(value) != len(self.refElements): - return False + def __str__(self): + return ' | '.join(map(str, self.refElements)) - for refElement, element in zip(self.refElements, value): - if not refElement.match(element): - return False - return True +class ArrayMatcher(Matcher): + + def __init__(self, refElements): + self.refElements = refElements + + def match(self, value): + return self._matchSequence(self.refElements, value) def __str__(self): return '{' + ', '.join(map(str, self.refElements)) + '}' -class StructMatcher(ValueMatcher): +class StructMatcher(Matcher): def __init__(self, refMembers): self.refMembers = refMembers @@ -128,48 +146,68 @@ class StructMatcher(ValueMatcher): return '{' + ', '.join(['%s = %s' % refMember for refMember in self.refMembers.iteritems()]) + '}' -class CallMatcher: +class CallMatcher(Matcher): - def __init__(self, refFunctionName, refArgs, refRet = None): - self.refFunctionName = refFunctionName - self.refArgs = refArgs - self.refRet = refRet + def __init__(self, functionName, args, ret = None): + self.functionName = functionName + self.args = args + self.ret = ret - def match(self, functionName, args, ret = None): - if refFunctionName != functionName: - return False + def match(self, call): + srcFunctionName, srcArgs, srcRet = call - if len(self.refArgs) != len(args): + if self.functionName != srcFunctionName: return False - for (refArgName, refArg), (argName, arg) in zip(self.refArgs, args): - if not refArg.match(arg): - return False + refArgs = [value for name, value in self.args] + srcArgs = [value for name, value in srcArgs] - if self.refRet is None: - if ret is not None: + if not self._matchSequence(refArgs, srcArgs): + return False + + if self.ret is None: + if srcRet is not None: return False else: - if not self.refRet.match(ret): + if not self.ret.match(srcRet): return False return True def __str__(self): - s = self.refFunctionName - s += '(' + ', '.join(['%s = %s' % refArg for refArg in self.refArgs]) + ')' - if self.refRet is not None: - s += ' = ' + str(self.refRet) + s = self.functionName + s += '(' + ', '.join(['%s = %s' % refArg for refArg in self.args]) + ')' + if self.ret is not None: + s += ' = ' + str(self.ret) return s class TraceMatcher: - def __init__(self, refCalls): - self.refCalls = refCalls + def __init__(self, calls): + self.calls = calls + + def match(self, trace): + + srcCalls = iter(trace.calls) + for refCall in self.calls: + skippedSrcCalls = [] + while True: + try: + srcCall = srcCalls.next() + except StopIteration: + if skippedSrcCalls: + raise Exception('missing call `%s` (found `%s`)' % (refCall, skippedSrcCalls[0])) + else: + raise Exception('missing call %s' % refCall) + if refCall.match(srcCall): + break + else: + skippedSrcCalls.append(srcCall) + return True def __str__(self): - return ''.join(['%s\n' % refCall for refCall in self.refCalls]) + return ''.join(['%s\n' % call for call in self.calls]) ####################################################################### @@ -348,7 +386,7 @@ class Parser: ####################################################################### -ID, NUMBER, HEXNUM, STRING, PRAGMA, LPAREN, RPAREN, LCURLY, RCURLY, COMMA, AMP, EQUAL, BLOB = xrange(13) +ID, NUMBER, HEXNUM, STRING, PRAGMA, LPAREN, RPAREN, LCURLY, RCURLY, COMMA, AMP, EQUAL, VERT, BLOB = xrange(14) class CallScanner(Scanner): @@ -359,13 +397,13 @@ class CallScanner(Scanner): (SKIP, r'[ \t\f\r\n\v]+', False), # Alphanumeric IDs - (ID, r'[a-zA-Z_\x80-\xff][a-zA-Z0-9_\x80-\xff]*', True), + (ID, r'[a-zA-Z_][a-zA-Z0-9_]*(?:::[a-zA-Z_][a-zA-Z0-9_]*)?', True), # Numeric IDs (HEXNUM, r'-?0x[0-9a-fA-F]+', False), # Numeric IDs - (NUMBER, r'-?(?:\.[0-9]+|[0-9]+(?:\.[0-9]*)?)', False), + (NUMBER, r'-?(?:\.[0-9]+|[0-9]+(?:\.[0-9]*)?)(?:[eE][-+][0-9]+)?', False), # String IDs (STRING, r'"[^"\\]*(?:\\.[^"\\]*)*"', False), @@ -383,6 +421,7 @@ class CallScanner(Scanner): ',': COMMA, '&': AMP, '=': EQUAL, + '|': VERT, } # literal table @@ -412,7 +451,7 @@ class CallLexer(Lexer): return type, text -class CallParser(Parser): +class TraceParser(Parser): def __init__(self, stream): lexer = CallLexer(fp = stream) @@ -424,6 +463,7 @@ class CallParser(Parser): def parse(self): while not self.eof(): self.parse_element() + return TraceMatcher(self.calls) def parse_element(self): if self.lookahead.type == PRAGMA: @@ -455,7 +495,7 @@ class CallParser(Parser): else: ret = None - return self.handleCall(callNo, functionName, args, ret) + self.handleCall(functionName, args, ret) def parse_pair(self): '''Parse a `name = value` pair.''' @@ -483,6 +523,18 @@ class CallParser(Parser): return name, value def parse_value(self): + value = self._parse_value() + if self.match(VERT): + flags = [value] + while self.match(VERT): + self.consume() + value = self._parse_value() + flags.append(value) + return self.handleBitmask(flags) + else: + return value + + def _parse_value(self): if self.match(AMP): self.consume() value = [self.parse_value()] @@ -515,7 +567,7 @@ class CallParser(Parser): self.consume(LPAREN) length = self.consume() self.consume(RPAREN) - return self.handleBlob() + return self.handleBlob(length) else: self.error() @@ -538,16 +590,62 @@ class CallParser(Parser): return elements def handleID(self, value): - return LiteralValueMatcher(value) + raise NotImplementedError def handleInt(self, value): - return LiteralValueMatcher(value) + raise NotImplementedError def handleFloat(self, value): - return ApproxValueMatcher(value) + raise NotImplementedError def handleString(self, value): - return LiteralValueMatcher(value) + raise NotImplementedError + + def handleBitmask(self, value): + raise NotImplementedError + + def handleArray(self, value): + raise NotImplementedError + + def handleStruct(self, value): + raise NotImplementedError + + def handleBlob(self, length): + raise NotImplementedError + # TODO + return WildcardMatcher() + + def handleCall(self, functionName, args, ret): + raise NotImplementedError + + def handlePragma(self, line): + pass + + +class RefTraceParser(TraceParser): + + def __init__(self, stream): + TraceParser.__init__(self, stream) + self.calls = [] + + def parse(self): + TraceParser.parse(self) + return TraceMatcher(self.calls) + + def handleID(self, value): + return LiteralMatcher(value) + + def handleInt(self, value): + return LiteralMatcher(value) + + def handleFloat(self, value): + return ApproxMatcher(value) + + def handleString(self, value): + return LiteralMatcher(value) + + def handleBitmask(self, value): + return BitmaskMatcher(value) def handleArray(self, value): return ArrayMatcher(value) @@ -555,26 +653,76 @@ class CallParser(Parser): def handleStruct(self, value): return StructMatcher(value) - def handleBlob(self, value): + def handleBlob(self, length): # TODO return WildcardMatcher() - def handleCall(self, callNo, functionName, args, ret): - matcher = CallMatcher(functionName, args, ret) + def handleCall(self, functionName, args, ret): + call = CallMatcher(functionName, args, ret) + self.calls.append(call) - if callNo is not None: - sys.stdout.write('%u ' % callNo) - sys.stdout.write(str(matcher)) - sys.stdout.write('\n') - def handlePragma(self, line): - sys.stdout.write(line) - sys.stdout.write('\n') +class SrcTraceParser(TraceParser): + + def __init__(self, stream): + TraceParser.__init__(self, stream) + self.calls = [] + + def parse(self): + TraceParser.parse(self) + return TraceMatcher(self.calls) + + def handleID(self, value): + return value + + def handleInt(self, value): + return int(value) + + def handleFloat(self, value): + return float(value) + + def handleString(self, value): + return value + + def handleBitmask(self, value): + return value + + def handleArray(self, elements): + return list(elements) + + def handleStruct(self, members): + return dict(members) + + def handleBlob(self, length): + # TODO + return None + + def handleCall(self, functionName, args, ret): + call = (functionName, args, ret) + self.calls.append(call) def main(): - parser = CallParser(sys.stdin) - parser.parse() + # Parse command line options + optparser = optparse.OptionParser( + usage='\n\t%prog [OPTIONS] REF_TRACE SRC_TRACE', + version='%%prog') + optparser.add_option( + '-v', '--verbose', + action="store_true", + dest="verbose", default=False, + help="verbose output") + (options, args) = optparser.parse_args(sys.argv[1:]) + + if len(args) != 2: + optparser.error('wrong number of arguments') + + refParser = RefTraceParser(open(args[0], 'rt')) + refTrace = refParser.parse() + sys.stdout.write(str(refTrace)) + srcParser = SrcTraceParser(open(args[1], 'rt')) + srcTrace = srcParser.parse() + refTrace.match(srcTrace) if __name__ == '__main__':