]> git.cworth.org Git - apitrace-tests/blob - tracematch.py
Drop pragmas.
[apitrace-tests] / tracematch.py
1 #!/usr/bin/env python
2 ##########################################################################
3 #
4 # Copyright 2008-2012 Jose Fonseca
5 # All Rights Reserved.
6 #
7 # Permission is hereby granted, free of charge, to any person obtaining a copy
8 # of this software and associated documentation files (the "Software"), to deal
9 # in the Software without restriction, including without limitation the rights
10 # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 # copies of the Software, and to permit persons to whom the Software is
12 # furnished to do so, subject to the following conditions:
13 #
14 # The above copyright notice and this permission notice shall be included in
15 # all copies or substantial portions of the Software.
16 #
17 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23 # THE SOFTWARE.
24 #
25 ##########################################################################/
26
27
28 import sys
29 import optparse
30 import os
31 import re
32 import subprocess
33
34
35 class MatchObject:
36
37     def __init__(self):
38         self.params = {}
39
40
41 class Matcher:
42
43     def match(self, value, mo):
44         raise NotImplementedError
45
46     def _matchSequence(self, refValues, srcValues, mo):
47         if not isinstance(srcValues, (list, tuple)):
48             return False
49
50         if len(refValues) != len(srcValues):
51             return False
52
53         for refValue, srcValue in zip(refValues, srcValues):
54             if not refValue.match(srcValue, mo):
55                 return False
56         return True
57
58     def __str__(self):
59         raise NotImplementerError
60
61     def __repr__(self):
62         return str(self)
63
64
65 class WildcardMatcher(Matcher):
66
67     def __init__(self, name = ''):
68         self.name = name
69
70     def match(self, value, mo):
71         if self.name:
72             try:
73                 refValue = mo.params[self.name]
74             except KeyError:
75                 mo.params[self.name] = value
76             else:
77                 return refValue == value
78         return True
79
80     def __str__(self):
81         return '<' + self.name + '>'
82
83
84 class LiteralMatcher(Matcher):
85
86     def __init__(self, refValue):
87         self.refValue = refValue
88
89     def match(self, value, mo):
90         return self.refValue == value
91
92     def __str__(self):
93         return repr(self.refValue)
94
95
96 class ApproxMatcher(Matcher):
97
98     def __init__(self, refValue, tolerance = 2**-23):
99         self.refValue = refValue
100         self.tolerance = tolerance
101
102     def match(self, value, mo):
103         if not isinstance(value, float):
104             return 
105
106         error = abs(self.refValue - value)
107         if self.refValue:
108             error = error / self.refValue
109         return error <= self.tolerance
110
111     def __str__(self):
112         return repr(self.refValue)
113
114
115 class BitmaskMatcher(Matcher):
116
117     def __init__(self, refElements):
118         self.refElements = refElements
119
120     def match(self, value, mo):
121         return self._matchSequence(self.refElements, value, mo)
122
123     def __str__(self):
124         return ' | '.join(map(str, self.refElements))
125
126
127 class OffsetMatcher(Matcher):
128
129     def __init__(self, refValue, offset):
130         self.refValue = refValue
131         self.offset = offset
132
133     def match(self, value, mo):
134         return self.refValue.match(value - self.offset, mo)
135
136     def __str__(self):
137         return '%s + %i' % (self.refValue, self.offset)
138
139
140 class ArrayMatcher(Matcher):
141
142     def __init__(self, refElements):
143         self.refElements = refElements
144
145     def match(self, value, mo):
146         return self._matchSequence(self.refElements, value, mo)
147
148     def __str__(self):
149         return '{' + ', '.join(map(str, self.refElements)) + '}'
150
151
152 class StructMatcher(Matcher):
153
154     def __init__(self, refMembers):
155         self.refMembers = refMembers
156
157     def match(self, value, mo):
158         if not isinstance(value, dict):
159             return False
160
161         if len(value) != len(self.refMembers):
162             return False
163
164         for name, refMember in self.refMembers.iteritems():
165             try:
166                 member = value[name]
167             except KeyError:
168                 return False
169             else:
170                 if not refMember.match(member, mo):
171                     return False
172
173         return True
174
175     def __str__(self):
176         return '{' + ', '.join(['%s = %s' % refMember for refMember in self.refMembers.iteritems()]) + '}'
177
178
179 class CallMatcher(Matcher):
180
181     def __init__(self, callNo, functionName, args, ret):
182         self.callNo = callNo
183         self.functionName = functionName
184         self.args = args
185         self.ret = ret
186
187     def match(self, call, mo):
188         callNo, srcFunctionName, srcArgs, srcRet = call
189
190         if self.functionName != srcFunctionName:
191             return False
192
193         refArgs = [value for name, value in self.args]
194         srcArgs = [value for name, value in srcArgs]
195
196         if not self._matchSequence(refArgs, srcArgs, mo):
197             return False
198
199         if self.ret is None:
200             if srcRet is not None:
201                 return False
202         else:
203             if not self.ret.match(srcRet, mo):
204                 return False
205
206         if self.callNo is not None:
207             if not self.callNo.match(callNo, mo):
208                 return False
209
210         return True
211
212     def __str__(self):
213         s = self.functionName
214         s += '(' + ', '.join(['%s = %s' % refArg for refArg in self.args]) + ')'
215         if self.ret is not None:
216             s += ' = ' + str(self.ret)
217         return s
218
219
220 class TraceMismatch(Exception):
221
222     pass
223
224
225 class TraceMatcher:
226
227     def __init__(self, calls):
228         self.calls = calls
229
230     def match(self, calls, verbose = False):
231         mo = MatchObject()
232         srcCalls = iter(calls)
233         for refCall in self.calls:
234             if verbose:
235                 print refCall
236             skippedSrcCalls = []
237             while True:
238                 try:
239                     srcCall = srcCalls.next()
240                 except StopIteration:
241                     if skippedSrcCalls:
242                         raise TraceMismatch('missing call `%s` (found `%s`)' % (refCall, skippedSrcCalls[0]))
243                     else:
244                         raise TraceMismatch('missing call %s' % refCall)
245                 if verbose:
246                     print '\t%s %s%r = %r' % srcCall
247                 if refCall.match(srcCall, mo):
248                     break
249                 else:
250                     skippedSrcCalls.append(srcCall)
251         return mo
252
253     def __str__(self):
254         return ''.join(['%s\n' % call for call in self.calls])
255
256
257 #######################################################################
258
259 EOF = -1
260 SKIP = -2
261
262
263 class ParseError(Exception):
264
265     def __init__(self, msg=None, filename=None, line=None, col=None):
266         self.msg = msg
267         self.filename = filename
268         self.line = line
269         self.col = col
270
271     def __str__(self):
272         return ':'.join([str(part) for part in (self.filename, self.line, self.col, self.msg) if part != None])
273         
274
275 class Scanner:
276     """Stateless scanner."""
277
278     # should be overriden by derived classes
279     tokens = []
280     symbols = {}
281     literals = {}
282     ignorecase = False
283
284     def __init__(self):
285         flags = re.DOTALL
286         if self.ignorecase:
287             flags |= re.IGNORECASE
288         self.tokens_re = re.compile(
289             '|'.join(['(' + regexp + ')' for type, regexp, test_lit in self.tokens]),
290              flags
291         )
292
293     def next(self, buf, pos):
294         if pos >= len(buf):
295             return EOF, '', pos
296         mo = self.tokens_re.match(buf, pos)
297         if mo:
298             text = mo.group()
299             type, regexp, test_lit = self.tokens[mo.lastindex - 1]
300             pos = mo.end()
301             if test_lit:
302                 type = self.literals.get(text, type)
303             return type, text, pos
304         else:
305             c = buf[pos]
306             return self.symbols.get(c, None), c, pos + 1
307
308
309 class Token:
310
311     def __init__(self, type, text, line, col):
312         self.type = type
313         self.text = text
314         self.line = line
315         self.col = col
316
317
318 class Lexer:
319
320     # should be overriden by derived classes
321     scanner = None
322     tabsize = 8
323
324     newline_re = re.compile(r'\r\n?|\n')
325
326     def __init__(self, buf = None, pos = 0, filename = None, fp = None):
327         if fp is not None:
328             try:
329                 fileno = fp.fileno()
330                 length = os.path.getsize(fp.name)
331                 import mmap
332             except:
333                 # read whole file into memory
334                 buf = fp.read()
335                 pos = 0
336             else:
337                 # map the whole file into memory
338                 if length:
339                     # length must not be zero
340                     buf = mmap.mmap(fileno, length, access = mmap.ACCESS_READ)
341                     pos = os.lseek(fileno, 0, 1)
342                 else:
343                     buf = ''
344                     pos = 0
345
346             if filename is None:
347                 try:
348                     filename = fp.name
349                 except AttributeError:
350                     filename = None
351
352         self.buf = buf
353         self.pos = pos
354         self.line = 1
355         self.col = 1
356         self.filename = filename
357
358     def next(self):
359         while True:
360             # save state
361             pos = self.pos
362             line = self.line
363             col = self.col
364
365             type, text, endpos = self.scanner.next(self.buf, pos)
366             assert pos + len(text) == endpos
367             self.consume(text)
368             type, text = self.filter(type, text)
369             self.pos = endpos
370
371             if type == SKIP:
372                 continue
373             elif type is None:
374                 msg = 'unexpected char '
375                 if text >= ' ' and text <= '~':
376                     msg += "'%s'" % text
377                 else:
378                     msg += "0x%X" % ord(text)
379                 raise ParseError(msg, self.filename, line, col)
380             else:
381                 break
382         return Token(type = type, text = text, line = line, col = col)
383
384     def consume(self, text):
385         # update line number
386         pos = 0
387         for mo in self.newline_re.finditer(text, pos):
388             self.line += 1
389             self.col = 1
390             pos = mo.end()
391
392         # update column number
393         while True:
394             tabpos = text.find('\t', pos)
395             if tabpos == -1:
396                 break
397             self.col += tabpos - pos
398             self.col = ((self.col - 1)//self.tabsize + 1)*self.tabsize + 1
399             pos = tabpos + 1
400         self.col += len(text) - pos
401
402
403 class Parser:
404
405     def __init__(self, lexer):
406         self.lexer = lexer
407         self.lookahead = self.lexer.next()
408
409     def match(self, type):
410         return self.lookahead.type == type
411
412     def skip(self, type):
413         while not self.match(type):
414             self.consume()
415
416     def error(self):
417         raise ParseError(
418             msg = 'unexpected token %r' % self.lookahead.text, 
419             filename = self.lexer.filename, 
420             line = self.lookahead.line, 
421             col = self.lookahead.col)
422
423     def consume(self, type = None):
424         if type is not None and not self.match(type):
425             self.error()
426         token = self.lookahead
427         self.lookahead = self.lexer.next()
428         return token
429
430
431 #######################################################################
432
433 ID, NUMBER, HEXNUM, STRING, WILDCARD, LPAREN, RPAREN, LCURLY, RCURLY, COMMA, AMP, EQUAL, PLUS, VERT, BLOB = xrange(15)
434
435
436 class CallScanner(Scanner):
437
438     # token regular expression table
439     tokens = [
440         # whitespace
441         (SKIP, r'[ \t\f\r\n\v]+', False),
442
443         # Alphanumeric IDs
444         (ID, r'[a-zA-Z_][a-zA-Z0-9_]*(?:::[a-zA-Z_][a-zA-Z0-9_]*)?', True),
445
446         # Numeric IDs
447         (HEXNUM, r'-?0x[0-9a-fA-F]+', False),
448         
449         # Numeric IDs
450         (NUMBER, r'-?(?:\.[0-9]+|[0-9]+(?:\.[0-9]*)?)(?:[eE][-+][0-9]+)?', False),
451
452         # String IDs
453         (STRING, r'"[^"\\]*(?:\\.[^"\\]*)*"', False),
454         
455         # Wildcard
456         (WILDCARD, r'<[^>]*>', False),
457     ]
458
459     # symbol table
460     symbols = {
461         '(': LPAREN,
462         ')': RPAREN,
463         '{': LCURLY,
464         '}': RCURLY,
465         ',': COMMA,
466         '&': AMP,
467         '=': EQUAL,
468         '+': PLUS,
469         '|': VERT,
470     }
471
472     # literal table
473     literals = {
474         'blob': BLOB
475     }
476
477
478 class CallLexer(Lexer):
479
480     scanner = CallScanner()
481
482     def filter(self, type, text):
483         if type == STRING:
484             text = text[1:-1]
485
486             # line continuations
487             text = text.replace('\\\r\n', '')
488             text = text.replace('\\\r', '')
489             text = text.replace('\\\n', '')
490             
491             # quotes
492             text = text.replace('\\"', '"')
493
494             type = ID
495
496         return type, text
497
498
499 class TraceParser(Parser):
500
501     def __init__(self, stream):
502         lexer = CallLexer(fp = stream)
503         Parser.__init__(self, lexer)
504
505     def eof(self):
506         return self.match(EOF)
507
508     def parse(self):
509         while not self.eof():
510             self.parse_call()
511         return TraceMatcher(self.calls)
512
513     def parse_call(self):
514         if self.lookahead.type == NUMBER:
515             token = self.consume()
516             callNo = self.handleInt(int(token.text))
517         elif self.lookahead.type == WILDCARD:
518             token = self.consume()
519             callNo = self.handleWildcard((token.text[1:-1]))
520         else:
521             callNo = None
522         
523         functionName = self.consume(ID).text
524
525         args = self.parse_sequence(LPAREN, RPAREN, self.parse_pair)
526
527         if self.match(EQUAL):
528             self.consume(EQUAL)
529             ret = self.parse_value()
530         else:
531             ret = None
532
533         self.handleCall(callNo, functionName, args, ret)
534
535     def parse_pair(self):
536         '''Parse a `name = value` pair.'''
537         name = self.consume(ID).text
538         self.consume(EQUAL)
539         value = self.parse_value()
540         return name, value
541
542     def parse_opt_pair(self):
543         '''Parse an optional `name = value` pair.'''
544         if self.match(ID):
545             name = self.consume(ID).text
546             if self.match(EQUAL):
547                 self.consume(EQUAL)
548                 value = self.parse_value()
549             else:
550                 value = name
551                 name = None
552         else:
553             name = None
554             value = self.parse_value()
555         if name is None:
556             return value
557         else:
558             return name, value
559
560     def parse_value(self):
561         value = self._parse_value()
562         if self.match(VERT):
563             flags = [value]
564             while self.match(VERT):
565                 self.consume()
566                 value = self._parse_value()
567                 flags.append(value)
568             return self.handleBitmask(flags)
569         elif self.match(PLUS):
570             self.consume()
571             if self.match(NUMBER):
572                 token = self.consume()
573                 offset = int(token.text)
574             elif self.match(HEXNUM):
575                 token = self.consume()
576                 offset = int(token.text, 16)
577             else:
578                 assert 0
579             return self.handleOffset(value, offset)
580         else:
581             return value
582
583     def _parse_value(self):
584         if self.match(AMP):
585             self.consume()
586             value = [self.parse_value()]
587             return self.handleArray(value)
588         elif self.match(ID):
589             token = self.consume()
590             value = token.text
591             return self.handleID(value)
592         elif self.match(STRING):
593             token = self.consume()
594             value = token.text
595             return self.handleString(value)
596         elif self.match(NUMBER):
597             token = self.consume()
598             if '.' in token.text:
599                 value = float(token.text)
600                 return self.handleFloat(value)
601             else:
602                 value = int(token.text)
603                 return self.handleInt(value)
604         elif self.match(HEXNUM):
605             token = self.consume()
606             value = int(token.text, 16)
607             return self.handleInt(value)
608         elif self.match(LCURLY):
609             value = self.parse_sequence(LCURLY, RCURLY, self.parse_opt_pair)
610             if len(value) and isinstance(value[0], tuple):
611                 value = dict(value)
612                 return self.handleStruct(value)
613             else:
614                 return self.handleArray(value)
615         elif self.match(BLOB):
616             token = self.consume()
617             self.consume(LPAREN)
618             token = self.consume()
619             length = int(token.text)
620             self.consume(RPAREN)
621             return self.handleBlob(length)
622         elif self.match(WILDCARD):
623             token = self.consume()
624             return self.handleWildcard(token.text[1:-1])
625         else:
626             self.error()
627
628     def parse_sequence(self, ltype, rtype, elementParser):
629         '''Parse a comma separated list'''
630
631         elements = []
632
633         self.consume(ltype)
634         sep = None
635         while not self.match(rtype):
636             if sep is None:
637                 sep = COMMA
638             else:
639                 self.consume(sep)
640             element = elementParser()
641             elements.append(element)
642         self.consume(rtype)
643
644         return elements
645     
646     def handleID(self, value):
647         raise NotImplementedError
648
649     def handleInt(self, value):
650         raise NotImplementedError
651
652     def handleFloat(self, value):
653         raise NotImplementedError
654
655     def handleString(self, value):
656         raise NotImplementedError
657
658     def handleBitmask(self, value):
659         raise NotImplementedError
660
661     def handleOffset(self, value, offset):
662         raise NotImplementedError
663
664     def handleArray(self, value):
665         raise NotImplementedError
666
667     def handleStruct(self, value):
668         raise NotImplementedError
669
670     def handleBlob(self, length):
671         return self.handleID('blob(%u)' % length)
672
673     def handleWildcard(self, name):
674         raise NotImplementedError
675
676     def handleCall(self, callNo, functionName, args, ret):
677         raise NotImplementedError
678
679
680 class RefTraceParser(TraceParser):
681
682     def __init__(self, fileName):
683         TraceParser.__init__(self, open(fileName, 'rt'))
684         self.calls = []
685
686     def parse(self):
687         TraceParser.parse(self)
688         return TraceMatcher(self.calls)
689
690     def handleID(self, value):
691         return LiteralMatcher(value)
692
693     def handleInt(self, value):
694         return LiteralMatcher(value)
695
696     def handleFloat(self, value):
697         return ApproxMatcher(value)
698
699     def handleString(self, value):
700         return LiteralMatcher(value)
701
702     def handleBitmask(self, value):
703         return BitmaskMatcher(value)
704
705     def handleOffset(self, value, offset):
706         return OffsetMatcher(value, offset)
707
708     def handleArray(self, value):
709         return ArrayMatcher(value)
710
711     def handleStruct(self, value):
712         return StructMatcher(value)
713
714     def handleWildcard(self, name):
715         return WildcardMatcher(name)
716
717     def handleCall(self, callNo, functionName, args, ret):
718         call = CallMatcher(callNo, functionName, args, ret)
719         self.calls.append(call)
720
721
722 class SrcTraceParser(TraceParser):
723
724     def __init__(self, stream):
725         TraceParser.__init__(self, stream)
726         self.calls = []
727
728     def parse(self):
729         TraceParser.parse(self)
730         return self.calls
731
732     def handleID(self, value):
733         return value
734
735     def handleInt(self, value):
736         return int(value)
737
738     def handleFloat(self, value):
739         return float(value)
740
741     def handleString(self, value):
742         return value
743
744     def handleBitmask(self, value):
745         return value
746
747     def handleArray(self, elements):
748         return list(elements)
749
750     def handleStruct(self, members):
751         return dict(members)
752
753     def handleCall(self, callNo, functionName, args, ret):
754         call = (callNo, functionName, args, ret)
755         self.calls.append(call)
756
757
758 def main():
759     # Parse command line options
760     optparser = optparse.OptionParser(
761         usage='\n\t%prog [OPTIONS] REF_TXT SRC_TRACE',
762         version='%%prog')
763     optparser.add_option(
764         '--apitrace', metavar='PROGRAM',
765         type='string', dest='apitrace', default=os.environ.get('APITRACE', 'apitrace'),
766         help='path to apitrace executable')
767     optparser.add_option(
768         '-v', '--verbose',
769         action="store_true",
770         dest="verbose", default=True,
771         help="verbose output")
772     (options, args) = optparser.parse_args(sys.argv[1:])
773
774     if len(args) != 2:
775         optparser.error('wrong number of arguments')
776
777     refFileName, srcFileName = args
778
779     refParser = RefTraceParser(refFileName)
780     refTrace = refParser.parse()
781     if options.verbose:
782         sys.stdout.write('// Reference\n')
783         sys.stdout.write(str(refTrace))
784         sys.stdout.write('\n')
785
786     if srcFileName.endswith('.trace'):
787         cmd = [options.apitrace, 'dump', '--color=never', srcFileName]
788         p = subprocess.Popen(cmd, stdout=subprocess.PIPE)
789         srcStream = p.stdout
790     else:
791         srcStream = open(srcFileName, 'rt')
792     srcParser = SrcTraceParser(srcStream)
793     srcTrace = srcParser.parse()
794     if options.verbose:
795         sys.stdout.write('// Source\n')
796         sys.stdout.write(''.join(['%s %s%r = %r\n' % call for call in srcTrace]))
797         sys.stdout.write('\n')
798
799     if options.verbose:
800         sys.stdout.write('// Matching\n')
801     mo = refTrace.match(srcTrace, options.verbose)
802     if options.verbose:
803         sys.stdout.write('\n')
804
805     if options.verbose:
806         sys.stdout.write('// Parameters\n')
807         paramNames = mo.params.keys()
808         paramNames.sort()
809         for paramName in paramNames:
810             print '%s = %r' % (paramName, mo.params[paramName])
811
812
813 if __name__ == '__main__':
814     main()