X-Git-Url: https://git.cworth.org/git?a=blobdiff_plain;f=inject%2Finjectee.cpp;fp=inject%2Finjectee.cpp;h=86992b83e94e3923b324e42378d446fa5b2073ee;hb=bd4937e47675d600b13174773dc05ab1129c266b;hp=0000000000000000000000000000000000000000;hpb=1592ad278ba90ed2d5f6fa2b49fadf2cc6e4e77e;p=apitrace diff --git a/inject/injectee.cpp b/inject/injectee.cpp new file mode 100644 index 0000000..86992b8 --- /dev/null +++ b/inject/injectee.cpp @@ -0,0 +1,572 @@ +/************************************************************************** + * + * Copyright 2011-2012 Jose Fonseca + * All Rights Reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + **************************************************************************/ + + +/* + * Code for the DLL that will be injected in the target process. + * + * The injected DLL will manipulate the import tables to hook the + * modules/functions of interest. + * + * See also: + * - http://www.codeproject.com/KB/system/api_spying_hack.aspx + * - http://www.codeproject.com/KB/threads/APIHooking.aspx + * - http://msdn.microsoft.com/en-us/magazine/cc301808.aspx + */ + + +#include +#include +#include + +#include +#include + +#include "inject.h" + + +#define VERBOSITY 0 +#define NOOP 0 + + +static CRITICAL_SECTION Mutex = {(PCRITICAL_SECTION_DEBUG)-1, -1, 0, 0, 0, 0}; + + +static void +debugPrintf(const char *format, ...) +{ +#if VERBOSITY > 0 + static char buf[4096]; + + EnterCriticalSection(&Mutex); + + va_list ap; + va_start(ap, format); + _vsnprintf(buf, sizeof buf, format, ap); + va_end(ap); + + OutputDebugStringA(buf); + + LeaveCriticalSection(&Mutex); +#endif +} + + +static HMODULE WINAPI +MyLoadLibraryA(LPCSTR lpLibFileName); + +static HMODULE WINAPI +MyLoadLibraryW(LPCWSTR lpLibFileName); + +static HMODULE WINAPI +MyLoadLibraryExA(LPCSTR lpFileName, HANDLE hFile, DWORD dwFlags); + +static HMODULE WINAPI +MyLoadLibraryExW(LPCWSTR lpFileName, HANDLE hFile, DWORD dwFlags); + +static FARPROC WINAPI +MyGetProcAddress(HMODULE hModule, LPCSTR lpProcName); + + +static const char * +getImportDescriptionName(HMODULE hModule, const PIMAGE_IMPORT_DESCRIPTOR pImportDescriptor) { + const char* szName = (const char*)((PBYTE)hModule + pImportDescriptor->Name); + return szName; +} + + +static PIMAGE_IMPORT_DESCRIPTOR +getImportDescriptor(HMODULE hModule, + const char *szModule, + const char *pszDllName) +{ + MEMORY_BASIC_INFORMATION MemoryInfo; + if (VirtualQuery(hModule, &MemoryInfo, sizeof MemoryInfo) != sizeof MemoryInfo) { + debugPrintf("%s: %s: VirtualQuery failed\n", __FUNCTION__, szModule); + return NULL; + } + if (MemoryInfo.Protect & (PAGE_NOACCESS | PAGE_EXECUTE)) { + debugPrintf("%s: %s: no read access (Protect = 0x%08x)\n", __FUNCTION__, szModule, MemoryInfo.Protect); + return NULL; + } + + PIMAGE_DOS_HEADER pDosHeader = (PIMAGE_DOS_HEADER)hModule; + PIMAGE_NT_HEADERS pNtHeaders = (PIMAGE_NT_HEADERS)((PBYTE)hModule + pDosHeader->e_lfanew); + + PIMAGE_OPTIONAL_HEADER pOptionalHeader = &pNtHeaders->OptionalHeader; + + UINT_PTR ImportAddress = pOptionalHeader->DataDirectory[IMAGE_DIRECTORY_ENTRY_IMPORT].VirtualAddress; + + if (!ImportAddress) { + return NULL; + } + + PIMAGE_IMPORT_DESCRIPTOR pImportDescriptor = (PIMAGE_IMPORT_DESCRIPTOR)((PBYTE)hModule + ImportAddress); + + while (pImportDescriptor->FirstThunk) { + const char* szName = getImportDescriptionName(hModule, pImportDescriptor); + if (stricmp(pszDllName, szName) == 0) { + return pImportDescriptor; + } + ++pImportDescriptor; + } + + return NULL; +} + + +static BOOL +replaceAddress(LPVOID *lpOldAddress, LPVOID lpNewAddress) +{ + DWORD flOldProtect; + + if (*lpOldAddress == lpNewAddress) { + return TRUE; + } + + EnterCriticalSection(&Mutex); + + if (!(VirtualProtect(lpOldAddress, sizeof *lpOldAddress, PAGE_READWRITE, &flOldProtect))) { + LeaveCriticalSection(&Mutex); + return FALSE; + } + + *lpOldAddress = lpNewAddress; + + if (!(VirtualProtect(lpOldAddress, sizeof *lpOldAddress, flOldProtect, &flOldProtect))) { + LeaveCriticalSection(&Mutex); + return FALSE; + } + + LeaveCriticalSection(&Mutex); + return TRUE; +} + + +static LPVOID * +getOldFunctionAddress(HMODULE hModule, + PIMAGE_IMPORT_DESCRIPTOR pImportDescriptor, + const char* pszFunctionName) +{ + PIMAGE_THUNK_DATA pOriginalFirstThunk = (PIMAGE_THUNK_DATA)((PBYTE)hModule + pImportDescriptor->OriginalFirstThunk); + PIMAGE_THUNK_DATA pFirstThunk = (PIMAGE_THUNK_DATA)((PBYTE)hModule + pImportDescriptor->FirstThunk); + + //debugPrintf(" %s\n", __FUNCTION__); + + while (pOriginalFirstThunk->u1.Function) { + PIMAGE_IMPORT_BY_NAME pImport = (PIMAGE_IMPORT_BY_NAME)((PBYTE)hModule + pOriginalFirstThunk->u1.AddressOfData); + const char* szName = (const char* )pImport->Name; + //debugPrintf(" %s\n", szName); + if (strcmp(pszFunctionName, szName) == 0) { + //debugPrintf(" %s succeeded\n", __FUNCTION__); + return (LPVOID *)(&pFirstThunk->u1.Function); + } + ++pOriginalFirstThunk; + ++pFirstThunk; + } + + //debugPrintf(" %s failed\n", __FUNCTION__); + + return NULL; +} + + +static void +replaceModule(HMODULE hModule, + const char *szModule, + PIMAGE_IMPORT_DESCRIPTOR pImportDescriptor, + HMODULE hNewModule) +{ + PIMAGE_THUNK_DATA pOriginalFirstThunk = (PIMAGE_THUNK_DATA)((PBYTE)hModule + pImportDescriptor->OriginalFirstThunk); + PIMAGE_THUNK_DATA pFirstThunk = (PIMAGE_THUNK_DATA)((PBYTE)hModule + pImportDescriptor->FirstThunk); + + while (pOriginalFirstThunk->u1.Function) { + PIMAGE_IMPORT_BY_NAME pImport = (PIMAGE_IMPORT_BY_NAME)((PBYTE)hModule + pOriginalFirstThunk->u1.AddressOfData); + const char* szFunctionName = (const char* )pImport->Name; + debugPrintf(" hooking %s->%s!%s\n", szModule, + getImportDescriptionName(hModule, pImportDescriptor), + szFunctionName); + + PROC pNewProc = GetProcAddress(hNewModule, szFunctionName); + if (!pNewProc) { + debugPrintf(" warning: no replacement for %s\n", szFunctionName); + } else { + LPVOID *lpOldAddress = (LPVOID *)(&pFirstThunk->u1.Function); + replaceAddress(lpOldAddress, (LPVOID)pNewProc); + } + + ++pOriginalFirstThunk; + ++pFirstThunk; + } +} + + +static BOOL +hookFunction(HMODULE hModule, + const char *szModule, + const char *pszDllName, + const char *pszFunctionName, + LPVOID lpNewAddress) +{ + PIMAGE_IMPORT_DESCRIPTOR pImportDescriptor = getImportDescriptor(hModule, szModule, pszDllName); + if (pImportDescriptor == NULL) { + return FALSE; + } + LPVOID* lpOldFunctionAddress = getOldFunctionAddress(hModule, pImportDescriptor, pszFunctionName); + if (lpOldFunctionAddress == NULL) { + return FALSE; + } + + if (*lpOldFunctionAddress == lpNewAddress) { + return TRUE; + } + + if (VERBOSITY >= 3) { + debugPrintf(" hooking %s->%s!%s\n", szModule, pszDllName, pszFunctionName); + } + + return replaceAddress(lpOldFunctionAddress, lpNewAddress); +} + + +static BOOL +replaceImport(HMODULE hModule, + const char *szModule, + const char *pszDllName, + HMODULE hNewModule) +{ +#if NOOP + return TRUE; +#endif + + PIMAGE_IMPORT_DESCRIPTOR pImportDescriptor = getImportDescriptor(hModule, szModule, pszDllName); + if (pImportDescriptor == NULL) { + return TRUE; + } + + replaceModule(hModule, szModule, pImportDescriptor, hNewModule); + + return TRUE; +} + +static HMODULE g_hThisModule = NULL; + + +struct Replacement { + const char *szMatchModule; + HMODULE hReplaceModule; +}; + +static unsigned numReplacements = 0; +static Replacement replacements[32]; + + + +static void +hookModule(HMODULE hModule, + const char *szModule) +{ + if (hModule == g_hThisModule) { + return; + } + + for (unsigned i = 0; i < numReplacements; ++i) { + if (hModule == replacements[i].hReplaceModule) { + return; + } + } + + hookFunction(hModule, szModule, "kernel32.dll", "LoadLibraryA", (LPVOID)MyLoadLibraryA); + hookFunction(hModule, szModule, "kernel32.dll", "LoadLibraryW", (LPVOID)MyLoadLibraryW); + hookFunction(hModule, szModule, "kernel32.dll", "LoadLibraryExA", (LPVOID)MyLoadLibraryExA); + hookFunction(hModule, szModule, "kernel32.dll", "LoadLibraryExW", (LPVOID)MyLoadLibraryExW); + hookFunction(hModule, szModule, "kernel32.dll", "GetProcAddress", (LPVOID)MyGetProcAddress); + + const char *szBaseName = getBaseName(szModule); + for (unsigned i = 0; i < numReplacements; ++i) { + if (stricmp(szBaseName, replacements[i].szMatchModule) == 0) { + return; + } + } + + /* Don't hook internal dependencies */ + if (stricmp(szBaseName, "d3d10core.dll") == 0 || + stricmp(szBaseName, "d3d10level9.dll") == 0 || + stricmp(szBaseName, "d3d10sdklayers.dll") == 0 || + stricmp(szBaseName, "d3d10_1core.dll") == 0 || + stricmp(szBaseName, "d3d11sdklayers.dll") == 0 || + stricmp(szBaseName, "d3d11_1sdklayers.dll") == 0) { + return; + } + + for (unsigned i = 0; i < numReplacements; ++i) { + replaceImport(hModule, szModule, replacements[i].szMatchModule, replacements[i].hReplaceModule); + replaceImport(hModule, szModule, replacements[i].szMatchModule, replacements[i].hReplaceModule); + replaceImport(hModule, szModule, replacements[i].szMatchModule, replacements[i].hReplaceModule); + } +} + +static void +hookAllModules(void) +{ + HANDLE hModuleSnap = CreateToolhelp32Snapshot(TH32CS_SNAPMODULE, GetCurrentProcessId()); + if (hModuleSnap == INVALID_HANDLE_VALUE) { + return; + } + + MODULEENTRY32 me32; + me32.dwSize = sizeof me32; + + static bool first = true; + if (first) { + if (Module32First(hModuleSnap, &me32)) { + debugPrintf(" modules:\n"); + do { + debugPrintf(" %s\n", me32.szExePath); + } while (Module32Next(hModuleSnap, &me32)); + } + first = false; + } + + if (Module32First(hModuleSnap, &me32)) { + do { + hookModule(me32.hModule, me32.szExePath); + } while (Module32Next(hModuleSnap, &me32)); + } + + CloseHandle(hModuleSnap); +} + + + + +static HMODULE WINAPI +MyLoadLibrary(LPCSTR lpLibFileName, HANDLE hFile = NULL, DWORD dwFlags = 0) +{ + // To Send the information to the server informing that, + // LoadLibrary is invoked. + HMODULE hModule = LoadLibraryExA(lpLibFileName, hFile, dwFlags); + + //hookModule(hModule, lpLibFileName); + hookAllModules(); + + return hModule; +} + +static HMODULE WINAPI +MyLoadLibraryA(LPCSTR lpLibFileName) +{ + if (VERBOSITY >= 2) { + debugPrintf("%s(\"%s\")\n", __FUNCTION__, lpLibFileName); + } + + const char *szBaseName = getBaseName(lpLibFileName); + for (unsigned i = 0; i < numReplacements; ++i) { + if (stricmp(szBaseName, replacements[i].szMatchModule) == 0) { + debugPrintf("%s(\"%s\")\n", __FUNCTION__, lpLibFileName); +#ifdef __GNUC__ + void *caller = __builtin_return_address (0); + + HMODULE hModule = 0; + BOOL bRet = GetModuleHandleEx(GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS, + (LPCTSTR)caller, + &hModule); + assert(bRet); + char szCaller[256]; + DWORD dwRet = GetModuleFileNameA(hModule, szCaller, sizeof szCaller); + assert(dwRet); + debugPrintf(" called from %s\n", szCaller); +#endif + break; + } + } + + return MyLoadLibrary(lpLibFileName); +} + +static HMODULE WINAPI +MyLoadLibraryW(LPCWSTR lpLibFileName) +{ + if (VERBOSITY >= 2) { + debugPrintf("%s(L\"%S\")\n", __FUNCTION__, lpLibFileName); + } + + char szFileName[256]; + wcstombs(szFileName, lpLibFileName, sizeof szFileName); + + return MyLoadLibrary(szFileName); +} + +static HMODULE WINAPI +MyLoadLibraryExA(LPCSTR lpLibFileName, HANDLE hFile, DWORD dwFlags) +{ + if (VERBOSITY >= 2) { + debugPrintf("%s(\"%s\")\n", __FUNCTION__, lpLibFileName); + } + return MyLoadLibrary(lpLibFileName, hFile, dwFlags); +} + +static HMODULE WINAPI +MyLoadLibraryExW(LPCWSTR lpLibFileName, HANDLE hFile, DWORD dwFlags) +{ + if (VERBOSITY >= 2) { + debugPrintf("%s(L\"%S\")\n", __FUNCTION__, lpLibFileName); + } + + char szFileName[256]; + wcstombs(szFileName, lpLibFileName, sizeof szFileName); + + return MyLoadLibrary(szFileName, hFile, dwFlags); +} + +static FARPROC WINAPI +MyGetProcAddress(HMODULE hModule, LPCSTR lpProcName) { + + if (VERBOSITY >= 99) { + /* XXX this can cause segmentation faults */ + debugPrintf("%s(\"%s\")\n", __FUNCTION__, lpProcName); + } + + assert(hModule != g_hThisModule); + for (unsigned i = 0; i < numReplacements; ++i) { + if (hModule == replacements[i].hReplaceModule) { + return GetProcAddress(hModule, lpProcName); + } + } + +#if !NOOP + char szModule[256]; + DWORD dwRet = GetModuleFileNameA(hModule, szModule, sizeof szModule); + assert(dwRet); + const char *szBaseName = getBaseName(szModule); + + for (unsigned i = 0; i < numReplacements; ++i) { + + if (stricmp(szBaseName, replacements[i].szMatchModule) == 0) { + debugPrintf(" %s(\"%s\", \"%s\")\n", __FUNCTION__, szModule, lpProcName); + FARPROC pProcAddress = GetProcAddress(replacements[i].hReplaceModule, lpProcName); + if (pProcAddress) { + if (VERBOSITY >= 2) { + debugPrintf(" replacing %s!%s\n", szBaseName, lpProcName); + } + return pProcAddress; + } else { + debugPrintf(" ignoring %s!%s\n", szBaseName, lpProcName); + break; + } + } + } +#endif + + return GetProcAddress(hModule, lpProcName); +} + + +EXTERN_C BOOL WINAPI +DllMain(HINSTANCE hinstDLL, DWORD fdwReason, LPVOID lpReserved) +{ + const char *szNewDllName = NULL; + HMODULE hNewModule = NULL; + const char *szNewDllBaseName; + + switch (fdwReason) { + case DLL_PROCESS_ATTACH: + debugPrintf("DLL_PROCESS_ATTACH\n"); + + g_hThisModule = hinstDLL; + + { + char szProcess[MAX_PATH]; + GetModuleFileNameA(NULL, szProcess, sizeof szProcess); + debugPrintf(" attached to %s\n", szProcess); + } + + /* + * Calling LoadLibrary inside DllMain is strongly discouraged. But it + * works quite well, provided that the loaded DLL does not require or do + * anything special in its DllMain, which seems to be the general case. + * + * See also: + * - http://stackoverflow.com/questions/4370812/calling-loadlibrary-from-dllmain + * - http://msdn.microsoft.com/en-us/library/ms682583 + */ + +#if 0 + szNewDllName = getenv("INJECT_DLL"); + if (!szNewDllName) { + debugPrintf("warning: INJECT_DLL not set\n"); + return FALSE; + } +#else + static char szSharedMemCopy[MAX_PATH]; + GetSharedMem(szSharedMemCopy, sizeof szSharedMemCopy); + szNewDllName = szSharedMemCopy; +#endif + debugPrintf(" injecting %s\n", szNewDllName); + + hNewModule = LoadLibraryA(szNewDllName); + if (!hNewModule) { + debugPrintf("warning: failed to load %s\n", szNewDllName); + return FALSE; + } + + szNewDllBaseName = getBaseName(szNewDllName); + if (stricmp(szNewDllBaseName, "dxgitrace.dll") == 0) { + replacements[numReplacements].szMatchModule = "dxgi.dll"; + replacements[numReplacements].hReplaceModule = hNewModule; + ++numReplacements; + + replacements[numReplacements].szMatchModule = "d3d10.dll"; + replacements[numReplacements].hReplaceModule = hNewModule; + ++numReplacements; + + replacements[numReplacements].szMatchModule = "d3d10_1.dll"; + replacements[numReplacements].hReplaceModule = hNewModule; + ++numReplacements; + + replacements[numReplacements].szMatchModule = "d3d11.dll"; + replacements[numReplacements].hReplaceModule = hNewModule; + ++numReplacements; + } else { + replacements[numReplacements].szMatchModule = szNewDllBaseName; + replacements[numReplacements].hReplaceModule = hNewModule; + ++numReplacements; + } + + hookAllModules(); + break; + + case DLL_THREAD_ATTACH: + break; + + case DLL_THREAD_DETACH: + break; + + case DLL_PROCESS_DETACH: + debugPrintf("DLL_PROCESS_DETACH\n"); + break; + } + return TRUE; +}