]> git.cworth.org Git - apitrace/blob - inject/injectee.cpp
inject: Add define to use environment var instead of shared memory.
[apitrace] / inject / injectee.cpp
1 /**************************************************************************
2  *
3  * Copyright 2011-2012 Jose Fonseca
4  * All Rights Reserved.
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to deal
8  * in the Software without restriction, including without limitation the rights
9  * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10  * copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in
14  * all copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
22  * THE SOFTWARE.
23  *
24  **************************************************************************/
25
26
27 /*
28  * Code for the DLL that will be injected in the target process.
29  *
30  * The injected DLL will manipulate the import tables to hook the
31  * modules/functions of interest.
32  *
33  * See also:
34  * - http://www.codeproject.com/KB/system/api_spying_hack.aspx
35  * - http://www.codeproject.com/KB/threads/APIHooking.aspx
36  * - http://msdn.microsoft.com/en-us/magazine/cc301808.aspx
37  */
38
39
40 #include <assert.h>
41 #include <stdio.h>
42 #include <stdarg.h>
43
44 #include <windows.h>
45 #include <tlhelp32.h>
46
47 #include "inject.h"
48
49
50 #define VERBOSITY 0
51 #define NOOP 0
52
53
54 static CRITICAL_SECTION Mutex = {(PCRITICAL_SECTION_DEBUG)-1, -1, 0, 0, 0, 0};
55
56
/*
 * printf-style helper that sends the formatted message to the debugger via
 * OutputDebugStringA().  Compiles to an empty function unless VERBOSITY > 0.
 *
 * The formatting buffer is static and therefore shared, so formatting and
 * output happen while holding Mutex.
 */
static void
debugPrintf(const char *format, ...)
{
#if VERBOSITY > 0
    static char buf[4096];

    EnterCriticalSection(&Mutex);

    va_list ap;
    va_start(ap, format);
    /* _vsnprintf() leaves the buffer unterminated when the output does not
     * fit, so reserve the last byte and terminate explicitly. */
    _vsnprintf(buf, sizeof buf - 1, format, ap);
    va_end(ap);
    buf[sizeof buf - 1] = '\0';

    OutputDebugStringA(buf);

    LeaveCriticalSection(&Mutex);
#endif
}
75
76
77 static HMODULE WINAPI
78 MyLoadLibraryA(LPCSTR lpLibFileName);
79
80 static HMODULE WINAPI
81 MyLoadLibraryW(LPCWSTR lpLibFileName);
82
83 static HMODULE WINAPI
84 MyLoadLibraryExA(LPCSTR lpFileName, HANDLE hFile, DWORD dwFlags);
85
86 static HMODULE WINAPI
87 MyLoadLibraryExW(LPCWSTR lpFileName, HANDLE hFile, DWORD dwFlags);
88
89 static FARPROC WINAPI
90 MyGetProcAddress(HMODULE hModule, LPCSTR lpProcName);
91
92
93 static const char *
94 getImportDescriptionName(HMODULE hModule, const PIMAGE_IMPORT_DESCRIPTOR pImportDescriptor) {
95     const char* szName = (const char*)((PBYTE)hModule + pImportDescriptor->Name);
96     return szName;
97 }
98
99
100 static PIMAGE_IMPORT_DESCRIPTOR
101 getImportDescriptor(HMODULE hModule,
102                     const char *szModule,
103                     const char *pszDllName)
104 {
105     MEMORY_BASIC_INFORMATION MemoryInfo;
106     if (VirtualQuery(hModule, &MemoryInfo, sizeof MemoryInfo) != sizeof MemoryInfo) {
107         debugPrintf("%s: %s: VirtualQuery failed\n", __FUNCTION__, szModule);
108         return NULL;
109     }
110     if (MemoryInfo.Protect & (PAGE_NOACCESS | PAGE_EXECUTE)) {
111         debugPrintf("%s: %s: no read access (Protect = 0x%08x)\n", __FUNCTION__, szModule, MemoryInfo.Protect);
112         return NULL;
113     }
114
115     PIMAGE_DOS_HEADER pDosHeader = (PIMAGE_DOS_HEADER)hModule;
116     PIMAGE_NT_HEADERS pNtHeaders = (PIMAGE_NT_HEADERS)((PBYTE)hModule + pDosHeader->e_lfanew);
117
118     PIMAGE_OPTIONAL_HEADER pOptionalHeader = &pNtHeaders->OptionalHeader;
119
120     UINT_PTR ImportAddress = pOptionalHeader->DataDirectory[IMAGE_DIRECTORY_ENTRY_IMPORT].VirtualAddress;
121
122     if (!ImportAddress) {
123         return NULL;
124     }
125
126     PIMAGE_IMPORT_DESCRIPTOR pImportDescriptor = (PIMAGE_IMPORT_DESCRIPTOR)((PBYTE)hModule + ImportAddress);
127
128     while (pImportDescriptor->FirstThunk) {
129         const char* szName = getImportDescriptionName(hModule, pImportDescriptor);
130         if (stricmp(pszDllName, szName) == 0) {
131             return pImportDescriptor;
132         }
133         ++pImportDescriptor;
134     }
135
136     return NULL;
137 }
138
139
140 static BOOL
141 replaceAddress(LPVOID *lpOldAddress, LPVOID lpNewAddress)
142 {
143     DWORD flOldProtect;
144
145     if (*lpOldAddress == lpNewAddress) {
146         return TRUE;
147     }
148
149     EnterCriticalSection(&Mutex);
150
151     if (!(VirtualProtect(lpOldAddress, sizeof *lpOldAddress, PAGE_READWRITE, &flOldProtect))) {
152         LeaveCriticalSection(&Mutex);
153         return FALSE;
154     }
155
156     *lpOldAddress = lpNewAddress;
157
158     if (!(VirtualProtect(lpOldAddress, sizeof *lpOldAddress, flOldProtect, &flOldProtect))) {
159         LeaveCriticalSection(&Mutex);
160         return FALSE;
161     }
162
163     LeaveCriticalSection(&Mutex);
164     return TRUE;
165 }
166
167
168 static LPVOID *
169 getOldFunctionAddress(HMODULE hModule,
170                     PIMAGE_IMPORT_DESCRIPTOR pImportDescriptor,
171                     const char* pszFunctionName)
172 {
173     PIMAGE_THUNK_DATA pOriginalFirstThunk = (PIMAGE_THUNK_DATA)((PBYTE)hModule + pImportDescriptor->OriginalFirstThunk);
174     PIMAGE_THUNK_DATA pFirstThunk = (PIMAGE_THUNK_DATA)((PBYTE)hModule + pImportDescriptor->FirstThunk);
175
176     //debugPrintf("  %s\n", __FUNCTION__);
177
178     while (pOriginalFirstThunk->u1.Function) {
179         PIMAGE_IMPORT_BY_NAME pImport = (PIMAGE_IMPORT_BY_NAME)((PBYTE)hModule + pOriginalFirstThunk->u1.AddressOfData);
180         const char* szName = (const char* )pImport->Name;
181         //debugPrintf("    %s\n", szName);
182         if (strcmp(pszFunctionName, szName) == 0) {
183             //debugPrintf("  %s succeeded\n", __FUNCTION__);
184             return (LPVOID *)(&pFirstThunk->u1.Function);
185         }
186         ++pOriginalFirstThunk;
187         ++pFirstThunk;
188     }
189
190     //debugPrintf("  %s failed\n", __FUNCTION__);
191
192     return NULL;
193 }
194
195
196 static void
197 replaceModule(HMODULE hModule,
198               const char *szModule,
199               PIMAGE_IMPORT_DESCRIPTOR pImportDescriptor,
200               HMODULE hNewModule)
201 {
202     PIMAGE_THUNK_DATA pOriginalFirstThunk = (PIMAGE_THUNK_DATA)((PBYTE)hModule + pImportDescriptor->OriginalFirstThunk);
203     PIMAGE_THUNK_DATA pFirstThunk = (PIMAGE_THUNK_DATA)((PBYTE)hModule + pImportDescriptor->FirstThunk);
204
205     while (pOriginalFirstThunk->u1.Function) {
206         PIMAGE_IMPORT_BY_NAME pImport = (PIMAGE_IMPORT_BY_NAME)((PBYTE)hModule + pOriginalFirstThunk->u1.AddressOfData);
207         const char* szFunctionName = (const char* )pImport->Name;
208         debugPrintf("      hooking %s->%s!%s\n", szModule,
209                 getImportDescriptionName(hModule, pImportDescriptor),
210                 szFunctionName);
211
212         PROC pNewProc = GetProcAddress(hNewModule, szFunctionName);
213         if (!pNewProc) {
214             debugPrintf("  warning: no replacement for %s\n", szFunctionName);
215         } else {
216             LPVOID *lpOldAddress = (LPVOID *)(&pFirstThunk->u1.Function);
217             replaceAddress(lpOldAddress, (LPVOID)pNewProc);
218         }
219
220         ++pOriginalFirstThunk;
221         ++pFirstThunk;
222     }
223 }
224
225
226 static BOOL
227 hookFunction(HMODULE hModule,
228              const char *szModule,
229              const char *pszDllName,
230              const char *pszFunctionName,
231              LPVOID lpNewAddress)
232 {
233     PIMAGE_IMPORT_DESCRIPTOR pImportDescriptor = getImportDescriptor(hModule, szModule, pszDllName);
234     if (pImportDescriptor == NULL) {
235         return FALSE;
236     }
237     LPVOID* lpOldFunctionAddress = getOldFunctionAddress(hModule, pImportDescriptor, pszFunctionName);
238     if (lpOldFunctionAddress == NULL) {
239         return FALSE;
240     }
241
242     if (*lpOldFunctionAddress == lpNewAddress) {
243         return TRUE;
244     }
245
246     if (VERBOSITY >= 3) {
247         debugPrintf("      hooking %s->%s!%s\n", szModule, pszDllName, pszFunctionName);
248     }
249
250     return replaceAddress(lpOldFunctionAddress, lpNewAddress);
251 }
252
253
254 static BOOL
255 replaceImport(HMODULE hModule,
256               const char *szModule,
257               const char *pszDllName,
258               HMODULE hNewModule)
259 {
260 #if NOOP
261     return TRUE;
262 #endif
263
264     PIMAGE_IMPORT_DESCRIPTOR pImportDescriptor = getImportDescriptor(hModule, szModule, pszDllName);
265     if (pImportDescriptor == NULL) {
266         return TRUE;
267     }
268
269     replaceModule(hModule, szModule, pImportDescriptor, hNewModule);
270
271     return TRUE;
272 }
273
274 static HMODULE g_hThisModule = NULL;
275
276
277 struct Replacement {
278     const char *szMatchModule;
279     HMODULE hReplaceModule;
280 };
281
282 static unsigned numReplacements = 0;
283 static Replacement replacements[32];
284
285
286
287 static void
288 hookModule(HMODULE hModule,
289            const char *szModule)
290 {
291     if (hModule == g_hThisModule) {
292         return;
293     }
294
295     for (unsigned i = 0; i < numReplacements; ++i) {
296         if (hModule == replacements[i].hReplaceModule) {
297             return;
298         }
299     }
300
301     hookFunction(hModule, szModule, "kernel32.dll", "LoadLibraryA", (LPVOID)MyLoadLibraryA);
302     hookFunction(hModule, szModule, "kernel32.dll", "LoadLibraryW", (LPVOID)MyLoadLibraryW);
303     hookFunction(hModule, szModule, "kernel32.dll", "LoadLibraryExA", (LPVOID)MyLoadLibraryExA);
304     hookFunction(hModule, szModule, "kernel32.dll", "LoadLibraryExW", (LPVOID)MyLoadLibraryExW);
305     hookFunction(hModule, szModule, "kernel32.dll", "GetProcAddress", (LPVOID)MyGetProcAddress);
306
307     const char *szBaseName = getBaseName(szModule);
308     for (unsigned i = 0; i < numReplacements; ++i) {
309         if (stricmp(szBaseName, replacements[i].szMatchModule) == 0) {
310             return;
311         }
312     }
313
314     /* Don't hook internal dependencies */
315     if (stricmp(szBaseName, "d3d10core.dll") == 0 ||
316         stricmp(szBaseName, "d3d10level9.dll") == 0 ||
317         stricmp(szBaseName, "d3d10sdklayers.dll") == 0 ||
318         stricmp(szBaseName, "d3d10_1core.dll") == 0 ||
319         stricmp(szBaseName, "d3d11sdklayers.dll") == 0 ||
320         stricmp(szBaseName, "d3d11_1sdklayers.dll") == 0) {
321         return;
322     }
323
324     for (unsigned i = 0; i < numReplacements; ++i) {
325         replaceImport(hModule, szModule, replacements[i].szMatchModule, replacements[i].hReplaceModule);
326         replaceImport(hModule, szModule, replacements[i].szMatchModule, replacements[i].hReplaceModule);
327         replaceImport(hModule, szModule, replacements[i].szMatchModule, replacements[i].hReplaceModule);
328     }
329 }
330
331 static void
332 hookAllModules(void)
333 {
334     HANDLE hModuleSnap = CreateToolhelp32Snapshot(TH32CS_SNAPMODULE, GetCurrentProcessId());
335     if (hModuleSnap == INVALID_HANDLE_VALUE) {
336         return;
337     }
338
339     MODULEENTRY32 me32;
340     me32.dwSize = sizeof me32;
341
342     static bool first = true;
343     if (first) {
344         if (Module32First(hModuleSnap, &me32)) {
345             debugPrintf("  modules:\n");
346             do  {
347                 debugPrintf("     %s\n", me32.szExePath);
348             } while (Module32Next(hModuleSnap, &me32));
349         }
350         first = false;
351     }
352
353     if (Module32First(hModuleSnap, &me32)) {
354         do  {
355             hookModule(me32.hModule, me32.szExePath);
356         } while (Module32Next(hModuleSnap, &me32));
357     }
358
359     CloseHandle(hModuleSnap);
360 }
361
362
363
364
365 static HMODULE WINAPI
366 MyLoadLibrary(LPCSTR lpLibFileName, HANDLE hFile = NULL, DWORD dwFlags = 0)
367 {
368     // To Send the information to the server informing that,
369     // LoadLibrary is invoked.
370     HMODULE hModule = LoadLibraryExA(lpLibFileName, hFile, dwFlags);
371
372     //hookModule(hModule, lpLibFileName);
373     hookAllModules();
374
375     return hModule;
376 }
377
378 static HMODULE WINAPI
379 MyLoadLibraryA(LPCSTR lpLibFileName)
380 {
381     if (VERBOSITY >= 2) {
382         debugPrintf("%s(\"%s\")\n", __FUNCTION__, lpLibFileName);
383     }
384
385     const char *szBaseName = getBaseName(lpLibFileName);
386     for (unsigned i = 0; i < numReplacements; ++i) {
387         if (stricmp(szBaseName, replacements[i].szMatchModule) == 0) {
388             debugPrintf("%s(\"%s\")\n", __FUNCTION__, lpLibFileName);
389 #ifdef __GNUC__
390             void *caller = __builtin_return_address (0);
391
392             HMODULE hModule = 0;
393             BOOL bRet = GetModuleHandleEx(GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS,
394                                      (LPCTSTR)caller,
395                                      &hModule);
396             assert(bRet);
397             char szCaller[256];
398             DWORD dwRet = GetModuleFileNameA(hModule, szCaller, sizeof szCaller);
399             assert(dwRet);
400             debugPrintf("  called from %s\n", szCaller);
401 #endif
402             break;
403         }
404     }
405
406     return MyLoadLibrary(lpLibFileName);
407 }
408
409 static HMODULE WINAPI
410 MyLoadLibraryW(LPCWSTR lpLibFileName)
411 {
412     if (VERBOSITY >= 2) {
413         debugPrintf("%s(L\"%S\")\n", __FUNCTION__, lpLibFileName);
414     }
415
416     char szFileName[256];
417     wcstombs(szFileName, lpLibFileName, sizeof szFileName);
418
419     return MyLoadLibrary(szFileName);
420 }
421
422 static HMODULE WINAPI
423 MyLoadLibraryExA(LPCSTR lpLibFileName, HANDLE hFile, DWORD dwFlags)
424 {
425     if (VERBOSITY >= 2) {
426         debugPrintf("%s(\"%s\")\n", __FUNCTION__, lpLibFileName);
427     }
428     return MyLoadLibrary(lpLibFileName, hFile, dwFlags);
429 }
430
431 static HMODULE WINAPI
432 MyLoadLibraryExW(LPCWSTR lpLibFileName, HANDLE hFile, DWORD dwFlags)
433 {
434     if (VERBOSITY >= 2) {
435         debugPrintf("%s(L\"%S\")\n", __FUNCTION__, lpLibFileName);
436     }
437
438     char szFileName[256];
439     wcstombs(szFileName, lpLibFileName, sizeof szFileName);
440
441     return MyLoadLibrary(szFileName, hFile, dwFlags);
442 }
443
444 static FARPROC WINAPI
445 MyGetProcAddress(HMODULE hModule, LPCSTR lpProcName) {
446
447     if (VERBOSITY >= 99) {
448         /* XXX this can cause segmentation faults */
449         debugPrintf("%s(\"%s\")\n", __FUNCTION__, lpProcName);
450     }
451
452     assert(hModule != g_hThisModule);
453     for (unsigned i = 0; i < numReplacements; ++i) {
454         if (hModule == replacements[i].hReplaceModule) {
455             return GetProcAddress(hModule, lpProcName);
456         }
457     }
458
459 #if !NOOP
460     char szModule[256];
461     DWORD dwRet = GetModuleFileNameA(hModule, szModule, sizeof szModule);
462     assert(dwRet);
463     const char *szBaseName = getBaseName(szModule);
464
465     for (unsigned i = 0; i < numReplacements; ++i) {
466
467         if (stricmp(szBaseName, replacements[i].szMatchModule) == 0) {
468             debugPrintf("  %s(\"%s\", \"%s\")\n", __FUNCTION__, szModule, lpProcName);
469             FARPROC pProcAddress = GetProcAddress(replacements[i].hReplaceModule, lpProcName);
470             if (pProcAddress) {
471                 if (VERBOSITY >= 2) {
472                     debugPrintf("      replacing %s!%s\n", szBaseName, lpProcName);
473                 }
474                 return pProcAddress;
475             } else {
476                 debugPrintf("      ignoring %s!%s\n", szBaseName, lpProcName);
477                 break;
478             }
479         }
480     }
481 #endif
482
483     return GetProcAddress(hModule, lpProcName);
484 }
485
486
487 EXTERN_C BOOL WINAPI
488 DllMain(HINSTANCE hinstDLL, DWORD fdwReason, LPVOID lpReserved)
489 {
490     const char *szNewDllName = NULL;
491     HMODULE hNewModule = NULL;
492     const char *szNewDllBaseName;
493
494     switch (fdwReason) {
495     case DLL_PROCESS_ATTACH:
496         debugPrintf("DLL_PROCESS_ATTACH\n");
497
498         g_hThisModule = hinstDLL;
499
500         {
501             char szProcess[MAX_PATH];
502             GetModuleFileNameA(NULL, szProcess, sizeof szProcess);
503             debugPrintf("  attached to %s\n", szProcess);
504         }
505
506         /*
507          * Calling LoadLibrary inside DllMain is strongly discouraged.  But it
508          * works quite well, provided that the loaded DLL does not require or do
509          * anything special in its DllMain, which seems to be the general case.
510          *
511          * See also:
512          * - http://stackoverflow.com/questions/4370812/calling-loadlibrary-from-dllmain
513          * - http://msdn.microsoft.com/en-us/library/ms682583
514          */
515
516 #if !USE_SHARED_MEM
517         szNewDllName = getenv("INJECT_DLL");
518         if (!szNewDllName) {
519             debugPrintf("warning: INJECT_DLL not set\n");
520             return FALSE;
521         }
522 #else
523         static char szSharedMemCopy[MAX_PATH];
524         GetSharedMem(szSharedMemCopy, sizeof szSharedMemCopy);
525         szNewDllName = szSharedMemCopy;
526 #endif
527         debugPrintf("  injecting %s\n", szNewDllName);
528
529         hNewModule = LoadLibraryA(szNewDllName);
530         if (!hNewModule) {
531             debugPrintf("warning: failed to load %s\n", szNewDllName);
532             return FALSE;
533         }
534
535         szNewDllBaseName = getBaseName(szNewDllName);
536         if (stricmp(szNewDllBaseName, "dxgitrace.dll") == 0) {
537             replacements[numReplacements].szMatchModule = "dxgi.dll";
538             replacements[numReplacements].hReplaceModule = hNewModule;
539             ++numReplacements;
540
541             replacements[numReplacements].szMatchModule = "d3d10.dll";
542             replacements[numReplacements].hReplaceModule = hNewModule;
543             ++numReplacements;
544
545             replacements[numReplacements].szMatchModule = "d3d10_1.dll";
546             replacements[numReplacements].hReplaceModule = hNewModule;
547             ++numReplacements;
548
549             replacements[numReplacements].szMatchModule = "d3d11.dll";
550             replacements[numReplacements].hReplaceModule = hNewModule;
551             ++numReplacements;
552         } else {
553             replacements[numReplacements].szMatchModule = szNewDllBaseName;
554             replacements[numReplacements].hReplaceModule = hNewModule;
555             ++numReplacements;
556         }
557
558         hookAllModules();
559         break;
560
561     case DLL_THREAD_ATTACH:
562         break;
563
564     case DLL_THREAD_DETACH:
565         break;
566
567     case DLL_PROCESS_DETACH:
568         debugPrintf("DLL_PROCESS_DETACH\n");
569         break;
570     }
571     return TRUE;
572 }