]> git.cworth.org Git - apitrace/blobdiff - wrappers/dxgitrace.py
dxgitrace: Handle simultanous D3D11 maps.
[apitrace] / wrappers / dxgitrace.py
index 122b773234149dcd5e8dacba22da2f0d19c8eeca..45767a0f6761410ac9908599496e7a9ffb9b91c4 100644 (file)
 
 import sys
 from dlltrace import DllTracer
+from trace import getWrapperInterfaceName
 from specs import stdapi
 from specs.stdapi import API
-from specs.dxgi import dxgi
-from specs.d3d10 import d3d10
-from specs.d3d10_1 import d3d10_1
-from specs.d3d11 import d3d11
+from specs import dxgi
+from specs import d3d10
+from specs import d3d10_1
+from specs import d3d11
 
 
 class D3DCommonTracer(DllTracer):
@@ -73,12 +74,19 @@ class D3DCommonTracer(DllTracer):
             return
 
         DllTracer.serializeArgValue(self, function, arg)
+
+    # Interfaces that need book-keeping for maps
+    mapInterfaces = (
+        dxgi.IDXGISurface,
+        d3d10.ID3D10Resource,
+        d3d11.ID3D11Resource,
+    )
     
     def enumWrapperInterfaceVariables(self, interface):
         variables = DllTracer.enumWrapperInterfaceVariables(self, interface)
         
         # Add additional members to track maps
-        if interface.getMethodByName('Map') is not None:
+        if interface.hasBase(*self.mapInterfaces):
             variables += [
                 ('_MAP_DESC', '_MapDesc', None),
             ]
@@ -86,7 +94,19 @@ class D3DCommonTracer(DllTracer):
         return variables
 
     def implementWrapperInterfaceMethodBody(self, interface, base, method):
+        if method.name in ('Map', 'Unmap'):
+            # On D3D11 Map/Unmap is not a resource method, but a context method instead.
+            resourceArg = method.getArgByName('pResource')
+            if resourceArg is None:
+                pResource = 'this'
+            else:
+                wrapperInterfaceName = getWrapperInterfaceName(resourceArg.type.type)
+                print '    %s * _pResource = static_cast<%s*>(%s);' % (wrapperInterfaceName, wrapperInterfaceName, resourceArg.name)
+                pResource = '_pResource'
+
         if method.name == 'Unmap':
+            print '    _MAP_DESC _MapDesc = %s->_MapDesc;' % pResource
+            #print r'    os::log("%%p -> %%p+%%lu\n", %s,_MapDesc.pData, (unsigned long)_MapDesc.Size);' % pResource
             print '    if (_MapDesc.Size && _MapDesc.pData) {'
             self.emit_memcpy('_MapDesc.pData', '_MapDesc.pData', '_MapDesc.Size')
             print '    }'
@@ -95,12 +115,15 @@ class D3DCommonTracer(DllTracer):
 
         if method.name == 'Map':
             # NOTE: recursive locks are explicitely forbidden
+            print '    _MAP_DESC _MapDesc;'
             print '    if (SUCCEEDED(_result)) {'
             print '        _getMapDesc(_this, %s, _MapDesc);' % ', '.join(method.argNames())
             print '    } else {'
             print '        _MapDesc.pData = NULL;'
             print '        _MapDesc.Size = 0;'
             print '    }'
+            #print r'    os::log("%%p <- %%p+%%lu\n", %s,_MapDesc.pData, (unsigned long)_MapDesc.Size);' % pResource
+            print '    %s->_MapDesc = _MapDesc;' % pResource
 
 
 if __name__ == '__main__':
@@ -117,24 +140,24 @@ if __name__ == '__main__':
     api = API()
     
     if moduleNames:
-        api.addModule(dxgi)
+        api.addModule(dxgi.dxgi)
     
     if 'd3d10' in moduleNames:
         if 'd3d10_1' in moduleNames:
             print r'#include "d3d10_1imports.hpp"'
-            api.addModule(d3d10_1)
+            api.addModule(d3d10_1.d3d10_1)
         else:
             print r'#include "d3d10imports.hpp"'
         print r'#include "d3d10size.hpp"'
-        api.addModule(d3d10)
+        api.addModule(d3d10.d3d10)
 
     if 'd3d11' in moduleNames:
         print r'#include "d3d11imports.hpp"'
         if 'd3d11_1' in moduleNames:
             print '#include <d3d11_1.h>'
-            import specs.d3d11_1
+            from specs import d3d11_1
         print r'#include "d3d11size.hpp"'
-        api.addModule(d3d11)
+        api.addModule(d3d11.d3d11)
 
     tracer = D3DCommonTracer()
     tracer.traceApi(api)