Use the new trace checker.
authorJosé Fonseca <jose.r.fonseca@gmail.com>
Fri, 23 Nov 2012 07:52:17 +0000 (07:52 +0000)
committerJosé Fonseca <jose.r.fonseca@gmail.com>
Fri, 23 Nov 2012 07:52:17 +0000 (07:52 +0000)
app_driver.py
checker.py

index 2a935d064af68fb481a70435def33975f76effbc..2e05a57b00fd3dabecd26abca7f299d7dbc0d792 100755 (executable)
@@ -45,96 +45,95 @@ except ImportError:
 from base_driver import *
 
 
-class TraceChecker:
+import checker
 
-    def __init__(self, srcStream, refFileName, verbose=False):
-        self.srcStream = srcStream
-        self.refFileName = refFileName
-        if refFileName:
-            self.refStream = open(refFileName, 'rt')
-        else:
-            self.refStream = None
-        self.verbose = verbose
-        self.doubleBuffer = False
-        self.callNo = 0
-        self.refLine = ''
-        self.images = []
-        self.states = []
 
-    call_re = re.compile(r'^([0-9]+) (\w+)\(')
+class RefTraceParser(checker.RefTraceParser):
 
-    def check(self):
-
-        swapbuffers = 0
-        flushes = 0
-
-        srcLines = []
-        self.consumeRefLine()
-        for line in self.srcStream:
-            line = line.rstrip()
-            if self.verbose:
-                sys.stdout.write(line + '\n')
-            mo = self.call_re.match(line)
-            if mo:
-                self.callNo = int(mo.group(1))
-                function_name = mo.group(2)
-                if function_name.find('SwapBuffers') != -1 or \
-                   line.find('kCGLPFADoubleBuffer') != -1:
-                    swapbuffers += 1
-                if function_name in ('glFlush', 'glFinish'):
-                    flushes += 1
-                srcLine = line[mo.start(2):]
-            else:
-                srcLine = line
-            if self.refLine:
-                if srcLine == self.refLine:
-                    self.consumeRefLine()
-                    srcLines = []
-                else:
-                    srcLines.append(srcLine)
+    def __init__(self, fileName):
+        checker.RefTraceParser.__init__(self, open(fileName, 'rt'))
+        self.fileName = fileName
+        self.images = []
+        self.states = []
+        self.pragmaNo = 0
 
-        if self.refLine:
-            if srcLines:
-                fail('missing call `%s` (found `%s`)' % (self.refLine, srcLines[0]))
+    def handlePragma(self, line):
+        if self.calls:
+            lastCall = self.calls[-1]
+            if lastCall.callNo is None:
+                paramName = 'pragma%u' % self.pragmaNo
+                lastCall.callNo = checker.WildcardMatcher(paramName)
             else:
-                fail('missing call %s' % self.refLine)
-
-        if swapbuffers:
-            self.doubleBuffer = True
+                paramName = lastCall.callNo.name
         else:
-            self.doubleBuffer = False
+            paramName = 0
+            self.pragmaNo += 1
 
-    def consumeRefLine(self):
-        if not self.refStream:
-            self.refLine = ''
-            return
-
-        while True:
-            line = self.refStream.readline()
-            if not line:
-                break
-            line = line.rstrip()
-            if line.startswith('#'):
-                self.handlePragma(line)
-            else:
-                break
-        self.refLine = line
-
-    def handlePragma(self, line):
         pragma, rest = line.split(None, 1)
         if pragma == '#image':
             imageFileName = self.getAbsPath(rest)
-            self.images.append((self.callNo, imageFileName))
+            self.images.append((paramName, imageFileName))
         elif pragma == '#state':
             stateFileName = self.getAbsPath(rest)
-            self.states.append((self.callNo, stateFileName))
+            self.states.append((paramName, stateFileName))
         else:
             assert False
 
     def getAbsPath(self, path):
         '''Get the absolute from a path relative to the reference filename'''
-        return os.path.abspath(os.path.join(os.path.dirname(self.refFileName), path))
+        return os.path.abspath(os.path.join(os.path.dirname(self.fileName), path))
+
 
+class SrcTraceParser(checker.SrcTraceParser):
+
+    def __init__(self, stream):
+        checker.SrcTraceParser.__init__(self, stream)
+        self.swapbuffers = 0
+
+    def handleCall(self, callNo, functionName, args, ret):
+        checker.SrcTraceParser.handleCall(self, callNo, functionName, args, ret)
+
+        if functionName.find('SwapBuffers') != -1 or \
+           repr(args).find('kCGLPFADoubleBuffer') != -1:
+            self.swapbuffers += 1
+
+
+class TraceChecker:
+
+    def __init__(self, srcStream, refFileName):
+        self.srcStream = srcStream
+        self.refFileName = refFileName
+        self.doubleBuffer = False
+        self.callNo = 0
+        self.images = []
+        self.states = []
+
+    def check(self):
+        srcParser = SrcTraceParser(self.srcStream)
+        srcTrace = srcParser.parse()
+        self.doubleBuffer = srcParser.swapbuffers > 0
+
+        if self.refFileName:
+            refParser = RefTraceParser(self.refFileName)
+            refTrace = refParser.parse()
+
+            try:
+                mo = refTrace.match(srcTrace)
+            except checker.TraceMismatch, ex:
+                self.fail(str(ex))
+
+            for paramName, imageFileName in refParser.images:
+                if isinstance(paramName, int):
+                    callNo = paramName
+                else:
+                    callNo = mo.params[paramName]
+                self.images.append((callNo, imageFileName))
+            for paramName, stateFileName in refParser.states:
+                if isinstance(paramName, int):
+                    callNo = paramName
+                else:
+                    callNo = mo.params[paramName]
+                self.states.append((callNo, stateFileName))
 
 
 class AppDriver(Driver):
@@ -241,7 +240,7 @@ class AppDriver(Driver):
         cmd = [options.apitrace, 'dump', '--color=never', self.trace_file]
         p = popen(cmd, stdout=subprocess.PIPE)
 
-        checker = TraceChecker(p.stdout, self.ref_dump, self.verbose)
+        checker = TraceChecker(p.stdout, self.ref_dump)
         checker.check()
         p.wait()
         if p.returncode != 0:
index 08ef9d324c9d12bc1a2134a491131d3e11376e3a..2f69dd369590c0170ca9943ef60a1998da1f84cc 100755 (executable)
@@ -202,6 +202,11 @@ class CallMatcher(Matcher):
         return s
 
 
+class TraceMismatch(Exception):
+
+    pass
+
+
 class TraceMatcher:
 
     def __init__(self, calls):
@@ -217,9 +222,9 @@ class TraceMatcher:
                     srcCall = srcCalls.next()
                 except StopIteration:
                     if skippedSrcCalls:
-                        raise Exception('missing call `%s` (found `%s`)' % (refCall, skippedSrcCalls[0]))
+                        raise TraceMismatch('missing call `%s` (found `%s`)' % (refCall, skippedSrcCalls[0]))
                     else:
-                        raise Exception('missing call %s' % refCall)
+                        raise TraceMismatch('missing call %s' % refCall)
                 if refCall.match(srcCall, mo):
                     break
                 else:
@@ -490,17 +495,12 @@ class TraceParser(Parser):
 
     def parse_element(self):
         if self.lookahead.type == PRAGMA:
-            # TODO
             token = self.consume()
             self.handlePragma(token.text)
         else:
             self.parse_call()
 
     def parse_call(self):
-        while self.lookahead.type == PRAGMA:
-            # TODO
-            token = self.consume()
-
         if self.lookahead.type == NUMBER:
             token = self.consume()
             callNo = self.handleInt(int(token.text))
@@ -649,7 +649,7 @@ class TraceParser(Parser):
         raise NotImplementedError
 
     def handlePragma(self, line):
-        pass
+        raise NotImplementedError
 
 
 class RefTraceParser(TraceParser):
@@ -689,6 +689,9 @@ class RefTraceParser(TraceParser):
     def handleCall(self, callNo, functionName, args, ret):
         call = CallMatcher(callNo, functionName, args, ret)
         self.calls.append(call)
+    
+    def handlePragma(self, line):
+        pass
 
 
 class SrcTraceParser(TraceParser):