From 84cea3b6f23efbd60bf57a642459e402247f8902 Mon Sep 17 00:00:00 2001
From: =?utf8?q?Jos=C3=A9=20Fonseca?= <jose.r.fonseca@gmail.com>
Date: Wed, 9 May 2012 21:12:30 +0100
Subject: [PATCH] Fix D3D11 tracing with D3D11_CREATE_DEVICE_DEBUG flag.

---
 retrace/retrace.py | 13 ++++++------
 specs/d3d11.py     |  6 ++++++
 specs/stdapi.py    |  3 ++-
 wrappers/d3d11.def |  4 ++++
 wrappers/trace.py  | 49 +++++++++++++++++++++++++---------------------
 5 files changed, 46 insertions(+), 29 deletions(-)

diff --git a/retrace/retrace.py b/retrace/retrace.py
index 9828cfb..51da6fe 100644
--- a/retrace/retrace.py
+++ b/retrace/retrace.py
@@ -486,20 +486,21 @@ class Retracer:
 
         functions = filter(self.filterFunction, api.functions)
         for function in functions:
-            if function.sideeffects:
+            if function.sideeffects and not function.internal:
                 self.retraceFunction(function)
         interfaces = api.getAllInterfaces()
         for interface in interfaces:
             for method in interface.iterMethods():
-                if method.sideeffects:
+                if method.sideeffects and not method.internal:
                     self.retraceInterfaceMethod(interface, method)
 
         print 'const retrace::Entry %s[] = {' % self.table_name
         for function in functions:
-            if function.sideeffects:
-                print '    {"%s", &retrace_%s},' % (function.name, function.name)
-            else:
-                print '    {"%s", &retrace::ignore},' % (function.name,)
+            if not function.internal:
+                if function.sideeffects:
+                    print '    {"%s", &retrace_%s},' % (function.name, function.name)
+                else:
+                    print '    {"%s", &retrace::ignore},' % (function.name,)
         for interface in interfaces:
             for method in interface.iterMethods():                
                 if method.sideeffects:
diff --git a/specs/d3d11.py b/specs/d3d11.py
index 33dddf3..8a681bd 100644
--- a/specs/d3d11.py
+++ b/specs/d3d11.py
@@ -1218,6 +1218,12 @@ d3d11 = API("d3d11")
 d3d11.addFunctions([
     StdFunction(HRESULT, "D3D11CreateDevice", [(ObjPointer(IDXGIAdapter), "pAdapter"), (D3D_DRIVER_TYPE, "DriverType"), (HMODULE, "Software"), (D3D11_CREATE_DEVICE_FLAG, "Flags"), (Array(Const(D3D_FEATURE_LEVEL), "FeatureLevels"), "pFeatureLevels"), (UINT, "FeatureLevels"), (UINT, "SDKVersion"), Out(Pointer(ObjPointer(ID3D11Device)), "ppDevice"), Out(Pointer(D3D_FEATURE_LEVEL), "pFeatureLevel"), Out(Pointer(ObjPointer(ID3D11DeviceContext)), "ppImmediateContext")]),
     StdFunction(HRESULT, "D3D11CreateDeviceAndSwapChain", [(ObjPointer(IDXGIAdapter), "pAdapter"), (D3D_DRIVER_TYPE, "DriverType"), (HMODULE, "Software"), (D3D11_CREATE_DEVICE_FLAG, "Flags"), (Array(Const(D3D_FEATURE_LEVEL), "FeatureLevels"), "pFeatureLevels"), (UINT, "FeatureLevels"), (UINT, "SDKVersion"), (Pointer(Const(DXGI_SWAP_CHAIN_DESC)), "pSwapChainDesc"), Out(Pointer(ObjPointer(IDXGISwapChain)), "ppSwapChain"), Out(Pointer(ObjPointer(ID3D11Device)), "ppDevice"), Out(Pointer(D3D_FEATURE_LEVEL), "pFeatureLevel"), Out(Pointer(ObjPointer(ID3D11DeviceContext)), "ppImmediateContext")]),
+
+    # XXX: Undocumented functions, called by d3d11sdklayers.dll when D3D11_CREATE_DEVICE_DEBUG is set
+    StdFunction(HRESULT, "D3D11CoreRegisterLayers", [LPCVOID, DWORD], internal=True),
+    StdFunction(SIZE_T, "D3D11CoreGetLayeredDeviceSize", [LPCVOID, DWORD], internal=True),
+    StdFunction(HRESULT, "D3D11CoreCreateLayeredDevice", [LPCVOID, DWORD, LPCVOID, (REFIID, "riid"), Out(Pointer(ObjPointer(Void)), "ppvObj")], internal=True),
+    StdFunction(HRESULT, "D3D11CoreCreateDevice", [DWORD, DWORD, DWORD, DWORD, DWORD, DWORD, DWORD, DWORD, DWORD], internal=True),
 ])
 
 d3d11.addInterfaces([
diff --git a/specs/stdapi.py b/specs/stdapi.py
index 57e9aa6..b2bd904 100644
--- a/specs/stdapi.py
+++ b/specs/stdapi.py
@@ -332,7 +332,7 @@ class Function:
     # 0-3 are reserved to memcpy, malloc, free, and realloc
     __id = 4
 
-    def __init__(self, type, name, args, call = '', fail = None, sideeffects=True):
+    def __init__(self, type, name, args, call = '', fail = None, sideeffects=True, internal=False):
         self.id = Function.__id
         Function.__id += 1
 
@@ -356,6 +356,7 @@ class Function:
         self.call = call
         self.fail = fail
         self.sideeffects = sideeffects
+        self.internal = internal
 
     def prototype(self, name=None):
         if name is not None:
diff --git a/wrappers/d3d11.def b/wrappers/d3d11.def
index 2da8fc2..49333b6 100644
--- a/wrappers/d3d11.def
+++ b/wrappers/d3d11.def
@@ -1,5 +1,9 @@
 LIBRARY	"d3d11"
 
 EXPORTS
+        D3D11CoreCreateDevice
+        D3D11CoreCreateLayeredDevice
+        D3D11CoreGetLayeredDeviceSize
+        D3D11CoreRegisterLayers
         D3D11CreateDevice
         D3D11CreateDeviceAndSwapChain
diff --git a/wrappers/trace.py b/wrappers/trace.py
index d209051..659fcd9 100644
--- a/wrappers/trace.py
+++ b/wrappers/trace.py
@@ -396,12 +396,13 @@ class Tracer:
     def traceFunctionDecl(self, function):
         # Per-function declarations
 
-        if function.args:
-            print 'static const char * _%s_args[%u] = {%s};' % (function.name, len(function.args), ', '.join(['"%s"' % arg.name for arg in function.args]))
-        else:
-            print 'static const char ** _%s_args = NULL;' % (function.name,)
-        print 'static const trace::FunctionSig _%s_sig = {%u, "%s", %u, _%s_args};' % (function.name, function.id, function.name, len(function.args), function.name)
-        print
+        if not function.internal:
+            if function.args:
+                print 'static const char * _%s_args[%u] = {%s};' % (function.name, len(function.args), ', '.join(['"%s"' % arg.name for arg in function.args]))
+            else:
+                print 'static const char ** _%s_args = NULL;' % (function.name,)
+            print 'static const trace::FunctionSig _%s_sig = {%u, "%s", %u, _%s_args};' % (function.name, function.id, function.name, len(function.args), function.name)
+            print
 
     def isFunctionPublic(self, function):
         return True
@@ -421,23 +422,25 @@ class Tracer:
         print
 
     def traceFunctionImplBody(self, function):
-        print '    unsigned _call = trace::localWriter.beginEnter(&_%s_sig);' % (function.name,)
-        for arg in function.args:
-            if not arg.output:
-                self.unwrapArg(function, arg)
-                self.serializeArg(function, arg)
-        print '    trace::localWriter.endEnter();'
+        if not function.internal:
+            print '    unsigned _call = trace::localWriter.beginEnter(&_%s_sig);' % (function.name,)
+            for arg in function.args:
+                if not arg.output:
+                    self.unwrapArg(function, arg)
+                    self.serializeArg(function, arg)
+            print '    trace::localWriter.endEnter();'
         self.invokeFunction(function)
-        print '    trace::localWriter.beginLeave(_call);'
-        for arg in function.args:
-            if arg.output:
-                self.serializeArg(function, arg)
-                self.wrapArg(function, arg)
-        if function.type is not stdapi.Void:
-            self.serializeRet(function, "_result")
-        print '    trace::localWriter.endLeave();'
-        if function.type is not stdapi.Void:
-            self.wrapRet(function, "_result")
+        if not function.internal:
+            print '    trace::localWriter.beginLeave(_call);'
+            for arg in function.args:
+                if arg.output:
+                    self.serializeArg(function, arg)
+                    self.wrapArg(function, arg)
+            if function.type is not stdapi.Void:
+                self.serializeRet(function, "_result")
+            print '    trace::localWriter.endLeave();'
+            if function.type is not stdapi.Void:
+                self.wrapRet(function, "_result")
 
     def invokeFunction(self, function, prefix='_', suffix=''):
         if function.type is stdapi.Void:
@@ -566,6 +569,8 @@ class Tracer:
         print
 
     def implementWrapperInterfaceMethodBody(self, interface, base, method):
+        assert not method.internal
+
         print '    static const char * _args[%u] = {%s};' % (len(method.args) + 1, ', '.join(['"this"'] + ['"%s"' % arg.name for arg in method.args]))
         print '    static const trace::FunctionSig _sig = {%u, "%s", %u, _args};' % (method.id, interface.name + '::' + method.name, len(method.args) + 1)
 
-- 
2.45.2