]> git.cworth.org Git - apitrace-tests/blob - tracematch.py
apps/d3d11: Comprehensive test.
[apitrace-tests] / tracematch.py
1 #!/usr/bin/env python
2 ##########################################################################
3 #
4 # Copyright 2008-2012 Jose Fonseca
5 # All Rights Reserved.
6 #
7 # Permission is hereby granted, free of charge, to any person obtaining a copy
8 # of this software and associated documentation files (the "Software"), to deal
9 # in the Software without restriction, including without limitation the rights
10 # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 # copies of the Software, and to permit persons to whom the Software is
12 # furnished to do so, subject to the following conditions:
13 #
14 # The above copyright notice and this permission notice shall be included in
15 # all copies or substantial portions of the Software.
16 #
17 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23 # THE SOFTWARE.
24 #
25 ##########################################################################/
26
27
28 import sys
29 import optparse
30 import os
31 import re
32 import subprocess
33
34
35 class MatchObject:
36
37     def __init__(self):
38         self.params = {}
39
40
41 class Matcher:
42
43     def match(self, value, mo):
44         raise NotImplementedError
45
46     def _matchSequence(self, refValues, srcValues, mo):
47         if not isinstance(srcValues, (list, tuple)):
48             return False
49
50         if len(refValues) != len(srcValues):
51             return False
52
53         for refValue, srcValue in zip(refValues, srcValues):
54             if not refValue.match(srcValue, mo):
55                 return False
56         return True
57
58     def __str__(self):
59         raise NotImplementerError
60
61     def __repr__(self):
62         return str(self)
63
64
65 class WildcardMatcher(Matcher):
66
67     def __init__(self, name = ''):
68         self.name = name
69
70     def match(self, value, mo):
71         if self.name:
72             try:
73                 refValue = mo.params[self.name]
74             except KeyError:
75                 mo.params[self.name] = value
76             else:
77                 return refValue == value
78         return True
79
80     def __str__(self):
81         return '<' + self.name + '>'
82
83
84 class LiteralMatcher(Matcher):
85
86     def __init__(self, refValue):
87         self.refValue = refValue
88
89     def match(self, value, mo):
90         return self.refValue == value
91
92     def __str__(self):
93         return repr(self.refValue)
94
95
96 class ApproxMatcher(Matcher):
97
98     def __init__(self, refValue, tolerance = 2**-23):
99         self.refValue = refValue
100         self.tolerance = tolerance
101
102     def match(self, value, mo):
103         if not isinstance(value, float):
104             return 
105
106         error = abs(self.refValue - value)
107         if self.refValue:
108             error = error / self.refValue
109         return error <= self.tolerance
110
111     def __str__(self):
112         return repr(self.refValue)
113
114
115 class BitmaskMatcher(Matcher):
116
117     def __init__(self, refElements):
118         self.refElements = refElements
119
120     def match(self, value, mo):
121         return self._matchSequence(self.refElements, value, mo)
122
123     def __str__(self):
124         return ' | '.join(map(str, self.refElements))
125
126
127 class OffsetMatcher(Matcher):
128
129     def __init__(self, refValue, offset):
130         self.refValue = refValue
131         self.offset = offset
132
133     def match(self, value, mo):
134         return self.refValue.match(value - self.offset, mo)
135
136     def __str__(self):
137         return '%s + %i' % (self.refValue, self.offset)
138
139
140 class ArrayMatcher(Matcher):
141
142     def __init__(self, refElements):
143         self.refElements = refElements
144
145     def match(self, value, mo):
146         return self._matchSequence(self.refElements, value, mo)
147
148     def __str__(self):
149         return '{' + ', '.join(map(str, self.refElements)) + '}'
150
151
152 class StructMatcher(Matcher):
153
154     def __init__(self, refMembers):
155         self.refMembers = refMembers
156
157     def match(self, value, mo):
158         if not isinstance(value, dict):
159             return False
160
161         if len(value) != len(self.refMembers):
162             return False
163
164         for name, refMember in self.refMembers.iteritems():
165             try:
166                 member = value[name]
167             except KeyError:
168                 return False
169             else:
170                 if not refMember.match(member, mo):
171                     return False
172
173         return True
174
175     def __str__(self):
176         return '{' + ', '.join(['%s = %s' % refMember for refMember in self.refMembers.iteritems()]) + '}'
177
178
179 class CallMatcher(Matcher):
180
181     def __init__(self, callNo, functionName, args, ret):
182         self.callNo = callNo
183         self.functionName = functionName
184         self.args = args
185         self.ret = ret
186
187     def match(self, call, mo):
188         callNo, srcFunctionName, srcArgs, srcRet = call
189
190         if self.functionName != srcFunctionName:
191             return False
192
193         refArgs = [value for name, value in self.args]
194         srcArgs = [value for name, value in srcArgs]
195
196         if not self._matchSequence(refArgs, srcArgs, mo):
197             return False
198
199         if self.ret is None:
200             if srcRet is not None:
201                 return False
202         else:
203             if not self.ret.match(srcRet, mo):
204                 return False
205
206         if self.callNo is not None:
207             if not self.callNo.match(callNo, mo):
208                 return False
209
210         return True
211
212     def __str__(self):
213         s = self.functionName
214         s += '(' + ', '.join(['%s = %s' % refArg for refArg in self.args]) + ')'
215         if self.ret is not None:
216             s += ' = ' + str(self.ret)
217         return s
218
219
220 class TraceMismatch(Exception):
221
222     pass
223
224
225 class TraceMatcher:
226
227     def __init__(self, calls):
228         self.calls = calls
229
230     def match(self, calls, verbose = False):
231         mo = MatchObject()
232         srcCalls = iter(calls)
233         for refCall in self.calls:
234             if verbose:
235                 print refCall
236             skippedSrcCalls = []
237             while True:
238                 try:
239                     srcCall = srcCalls.next()
240                 except StopIteration:
241                     if skippedSrcCalls:
242                         raise TraceMismatch('missing call `%s` (found `%s`)' % (refCall, skippedSrcCalls[0]))
243                     else:
244                         raise TraceMismatch('missing call %s' % refCall)
245                 if verbose:
246                     print '\t%s %s%r = %r' % srcCall
247                 if refCall.match(srcCall, mo):
248                     break
249                 else:
250                     skippedSrcCalls.append(srcCall)
251         return mo
252
253     def __str__(self):
254         return ''.join(['%s\n' % call for call in self.calls])
255
256
257 #######################################################################
258
259 EOF = -1
260 SKIP = -2
261
262
263 class ParseError(Exception):
264
265     def __init__(self, msg=None, filename=None, line=None, col=None):
266         self.msg = msg
267         self.filename = filename
268         self.line = line
269         self.col = col
270
271     def __str__(self):
272         return ':'.join([str(part) for part in (self.filename, self.line, self.col, self.msg) if part != None])
273         
274
275 class Scanner:
276     """Stateless scanner."""
277
278     # should be overriden by derived classes
279     tokens = []
280     symbols = {}
281     literals = {}
282     ignorecase = False
283
284     def __init__(self):
285         flags = re.DOTALL
286         if self.ignorecase:
287             flags |= re.IGNORECASE
288         self.tokens_re = re.compile(
289             '|'.join(['(' + regexp + ')' for type, regexp, test_lit in self.tokens]),
290              flags
291         )
292
293     def next(self, buf, pos):
294         if pos >= len(buf):
295             return EOF, '', pos
296         mo = self.tokens_re.match(buf, pos)
297         if mo:
298             text = mo.group()
299             type, regexp, test_lit = self.tokens[mo.lastindex - 1]
300             pos = mo.end()
301             if test_lit:
302                 type = self.literals.get(text, type)
303             return type, text, pos
304         else:
305             c = buf[pos]
306             return self.symbols.get(c, None), c, pos + 1
307
308
309 class Token:
310
311     def __init__(self, type, text, line, col):
312         self.type = type
313         self.text = text
314         self.line = line
315         self.col = col
316
317
318 class Lexer:
319
320     # should be overriden by derived classes
321     scanner = None
322     tabsize = 8
323
324     newline_re = re.compile(r'\r\n?|\n')
325
326     def __init__(self, buf = None, pos = 0, filename = None, fp = None):
327         if fp is not None:
328             try:
329                 fileno = fp.fileno()
330                 length = os.path.getsize(fp.name)
331                 import mmap
332             except:
333                 # read whole file into memory
334                 buf = fp.read()
335                 pos = 0
336             else:
337                 # map the whole file into memory
338                 if length:
339                     # length must not be zero
340                     buf = mmap.mmap(fileno, length, access = mmap.ACCESS_READ)
341                     pos = os.lseek(fileno, 0, 1)
342                 else:
343                     buf = ''
344                     pos = 0
345
346             if filename is None:
347                 try:
348                     filename = fp.name
349                 except AttributeError:
350                     filename = None
351
352         self.buf = buf
353         self.pos = pos
354         self.line = 1
355         self.col = 1
356         self.filename = filename
357
358     def next(self):
359         while True:
360             # save state
361             pos = self.pos
362             line = self.line
363             col = self.col
364
365             type, text, endpos = self.scanner.next(self.buf, pos)
366             assert pos + len(text) == endpos
367             self.consume(text)
368             type, text = self.filter(type, text)
369             self.pos = endpos
370
371             if type == SKIP:
372                 continue
373             elif type is None:
374                 msg = 'unexpected char '
375                 if text >= ' ' and text <= '~':
376                     msg += "'%s'" % text
377                 else:
378                     msg += "0x%X" % ord(text)
379                 raise ParseError(msg, self.filename, line, col)
380             else:
381                 break
382         return Token(type = type, text = text, line = line, col = col)
383
384     def consume(self, text):
385         # update line number
386         pos = 0
387         for mo in self.newline_re.finditer(text, pos):
388             self.line += 1
389             self.col = 1
390             pos = mo.end()
391
392         # update column number
393         while True:
394             tabpos = text.find('\t', pos)
395             if tabpos == -1:
396                 break
397             self.col += tabpos - pos
398             self.col = ((self.col - 1)//self.tabsize + 1)*self.tabsize + 1
399             pos = tabpos + 1
400         self.col += len(text) - pos
401
402
403 class Parser:
404
405     def __init__(self, lexer):
406         self.lexer = lexer
407         self.lookahead = self.lexer.next()
408
409     def match(self, type):
410         return self.lookahead.type == type
411
412     def skip(self, type):
413         while not self.match(type):
414             self.consume()
415
416     def error(self):
417         raise ParseError(
418             msg = 'unexpected token %r' % self.lookahead.text, 
419             filename = self.lexer.filename, 
420             line = self.lookahead.line, 
421             col = self.lookahead.col)
422
423     def consume(self, type = None):
424         if type is not None and not self.match(type):
425             self.error()
426         token = self.lookahead
427         self.lookahead = self.lexer.next()
428         return token
429
430
431 #######################################################################
432
433 ID, NUMBER, HEXNUM, STRING, WILDCARD, LPAREN, RPAREN, LCURLY, RCURLY, COMMA, AMP, EQUAL, PLUS, VERT, BLOB = xrange(15)
434
435
436 class CallScanner(Scanner):
437
438     # token regular expression table
439     tokens = [
440         # whitespace
441         (SKIP, r'[ \t\f\r\n\v]+', False),
442
443         # comments
444         (SKIP, r'//[^\r\n]*', False),
445
446         # Alphanumeric IDs
447         (ID, r'[a-zA-Z_][a-zA-Z0-9_]*(?:::[a-zA-Z_][a-zA-Z0-9_]*)?', True),
448
449         # Numeric IDs
450         (HEXNUM, r'-?0x[0-9a-fA-F]+', False),
451         
452         # Numeric IDs
453         (NUMBER, r'-?(?:\.[0-9]+|[0-9]+(?:\.[0-9]*)?)(?:[eE][-+][0-9]+)?', False),
454
455         # String IDs
456         (STRING, r'"[^"\\]*(?:\\.[^"\\]*)*"', False),
457         
458         # Wildcard
459         (WILDCARD, r'<[^>]*>', False),
460     ]
461
462     # symbol table
463     symbols = {
464         '(': LPAREN,
465         ')': RPAREN,
466         '{': LCURLY,
467         '}': RCURLY,
468         ',': COMMA,
469         '&': AMP,
470         '=': EQUAL,
471         '+': PLUS,
472         '|': VERT,
473     }
474
475     # literal table
476     literals = {
477         'blob': BLOB
478     }
479
480
481 class CallLexer(Lexer):
482
483     scanner = CallScanner()
484
485     def filter(self, type, text):
486         if type == STRING:
487             text = text[1:-1]
488
489             # line continuations
490             text = text.replace('\\\r\n', '')
491             text = text.replace('\\\r', '')
492             text = text.replace('\\\n', '')
493             
494             # quotes
495             text = text.replace('\\"', '"')
496
497             type = ID
498
499         return type, text
500
501
502 class TraceParser(Parser):
503
504     def __init__(self, stream):
505         lexer = CallLexer(fp = stream)
506         Parser.__init__(self, lexer)
507
508     def eof(self):
509         return self.match(EOF)
510
511     def parse(self):
512         while not self.eof():
513             self.parse_call()
514         return TraceMatcher(self.calls)
515
516     def parse_call(self):
517         if self.lookahead.type == NUMBER:
518             token = self.consume()
519             callNo = self.handleInt(int(token.text))
520         elif self.lookahead.type == WILDCARD:
521             token = self.consume()
522             callNo = self.handleWildcard((token.text[1:-1]))
523         else:
524             callNo = None
525         
526         functionName = self.consume(ID).text
527
528         args = self.parse_sequence(LPAREN, RPAREN, self.parse_pair)
529
530         if self.match(EQUAL):
531             self.consume(EQUAL)
532             ret = self.parse_value()
533         else:
534             ret = None
535
536         self.handleCall(callNo, functionName, args, ret)
537
538     def parse_pair(self):
539         '''Parse a `name = value` pair.'''
540         name = self.consume(ID).text
541         self.consume(EQUAL)
542         value = self.parse_value()
543         return name, value
544
545     def parse_opt_pair(self):
546         '''Parse an optional `name = value` pair.'''
547         if self.match(ID):
548             token = self.consume(ID)
549             if self.match(EQUAL):
550                 self.consume(EQUAL)
551                 name = token.text
552                 value = self.parse_value()
553             else:
554                 name = None
555                 value = self.handleID(token.text)
556         else:
557             name = None
558             value = self.parse_value()
559         if name is None:
560             return value
561         else:
562             return name, value
563
564     def parse_value(self):
565         value = self._parse_value()
566         if self.match(VERT):
567             flags = [value]
568             while self.match(VERT):
569                 self.consume()
570                 value = self._parse_value()
571                 flags.append(value)
572             return self.handleBitmask(flags)
573         elif self.match(PLUS):
574             self.consume()
575             if self.match(NUMBER):
576                 token = self.consume()
577                 offset = int(token.text)
578             elif self.match(HEXNUM):
579                 token = self.consume()
580                 offset = int(token.text, 16)
581             else:
582                 assert 0
583             return self.handleOffset(value, offset)
584         else:
585             return value
586
587     def _parse_value(self):
588         if self.match(AMP):
589             self.consume()
590             value = [self.parse_value()]
591             return self.handleArray(value)
592         elif self.match(ID):
593             token = self.consume()
594             value = token.text
595             return self.handleID(value)
596         elif self.match(STRING):
597             token = self.consume()
598             value = token.text
599             return self.handleString(value)
600         elif self.match(NUMBER):
601             token = self.consume()
602             try:
603                 value = int(token.text)
604             except ValueError:
605                 value = float(token.text)
606                 return self.handleFloat(value)
607             else:
608                 return self.handleInt(value)
609         elif self.match(HEXNUM):
610             token = self.consume()
611             value = int(token.text, 16)
612             return self.handleInt(value)
613         elif self.match(LCURLY):
614             value = self.parse_sequence(LCURLY, RCURLY, self.parse_opt_pair)
615             if len(value) and isinstance(value[0], tuple):
616                 value = dict(value)
617                 return self.handleStruct(value)
618             else:
619                 return self.handleArray(value)
620         elif self.match(BLOB):
621             token = self.consume()
622             self.consume(LPAREN)
623             token = self.consume()
624             length = int(token.text)
625             self.consume(RPAREN)
626             return self.handleBlob(length)
627         elif self.match(WILDCARD):
628             token = self.consume()
629             return self.handleWildcard(token.text[1:-1])
630         else:
631             self.error()
632
633     def parse_sequence(self, ltype, rtype, elementParser):
634         '''Parse a comma separated list'''
635
636         elements = []
637
638         self.consume(ltype)
639         sep = None
640         while not self.match(rtype):
641             if sep is None:
642                 sep = COMMA
643             else:
644                 self.consume(sep)
645             element = elementParser()
646             elements.append(element)
647         self.consume(rtype)
648
649         return elements
650     
651     def handleID(self, value):
652         raise NotImplementedError
653
654     def handleInt(self, value):
655         raise NotImplementedError
656
657     def handleFloat(self, value):
658         raise NotImplementedError
659
660     def handleString(self, value):
661         raise NotImplementedError
662
663     def handleBitmask(self, value):
664         raise NotImplementedError
665
666     def handleOffset(self, value, offset):
667         raise NotImplementedError
668
669     def handleArray(self, value):
670         raise NotImplementedError
671
672     def handleStruct(self, value):
673         raise NotImplementedError
674
675     def handleBlob(self, length):
676         return self.handleID('blob(%u)' % length)
677
678     def handleWildcard(self, name):
679         raise NotImplementedError
680
681     def handleCall(self, callNo, functionName, args, ret):
682         raise NotImplementedError
683
684
685 class RefTraceParser(TraceParser):
686
687     def __init__(self, fileName):
688         TraceParser.__init__(self, open(fileName, 'rt'))
689         self.calls = []
690
691     def parse(self):
692         TraceParser.parse(self)
693         return TraceMatcher(self.calls)
694
695     def handleID(self, value):
696         return LiteralMatcher(value)
697
698     def handleInt(self, value):
699         return LiteralMatcher(value)
700
701     def handleFloat(self, value):
702         return ApproxMatcher(value)
703
704     def handleString(self, value):
705         return LiteralMatcher(value)
706
707     def handleBitmask(self, value):
708         return BitmaskMatcher(value)
709
710     def handleOffset(self, value, offset):
711         return OffsetMatcher(value, offset)
712
713     def handleArray(self, value):
714         return ArrayMatcher(value)
715
716     def handleStruct(self, value):
717         return StructMatcher(value)
718
719     def handleWildcard(self, name):
720         return WildcardMatcher(name)
721
722     def handleCall(self, callNo, functionName, args, ret):
723         call = CallMatcher(callNo, functionName, args, ret)
724         self.calls.append(call)
725
726
727 class SrcTraceParser(TraceParser):
728
729     def __init__(self, stream):
730         TraceParser.__init__(self, stream)
731         self.calls = []
732
733     def parse(self):
734         TraceParser.parse(self)
735         return self.calls
736
737     def handleID(self, value):
738         return value
739
740     def handleInt(self, value):
741         return int(value)
742
743     def handleFloat(self, value):
744         return float(value)
745
746     def handleString(self, value):
747         return value
748
749     def handleBitmask(self, value):
750         return value
751
752     def handleArray(self, elements):
753         return list(elements)
754
755     def handleStruct(self, members):
756         return dict(members)
757
758     def handleCall(self, callNo, functionName, args, ret):
759         call = (callNo, functionName, args, ret)
760         self.calls.append(call)
761
762
763 def main():
764     # Parse command line options
765     optparser = optparse.OptionParser(
766         usage='\n\t%prog [OPTIONS] REF_TXT SRC_TRACE',
767         version='%%prog')
768     optparser.add_option(
769         '--apitrace', metavar='PROGRAM',
770         type='string', dest='apitrace', default=os.environ.get('APITRACE', 'apitrace'),
771         help='path to apitrace executable')
772     optparser.add_option(
773         '-v', '--verbose',
774         action="store_true",
775         dest="verbose", default=True,
776         help="verbose output")
777     (options, args) = optparser.parse_args(sys.argv[1:])
778
779     if len(args) != 2:
780         optparser.error('wrong number of arguments')
781
782     refFileName, srcFileName = args
783
784     refParser = RefTraceParser(refFileName)
785     refTrace = refParser.parse()
786     if options.verbose:
787         sys.stdout.write('// Reference\n')
788         sys.stdout.write(str(refTrace))
789         sys.stdout.write('\n')
790
791     if srcFileName.endswith('.trace'):
792         cmd = [options.apitrace, 'dump', '--verbose', '--color=never', srcFileName]
793         p = subprocess.Popen(cmd, stdout=subprocess.PIPE, universal_newlines=True)
794         srcStream = p.stdout
795     else:
796         srcStream = open(srcFileName, 'rt')
797     srcParser = SrcTraceParser(srcStream)
798     srcTrace = srcParser.parse()
799     if options.verbose:
800         sys.stdout.write('// Source\n')
801         sys.stdout.write(''.join(['%s %s%r = %r\n' % call for call in srcTrace]))
802         sys.stdout.write('\n')
803
804     if options.verbose:
805         sys.stdout.write('// Matching\n')
806     mo = refTrace.match(srcTrace, options.verbose)
807     if options.verbose:
808         sys.stdout.write('\n')
809
810     if options.verbose:
811         sys.stdout.write('// Parameters\n')
812         paramNames = mo.params.keys()
813         paramNames.sort()
814         for paramName in paramNames:
815             print '%s = %r' % (paramName, mo.params[paramName])
816
817
818 if __name__ == '__main__':
819     main()