]> git.cworth.org Git - apitrace/commitdiff
inject: Use DLL injection for D3D10+ tracing.
authorJosé Fonseca <jfonseca@vmware.com>
Tue, 4 Dec 2012 13:23:03 +0000 (13:23 +0000)
committerJosé Fonseca <jfonseca@vmware.com>
Tue, 4 Dec 2012 13:23:50 +0000 (13:23 +0000)
CMakeLists.txt
common/trace_tools_trace.cpp
inject/CMakeLists.txt [new file with mode: 0644]
inject/inject.h [new file with mode: 0644]
inject/injectee.cpp [new file with mode: 0644]
inject/injector.cpp [new file with mode: 0644]

index 2ace8ac44576231b41cce7ef0017ef62bd6404c7..961e73e84607be664dcb548e34d7240f3f52a3ef 100644 (file)
@@ -357,6 +357,9 @@ add_subdirectory (retrace)
 # CLI
 
 if (ENABLE_CLI)
+    if (WIN32)
+        add_subdirectory (inject)
+    endif ()
     add_subdirectory (cli)
 endif ()
 
index 4c0082d7e51771e1d540c9ce53e5372ca27ef713..77fc5520ba90a3797bb5a1e3b6d75d50b9d91b21 100644 (file)
@@ -82,81 +82,45 @@ copyWrapper(const os::String & wrapperPath,
 }
 
 
-static const char *glWrappers[] = {
-    GL_TRACE_WRAPPER,
-    NULL
-};
-
-#ifdef EGL_TRACE_WRAPPER
-static const char *eglWrappers[] = {
-    EGL_TRACE_WRAPPER,
-    NULL
-};
-#endif
-
-#ifdef _WIN32
-static const char *d3d7Wrappers[] = {
-    "ddraw.dll",
-    NULL
-};
-
-static const char *d3d8Wrappers[] = {
-    "d3d8.dll",
-    NULL
-};
-
-static const char *d3d9Wrappers[] = {
-    "d3d9.dll",
-    NULL
-};
-
-static const char *dxgiWrappers[] = {
-    "dxgitrace.dll",
-    //"dxgi.dll",
-    "d3d10.dll",
-    "d3d10_1.dll",
-    "d3d11.dll",
-    NULL
-};
-#endif
-
 int
 traceProgram(API api,
              char * const *argv,
              const char *output,
              bool verbose)
 {
-    const char **wrapperFilenames;
-    unsigned numWrappers;
+    const char *wrapperFilename;
+    std::vector<const char *> args;
     int status = 1;
 
     /*
      * TODO: simplify code
      */
 
+    bool useInject = false;
     switch (api) {
     case API_GL:
-        wrapperFilenames = glWrappers;
+        wrapperFilename = GL_TRACE_WRAPPER;
         break;
 #ifdef EGL_TRACE_WRAPPER
     case API_EGL:
-        wrapperFilenames = eglWrappers;
+        wrapperFilename = EGL_TRACE_WRAPPER;
         break;
 #endif
 #ifdef _WIN32
     case API_D3D7:
-        wrapperFilenames = d3d7Wrappers;
+        wrapperFilename = "ddraw.dll";
         break;
     case API_D3D8:
-        wrapperFilenames = d3d8Wrappers;
+        wrapperFilename = "d3d8.dll";
         break;
     case API_D3D9:
-        wrapperFilenames = d3d9Wrappers;
+        wrapperFilename = "d3d9.dll";
         break;
     case API_D3D10:
     case API_D3D10_1:
     case API_D3D11:
-        wrapperFilenames = dxgiWrappers;
+        wrapperFilename = "dxgitrace.dll";
+        useInject = true;
         break;
 #endif
     default:
@@ -164,68 +128,67 @@ traceProgram(API api,
         return 1;
     }
 
-    numWrappers = 0;
-    while (wrapperFilenames[numWrappers]) {
-        ++numWrappers;
+    os::String wrapperPath = findWrapper(wrapperFilename);
+    if (!wrapperPath.length()) {
+        std::cerr << "error: failed to find " << wrapperFilename << "\n";
+        goto exit;
     }
 
-    unsigned i;
-    for (i = 0; i < numWrappers; ++i) {
-        const char *wrapperFilename = wrapperFilenames[i];
-
-        os::String wrapperPath = findWrapper(wrapperFilename);
-
-        if (!wrapperPath.length()) {
-            std::cerr << "error: failed to find " << wrapperFilename << "\n";
-            goto exit;
-        }
-
 #if defined(_WIN32)
+    if (useInject) {
+        args.push_back("inject");
+        args.push_back(wrapperPath);
+    } else {
         /* On Windows copy the wrapper to the program directory.
          */
         if (!copyWrapper(wrapperPath, argv[0], verbose)) {
             goto exit;
         }
-#endif /* _WIN32 */
+    }
+#else  /* !_WIN32 */
+    (void)useInject;
+#endif /* !_WIN32 */
 
 #if defined(__APPLE__)
-        /* On Mac OS X, using DYLD_LIBRARY_PATH, we actually set the
-         * directory, not the file. */
-        wrapperPath.trimFilename();
+    /* On Mac OS X, using DYLD_LIBRARY_PATH, we actually set the
+     * directory, not the file. */
+    wrapperPath.trimFilename();
 #endif
 
 #if defined(TRACE_VARIABLE)
-        assert(numWrappers == 1);
-        if (verbose) {
-            std::cerr << TRACE_VARIABLE << "=" << wrapperPath.str() << "\n";
-        }
-        /* FIXME: Don't modify the current environment */
-        os::setEnvironment(TRACE_VARIABLE, wrapperPath.str());
-#endif /* TRACE_VARIABLE */
+    if (verbose) {
+        std::cerr << TRACE_VARIABLE << "=" << wrapperPath.str() << "\n";
     }
+    /* FIXME: Don't modify the current environment */
+    os::setEnvironment(TRACE_VARIABLE, wrapperPath.str());
+#endif /* TRACE_VARIABLE */
 
     if (output) {
         os::setEnvironment("TRACE_FILE", output);
     }
 
+    for (char * const * arg = argv; *arg; ++arg) {
+        args.push_back(*arg);
+    }
+    args.push_back(NULL);
+
     if (verbose) {
         const char *sep = "";
-        for (char * const * arg = argv; *arg; ++arg) {
-            std::cerr << *arg << sep;
+        for (unsigned i = 0; i < args.size(); ++i) {
+            std::cerr << sep << args[i];
             sep = " ";
         }
         std::cerr << "\n";
     }
 
-    status = os::execute(argv);
+    status = os::execute((char * const *)&args[0]);
 
 exit:
 #if defined(TRACE_VARIABLE)
     os::unsetEnvironment(TRACE_VARIABLE);
 #endif
 #if defined(_WIN32)
-    for (unsigned j = 0; j < i; ++j) {
-        const char *wrapperFilename = wrapperFilenames[j];
+    if (!useInject) {
         os::String tmpWrapper(argv[0]);
         tmpWrapper.trimFilename();
         tmpWrapper.join(wrapperFilename);
diff --git a/inject/CMakeLists.txt b/inject/CMakeLists.txt
new file mode 100644 (file)
index 0000000..489b819
--- /dev/null
@@ -0,0 +1,19 @@
+set (CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR})
+
+add_library (injectee MODULE
+    injectee.cpp
+)
+set_target_properties (injectee PROPERTIES
+    PREFIX ""
+    OUTPUT_NAME inject
+)
+install (TARGETS injectee LIBRARY DESTINATION bin)
+
+add_executable (injector
+    injector.cpp
+)
+set_target_properties (injector PROPERTIES
+    PREFIX ""
+    OUTPUT_NAME inject
+)
+install (TARGETS injector RUNTIME DESTINATION bin)
diff --git a/inject/inject.h b/inject/inject.h
new file mode 100644 (file)
index 0000000..058ada6
--- /dev/null
@@ -0,0 +1,181 @@
+/**************************************************************************
+ *
+ * Copyright 2011-2012 Jose Fonseca
+ * All Rights Reserved.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+ * THE SOFTWARE.
+ *
+ **************************************************************************/
+
+
+/*
+ * Code for the DLL that will be injected in the target process.
+ *
+ * The injected DLL will manipulate the import tables to hook the
+ * modules/functions of interest.
+ *
+ * See also:
+ * - http://www.codeproject.com/KB/system/api_spying_hack.aspx
+ * - http://www.codeproject.com/KB/threads/APIHooking.aspx
+ * - http://msdn.microsoft.com/en-us/magazine/cc301808.aspx
+ */
+
+
+#include <windows.h>
+
+
+static inline const char *
+getSeparator(const char *szFilename) {
+    const char *p, *q;
+    p = NULL;
+    q = szFilename;
+    char c;
+    do  {
+        c = *q++;
+        if (c == '\\' || c == '/' || c == ':') {
+            p = q;
+        }
+    } while (c);
+    return p;
+}
+
+
+static inline const char *
+getBaseName(const char *szFilename) {
+    const char *pSeparator = getSeparator(szFilename);
+    if (!pSeparator) {
+        return szFilename;
+    }
+    return pSeparator;
+}
+
+
+static inline void
+getDirName(char *szFilename) {
+    char *pSeparator = const_cast<char *>(getSeparator(szFilename));
+    if (pSeparator) {
+        *pSeparator = '\0';
+    }
+}
+
+
+static inline void
+getModuleName(char *szModuleName, size_t n, const char *szFilename) {
+    char *p = szModuleName;
+    const char *q = getBaseName(szFilename);
+    char c;
+    while (--n) {
+        c = *q++;
+        if (c == '.' || c == '\0') {
+            break;
+        }
+        *p++ = c;
+    };
+    *p++ = '\0';
+}
+
+
+#define SHARED_MEM_SIZE 4096
+
+static LPVOID pSharedMem = NULL;
+static HANDLE hFileMapping = NULL;
+
+
+static LPSTR
+OpenSharedMemory(void) {
+    if (pSharedMem) {
+        return (LPSTR)pSharedMem;
+    }
+
+    hFileMapping = CreateFileMapping(
+        INVALID_HANDLE_VALUE,   // system paging file
+        NULL,                   // lpAttributes
+        PAGE_READWRITE,         // read/write access
+        0,                      // dwMaximumSizeHigh
+        SHARED_MEM_SIZE,              // dwMaximumSizeLow
+        TEXT("injectfilemap")); // name of map object
+    if (hFileMapping == NULL) {
+        fprintf(stderr, "Failed to create file mapping\n");
+        return NULL;
+    }
+
+    BOOL bAlreadyExists = (GetLastError() == ERROR_ALREADY_EXISTS);
+
+    pSharedMem = MapViewOfFile(
+        hFileMapping,
+        FILE_MAP_WRITE, // read/write access
+        0,              // dwFileOffsetHigh
+        0,              // dwFileOffsetLow
+        0);             // dwNumberOfBytesToMap (entire file)
+    if (pSharedMem == NULL) {
+        fprintf(stderr, "Failed to map view \n");
+        return NULL;
+    }
+
+    if (!bAlreadyExists) {
+        memset(pSharedMem, 0, SHARED_MEM_SIZE);
+    }
+
+    return (LPSTR)pSharedMem;
+}
+
+
+static inline VOID
+CloseSharedMem(void) {
+    if (!pSharedMem) {
+        return;
+    }
+
+    UnmapViewOfFile(pSharedMem);
+    pSharedMem = NULL;
+
+    CloseHandle(hFileMapping);
+    hFileMapping = NULL;
+}
+
+
+static inline VOID
+SetSharedMem(LPCSTR lpszSrc) {
+    LPSTR lpszDst = OpenSharedMemory();
+    if (!lpszDst) {
+        return;
+    }
+
+    size_t n = 1;
+    while (*lpszSrc && n < SHARED_MEM_SIZE) {
+        *lpszDst++ = *lpszSrc++;
+        n++;
+    }
+    *lpszDst = '\0';
+}
+
+
+static inline VOID
+GetSharedMem(LPSTR lpszDst, size_t n) {
+    LPCSTR lpszSrc = OpenSharedMemory();
+    if (!lpszSrc) {
+        return;
+    }
+
+    while (*lpszSrc && --n) {
+        *lpszDst++ = *lpszSrc++;
+    }
+    *lpszDst = '\0';
+}
+
diff --git a/inject/injectee.cpp b/inject/injectee.cpp
new file mode 100644 (file)
index 0000000..86992b8
--- /dev/null
@@ -0,0 +1,572 @@
+/**************************************************************************
+ *
+ * Copyright 2011-2012 Jose Fonseca
+ * All Rights Reserved.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+ * THE SOFTWARE.
+ *
+ **************************************************************************/
+
+
+/*
+ * Code for the DLL that will be injected in the target process.
+ *
+ * The injected DLL will manipulate the import tables to hook the
+ * modules/functions of interest.
+ *
+ * See also:
+ * - http://www.codeproject.com/KB/system/api_spying_hack.aspx
+ * - http://www.codeproject.com/KB/threads/APIHooking.aspx
+ * - http://msdn.microsoft.com/en-us/magazine/cc301808.aspx
+ */
+
+
+#include <assert.h>
+#include <stdio.h>
+#include <stdarg.h>
+
+#include <windows.h>
+#include <tlhelp32.h>
+
+#include "inject.h"
+
+
+#define VERBOSITY 0
+#define NOOP 0
+
+
+static CRITICAL_SECTION Mutex = {(PCRITICAL_SECTION_DEBUG)-1, -1, 0, 0, 0, 0};
+
+
+static void
+debugPrintf(const char *format, ...)
+{
+#if VERBOSITY > 0
+    static char buf[4096];
+
+    EnterCriticalSection(&Mutex);
+
+    va_list ap;
+    va_start(ap, format);
+    _vsnprintf(buf, sizeof buf, format, ap);
+    va_end(ap);
+
+    OutputDebugStringA(buf);
+
+    LeaveCriticalSection(&Mutex);
+#endif
+}
+
+
+static HMODULE WINAPI
+MyLoadLibraryA(LPCSTR lpLibFileName);
+
+static HMODULE WINAPI
+MyLoadLibraryW(LPCWSTR lpLibFileName);
+
+static HMODULE WINAPI
+MyLoadLibraryExA(LPCSTR lpFileName, HANDLE hFile, DWORD dwFlags);
+
+static HMODULE WINAPI
+MyLoadLibraryExW(LPCWSTR lpFileName, HANDLE hFile, DWORD dwFlags);
+
+static FARPROC WINAPI
+MyGetProcAddress(HMODULE hModule, LPCSTR lpProcName);
+
+
+static const char *
+getImportDescriptionName(HMODULE hModule, const PIMAGE_IMPORT_DESCRIPTOR pImportDescriptor) {
+    const char* szName = (const char*)((PBYTE)hModule + pImportDescriptor->Name);
+    return szName;
+}
+
+
+static PIMAGE_IMPORT_DESCRIPTOR
+getImportDescriptor(HMODULE hModule,
+                    const char *szModule,
+                    const char *pszDllName)
+{
+    MEMORY_BASIC_INFORMATION MemoryInfo;
+    if (VirtualQuery(hModule, &MemoryInfo, sizeof MemoryInfo) != sizeof MemoryInfo) {
+        debugPrintf("%s: %s: VirtualQuery failed\n", __FUNCTION__, szModule);
+        return NULL;
+    }
+    if (MemoryInfo.Protect & (PAGE_NOACCESS | PAGE_EXECUTE)) {
+        debugPrintf("%s: %s: no read access (Protect = 0x%08x)\n", __FUNCTION__, szModule, MemoryInfo.Protect);
+        return NULL;
+    }
+
+    PIMAGE_DOS_HEADER pDosHeader = (PIMAGE_DOS_HEADER)hModule;
+    PIMAGE_NT_HEADERS pNtHeaders = (PIMAGE_NT_HEADERS)((PBYTE)hModule + pDosHeader->e_lfanew);
+
+    PIMAGE_OPTIONAL_HEADER pOptionalHeader = &pNtHeaders->OptionalHeader;
+
+    UINT_PTR ImportAddress = pOptionalHeader->DataDirectory[IMAGE_DIRECTORY_ENTRY_IMPORT].VirtualAddress;
+
+    if (!ImportAddress) {
+        return NULL;
+    }
+
+    PIMAGE_IMPORT_DESCRIPTOR pImportDescriptor = (PIMAGE_IMPORT_DESCRIPTOR)((PBYTE)hModule + ImportAddress);
+
+    while (pImportDescriptor->FirstThunk) {
+        const char* szName = getImportDescriptionName(hModule, pImportDescriptor);
+        if (stricmp(pszDllName, szName) == 0) {
+            return pImportDescriptor;
+        }
+        ++pImportDescriptor;
+    }
+
+    return NULL;
+}
+
+
+static BOOL
+replaceAddress(LPVOID *lpOldAddress, LPVOID lpNewAddress)
+{
+    DWORD flOldProtect;
+
+    if (*lpOldAddress == lpNewAddress) {
+        return TRUE;
+    }
+
+    EnterCriticalSection(&Mutex);
+
+    if (!(VirtualProtect(lpOldAddress, sizeof *lpOldAddress, PAGE_READWRITE, &flOldProtect))) {
+        LeaveCriticalSection(&Mutex);
+        return FALSE;
+    }
+
+    *lpOldAddress = lpNewAddress;
+
+    if (!(VirtualProtect(lpOldAddress, sizeof *lpOldAddress, flOldProtect, &flOldProtect))) {
+        LeaveCriticalSection(&Mutex);
+        return FALSE;
+    }
+
+    LeaveCriticalSection(&Mutex);
+    return TRUE;
+}
+
+
+static LPVOID *
+getOldFunctionAddress(HMODULE hModule,
+                    PIMAGE_IMPORT_DESCRIPTOR pImportDescriptor,
+                    const char* pszFunctionName)
+{
+    PIMAGE_THUNK_DATA pOriginalFirstThunk = (PIMAGE_THUNK_DATA)((PBYTE)hModule + pImportDescriptor->OriginalFirstThunk);
+    PIMAGE_THUNK_DATA pFirstThunk = (PIMAGE_THUNK_DATA)((PBYTE)hModule + pImportDescriptor->FirstThunk);
+
+    //debugPrintf("  %s\n", __FUNCTION__);
+
+    while (pOriginalFirstThunk->u1.Function) {
+        PIMAGE_IMPORT_BY_NAME pImport = (PIMAGE_IMPORT_BY_NAME)((PBYTE)hModule + pOriginalFirstThunk->u1.AddressOfData);
+        const char* szName = (const char* )pImport->Name;
+        //debugPrintf("    %s\n", szName);
+        if (strcmp(pszFunctionName, szName) == 0) {
+            //debugPrintf("  %s succeeded\n", __FUNCTION__);
+            return (LPVOID *)(&pFirstThunk->u1.Function);
+        }
+        ++pOriginalFirstThunk;
+        ++pFirstThunk;
+    }
+
+    //debugPrintf("  %s failed\n", __FUNCTION__);
+
+    return NULL;
+}
+
+
+static void
+replaceModule(HMODULE hModule,
+              const char *szModule,
+              PIMAGE_IMPORT_DESCRIPTOR pImportDescriptor,
+              HMODULE hNewModule)
+{
+    PIMAGE_THUNK_DATA pOriginalFirstThunk = (PIMAGE_THUNK_DATA)((PBYTE)hModule + pImportDescriptor->OriginalFirstThunk);
+    PIMAGE_THUNK_DATA pFirstThunk = (PIMAGE_THUNK_DATA)((PBYTE)hModule + pImportDescriptor->FirstThunk);
+
+    while (pOriginalFirstThunk->u1.Function) {
+        PIMAGE_IMPORT_BY_NAME pImport = (PIMAGE_IMPORT_BY_NAME)((PBYTE)hModule + pOriginalFirstThunk->u1.AddressOfData);
+        const char* szFunctionName = (const char* )pImport->Name;
+        debugPrintf("      hooking %s->%s!%s\n", szModule,
+                getImportDescriptionName(hModule, pImportDescriptor),
+                szFunctionName);
+
+        PROC pNewProc = GetProcAddress(hNewModule, szFunctionName);
+        if (!pNewProc) {
+            debugPrintf("  warning: no replacement for %s\n", szFunctionName);
+        } else {
+            LPVOID *lpOldAddress = (LPVOID *)(&pFirstThunk->u1.Function);
+            replaceAddress(lpOldAddress, (LPVOID)pNewProc);
+        }
+
+        ++pOriginalFirstThunk;
+        ++pFirstThunk;
+    }
+}
+
+
+static BOOL
+hookFunction(HMODULE hModule,
+             const char *szModule,
+             const char *pszDllName,
+             const char *pszFunctionName,
+             LPVOID lpNewAddress)
+{
+    PIMAGE_IMPORT_DESCRIPTOR pImportDescriptor = getImportDescriptor(hModule, szModule, pszDllName);
+    if (pImportDescriptor == NULL) {
+        return FALSE;
+    }
+    LPVOID* lpOldFunctionAddress = getOldFunctionAddress(hModule, pImportDescriptor, pszFunctionName);
+    if (lpOldFunctionAddress == NULL) {
+        return FALSE;
+    }
+
+    if (*lpOldFunctionAddress == lpNewAddress) {
+        return TRUE;
+    }
+
+    if (VERBOSITY >= 3) {
+        debugPrintf("      hooking %s->%s!%s\n", szModule, pszDllName, pszFunctionName);
+    }
+
+    return replaceAddress(lpOldFunctionAddress, lpNewAddress);
+}
+
+
+static BOOL
+replaceImport(HMODULE hModule,
+              const char *szModule,
+              const char *pszDllName,
+              HMODULE hNewModule)
+{
+#if NOOP
+    return TRUE;
+#endif
+
+    PIMAGE_IMPORT_DESCRIPTOR pImportDescriptor = getImportDescriptor(hModule, szModule, pszDllName);
+    if (pImportDescriptor == NULL) {
+        return TRUE;
+    }
+
+    replaceModule(hModule, szModule, pImportDescriptor, hNewModule);
+
+    return TRUE;
+}
+
+static HMODULE g_hThisModule = NULL;
+
+
+struct Replacement {
+    const char *szMatchModule;
+    HMODULE hReplaceModule;
+};
+
+static unsigned numReplacements = 0;
+static Replacement replacements[32];
+
+
+
+static void
+hookModule(HMODULE hModule,
+           const char *szModule)
+{
+    if (hModule == g_hThisModule) {
+        return;
+    }
+
+    for (unsigned i = 0; i < numReplacements; ++i) {
+        if (hModule == replacements[i].hReplaceModule) {
+            return;
+        }
+    }
+
+    hookFunction(hModule, szModule, "kernel32.dll", "LoadLibraryA", (LPVOID)MyLoadLibraryA);
+    hookFunction(hModule, szModule, "kernel32.dll", "LoadLibraryW", (LPVOID)MyLoadLibraryW);
+    hookFunction(hModule, szModule, "kernel32.dll", "LoadLibraryExA", (LPVOID)MyLoadLibraryExA);
+    hookFunction(hModule, szModule, "kernel32.dll", "LoadLibraryExW", (LPVOID)MyLoadLibraryExW);
+    hookFunction(hModule, szModule, "kernel32.dll", "GetProcAddress", (LPVOID)MyGetProcAddress);
+
+    const char *szBaseName = getBaseName(szModule);
+    for (unsigned i = 0; i < numReplacements; ++i) {
+        if (stricmp(szBaseName, replacements[i].szMatchModule) == 0) {
+            return;
+        }
+    }
+
+    /* Don't hook internal dependencies */
+    if (stricmp(szBaseName, "d3d10core.dll") == 0 ||
+        stricmp(szBaseName, "d3d10level9.dll") == 0 ||
+        stricmp(szBaseName, "d3d10sdklayers.dll") == 0 ||
+        stricmp(szBaseName, "d3d10_1core.dll") == 0 ||
+        stricmp(szBaseName, "d3d11sdklayers.dll") == 0 ||
+        stricmp(szBaseName, "d3d11_1sdklayers.dll") == 0) {
+        return;
+    }
+
+    for (unsigned i = 0; i < numReplacements; ++i) {
+        replaceImport(hModule, szModule, replacements[i].szMatchModule, replacements[i].hReplaceModule);
+        replaceImport(hModule, szModule, replacements[i].szMatchModule, replacements[i].hReplaceModule);
+        replaceImport(hModule, szModule, replacements[i].szMatchModule, replacements[i].hReplaceModule);
+    }
+}
+
+static void
+hookAllModules(void)
+{
+    HANDLE hModuleSnap = CreateToolhelp32Snapshot(TH32CS_SNAPMODULE, GetCurrentProcessId());
+    if (hModuleSnap == INVALID_HANDLE_VALUE) {
+        return;
+    }
+
+    MODULEENTRY32 me32;
+    me32.dwSize = sizeof me32;
+
+    static bool first = true;
+    if (first) {
+        if (Module32First(hModuleSnap, &me32)) {
+            debugPrintf("  modules:\n");
+            do  {
+                debugPrintf("     %s\n", me32.szExePath);
+            } while (Module32Next(hModuleSnap, &me32));
+        }
+        first = false;
+    }
+
+    if (Module32First(hModuleSnap, &me32)) {
+        do  {
+            hookModule(me32.hModule, me32.szExePath);
+        } while (Module32Next(hModuleSnap, &me32));
+    }
+
+    CloseHandle(hModuleSnap);
+}
+
+
+
+
+static HMODULE WINAPI
+MyLoadLibrary(LPCSTR lpLibFileName, HANDLE hFile = NULL, DWORD dwFlags = 0)
+{
+    // To Send the information to the server informing that,
+    // LoadLibrary is invoked.
+    HMODULE hModule = LoadLibraryExA(lpLibFileName, hFile, dwFlags);
+
+    //hookModule(hModule, lpLibFileName);
+    hookAllModules();
+
+    return hModule;
+}
+
+static HMODULE WINAPI
+MyLoadLibraryA(LPCSTR lpLibFileName)
+{
+    if (VERBOSITY >= 2) {
+        debugPrintf("%s(\"%s\")\n", __FUNCTION__, lpLibFileName);
+    }
+
+    const char *szBaseName = getBaseName(lpLibFileName);
+    for (unsigned i = 0; i < numReplacements; ++i) {
+        if (stricmp(szBaseName, replacements[i].szMatchModule) == 0) {
+            debugPrintf("%s(\"%s\")\n", __FUNCTION__, lpLibFileName);
+#ifdef __GNUC__
+            void *caller = __builtin_return_address (0);
+
+            HMODULE hModule = 0;
+            BOOL bRet = GetModuleHandleEx(GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS,
+                                     (LPCTSTR)caller,
+                                     &hModule);
+            assert(bRet);
+            char szCaller[256];
+            DWORD dwRet = GetModuleFileNameA(hModule, szCaller, sizeof szCaller);
+            assert(dwRet);
+            debugPrintf("  called from %s\n", szCaller);
+#endif
+            break;
+        }
+    }
+
+    return MyLoadLibrary(lpLibFileName);
+}
+
+static HMODULE WINAPI
+MyLoadLibraryW(LPCWSTR lpLibFileName)
+{
+    if (VERBOSITY >= 2) {
+        debugPrintf("%s(L\"%S\")\n", __FUNCTION__, lpLibFileName);
+    }
+
+    char szFileName[256];
+    wcstombs(szFileName, lpLibFileName, sizeof szFileName);
+
+    return MyLoadLibrary(szFileName);
+}
+
+static HMODULE WINAPI
+MyLoadLibraryExA(LPCSTR lpLibFileName, HANDLE hFile, DWORD dwFlags)
+{
+    if (VERBOSITY >= 2) {
+        debugPrintf("%s(\"%s\")\n", __FUNCTION__, lpLibFileName);
+    }
+    return MyLoadLibrary(lpLibFileName, hFile, dwFlags);
+}
+
+static HMODULE WINAPI
+MyLoadLibraryExW(LPCWSTR lpLibFileName, HANDLE hFile, DWORD dwFlags)
+{
+    if (VERBOSITY >= 2) {
+        debugPrintf("%s(L\"%S\")\n", __FUNCTION__, lpLibFileName);
+    }
+
+    char szFileName[256];
+    wcstombs(szFileName, lpLibFileName, sizeof szFileName);
+
+    return MyLoadLibrary(szFileName, hFile, dwFlags);
+}
+
+static FARPROC WINAPI
+MyGetProcAddress(HMODULE hModule, LPCSTR lpProcName) {
+
+    if (VERBOSITY >= 99) {
+        /* XXX this can cause segmentation faults */
+        debugPrintf("%s(\"%s\")\n", __FUNCTION__, lpProcName);
+    }
+
+    assert(hModule != g_hThisModule);
+    for (unsigned i = 0; i < numReplacements; ++i) {
+        if (hModule == replacements[i].hReplaceModule) {
+            return GetProcAddress(hModule, lpProcName);
+        }
+    }
+
+#if !NOOP
+    char szModule[256];
+    DWORD dwRet = GetModuleFileNameA(hModule, szModule, sizeof szModule);
+    assert(dwRet);
+    const char *szBaseName = getBaseName(szModule);
+
+    for (unsigned i = 0; i < numReplacements; ++i) {
+
+        if (stricmp(szBaseName, replacements[i].szMatchModule) == 0) {
+            debugPrintf("  %s(\"%s\", \"%s\")\n", __FUNCTION__, szModule, lpProcName);
+            FARPROC pProcAddress = GetProcAddress(replacements[i].hReplaceModule, lpProcName);
+            if (pProcAddress) {
+                if (VERBOSITY >= 2) {
+                    debugPrintf("      replacing %s!%s\n", szBaseName, lpProcName);
+                }
+                return pProcAddress;
+            } else {
+                debugPrintf("      ignoring %s!%s\n", szBaseName, lpProcName);
+                break;
+            }
+        }
+    }
+#endif
+
+    return GetProcAddress(hModule, lpProcName);
+}
+
+
+EXTERN_C BOOL WINAPI
+DllMain(HINSTANCE hinstDLL, DWORD fdwReason, LPVOID lpReserved)
+{
+    const char *szNewDllName = NULL;
+    HMODULE hNewModule = NULL;
+    const char *szNewDllBaseName;
+
+    switch (fdwReason) {
+    case DLL_PROCESS_ATTACH:
+        debugPrintf("DLL_PROCESS_ATTACH\n");
+
+        g_hThisModule = hinstDLL;
+
+        {
+            char szProcess[MAX_PATH];
+            GetModuleFileNameA(NULL, szProcess, sizeof szProcess);
+            debugPrintf("  attached to %s\n", szProcess);
+        }
+
+        /*
+         * Calling LoadLibrary inside DllMain is strongly discouraged.  But it
+         * works quite well, provided that the loaded DLL does not require or do
+         * anything special in its DllMain, which seems to be the general case.
+         *
+         * See also:
+         * - http://stackoverflow.com/questions/4370812/calling-loadlibrary-from-dllmain
+         * - http://msdn.microsoft.com/en-us/library/ms682583
+         */
+
+#if 0
+        szNewDllName = getenv("INJECT_DLL");
+        if (!szNewDllName) {
+            debugPrintf("warning: INJECT_DLL not set\n");
+            return FALSE;
+        }
+#else
+        static char szSharedMemCopy[MAX_PATH];
+        GetSharedMem(szSharedMemCopy, sizeof szSharedMemCopy);
+        szNewDllName = szSharedMemCopy;
+#endif
+        debugPrintf("  injecting %s\n", szNewDllName);
+
+        hNewModule = LoadLibraryA(szNewDllName);
+        if (!hNewModule) {
+            debugPrintf("warning: failed to load %s\n", szNewDllName);
+            return FALSE;
+        }
+
+        szNewDllBaseName = getBaseName(szNewDllName);
+        if (stricmp(szNewDllBaseName, "dxgitrace.dll") == 0) {
+            replacements[numReplacements].szMatchModule = "dxgi.dll";
+            replacements[numReplacements].hReplaceModule = hNewModule;
+            ++numReplacements;
+
+            replacements[numReplacements].szMatchModule = "d3d10.dll";
+            replacements[numReplacements].hReplaceModule = hNewModule;
+            ++numReplacements;
+
+            replacements[numReplacements].szMatchModule = "d3d10_1.dll";
+            replacements[numReplacements].hReplaceModule = hNewModule;
+            ++numReplacements;
+
+            replacements[numReplacements].szMatchModule = "d3d11.dll";
+            replacements[numReplacements].hReplaceModule = hNewModule;
+            ++numReplacements;
+        } else {
+            replacements[numReplacements].szMatchModule = szNewDllBaseName;
+            replacements[numReplacements].hReplaceModule = hNewModule;
+            ++numReplacements;
+        }
+
+        hookAllModules();
+        break;
+
+    case DLL_THREAD_ATTACH:
+        break;
+
+    case DLL_THREAD_DETACH:
+        break;
+
+    case DLL_PROCESS_DETACH:
+        debugPrintf("DLL_PROCESS_DETACH\n");
+        break;
+    }
+    return TRUE;
+}
diff --git a/inject/injector.cpp b/inject/injector.cpp
new file mode 100644 (file)
index 0000000..04e0959
--- /dev/null
@@ -0,0 +1,289 @@
+/**************************************************************************
+ *
+ * Copyright 2011 Jose Fonseca
+ * All Rights Reserved.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+ * THE SOFTWARE.
+ *
+ **************************************************************************/
+
+
+/*
+ * Main program to start and inject a DLL into a process via a remote thread.
+ *
+ * For background see:
+ * - http://en.wikipedia.org/wiki/DLL_injection#Approaches_on_Microsoft_Windows
+ * - http://www.codeproject.com/KB/threads/completeinject.aspx
+ * - http://www.codeproject.com/KB/threads/winspy.aspx
+ * - http://www.codeproject.com/KB/DLL/DLL_Injection_tutorial.aspx
+ * - http://www.codeproject.com/KB/threads/APIHooking.aspx
+ *
+ * Other slightly different techniques:
+ * - http://www.fr33project.org/pages/projects/phook.htm
+ * - http://www.hbgary.com/loading-a-dll-without-calling-loadlibrary
+ * - http://securityxploded.com/ntcreatethreadex.php
+ */
+
+#include <string>
+
+#include <windows.h>
+#include <stdio.h>
+
+#include "inject.h"
+
+
+/**
+ * Determine whether an argument should be quoted.
+ */
+static bool
+needsQuote(const char *arg)
+{
+    char c;
+    while (true) {
+        c = *arg++;
+        if (c == '\0') {
+            break;
+        }
+        if (c == ' ' || c == '\t' || c == '\"') {
+            return true;
+        }
+        if (c == '\\') {
+            c = *arg++;
+            if (c == '\0') {
+                break;
+            }
+            if (c == '"') {
+                return true;
+            }
+        }
+    }
+    return false;
+}
+
+static void
+quoteArg(std::string &s, const char *arg)
+{
+    char c;
+    unsigned backslashes = 0;
+
+    s.push_back('"');
+    while (true) {
+        c = *arg++;
+        if (c == '\0') {
+            break;
+        } else if (c == '"') {
+            while (backslashes) {
+                s.push_back('\\');
+                --backslashes;
+            }
+            s.push_back('\\');
+        } else {
+            if (c == '\\') {
+                ++backslashes;
+            } else {
+                backslashes = 0;
+            }
+        }
+        s.push_back(c);
+    }
+    s.push_back('"');
+}
+
+
+int
+main(int argc, char *argv[])
+{
+
+    if (argc < 3) {
+        fprintf(stderr, "inject dllname.dll command [args] ...\n");
+        return 1;
+    }
+
+    const char *szDll = argv[1];
+#if 0
+    SetEnvironmentVariableA("INJECT_DLL", szDll);
+#else
+    SetSharedMem(szDll);
+#endif
+
+    PROCESS_INFORMATION processInfo;
+    HANDLE hProcess;
+    BOOL bAttach;
+    if (isdigit(argv[2][0])) {
+        bAttach = TRUE;
+
+        BOOL bRet;
+        HANDLE hToken   = NULL;
+        bRet = OpenProcessToken(GetCurrentProcess(), TOKEN_ALL_ACCESS, &hToken);
+        if (!bRet) {
+            fprintf(stderr, "error: OpenProcessToken returned %u\n", (unsigned)bRet);
+            return 1;
+        }
+
+        LUID Luid;
+        bRet = LookupPrivilegeValue(NULL, SE_DEBUG_NAME, &Luid);
+        if (!bRet) {
+            fprintf(stderr, "error: LookupPrivilegeValue returned %u\n", (unsigned)bRet);
+            return 1;
+        }
+
+        TOKEN_PRIVILEGES tp;
+        tp.PrivilegeCount = 1;
+        tp.Privileges[0].Luid = Luid;
+        tp.Privileges[0].Attributes = SE_PRIVILEGE_ENABLED;
+        bRet = AdjustTokenPrivileges(hToken, FALSE, &tp, sizeof tp, NULL, NULL);
+        if (!bRet) {
+            fprintf(stderr, "error: AdjustTokenPrivileges returned %u\n", (unsigned)bRet);
+            return 1;
+        }
+
+        DWORD dwDesiredAccess =
+            PROCESS_CREATE_THREAD |
+            PROCESS_QUERY_INFORMATION |
+            PROCESS_QUERY_LIMITED_INFORMATION |
+            PROCESS_VM_OPERATION |
+            PROCESS_VM_WRITE |
+            PROCESS_VM_READ;
+        DWORD dwProcessId = atol(argv[2]);
+        hProcess = OpenProcess(
+            dwDesiredAccess,
+            FALSE /* bInheritHandle */,
+            dwProcessId);
+        if (!hProcess) {
+            DWORD dwLastError = GetLastError();
+            fprintf(stderr, "error: failed to open process %lu (%lu)\n", dwProcessId, dwLastError);
+            return 1;
+        }
+    } else {
+        bAttach = FALSE;
+        std::string commandLine;
+        char sep = 0;
+        for (int i = 2; i < argc; ++i) {
+            const char *arg = argv[i];
+
+            if (sep) {
+                commandLine.push_back(sep);
+            }
+
+            if (needsQuote(arg)) {
+                quoteArg(commandLine, arg);
+            } else {
+                commandLine.append(arg);
+            }
+
+            sep = ' ';
+        }
+
+        STARTUPINFO startupInfo;
+        memset(&startupInfo, 0, sizeof startupInfo);
+        startupInfo.cb = sizeof startupInfo;
+
+        // Create the process in suspended state
+        if (!CreateProcessA(
+               NULL,
+               const_cast<char *>(commandLine.c_str()), // only modified by CreateProcessW
+               0, // process attributes
+               0, // thread attributes
+               TRUE, // inherit handles
+               CREATE_SUSPENDED,
+               NULL, // environment
+               NULL, // current directory
+               &startupInfo,
+               &processInfo)) {
+            fprintf(stderr, "error: failed to execute %s\n", commandLine.c_str());
+            return 1;
+        }
+
+        hProcess = processInfo.hProcess;
+    }
+
+    /*
+     * XXX: Mixed architecture don't quite work.  See also
+     * http://www.corsix.org/content/dll-injection-and-wow64
+     */
+    const char *szDllName;
+    szDllName = "inject.dll";
+
+    char szDllPath[MAX_PATH];
+    GetModuleFileNameA(NULL, szDllPath, sizeof szDllPath);
+    getDirName(szDllPath);
+    strncat(szDllPath, szDllName, sizeof szDllPath - strlen(szDllPath) - 1);
+
+    size_t szDllPathLength = strlen(szDllPath) + 1;
+
+    // Allocate memory in the target process to hold the DLL name
+    void *lpMemory = VirtualAllocEx(hProcess, NULL, szDllPathLength, MEM_COMMIT, PAGE_READWRITE);
+    if (!lpMemory) {
+        fprintf(stderr, "error: failed to allocate memory in the process\n");
+        TerminateProcess(hProcess, 1);
+        return 1;
+    }
+
+    // Copy DLL name into the target process
+    if (!WriteProcessMemory(hProcess, lpMemory, szDllPath, szDllPathLength, NULL)) {
+        fprintf(stderr, "error: failed to write into process memory\n");
+        TerminateProcess(hProcess, 1);
+        return 1;
+    }
+
+    /*
+     * Get LoadLibraryA address from kernel32.dll.  It's the same for all the
+     * process (XXX: but only within the same architecture).
+     */
+    PTHREAD_START_ROUTINE lpStartAddress =
+        (PTHREAD_START_ROUTINE)GetProcAddress(GetModuleHandleA("KERNEL32"), "LoadLibraryA");
+
+    // Create remote thread in another process
+    HANDLE hThread = CreateRemoteThread(hProcess, NULL, 0, lpStartAddress, lpMemory, 0, NULL);
+    if (!hThread) {
+        fprintf(stderr, "error: failed to create remote thread\n");
+        TerminateProcess(hProcess, 1);
+        return 1;
+    }
+
+    // Wait for it to finish
+    WaitForSingleObject(hThread, INFINITE);
+
+    DWORD hModule = 0;
+    GetExitCodeThread(hThread, &hModule);
+    if (!hModule) {
+        fprintf(stderr, "error: failed to inject %s\n", szDllPath);
+        TerminateProcess(hProcess, 1);
+        return 1;
+    }
+
+    if (bAttach) {
+        return 0;
+    }
+
+    // Start main process thread
+    ResumeThread(processInfo.hThread);
+
+    // Wait for it to finish
+    WaitForSingleObject(hProcess, INFINITE);
+
+    DWORD exitCode = ~0;
+    GetExitCodeProcess(hProcess, &exitCode);
+
+    CloseHandle(hProcess);
+    CloseHandle(processInfo.hThread);
+
+    return (int)exitCode;
+
+}