]> git.cworth.org Git - apitrace-tests/blob - checker.py
Improve checker.
[apitrace-tests] / checker.py
1 #!/usr/bin/env python
2 ##########################################################################
3 #
4 # Copyright 2008-2012 Jose Fonseca
5 # All Rights Reserved.
6 #
7 # Permission is hereby granted, free of charge, to any person obtaining a copy
8 # of this software and associated documentation files (the "Software"), to deal
9 # in the Software without restriction, including without limitation the rights
10 # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 # copies of the Software, and to permit persons to whom the Software is
12 # furnished to do so, subject to the following conditions:
13 #
14 # The above copyright notice and this permission notice shall be included in
15 # all copies or substantial portions of the Software.
16 #
17 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23 # THE SOFTWARE.
24 #
25 ##########################################################################/
26
27
28 import sys
29 import optparse
30 import re
31
32
33 class Matcher:
34
35     def match(self, value):
36         raise NotImplementedError
37
38     def _matchSequence(self, refValues, srcValues):
39         if not isinstance(srcValues, (list, tuple)):
40             return False
41
42         if len(refValues) != len(srcValues):
43             return False
44
45         for refValue, srcValue in zip(refValues, srcValues):
46             if not refValue.match(srcValue):
47                 return False
48         return True
49
50     def __str__(self):
51         raise NotImplementerError
52
53     def __repr__(self):
54         return str(self)
55
56
57 class WildcardMatcher(Matcher):
58
59     def match(self, value):
60         return True
61
62     def __str__(self):
63         return '*'
64
65
66 class LiteralMatcher(Matcher):
67
68     def __init__(self, refValue):
69         self.refValue = refValue
70
71     def match(self, value):
72         return self.refValue == value
73
74     def __str__(self):
75         return repr(self.refValue)
76
77
78 class ApproxMatcher(Matcher):
79
80     def __init__(self, refValue, tolerance = 2**-23):
81         self.refValue = refValue
82         self.tolerance = tolerance
83
84     def match(self, value):
85         if not isinstance(value, float):
86             return 
87
88         error = abs(self.refValue - value)
89         if self.refValue:
90             error = error / self.refValue
91         return error <= self.tolerance
92
93     def __str__(self):
94         return repr(self.refValue)
95
96
97 class BitmaskMatcher(Matcher):
98
99     def __init__(self, refElements):
100         self.refElements = refElements
101
102     def match(self, value):
103         return self._matchSequence(self.refElements, value)
104
105     def __str__(self):
106         return ' | '.join(map(str, self.refElements))
107
108
109 class ArrayMatcher(Matcher):
110
111     def __init__(self, refElements):
112         self.refElements = refElements
113
114     def match(self, value):
115         return self._matchSequence(self.refElements, value)
116
117     def __str__(self):
118         return '{' + ', '.join(map(str, self.refElements)) + '}'
119
120
121 class StructMatcher(Matcher):
122
123     def __init__(self, refMembers):
124         self.refMembers = refMembers
125
126     def match(self, value):
127         if not isinstance(value, dict):
128             return False
129
130         if len(value) != len(self.refMembers):
131             return False
132
133         for name, refMember in self.refMembers.iteritems():
134             try:
135                 member = value[name]
136             except KeyError:
137                 return False
138             else:
139                 if not refMember.match(member):
140                     return False
141
142         return True
143
144     def __str__(self):
145         print self.refMembers
146         return '{' + ', '.join(['%s = %s' % refMember for refMember in self.refMembers.iteritems()]) + '}'
147
148
149 class CallMatcher(Matcher):
150
151     def __init__(self, functionName, args, ret = None):
152         self.functionName = functionName
153         self.args = args
154         self.ret = ret
155
156     def match(self, call):
157         srcFunctionName, srcArgs, srcRet = call
158
159         if self.functionName != srcFunctionName:
160             return False
161
162         refArgs = [value for name, value in self.args]
163         srcArgs = [value for name, value in srcArgs]
164
165         if not self._matchSequence(refArgs, srcArgs):
166             return False
167
168         if self.ret is None:
169             if srcRet is not None:
170                 return False
171         else:
172             if not self.ret.match(srcRet):
173                 return False
174
175         return True
176
177     def __str__(self):
178         s = self.functionName
179         s += '(' + ', '.join(['%s = %s' % refArg for refArg in self.args]) + ')'
180         if self.ret is not None:
181             s += ' = ' + str(self.ret)
182         return s
183
184
185 class TraceMatcher:
186
187     def __init__(self, calls):
188         self.calls = calls
189
190     def match(self, trace):
191
192         srcCalls = iter(trace.calls)
193         for refCall in self.calls:
194             skippedSrcCalls = []
195             while True:
196                 try:
197                     srcCall = srcCalls.next()
198                 except StopIteration:
199                     if skippedSrcCalls:
200                         raise Exception('missing call `%s` (found `%s`)' % (refCall, skippedSrcCalls[0]))
201                     else:
202                         raise Exception('missing call %s' % refCall)
203                 if refCall.match(srcCall):
204                     break
205                 else:
206                     skippedSrcCalls.append(srcCall)
207         return True
208
209     def __str__(self):
210         return ''.join(['%s\n' % call for call in self.calls])
211
212
213 #######################################################################
214
215 EOF = -1
216 SKIP = -2
217
218
219 class ParseError(Exception):
220
221     def __init__(self, msg=None, filename=None, line=None, col=None):
222         self.msg = msg
223         self.filename = filename
224         self.line = line
225         self.col = col
226
227     def __str__(self):
228         return ':'.join([str(part) for part in (self.filename, self.line, self.col, self.msg) if part != None])
229         
230
231 class Scanner:
232     """Stateless scanner."""
233
234     # should be overriden by derived classes
235     tokens = []
236     symbols = {}
237     literals = {}
238     ignorecase = False
239
240     def __init__(self):
241         flags = re.DOTALL
242         if self.ignorecase:
243             flags |= re.IGNORECASE
244         self.tokens_re = re.compile(
245             '|'.join(['(' + regexp + ')' for type, regexp, test_lit in self.tokens]),
246              flags
247         )
248
249     def next(self, buf, pos):
250         if pos >= len(buf):
251             return EOF, '', pos
252         mo = self.tokens_re.match(buf, pos)
253         if mo:
254             text = mo.group()
255             type, regexp, test_lit = self.tokens[mo.lastindex - 1]
256             pos = mo.end()
257             if test_lit:
258                 type = self.literals.get(text, type)
259             return type, text, pos
260         else:
261             c = buf[pos]
262             return self.symbols.get(c, None), c, pos + 1
263
264
265 class Token:
266
267     def __init__(self, type, text, line, col):
268         self.type = type
269         self.text = text
270         self.line = line
271         self.col = col
272
273
274 class Lexer:
275
276     # should be overriden by derived classes
277     scanner = None
278     tabsize = 8
279
280     newline_re = re.compile(r'\r\n?|\n')
281
282     def __init__(self, buf = None, pos = 0, filename = None, fp = None):
283         if fp is not None:
284             try:
285                 fileno = fp.fileno()
286                 length = os.path.getsize(fp.name)
287                 import mmap
288             except:
289                 # read whole file into memory
290                 buf = fp.read()
291                 pos = 0
292             else:
293                 # map the whole file into memory
294                 if length:
295                     # length must not be zero
296                     buf = mmap.mmap(fileno, length, access = mmap.ACCESS_READ)
297                     pos = os.lseek(fileno, 0, 1)
298                 else:
299                     buf = ''
300                     pos = 0
301
302             if filename is None:
303                 try:
304                     filename = fp.name
305                 except AttributeError:
306                     filename = None
307
308         self.buf = buf
309         self.pos = pos
310         self.line = 1
311         self.col = 1
312         self.filename = filename
313
314     def next(self):
315         while True:
316             # save state
317             pos = self.pos
318             line = self.line
319             col = self.col
320
321             type, text, endpos = self.scanner.next(self.buf, pos)
322             assert pos + len(text) == endpos
323             self.consume(text)
324             type, text = self.filter(type, text)
325             self.pos = endpos
326
327             if type == SKIP:
328                 continue
329             elif type is None:
330                 msg = 'unexpected char '
331                 if text >= ' ' and text <= '~':
332                     msg += "'%s'" % text
333                 else:
334                     msg += "0x%X" % ord(text)
335                 raise ParseError(msg, self.filename, line, col)
336             else:
337                 break
338         return Token(type = type, text = text, line = line, col = col)
339
340     def consume(self, text):
341         # update line number
342         pos = 0
343         for mo in self.newline_re.finditer(text, pos):
344             self.line += 1
345             self.col = 1
346             pos = mo.end()
347
348         # update column number
349         while True:
350             tabpos = text.find('\t', pos)
351             if tabpos == -1:
352                 break
353             self.col += tabpos - pos
354             self.col = ((self.col - 1)//self.tabsize + 1)*self.tabsize + 1
355             pos = tabpos + 1
356         self.col += len(text) - pos
357
358
359 class Parser:
360
361     def __init__(self, lexer):
362         self.lexer = lexer
363         self.lookahead = self.lexer.next()
364
365     def match(self, type):
366         return self.lookahead.type == type
367
368     def skip(self, type):
369         while not self.match(type):
370             self.consume()
371
372     def error(self):
373         raise ParseError(
374             msg = 'unexpected token %r' % self.lookahead.text, 
375             filename = self.lexer.filename, 
376             line = self.lookahead.line, 
377             col = self.lookahead.col)
378
379     def consume(self, type = None):
380         if type is not None and not self.match(type):
381             self.error()
382         token = self.lookahead
383         self.lookahead = self.lexer.next()
384         return token
385
386
387 #######################################################################
388
389 ID, NUMBER, HEXNUM, STRING, PRAGMA, LPAREN, RPAREN, LCURLY, RCURLY, COMMA, AMP, EQUAL, VERT, BLOB = xrange(14)
390
391
392 class CallScanner(Scanner):
393
394     # token regular expression table
395     tokens = [
396         # whitespace
397         (SKIP, r'[ \t\f\r\n\v]+', False),
398
399         # Alphanumeric IDs
400         (ID, r'[a-zA-Z_][a-zA-Z0-9_]*(?:::[a-zA-Z_][a-zA-Z0-9_]*)?', True),
401
402         # Numeric IDs
403         (HEXNUM, r'-?0x[0-9a-fA-F]+', False),
404         
405         # Numeric IDs
406         (NUMBER, r'-?(?:\.[0-9]+|[0-9]+(?:\.[0-9]*)?)(?:[eE][-+][0-9]+)?', False),
407
408         # String IDs
409         (STRING, r'"[^"\\]*(?:\\.[^"\\]*)*"', False),
410         
411         # Pragma
412         (PRAGMA, r'#[^\r\n]*', False),
413     ]
414
415     # symbol table
416     symbols = {
417         '(': LPAREN,
418         ')': RPAREN,
419         '{': LCURLY,
420         '}': RCURLY,
421         ',': COMMA,
422         '&': AMP,
423         '=': EQUAL,
424         '|': VERT,
425     }
426
427     # literal table
428     literals = {
429         'blob': BLOB
430     }
431
432
433 class CallLexer(Lexer):
434
435     scanner = CallScanner()
436
437     def filter(self, type, text):
438         if type == STRING:
439             text = text[1:-1]
440
441             # line continuations
442             text = text.replace('\\\r\n', '')
443             text = text.replace('\\\r', '')
444             text = text.replace('\\\n', '')
445             
446             # quotes
447             text = text.replace('\\"', '"')
448
449             type = ID
450
451         return type, text
452
453
454 class TraceParser(Parser):
455
456     def __init__(self, stream):
457         lexer = CallLexer(fp = stream)
458         Parser.__init__(self, lexer)
459
460     def eof(self):
461         return self.match(EOF)
462
463     def parse(self):
464         while not self.eof():
465             self.parse_element()
466         return TraceMatcher(self.calls)
467
468     def parse_element(self):
469         if self.lookahead.type == PRAGMA:
470             # TODO
471             token = self.consume()
472             self.handlePragma(token.text)
473         else:
474             self.parse_call()
475
476     def parse_call(self):
477         while self.lookahead.type == PRAGMA:
478             # TODO
479             token = self.consume()
480             print token.text
481
482         if self.lookahead.type == NUMBER:
483             token = self.consume()
484             callNo = int(token.text)
485         else:
486             callNo = None
487         
488         functionName = self.consume(ID).text
489
490         args = self.parse_sequence(LPAREN, RPAREN, self.parse_pair)
491
492         if self.match(EQUAL):
493             self.consume(EQUAL)
494             ret = self.parse_value()
495         else:
496             ret = None
497
498         self.handleCall(functionName, args, ret)
499
500     def parse_pair(self):
501         '''Parse a `name = value` pair.'''
502         name = self.consume(ID).text
503         self.consume(EQUAL)
504         value = self.parse_value()
505         return name, value
506
507     def parse_opt_pair(self):
508         '''Parse an optional `name = value` pair.'''
509         if self.match(ID):
510             name = self.consume(ID).text
511             if self.match(EQUAL):
512                 self.consume(EQUAL)
513                 value = self.parse_value()
514             else:
515                 value = name
516                 name = None
517         else:
518             name = None
519             value = self.parse_value()
520         if name is None:
521             return value
522         else:
523             return name, value
524
525     def parse_value(self):
526         value = self._parse_value()
527         if self.match(VERT):
528             flags = [value]
529             while self.match(VERT):
530                 self.consume()
531                 value = self._parse_value()
532                 flags.append(value)
533             return self.handleBitmask(flags)
534         else:
535             return value
536
537     def _parse_value(self):
538         if self.match(AMP):
539             self.consume()
540             value = [self.parse_value()]
541             return self.handleArray(value)
542         elif self.match(ID):
543             token = self.consume()
544             value = token.text
545             return self.handleID(value)
546         elif self.match(STRING):
547             token = self.consume()
548             value = token.text
549             return self.handleString(value)
550         elif self.match(NUMBER):
551             token = self.consume()
552             value = float(token.text)
553             return self.handleFloat(value)
554         elif self.match(HEXNUM):
555             token = self.consume()
556             value = int(token.text, 16)
557             return self.handleInt(value)
558         elif self.match(LCURLY):
559             value = self.parse_sequence(LCURLY, RCURLY, self.parse_opt_pair)
560             if len(value) and isinstance(value[0], tuple):
561                 value = dict(value)
562                 return self.handleStruct(value)
563             else:
564                 return self.handleArray(value)
565         elif self.match(BLOB):
566             token = self.consume()
567             self.consume(LPAREN)
568             length = self.consume()
569             self.consume(RPAREN)
570             return self.handleBlob(length)
571         else:
572             self.error()
573
574     def parse_sequence(self, ltype, rtype, elementParser):
575         '''Parse a comma separated list'''
576
577         elements = []
578
579         self.consume(ltype)
580         sep = None
581         while not self.match(rtype):
582             if sep is None:
583                 sep = COMMA
584             else:
585                 self.consume(sep)
586             element = elementParser()
587             elements.append(element)
588         self.consume(rtype)
589
590         return elements
591     
592     def handleID(self, value):
593         raise NotImplementedError
594
595     def handleInt(self, value):
596         raise NotImplementedError
597
598     def handleFloat(self, value):
599         raise NotImplementedError
600
601     def handleString(self, value):
602         raise NotImplementedError
603
604     def handleBitmask(self, value):
605         raise NotImplementedError
606
607     def handleArray(self, value):
608         raise NotImplementedError
609
610     def handleStruct(self, value):
611         raise NotImplementedError
612
613     def handleBlob(self, length):
614         raise NotImplementedError
615         # TODO
616         return WildcardMatcher()
617
618     def handleCall(self, functionName, args, ret):
619         raise NotImplementedError
620
621     def handlePragma(self, line):
622         pass
623
624
625 class RefTraceParser(TraceParser):
626
627     def __init__(self, stream):
628         TraceParser.__init__(self, stream)
629         self.calls = []
630
631     def parse(self):
632         TraceParser.parse(self)
633         return TraceMatcher(self.calls)
634
635     def handleID(self, value):
636         return LiteralMatcher(value)
637
638     def handleInt(self, value):
639         return LiteralMatcher(value)
640
641     def handleFloat(self, value):
642         return ApproxMatcher(value)
643
644     def handleString(self, value):
645         return LiteralMatcher(value)
646
647     def handleBitmask(self, value):
648         return BitmaskMatcher(value)
649
650     def handleArray(self, value):
651         return ArrayMatcher(value)
652
653     def handleStruct(self, value):
654         return StructMatcher(value)
655
656     def handleBlob(self, length):
657         # TODO
658         return WildcardMatcher()
659
660     def handleCall(self, functionName, args, ret):
661         call = CallMatcher(functionName, args, ret)
662         self.calls.append(call)
663
664
665 class SrcTraceParser(TraceParser):
666
667     def __init__(self, stream):
668         TraceParser.__init__(self, stream)
669         self.calls = []
670
671     def parse(self):
672         TraceParser.parse(self)
673         return TraceMatcher(self.calls)
674
675     def handleID(self, value):
676         return value
677
678     def handleInt(self, value):
679         return int(value)
680
681     def handleFloat(self, value):
682         return float(value)
683
684     def handleString(self, value):
685         return value
686
687     def handleBitmask(self, value):
688         return value
689
690     def handleArray(self, elements):
691         return list(elements)
692
693     def handleStruct(self, members):
694         return dict(members)
695
696     def handleBlob(self, length):
697         # TODO
698         return None
699
700     def handleCall(self, functionName, args, ret):
701         call = (functionName, args, ret)
702         self.calls.append(call)
703
704
705 def main():
706     # Parse command line options
707     optparser = optparse.OptionParser(
708         usage='\n\t%prog [OPTIONS] REF_TRACE SRC_TRACE',
709         version='%%prog')
710     optparser.add_option(
711         '-v', '--verbose',
712         action="store_true",
713         dest="verbose", default=False,
714         help="verbose output")
715     (options, args) = optparser.parse_args(sys.argv[1:])
716
717     if len(args) != 2:
718         optparser.error('wrong number of arguments')
719
720     refParser = RefTraceParser(open(args[0], 'rt'))
721     refTrace = refParser.parse()
722     sys.stdout.write(str(refTrace))
723     srcParser = SrcTraceParser(open(args[1], 'rt'))
724     srcTrace = srcParser.parse()
725     refTrace.match(srcTrace)
726
727
728 if __name__ == '__main__':
729     main()