]> git.cworth.org Git - apitrace-tests/blob - checker.py
Further improvements to checker.
[apitrace-tests] / checker.py
1 #!/usr/bin/env python
2 ##########################################################################
3 #
4 # Copyright 2008-2012 Jose Fonseca
5 # All Rights Reserved.
6 #
7 # Permission is hereby granted, free of charge, to any person obtaining a copy
8 # of this software and associated documentation files (the "Software"), to deal
9 # in the Software without restriction, including without limitation the rights
10 # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 # copies of the Software, and to permit persons to whom the Software is
12 # furnished to do so, subject to the following conditions:
13 #
14 # The above copyright notice and this permission notice shall be included in
15 # all copies or substantial portions of the Software.
16 #
17 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23 # THE SOFTWARE.
24 #
25 ##########################################################################/
26
27
28 import sys
29 import optparse
30 import re
31
32
33 class MatchObject:
34
35     def __init__(self):
36         self.params = {}
37
38
39 class Matcher:
40
41     def match(self, value, mo):
42         raise NotImplementedError
43
44     def _matchSequence(self, refValues, srcValues, mo):
45         if not isinstance(srcValues, (list, tuple)):
46             return False
47
48         if len(refValues) != len(srcValues):
49             return False
50
51         for refValue, srcValue in zip(refValues, srcValues):
52             if not refValue.match(srcValue, mo):
53                 return False
54         return True
55
56     def __str__(self):
57         raise NotImplementerError
58
59     def __repr__(self):
60         return str(self)
61
62
63 class WildcardMatcher(Matcher):
64
65     def __init__(self, name = ''):
66         self.name = name
67
68     def match(self, value, mo):
69         if self.name:
70             try:
71                 refValue = mo.params[self.name]
72             except KeyError:
73                 mo.params[self.name] = value
74             else:
75                 return refValue == value
76         return True
77
78     def __str__(self):
79         return '<' + self.name + '>'
80
81
82 class LiteralMatcher(Matcher):
83
84     def __init__(self, refValue):
85         self.refValue = refValue
86
87     def match(self, value, mo):
88         return self.refValue == value
89
90     def __str__(self):
91         return repr(self.refValue)
92
93
94 class ApproxMatcher(Matcher):
95
96     def __init__(self, refValue, tolerance = 2**-23):
97         self.refValue = refValue
98         self.tolerance = tolerance
99
100     def match(self, value, mo):
101         if not isinstance(value, float):
102             return 
103
104         error = abs(self.refValue - value)
105         if self.refValue:
106             error = error / self.refValue
107         return error <= self.tolerance
108
109     def __str__(self):
110         return repr(self.refValue)
111
112
113 class BitmaskMatcher(Matcher):
114
115     def __init__(self, refElements):
116         self.refElements = refElements
117
118     def match(self, value, mo):
119         return self._matchSequence(self.refElements, value, mo)
120
121     def __str__(self):
122         return ' | '.join(map(str, self.refElements))
123
124
125 class ArrayMatcher(Matcher):
126
127     def __init__(self, refElements):
128         self.refElements = refElements
129
130     def match(self, value, mo):
131         return self._matchSequence(self.refElements, value, mo)
132
133     def __str__(self):
134         return '{' + ', '.join(map(str, self.refElements)) + '}'
135
136
137 class StructMatcher(Matcher):
138
139     def __init__(self, refMembers):
140         self.refMembers = refMembers
141
142     def match(self, value, mo):
143         if not isinstance(value, dict):
144             return False
145
146         if len(value) != len(self.refMembers):
147             return False
148
149         for name, refMember in self.refMembers.iteritems():
150             try:
151                 member = value[name]
152             except KeyError:
153                 return False
154             else:
155                 if not refMember.match(member, mo):
156                     return False
157
158         return True
159
160     def __str__(self):
161         return '{' + ', '.join(['%s = %s' % refMember for refMember in self.refMembers.iteritems()]) + '}'
162
163
164 class CallMatcher(Matcher):
165
166     def __init__(self, callNo, functionName, args, ret):
167         self.callNo = callNo
168         self.functionName = functionName
169         self.args = args
170         self.ret = ret
171
172     def match(self, call, mo):
173         callNo, srcFunctionName, srcArgs, srcRet = call
174
175         if self.functionName != srcFunctionName:
176             return False
177
178         refArgs = [value for name, value in self.args]
179         srcArgs = [value for name, value in srcArgs]
180
181         if not self._matchSequence(refArgs, srcArgs, mo):
182             return False
183
184         if self.ret is None:
185             if srcRet is not None:
186                 return False
187         else:
188             if not self.ret.match(srcRet, mo):
189                 return False
190
191         if self.callNo is not None:
192             if not self.callNo.match(callNo, mo):
193                 return False
194
195         return True
196
197     def __str__(self):
198         s = self.functionName
199         s += '(' + ', '.join(['%s = %s' % refArg for refArg in self.args]) + ')'
200         if self.ret is not None:
201             s += ' = ' + str(self.ret)
202         return s
203
204
205 class TraceMatcher:
206
207     def __init__(self, calls):
208         self.calls = calls
209
210     def match(self, trace):
211         mo = MatchObject()
212         srcCalls = iter(trace.calls)
213         for refCall in self.calls:
214             skippedSrcCalls = []
215             while True:
216                 try:
217                     srcCall = srcCalls.next()
218                 except StopIteration:
219                     if skippedSrcCalls:
220                         raise Exception('missing call `%s` (found `%s`)' % (refCall, skippedSrcCalls[0]))
221                     else:
222                         raise Exception('missing call %s' % refCall)
223                 if refCall.match(srcCall, mo):
224                     break
225                 else:
226                     skippedSrcCalls.append(srcCall)
227         return mo
228
229     def __str__(self):
230         return ''.join(['%s\n' % call for call in self.calls])
231
232
233 #######################################################################
234
235 EOF = -1
236 SKIP = -2
237
238
239 class ParseError(Exception):
240
241     def __init__(self, msg=None, filename=None, line=None, col=None):
242         self.msg = msg
243         self.filename = filename
244         self.line = line
245         self.col = col
246
247     def __str__(self):
248         return ':'.join([str(part) for part in (self.filename, self.line, self.col, self.msg) if part != None])
249         
250
251 class Scanner:
252     """Stateless scanner."""
253
254     # should be overriden by derived classes
255     tokens = []
256     symbols = {}
257     literals = {}
258     ignorecase = False
259
260     def __init__(self):
261         flags = re.DOTALL
262         if self.ignorecase:
263             flags |= re.IGNORECASE
264         self.tokens_re = re.compile(
265             '|'.join(['(' + regexp + ')' for type, regexp, test_lit in self.tokens]),
266              flags
267         )
268
269     def next(self, buf, pos):
270         if pos >= len(buf):
271             return EOF, '', pos
272         mo = self.tokens_re.match(buf, pos)
273         if mo:
274             text = mo.group()
275             type, regexp, test_lit = self.tokens[mo.lastindex - 1]
276             pos = mo.end()
277             if test_lit:
278                 type = self.literals.get(text, type)
279             return type, text, pos
280         else:
281             c = buf[pos]
282             return self.symbols.get(c, None), c, pos + 1
283
284
285 class Token:
286
287     def __init__(self, type, text, line, col):
288         self.type = type
289         self.text = text
290         self.line = line
291         self.col = col
292
293
294 class Lexer:
295
296     # should be overriden by derived classes
297     scanner = None
298     tabsize = 8
299
300     newline_re = re.compile(r'\r\n?|\n')
301
302     def __init__(self, buf = None, pos = 0, filename = None, fp = None):
303         if fp is not None:
304             try:
305                 fileno = fp.fileno()
306                 length = os.path.getsize(fp.name)
307                 import mmap
308             except:
309                 # read whole file into memory
310                 buf = fp.read()
311                 pos = 0
312             else:
313                 # map the whole file into memory
314                 if length:
315                     # length must not be zero
316                     buf = mmap.mmap(fileno, length, access = mmap.ACCESS_READ)
317                     pos = os.lseek(fileno, 0, 1)
318                 else:
319                     buf = ''
320                     pos = 0
321
322             if filename is None:
323                 try:
324                     filename = fp.name
325                 except AttributeError:
326                     filename = None
327
328         self.buf = buf
329         self.pos = pos
330         self.line = 1
331         self.col = 1
332         self.filename = filename
333
334     def next(self):
335         while True:
336             # save state
337             pos = self.pos
338             line = self.line
339             col = self.col
340
341             type, text, endpos = self.scanner.next(self.buf, pos)
342             assert pos + len(text) == endpos
343             self.consume(text)
344             type, text = self.filter(type, text)
345             self.pos = endpos
346
347             if type == SKIP:
348                 continue
349             elif type is None:
350                 msg = 'unexpected char '
351                 if text >= ' ' and text <= '~':
352                     msg += "'%s'" % text
353                 else:
354                     msg += "0x%X" % ord(text)
355                 raise ParseError(msg, self.filename, line, col)
356             else:
357                 break
358         return Token(type = type, text = text, line = line, col = col)
359
360     def consume(self, text):
361         # update line number
362         pos = 0
363         for mo in self.newline_re.finditer(text, pos):
364             self.line += 1
365             self.col = 1
366             pos = mo.end()
367
368         # update column number
369         while True:
370             tabpos = text.find('\t', pos)
371             if tabpos == -1:
372                 break
373             self.col += tabpos - pos
374             self.col = ((self.col - 1)//self.tabsize + 1)*self.tabsize + 1
375             pos = tabpos + 1
376         self.col += len(text) - pos
377
378
379 class Parser:
380
381     def __init__(self, lexer):
382         self.lexer = lexer
383         self.lookahead = self.lexer.next()
384
385     def match(self, type):
386         return self.lookahead.type == type
387
388     def skip(self, type):
389         while not self.match(type):
390             self.consume()
391
392     def error(self):
393         raise ParseError(
394             msg = 'unexpected token %r' % self.lookahead.text, 
395             filename = self.lexer.filename, 
396             line = self.lookahead.line, 
397             col = self.lookahead.col)
398
399     def consume(self, type = None):
400         if type is not None and not self.match(type):
401             self.error()
402         token = self.lookahead
403         self.lookahead = self.lexer.next()
404         return token
405
406
407 #######################################################################
408
409 ID, NUMBER, HEXNUM, STRING, WILDCARD, PRAGMA, LPAREN, RPAREN, LCURLY, RCURLY, COMMA, AMP, EQUAL, VERT, BLOB = xrange(15)
410
411
412 class CallScanner(Scanner):
413
414     # token regular expression table
415     tokens = [
416         # whitespace
417         (SKIP, r'[ \t\f\r\n\v]+', False),
418
419         # Alphanumeric IDs
420         (ID, r'[a-zA-Z_][a-zA-Z0-9_]*(?:::[a-zA-Z_][a-zA-Z0-9_]*)?', True),
421
422         # Numeric IDs
423         (HEXNUM, r'-?0x[0-9a-fA-F]+', False),
424         
425         # Numeric IDs
426         (NUMBER, r'-?(?:\.[0-9]+|[0-9]+(?:\.[0-9]*)?)(?:[eE][-+][0-9]+)?', False),
427
428         # String IDs
429         (STRING, r'"[^"\\]*(?:\\.[^"\\]*)*"', False),
430         
431         # Wildcard
432         (WILDCARD, r'<[^>]*>', False),
433         
434         # Pragma
435         (PRAGMA, r'#[^\r\n]*', False),
436     ]
437
438     # symbol table
439     symbols = {
440         '(': LPAREN,
441         ')': RPAREN,
442         '{': LCURLY,
443         '}': RCURLY,
444         ',': COMMA,
445         '&': AMP,
446         '=': EQUAL,
447         '|': VERT,
448     }
449
450     # literal table
451     literals = {
452         'blob': BLOB
453     }
454
455
456 class CallLexer(Lexer):
457
458     scanner = CallScanner()
459
460     def filter(self, type, text):
461         if type == STRING:
462             text = text[1:-1]
463
464             # line continuations
465             text = text.replace('\\\r\n', '')
466             text = text.replace('\\\r', '')
467             text = text.replace('\\\n', '')
468             
469             # quotes
470             text = text.replace('\\"', '"')
471
472             type = ID
473
474         return type, text
475
476
477 class TraceParser(Parser):
478
479     def __init__(self, stream):
480         lexer = CallLexer(fp = stream)
481         Parser.__init__(self, lexer)
482
483     def eof(self):
484         return self.match(EOF)
485
486     def parse(self):
487         while not self.eof():
488             self.parse_element()
489         return TraceMatcher(self.calls)
490
491     def parse_element(self):
492         if self.lookahead.type == PRAGMA:
493             # TODO
494             token = self.consume()
495             self.handlePragma(token.text)
496         else:
497             self.parse_call()
498
499     def parse_call(self):
500         while self.lookahead.type == PRAGMA:
501             # TODO
502             token = self.consume()
503
504         if self.lookahead.type == NUMBER:
505             token = self.consume()
506             callNo = self.handleInt(int(token.text))
507         elif self.lookahead.type == WILDCARD:
508             token = self.consume()
509             callNo = self.handleWildcard((token.text[1:-1]))
510         else:
511             callNo = None
512         
513         functionName = self.consume(ID).text
514
515         args = self.parse_sequence(LPAREN, RPAREN, self.parse_pair)
516
517         if self.match(EQUAL):
518             self.consume(EQUAL)
519             ret = self.parse_value()
520         else:
521             ret = None
522
523         self.handleCall(callNo, functionName, args, ret)
524
525     def parse_pair(self):
526         '''Parse a `name = value` pair.'''
527         name = self.consume(ID).text
528         self.consume(EQUAL)
529         value = self.parse_value()
530         return name, value
531
532     def parse_opt_pair(self):
533         '''Parse an optional `name = value` pair.'''
534         if self.match(ID):
535             name = self.consume(ID).text
536             if self.match(EQUAL):
537                 self.consume(EQUAL)
538                 value = self.parse_value()
539             else:
540                 value = name
541                 name = None
542         else:
543             name = None
544             value = self.parse_value()
545         if name is None:
546             return value
547         else:
548             return name, value
549
550     def parse_value(self):
551         value = self._parse_value()
552         if self.match(VERT):
553             flags = [value]
554             while self.match(VERT):
555                 self.consume()
556                 value = self._parse_value()
557                 flags.append(value)
558             return self.handleBitmask(flags)
559         else:
560             return value
561
562     def _parse_value(self):
563         if self.match(AMP):
564             self.consume()
565             value = [self.parse_value()]
566             return self.handleArray(value)
567         elif self.match(ID):
568             token = self.consume()
569             value = token.text
570             return self.handleID(value)
571         elif self.match(STRING):
572             token = self.consume()
573             value = token.text
574             return self.handleString(value)
575         elif self.match(NUMBER):
576             token = self.consume()
577             value = float(token.text)
578             return self.handleFloat(value)
579         elif self.match(HEXNUM):
580             token = self.consume()
581             value = int(token.text, 16)
582             return self.handleInt(value)
583         elif self.match(LCURLY):
584             value = self.parse_sequence(LCURLY, RCURLY, self.parse_opt_pair)
585             if len(value) and isinstance(value[0], tuple):
586                 value = dict(value)
587                 return self.handleStruct(value)
588             else:
589                 return self.handleArray(value)
590         elif self.match(BLOB):
591             token = self.consume()
592             self.consume(LPAREN)
593             token = self.consume()
594             length = int(token.text)
595             self.consume(RPAREN)
596             return self.handleBlob(length)
597         elif self.match(WILDCARD):
598             token = self.consume()
599             return self.handleWildcard(token.text[1:-1])
600         else:
601             self.error()
602
603     def parse_sequence(self, ltype, rtype, elementParser):
604         '''Parse a comma separated list'''
605
606         elements = []
607
608         self.consume(ltype)
609         sep = None
610         while not self.match(rtype):
611             if sep is None:
612                 sep = COMMA
613             else:
614                 self.consume(sep)
615             element = elementParser()
616             elements.append(element)
617         self.consume(rtype)
618
619         return elements
620     
621     def handleID(self, value):
622         raise NotImplementedError
623
624     def handleInt(self, value):
625         raise NotImplementedError
626
627     def handleFloat(self, value):
628         raise NotImplementedError
629
630     def handleString(self, value):
631         raise NotImplementedError
632
633     def handleBitmask(self, value):
634         raise NotImplementedError
635
636     def handleArray(self, value):
637         raise NotImplementedError
638
639     def handleStruct(self, value):
640         raise NotImplementedError
641
642     def handleBlob(self, length):
643         return self.handleID('blob(%u)' % length)
644
645     def handleWildcard(self, name):
646         raise NotImplementedError
647
648     def handleCall(self, callNo, functionName, args, ret):
649         raise NotImplementedError
650
651     def handlePragma(self, line):
652         pass
653
654
655 class RefTraceParser(TraceParser):
656
657     def __init__(self, stream):
658         TraceParser.__init__(self, stream)
659         self.calls = []
660
661     def parse(self):
662         TraceParser.parse(self)
663         return TraceMatcher(self.calls)
664
665     def handleID(self, value):
666         return LiteralMatcher(value)
667
668     def handleInt(self, value):
669         return LiteralMatcher(value)
670
671     def handleFloat(self, value):
672         return ApproxMatcher(value)
673
674     def handleString(self, value):
675         return LiteralMatcher(value)
676
677     def handleBitmask(self, value):
678         return BitmaskMatcher(value)
679
680     def handleArray(self, value):
681         return ArrayMatcher(value)
682
683     def handleStruct(self, value):
684         return StructMatcher(value)
685
686     def handleWildcard(self, name):
687         return WildcardMatcher(name)
688
689     def handleCall(self, callNo, functionName, args, ret):
690         call = CallMatcher(callNo, functionName, args, ret)
691         self.calls.append(call)
692
693
694 class SrcTraceParser(TraceParser):
695
696     def __init__(self, stream):
697         TraceParser.__init__(self, stream)
698         self.calls = []
699
700     def parse(self):
701         TraceParser.parse(self)
702         return TraceMatcher(self.calls)
703
704     def handleID(self, value):
705         return value
706
707     def handleInt(self, value):
708         return int(value)
709
710     def handleFloat(self, value):
711         return float(value)
712
713     def handleString(self, value):
714         return value
715
716     def handleBitmask(self, value):
717         return value
718
719     def handleArray(self, elements):
720         return list(elements)
721
722     def handleStruct(self, members):
723         return dict(members)
724
725     def handleCall(self, callNo, functionName, args, ret):
726         call = (callNo, functionName, args, ret)
727         self.calls.append(call)
728
729
730 def main():
731     # Parse command line options
732     optparser = optparse.OptionParser(
733         usage='\n\t%prog [OPTIONS] REF_TRACE SRC_TRACE',
734         version='%%prog')
735     optparser.add_option(
736         '-v', '--verbose',
737         action="store_true",
738         dest="verbose", default=False,
739         help="verbose output")
740     (options, args) = optparser.parse_args(sys.argv[1:])
741
742     if len(args) != 2:
743         optparser.error('wrong number of arguments')
744
745     refParser = RefTraceParser(open(args[0], 'rt'))
746     refTrace = refParser.parse()
747     sys.stdout.write(str(refTrace))
748     srcParser = SrcTraceParser(open(args[1], 'rt'))
749     srcTrace = srcParser.parse()
750     mo = refTrace.match(srcTrace)
751
752     paramNames = mo.params.keys()
753     paramNames.sort()
754     for paramName in paramNames:
755         print '%s = %r' % (paramName, mo.params[paramName])
756
757
758 if __name__ == '__main__':
759     main()