]> git.cworth.org Git - apitrace-tests/blobdiff - tracematch.py
Be more lenient with shader matching.
[apitrace-tests] / tracematch.py
index 9ffd310880323763b0c7705712cdd8dd1ad0ea71..45228bfd250fbd94c676e21526ae133b83075caf 100755 (executable)
@@ -112,6 +112,27 @@ class ApproxMatcher(Matcher):
         return repr(self.refValue)
 
 
+class StringMatcher(Matcher):
+
+    def __init__(self, refValue):
+        self.refValue = refValue
+
+    def isShaderDisassembly(self, value):
+        return value.find('// Generated by Microsoft (R) D3D Shader Disassembler\n') != -1
+
+    def normalizeShaderDisassembly(self, value):
+        # Unfortunately slightly different disassemblers produce different output
+        return '\n'.join([line.strip() for line in value.split('\n') if line.strip() and not line.startswith('//')])
+
+    def match(self, value, mo):
+        if self.isShaderDisassembly(self.refValue) and self.isShaderDisassembly(value):
+            return self.normalizeShaderDisassembly(self.refValue) == self.normalizeShaderDisassembly(value)
+        return self.refValue == value
+
+    def __str__(self):
+        return repr(self.refValue)
+
+
 class BitmaskMatcher(Matcher):
 
     def __init__(self, refElements):
@@ -398,6 +419,9 @@ class Lexer:
             self.col = ((self.col - 1)//self.tabsize + 1)*self.tabsize + 1
             pos = tabpos + 1
         self.col += len(text) - pos
+    
+    def filter(self, type, text):
+        return type, text
 
 
 class Parser:
@@ -440,6 +464,9 @@ class CallScanner(Scanner):
         # whitespace
         (SKIP, r'[ \t\f\r\n\v]+', False),
 
+        # comments
+        (SKIP, r'//[^\r\n]*', False),
+
         # Alphanumeric IDs
         (ID, r'[a-zA-Z_][a-zA-Z0-9_]*(?:::[a-zA-Z_][a-zA-Z0-9_]*)?', True),
 
@@ -483,16 +510,9 @@ class CallLexer(Lexer):
         if type == STRING:
             text = text[1:-1]
 
-            # line continuations
-            text = text.replace('\\\r\n', '')
-            text = text.replace('\\\r', '')
-            text = text.replace('\\\n', '')
-            
             # quotes
             text = text.replace('\\"', '"')
 
-            type = ID
-
         return type, text
 
 
@@ -542,13 +562,14 @@ class TraceParser(Parser):
     def parse_opt_pair(self):
         '''Parse an optional `name = value` pair.'''
         if self.match(ID):
-            name = self.consume(ID).text
+            token = self.consume(ID)
             if self.match(EQUAL):
                 self.consume(EQUAL)
+                name = token.text
                 value = self.parse_value()
             else:
-                value = name
                 name = None
+                value = self.handleID(token.text)
         else:
             name = None
             value = self.parse_value()
@@ -595,11 +616,12 @@ class TraceParser(Parser):
             return self.handleString(value)
         elif self.match(NUMBER):
             token = self.consume()
-            if '.' in token.text:
+            try:
+                value = int(token.text)
+            except ValueError:
                 value = float(token.text)
                 return self.handleFloat(value)
             else:
-                value = int(token.text)
                 return self.handleInt(value)
         elif self.match(HEXNUM):
             token = self.consume()
@@ -697,7 +719,7 @@ class RefTraceParser(TraceParser):
         return ApproxMatcher(value)
 
     def handleString(self, value):
-        return LiteralMatcher(value)
+        return StringMatcher(value)
 
     def handleBitmask(self, value):
         return BitmaskMatcher(value)