Be more lenient with shader matching.
[apitrace-tests] / tracematch.py
1 #!/usr/bin/env python
2 ##########################################################################
3 #
4 # Copyright 2008-2012 Jose Fonseca
5 # All Rights Reserved.
6 #
7 # Permission is hereby granted, free of charge, to any person obtaining a copy
8 # of this software and associated documentation files (the "Software"), to deal
9 # in the Software without restriction, including without limitation the rights
10 # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 # copies of the Software, and to permit persons to whom the Software is
12 # furnished to do so, subject to the following conditions:
13 #
14 # The above copyright notice and this permission notice shall be included in
15 # all copies or substantial portions of the Software.
16 #
17 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23 # THE SOFTWARE.
24 #
25 ##########################################################################/
26
27
28 import sys
29 import optparse
30 import os
31 import re
32 import subprocess
33
34
35 class MatchObject:
36
37     def __init__(self):
38         self.params = {}
39
40
41 class Matcher:
42
43     def match(self, value, mo):
44         raise NotImplementedError
45
46     def _matchSequence(self, refValues, srcValues, mo):
47         if not isinstance(srcValues, (list, tuple)):
48             return False
49
50         if len(refValues) != len(srcValues):
51             return False
52
53         for refValue, srcValue in zip(refValues, srcValues):
54             if not refValue.match(srcValue, mo):
55                 return False
56         return True
57
58     def __str__(self):
59         raise NotImplementerError
60
61     def __repr__(self):
62         return str(self)
63
64
65 class WildcardMatcher(Matcher):
66
67     def __init__(self, name = ''):
68         self.name = name
69
70     def match(self, value, mo):
71         if self.name:
72             try:
73                 refValue = mo.params[self.name]
74             except KeyError:
75                 mo.params[self.name] = value
76             else:
77                 return refValue == value
78         return True
79
80     def __str__(self):
81         return '<' + self.name + '>'
82
83
84 class LiteralMatcher(Matcher):
85
86     def __init__(self, refValue):
87         self.refValue = refValue
88
89     def match(self, value, mo):
90         return self.refValue == value
91
92     def __str__(self):
93         return repr(self.refValue)
94
95
96 class ApproxMatcher(Matcher):
97
98     def __init__(self, refValue, tolerance = 2**-23):
99         self.refValue = refValue
100         self.tolerance = tolerance
101
102     def match(self, value, mo):
103         if not isinstance(value, float):
104             return 
105
106         error = abs(self.refValue - value)
107         if self.refValue:
108             error = error / self.refValue
109         return error <= self.tolerance
110
111     def __str__(self):
112         return repr(self.refValue)
113
114
115 class StringMatcher(Matcher):
116
117     def __init__(self, refValue):
118         self.refValue = refValue
119
120     def isShaderDisassembly(self, value):
121         return value.find('// Generated by Microsoft (R) D3D Shader Disassembler\n') != -1
122
123     def normalizeShaderDisassembly(self, value):
124         # Unfortunately slightly different disassemblers produce different output
125         return '\n'.join([line.strip() for line in value.split('\n') if line.strip() and not line.startswith('//')])
126
127     def match(self, value, mo):
128         if self.isShaderDisassembly(self.refValue) and self.isShaderDisassembly(value):
129             return self.normalizeShaderDisassembly(self.refValue) == self.normalizeShaderDisassembly(value)
130         return self.refValue == value
131
132     def __str__(self):
133         return repr(self.refValue)
134
135
136 class BitmaskMatcher(Matcher):
137
138     def __init__(self, refElements):
139         self.refElements = refElements
140
141     def match(self, value, mo):
142         return self._matchSequence(self.refElements, value, mo)
143
144     def __str__(self):
145         return ' | '.join(map(str, self.refElements))
146
147
148 class OffsetMatcher(Matcher):
149
150     def __init__(self, refValue, offset):
151         self.refValue = refValue
152         self.offset = offset
153
154     def match(self, value, mo):
155         return self.refValue.match(value - self.offset, mo)
156
157     def __str__(self):
158         return '%s + %i' % (self.refValue, self.offset)
159
160
161 class ArrayMatcher(Matcher):
162
163     def __init__(self, refElements):
164         self.refElements = refElements
165
166     def match(self, value, mo):
167         return self._matchSequence(self.refElements, value, mo)
168
169     def __str__(self):
170         return '{' + ', '.join(map(str, self.refElements)) + '}'
171
172
173 class StructMatcher(Matcher):
174
175     def __init__(self, refMembers):
176         self.refMembers = refMembers
177
178     def match(self, value, mo):
179         if not isinstance(value, dict):
180             return False
181
182         if len(value) != len(self.refMembers):
183             return False
184
185         for name, refMember in self.refMembers.iteritems():
186             try:
187                 member = value[name]
188             except KeyError:
189                 return False
190             else:
191                 if not refMember.match(member, mo):
192                     return False
193
194         return True
195
196     def __str__(self):
197         return '{' + ', '.join(['%s = %s' % refMember for refMember in self.refMembers.iteritems()]) + '}'
198
199
200 class CallMatcher(Matcher):
201
202     def __init__(self, callNo, functionName, args, ret):
203         self.callNo = callNo
204         self.functionName = functionName
205         self.args = args
206         self.ret = ret
207
208     def match(self, call, mo):
209         callNo, srcFunctionName, srcArgs, srcRet = call
210
211         if self.functionName != srcFunctionName:
212             return False
213
214         refArgs = [value for name, value in self.args]
215         srcArgs = [value for name, value in srcArgs]
216
217         if not self._matchSequence(refArgs, srcArgs, mo):
218             return False
219
220         if self.ret is None:
221             if srcRet is not None:
222                 return False
223         else:
224             if not self.ret.match(srcRet, mo):
225                 return False
226
227         if self.callNo is not None:
228             if not self.callNo.match(callNo, mo):
229                 return False
230
231         return True
232
233     def __str__(self):
234         s = self.functionName
235         s += '(' + ', '.join(['%s = %s' % refArg for refArg in self.args]) + ')'
236         if self.ret is not None:
237             s += ' = ' + str(self.ret)
238         return s
239
240
241 class TraceMismatch(Exception):
242
243     pass
244
245
246 class TraceMatcher:
247
248     def __init__(self, calls):
249         self.calls = calls
250
251     def match(self, calls, verbose = False):
252         mo = MatchObject()
253         srcCalls = iter(calls)
254         for refCall in self.calls:
255             if verbose:
256                 print refCall
257             skippedSrcCalls = []
258             while True:
259                 try:
260                     srcCall = srcCalls.next()
261                 except StopIteration:
262                     if skippedSrcCalls:
263                         raise TraceMismatch('missing call `%s` (found `%s`)' % (refCall, skippedSrcCalls[0]))
264                     else:
265                         raise TraceMismatch('missing call %s' % refCall)
266                 if verbose:
267                     print '\t%s %s%r = %r' % srcCall
268                 if refCall.match(srcCall, mo):
269                     break
270                 else:
271                     skippedSrcCalls.append(srcCall)
272         return mo
273
274     def __str__(self):
275         return ''.join(['%s\n' % call for call in self.calls])
276
277
278 #######################################################################
279
280 EOF = -1
281 SKIP = -2
282
283
284 class ParseError(Exception):
285
286     def __init__(self, msg=None, filename=None, line=None, col=None):
287         self.msg = msg
288         self.filename = filename
289         self.line = line
290         self.col = col
291
292     def __str__(self):
293         return ':'.join([str(part) for part in (self.filename, self.line, self.col, self.msg) if part != None])
294         
295
296 class Scanner:
297     """Stateless scanner."""
298
299     # should be overriden by derived classes
300     tokens = []
301     symbols = {}
302     literals = {}
303     ignorecase = False
304
305     def __init__(self):
306         flags = re.DOTALL
307         if self.ignorecase:
308             flags |= re.IGNORECASE
309         self.tokens_re = re.compile(
310             '|'.join(['(' + regexp + ')' for type, regexp, test_lit in self.tokens]),
311              flags
312         )
313
314     def next(self, buf, pos):
315         if pos >= len(buf):
316             return EOF, '', pos
317         mo = self.tokens_re.match(buf, pos)
318         if mo:
319             text = mo.group()
320             type, regexp, test_lit = self.tokens[mo.lastindex - 1]
321             pos = mo.end()
322             if test_lit:
323                 type = self.literals.get(text, type)
324             return type, text, pos
325         else:
326             c = buf[pos]
327             return self.symbols.get(c, None), c, pos + 1
328
329
330 class Token:
331
332     def __init__(self, type, text, line, col):
333         self.type = type
334         self.text = text
335         self.line = line
336         self.col = col
337
338
339 class Lexer:
340
341     # should be overriden by derived classes
342     scanner = None
343     tabsize = 8
344
345     newline_re = re.compile(r'\r\n?|\n')
346
347     def __init__(self, buf = None, pos = 0, filename = None, fp = None):
348         if fp is not None:
349             try:
350                 fileno = fp.fileno()
351                 length = os.path.getsize(fp.name)
352                 import mmap
353             except:
354                 # read whole file into memory
355                 buf = fp.read()
356                 pos = 0
357             else:
358                 # map the whole file into memory
359                 if length:
360                     # length must not be zero
361                     buf = mmap.mmap(fileno, length, access = mmap.ACCESS_READ)
362                     pos = os.lseek(fileno, 0, 1)
363                 else:
364                     buf = ''
365                     pos = 0
366
367             if filename is None:
368                 try:
369                     filename = fp.name
370                 except AttributeError:
371                     filename = None
372
373         self.buf = buf
374         self.pos = pos
375         self.line = 1
376         self.col = 1
377         self.filename = filename
378
379     def next(self):
380         while True:
381             # save state
382             pos = self.pos
383             line = self.line
384             col = self.col
385
386             type, text, endpos = self.scanner.next(self.buf, pos)
387             assert pos + len(text) == endpos
388             self.consume(text)
389             type, text = self.filter(type, text)
390             self.pos = endpos
391
392             if type == SKIP:
393                 continue
394             elif type is None:
395                 msg = 'unexpected char '
396                 if text >= ' ' and text <= '~':
397                     msg += "'%s'" % text
398                 else:
399                     msg += "0x%X" % ord(text)
400                 raise ParseError(msg, self.filename, line, col)
401             else:
402                 break
403         return Token(type = type, text = text, line = line, col = col)
404
405     def consume(self, text):
406         # update line number
407         pos = 0
408         for mo in self.newline_re.finditer(text, pos):
409             self.line += 1
410             self.col = 1
411             pos = mo.end()
412
413         # update column number
414         while True:
415             tabpos = text.find('\t', pos)
416             if tabpos == -1:
417                 break
418             self.col += tabpos - pos
419             self.col = ((self.col - 1)//self.tabsize + 1)*self.tabsize + 1
420             pos = tabpos + 1
421         self.col += len(text) - pos
422     
423     def filter(self, type, text):
424         return type, text
425
426
427 class Parser:
428
429     def __init__(self, lexer):
430         self.lexer = lexer
431         self.lookahead = self.lexer.next()
432
433     def match(self, type):
434         return self.lookahead.type == type
435
436     def skip(self, type):
437         while not self.match(type):
438             self.consume()
439
440     def error(self):
441         raise ParseError(
442             msg = 'unexpected token %r' % self.lookahead.text, 
443             filename = self.lexer.filename, 
444             line = self.lookahead.line, 
445             col = self.lookahead.col)
446
447     def consume(self, type = None):
448         if type is not None and not self.match(type):
449             self.error()
450         token = self.lookahead
451         self.lookahead = self.lexer.next()
452         return token
453
454
455 #######################################################################
456
457 ID, NUMBER, HEXNUM, STRING, WILDCARD, LPAREN, RPAREN, LCURLY, RCURLY, COMMA, AMP, EQUAL, PLUS, VERT, BLOB = xrange(15)
458
459
460 class CallScanner(Scanner):
461
462     # token regular expression table
463     tokens = [
464         # whitespace
465         (SKIP, r'[ \t\f\r\n\v]+', False),
466
467         # comments
468         (SKIP, r'//[^\r\n]*', False),
469
470         # Alphanumeric IDs
471         (ID, r'[a-zA-Z_][a-zA-Z0-9_]*(?:::[a-zA-Z_][a-zA-Z0-9_]*)?', True),
472
473         # Numeric IDs
474         (HEXNUM, r'-?0x[0-9a-fA-F]+', False),
475         
476         # Numeric IDs
477         (NUMBER, r'-?(?:\.[0-9]+|[0-9]+(?:\.[0-9]*)?)(?:[eE][-+][0-9]+)?', False),
478
479         # String IDs
480         (STRING, r'"[^"\\]*(?:\\.[^"\\]*)*"', False),
481         
482         # Wildcard
483         (WILDCARD, r'<[^>]*>', False),
484     ]
485
486     # symbol table
487     symbols = {
488         '(': LPAREN,
489         ')': RPAREN,
490         '{': LCURLY,
491         '}': RCURLY,
492         ',': COMMA,
493         '&': AMP,
494         '=': EQUAL,
495         '+': PLUS,
496         '|': VERT,
497     }
498
499     # literal table
500     literals = {
501         'blob': BLOB
502     }
503
504
505 class CallLexer(Lexer):
506
507     scanner = CallScanner()
508
509     def filter(self, type, text):
510         if type == STRING:
511             text = text[1:-1]
512
513             # quotes
514             text = text.replace('\\"', '"')
515
516         return type, text
517
518
519 class TraceParser(Parser):
520
521     def __init__(self, stream):
522         lexer = CallLexer(fp = stream)
523         Parser.__init__(self, lexer)
524
525     def eof(self):
526         return self.match(EOF)
527
528     def parse(self):
529         while not self.eof():
530             self.parse_call()
531         return TraceMatcher(self.calls)
532
533     def parse_call(self):
534         if self.lookahead.type == NUMBER:
535             token = self.consume()
536             callNo = self.handleInt(int(token.text))
537         elif self.lookahead.type == WILDCARD:
538             token = self.consume()
539             callNo = self.handleWildcard((token.text[1:-1]))
540         else:
541             callNo = None
542         
543         functionName = self.consume(ID).text
544
545         args = self.parse_sequence(LPAREN, RPAREN, self.parse_pair)
546
547         if self.match(EQUAL):
548             self.consume(EQUAL)
549             ret = self.parse_value()
550         else:
551             ret = None
552
553         self.handleCall(callNo, functionName, args, ret)
554
555     def parse_pair(self):
556         '''Parse a `name = value` pair.'''
557         name = self.consume(ID).text
558         self.consume(EQUAL)
559         value = self.parse_value()
560         return name, value
561
562     def parse_opt_pair(self):
563         '''Parse an optional `name = value` pair.'''
564         if self.match(ID):
565             token = self.consume(ID)
566             if self.match(EQUAL):
567                 self.consume(EQUAL)
568                 name = token.text
569                 value = self.parse_value()
570             else:
571                 name = None
572                 value = self.handleID(token.text)
573         else:
574             name = None
575             value = self.parse_value()
576         if name is None:
577             return value
578         else:
579             return name, value
580
581     def parse_value(self):
582         value = self._parse_value()
583         if self.match(VERT):
584             flags = [value]
585             while self.match(VERT):
586                 self.consume()
587                 value = self._parse_value()
588                 flags.append(value)
589             return self.handleBitmask(flags)
590         elif self.match(PLUS):
591             self.consume()
592             if self.match(NUMBER):
593                 token = self.consume()
594                 offset = int(token.text)
595             elif self.match(HEXNUM):
596                 token = self.consume()
597                 offset = int(token.text, 16)
598             else:
599                 assert 0
600             return self.handleOffset(value, offset)
601         else:
602             return value
603
604     def _parse_value(self):
605         if self.match(AMP):
606             self.consume()
607             value = [self.parse_value()]
608             return self.handleArray(value)
609         elif self.match(ID):
610             token = self.consume()
611             value = token.text
612             return self.handleID(value)
613         elif self.match(STRING):
614             token = self.consume()
615             value = token.text
616             return self.handleString(value)
617         elif self.match(NUMBER):
618             token = self.consume()
619             try:
620                 value = int(token.text)
621             except ValueError:
622                 value = float(token.text)
623                 return self.handleFloat(value)
624             else:
625                 return self.handleInt(value)
626         elif self.match(HEXNUM):
627             token = self.consume()
628             value = int(token.text, 16)
629             return self.handleInt(value)
630         elif self.match(LCURLY):
631             value = self.parse_sequence(LCURLY, RCURLY, self.parse_opt_pair)
632             if len(value) and isinstance(value[0], tuple):
633                 value = dict(value)
634                 return self.handleStruct(value)
635             else:
636                 return self.handleArray(value)
637         elif self.match(BLOB):
638             token = self.consume()
639             self.consume(LPAREN)
640             token = self.consume()
641             length = int(token.text)
642             self.consume(RPAREN)
643             return self.handleBlob(length)
644         elif self.match(WILDCARD):
645             token = self.consume()
646             return self.handleWildcard(token.text[1:-1])
647         else:
648             self.error()
649
650     def parse_sequence(self, ltype, rtype, elementParser):
651         '''Parse a comma separated list'''
652
653         elements = []
654
655         self.consume(ltype)
656         sep = None
657         while not self.match(rtype):
658             if sep is None:
659                 sep = COMMA
660             else:
661                 self.consume(sep)
662             element = elementParser()
663             elements.append(element)
664         self.consume(rtype)
665
666         return elements
667     
668     def handleID(self, value):
669         raise NotImplementedError
670
671     def handleInt(self, value):
672         raise NotImplementedError
673
674     def handleFloat(self, value):
675         raise NotImplementedError
676
677     def handleString(self, value):
678         raise NotImplementedError
679
680     def handleBitmask(self, value):
681         raise NotImplementedError
682
683     def handleOffset(self, value, offset):
684         raise NotImplementedError
685
686     def handleArray(self, value):
687         raise NotImplementedError
688
689     def handleStruct(self, value):
690         raise NotImplementedError
691
692     def handleBlob(self, length):
693         return self.handleID('blob(%u)' % length)
694
695     def handleWildcard(self, name):
696         raise NotImplementedError
697
698     def handleCall(self, callNo, functionName, args, ret):
699         raise NotImplementedError
700
701
702 class RefTraceParser(TraceParser):
703
704     def __init__(self, fileName):
705         TraceParser.__init__(self, open(fileName, 'rt'))
706         self.calls = []
707
708     def parse(self):
709         TraceParser.parse(self)
710         return TraceMatcher(self.calls)
711
712     def handleID(self, value):
713         return LiteralMatcher(value)
714
715     def handleInt(self, value):
716         return LiteralMatcher(value)
717
718     def handleFloat(self, value):
719         return ApproxMatcher(value)
720
721     def handleString(self, value):
722         return StringMatcher(value)
723
724     def handleBitmask(self, value):
725         return BitmaskMatcher(value)
726
727     def handleOffset(self, value, offset):
728         return OffsetMatcher(value, offset)
729
730     def handleArray(self, value):
731         return ArrayMatcher(value)
732
733     def handleStruct(self, value):
734         return StructMatcher(value)
735
736     def handleWildcard(self, name):
737         return WildcardMatcher(name)
738
739     def handleCall(self, callNo, functionName, args, ret):
740         call = CallMatcher(callNo, functionName, args, ret)
741         self.calls.append(call)
742
743
744 class SrcTraceParser(TraceParser):
745
746     def __init__(self, stream):
747         TraceParser.__init__(self, stream)
748         self.calls = []
749
750     def parse(self):
751         TraceParser.parse(self)
752         return self.calls
753
754     def handleID(self, value):
755         return value
756
757     def handleInt(self, value):
758         return int(value)
759
760     def handleFloat(self, value):
761         return float(value)
762
763     def handleString(self, value):
764         return value
765
766     def handleBitmask(self, value):
767         return value
768
769     def handleArray(self, elements):
770         return list(elements)
771
772     def handleStruct(self, members):
773         return dict(members)
774
775     def handleCall(self, callNo, functionName, args, ret):
776         call = (callNo, functionName, args, ret)
777         self.calls.append(call)
778
779
780 def main():
781     # Parse command line options
782     optparser = optparse.OptionParser(
783         usage='\n\t%prog [OPTIONS] REF_TXT SRC_TRACE',
784         version='%%prog')
785     optparser.add_option(
786         '--apitrace', metavar='PROGRAM',
787         type='string', dest='apitrace', default=os.environ.get('APITRACE', 'apitrace'),
788         help='path to apitrace executable')
789     optparser.add_option(
790         '-v', '--verbose',
791         action="store_true",
792         dest="verbose", default=True,
793         help="verbose output")
794     (options, args) = optparser.parse_args(sys.argv[1:])
795
796     if len(args) != 2:
797         optparser.error('wrong number of arguments')
798
799     refFileName, srcFileName = args
800
801     refParser = RefTraceParser(refFileName)
802     refTrace = refParser.parse()
803     if options.verbose:
804         sys.stdout.write('// Reference\n')
805         sys.stdout.write(str(refTrace))
806         sys.stdout.write('\n')
807
808     if srcFileName.endswith('.trace'):
809         cmd = [options.apitrace, 'dump', '--verbose', '--color=never', srcFileName]
810         p = subprocess.Popen(cmd, stdout=subprocess.PIPE, universal_newlines=True)
811         srcStream = p.stdout
812     else:
813         srcStream = open(srcFileName, 'rt')
814     srcParser = SrcTraceParser(srcStream)
815     srcTrace = srcParser.parse()
816     if options.verbose:
817         sys.stdout.write('// Source\n')
818         sys.stdout.write(''.join(['%s %s%r = %r\n' % call for call in srcTrace]))
819         sys.stdout.write('\n')
820
821     if options.verbose:
822         sys.stdout.write('// Matching\n')
823     mo = refTrace.match(srcTrace, options.verbose)
824     if options.verbose:
825         sys.stdout.write('\n')
826
827     if options.verbose:
828         sys.stdout.write('// Parameters\n')
829         paramNames = mo.params.keys()
830         paramNames.sort()
831         for paramName in paramNames:
832             print '%s = %r' % (paramName, mo.params[paramName])
833
834
835 if __name__ == '__main__':
836     main()