]> git.cworth.org Git - apitrace/blob - inject/injectee.cpp
Use skiplist-based FastCallSet within trace::CallSet
[apitrace] / inject / injectee.cpp
1 /**************************************************************************
2  *
3  * Copyright 2011-2012 Jose Fonseca
4  * All Rights Reserved.
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to deal
8  * in the Software without restriction, including without limitation the rights
9  * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10  * copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in
14  * all copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
22  * THE SOFTWARE.
23  *
24  **************************************************************************/
25
26
27 /*
28  * Code for the DLL that will be injected in the target process.
29  *
30  * The injected DLL will manipulate the import tables to hook the
31  * modules/functions of interest.
32  *
33  * See also:
34  * - http://www.codeproject.com/KB/system/api_spying_hack.aspx
35  * - http://www.codeproject.com/KB/threads/APIHooking.aspx
36  * - http://msdn.microsoft.com/en-us/magazine/cc301808.aspx
37  */
38
39
40 #include <assert.h>
41 #include <stdio.h>
42 #include <stdarg.h>
43
44 #include <windows.h>
45 #include <tlhelp32.h>
46
47 #include "inject.h"
48
49
50 #define VERBOSITY 0
51 #define NOOP 0
52
53
54 static CRITICAL_SECTION Mutex = {(PCRITICAL_SECTION_DEBUG)-1, -1, 0, 0, 0, 0};
55
56
57 static void
58 debugPrintf(const char *format, ...)
59 {
60     char buf[512];
61
62     va_list ap;
63     va_start(ap, format);
64     _vsnprintf(buf, sizeof buf, format, ap);
65     va_end(ap);
66
67     OutputDebugStringA(buf);
68 }
69
70
71 static HMODULE WINAPI
72 MyLoadLibraryA(LPCSTR lpLibFileName);
73
74 static HMODULE WINAPI
75 MyLoadLibraryW(LPCWSTR lpLibFileName);
76
77 static HMODULE WINAPI
78 MyLoadLibraryExA(LPCSTR lpFileName, HANDLE hFile, DWORD dwFlags);
79
80 static HMODULE WINAPI
81 MyLoadLibraryExW(LPCWSTR lpFileName, HANDLE hFile, DWORD dwFlags);
82
83 static FARPROC WINAPI
84 MyGetProcAddress(HMODULE hModule, LPCSTR lpProcName);
85
86
87 static const char *
88 getImportDescriptionName(HMODULE hModule, const PIMAGE_IMPORT_DESCRIPTOR pImportDescriptor) {
89     const char* szName = (const char*)((PBYTE)hModule + pImportDescriptor->Name);
90     return szName;
91 }
92
93
94 static PIMAGE_IMPORT_DESCRIPTOR
95 getImportDescriptor(HMODULE hModule,
96                     const char *szModule,
97                     const char *pszDllName)
98 {
99     MEMORY_BASIC_INFORMATION MemoryInfo;
100     if (VirtualQuery(hModule, &MemoryInfo, sizeof MemoryInfo) != sizeof MemoryInfo) {
101         debugPrintf("%s: %s: VirtualQuery failed\n", __FUNCTION__, szModule);
102         return NULL;
103     }
104     if (MemoryInfo.Protect & (PAGE_NOACCESS | PAGE_EXECUTE)) {
105         debugPrintf("%s: %s: no read access (Protect = 0x%08x)\n", __FUNCTION__, szModule, MemoryInfo.Protect);
106         return NULL;
107     }
108
109     PIMAGE_DOS_HEADER pDosHeader = (PIMAGE_DOS_HEADER)hModule;
110     PIMAGE_NT_HEADERS pNtHeaders = (PIMAGE_NT_HEADERS)((PBYTE)hModule + pDosHeader->e_lfanew);
111
112     PIMAGE_OPTIONAL_HEADER pOptionalHeader = &pNtHeaders->OptionalHeader;
113
114     UINT_PTR ImportAddress = pOptionalHeader->DataDirectory[IMAGE_DIRECTORY_ENTRY_IMPORT].VirtualAddress;
115
116     if (!ImportAddress) {
117         return NULL;
118     }
119
120     PIMAGE_IMPORT_DESCRIPTOR pImportDescriptor = (PIMAGE_IMPORT_DESCRIPTOR)((PBYTE)hModule + ImportAddress);
121
122     while (pImportDescriptor->FirstThunk) {
123         const char* szName = getImportDescriptionName(hModule, pImportDescriptor);
124         if (stricmp(pszDllName, szName) == 0) {
125             return pImportDescriptor;
126         }
127         ++pImportDescriptor;
128     }
129
130     return NULL;
131 }
132
133
134 static BOOL
135 replaceAddress(LPVOID *lpOldAddress, LPVOID lpNewAddress)
136 {
137     DWORD flOldProtect;
138
139     if (*lpOldAddress == lpNewAddress) {
140         return TRUE;
141     }
142
143     EnterCriticalSection(&Mutex);
144
145     if (!(VirtualProtect(lpOldAddress, sizeof *lpOldAddress, PAGE_READWRITE, &flOldProtect))) {
146         LeaveCriticalSection(&Mutex);
147         return FALSE;
148     }
149
150     *lpOldAddress = lpNewAddress;
151
152     if (!(VirtualProtect(lpOldAddress, sizeof *lpOldAddress, flOldProtect, &flOldProtect))) {
153         LeaveCriticalSection(&Mutex);
154         return FALSE;
155     }
156
157     LeaveCriticalSection(&Mutex);
158     return TRUE;
159 }
160
161
162 static LPVOID *
163 getOldFunctionAddress(HMODULE hModule,
164                     PIMAGE_IMPORT_DESCRIPTOR pImportDescriptor,
165                     const char* pszFunctionName)
166 {
167     PIMAGE_THUNK_DATA pOriginalFirstThunk = (PIMAGE_THUNK_DATA)((PBYTE)hModule + pImportDescriptor->OriginalFirstThunk);
168     PIMAGE_THUNK_DATA pFirstThunk = (PIMAGE_THUNK_DATA)((PBYTE)hModule + pImportDescriptor->FirstThunk);
169
170     //debugPrintf("  %s\n", __FUNCTION__);
171
172     while (pOriginalFirstThunk->u1.Function) {
173         PIMAGE_IMPORT_BY_NAME pImport = (PIMAGE_IMPORT_BY_NAME)((PBYTE)hModule + pOriginalFirstThunk->u1.AddressOfData);
174         const char* szName = (const char* )pImport->Name;
175         //debugPrintf("    %s\n", szName);
176         if (strcmp(pszFunctionName, szName) == 0) {
177             //debugPrintf("  %s succeeded\n", __FUNCTION__);
178             return (LPVOID *)(&pFirstThunk->u1.Function);
179         }
180         ++pOriginalFirstThunk;
181         ++pFirstThunk;
182     }
183
184     //debugPrintf("  %s failed\n", __FUNCTION__);
185
186     return NULL;
187 }
188
189
190 static void
191 replaceModule(HMODULE hModule,
192               const char *szModule,
193               PIMAGE_IMPORT_DESCRIPTOR pImportDescriptor,
194               HMODULE hNewModule)
195 {
196     PIMAGE_THUNK_DATA pOriginalFirstThunk = (PIMAGE_THUNK_DATA)((PBYTE)hModule + pImportDescriptor->OriginalFirstThunk);
197     PIMAGE_THUNK_DATA pFirstThunk = (PIMAGE_THUNK_DATA)((PBYTE)hModule + pImportDescriptor->FirstThunk);
198
199     while (pOriginalFirstThunk->u1.Function) {
200         PIMAGE_IMPORT_BY_NAME pImport = (PIMAGE_IMPORT_BY_NAME)((PBYTE)hModule + pOriginalFirstThunk->u1.AddressOfData);
201         const char* szFunctionName = (const char* )pImport->Name;
202         if (VERBOSITY > 0) {
203             debugPrintf("      hooking %s->%s!%s\n", szModule,
204                     getImportDescriptionName(hModule, pImportDescriptor),
205                     szFunctionName);
206         }
207
208         PROC pNewProc = GetProcAddress(hNewModule, szFunctionName);
209         if (!pNewProc) {
210             debugPrintf("warning: no replacement for %s\n", szFunctionName);
211         } else {
212             LPVOID *lpOldAddress = (LPVOID *)(&pFirstThunk->u1.Function);
213             replaceAddress(lpOldAddress, (LPVOID)pNewProc);
214         }
215
216         ++pOriginalFirstThunk;
217         ++pFirstThunk;
218     }
219 }
220
221
222 static BOOL
223 hookFunction(HMODULE hModule,
224              const char *szModule,
225              const char *pszDllName,
226              const char *pszFunctionName,
227              LPVOID lpNewAddress)
228 {
229     PIMAGE_IMPORT_DESCRIPTOR pImportDescriptor = getImportDescriptor(hModule, szModule, pszDllName);
230     if (pImportDescriptor == NULL) {
231         return FALSE;
232     }
233     LPVOID* lpOldFunctionAddress = getOldFunctionAddress(hModule, pImportDescriptor, pszFunctionName);
234     if (lpOldFunctionAddress == NULL) {
235         return FALSE;
236     }
237
238     if (*lpOldFunctionAddress == lpNewAddress) {
239         return TRUE;
240     }
241
242     if (VERBOSITY >= 3) {
243         debugPrintf("      hooking %s->%s!%s\n", szModule, pszDllName, pszFunctionName);
244     }
245
246     return replaceAddress(lpOldFunctionAddress, lpNewAddress);
247 }
248
249
250 static BOOL
251 replaceImport(HMODULE hModule,
252               const char *szModule,
253               const char *pszDllName,
254               HMODULE hNewModule)
255 {
256     if (NOOP) {
257         return TRUE;
258     }
259
260     PIMAGE_IMPORT_DESCRIPTOR pImportDescriptor = getImportDescriptor(hModule, szModule, pszDllName);
261     if (pImportDescriptor == NULL) {
262         return TRUE;
263     }
264
265     replaceModule(hModule, szModule, pImportDescriptor, hNewModule);
266
267     return TRUE;
268 }
269
270 static HMODULE g_hThisModule = NULL;
271
272
273 struct Replacement {
274     const char *szMatchModule;
275     HMODULE hReplaceModule;
276 };
277
278 static unsigned numReplacements = 0;
279 static Replacement replacements[32];
280
281
282
283 static void
284 hookModule(HMODULE hModule,
285            const char *szModule)
286 {
287     if (hModule == g_hThisModule) {
288         return;
289     }
290
291     for (unsigned i = 0; i < numReplacements; ++i) {
292         if (hModule == replacements[i].hReplaceModule) {
293             return;
294         }
295     }
296
297     hookFunction(hModule, szModule, "kernel32.dll", "LoadLibraryA", (LPVOID)MyLoadLibraryA);
298     hookFunction(hModule, szModule, "kernel32.dll", "LoadLibraryW", (LPVOID)MyLoadLibraryW);
299     hookFunction(hModule, szModule, "kernel32.dll", "LoadLibraryExA", (LPVOID)MyLoadLibraryExA);
300     hookFunction(hModule, szModule, "kernel32.dll", "LoadLibraryExW", (LPVOID)MyLoadLibraryExW);
301     hookFunction(hModule, szModule, "kernel32.dll", "GetProcAddress", (LPVOID)MyGetProcAddress);
302
303     const char *szBaseName = getBaseName(szModule);
304     for (unsigned i = 0; i < numReplacements; ++i) {
305         if (stricmp(szBaseName, replacements[i].szMatchModule) == 0) {
306             return;
307         }
308     }
309
310     /* Don't hook internal dependencies */
311     if (stricmp(szBaseName, "d3d10core.dll") == 0 ||
312         stricmp(szBaseName, "d3d10level9.dll") == 0 ||
313         stricmp(szBaseName, "d3d10sdklayers.dll") == 0 ||
314         stricmp(szBaseName, "d3d10_1core.dll") == 0 ||
315         stricmp(szBaseName, "d3d11sdklayers.dll") == 0 ||
316         stricmp(szBaseName, "d3d11_1sdklayers.dll") == 0) {
317         return;
318     }
319
320     for (unsigned i = 0; i < numReplacements; ++i) {
321         replaceImport(hModule, szModule, replacements[i].szMatchModule, replacements[i].hReplaceModule);
322         replaceImport(hModule, szModule, replacements[i].szMatchModule, replacements[i].hReplaceModule);
323         replaceImport(hModule, szModule, replacements[i].szMatchModule, replacements[i].hReplaceModule);
324     }
325 }
326
327 static void
328 hookAllModules(void)
329 {
330     HANDLE hModuleSnap = CreateToolhelp32Snapshot(TH32CS_SNAPMODULE, GetCurrentProcessId());
331     if (hModuleSnap == INVALID_HANDLE_VALUE) {
332         return;
333     }
334
335     MODULEENTRY32 me32;
336     me32.dwSize = sizeof me32;
337
338     if (VERBOSITY > 0) {
339         static bool first = true;
340         if (first) {
341             if (Module32First(hModuleSnap, &me32)) {
342                 debugPrintf("  modules:\n");
343                 do  {
344                     debugPrintf("     %s\n", me32.szExePath);
345                 } while (Module32Next(hModuleSnap, &me32));
346             }
347             first = false;
348         }
349     }
350
351     if (Module32First(hModuleSnap, &me32)) {
352         do  {
353             hookModule(me32.hModule, me32.szExePath);
354         } while (Module32Next(hModuleSnap, &me32));
355     }
356
357     CloseHandle(hModuleSnap);
358 }
359
360
361
362
363 static HMODULE WINAPI
364 MyLoadLibrary(LPCSTR lpLibFileName, HANDLE hFile = NULL, DWORD dwFlags = 0)
365 {
366     // To Send the information to the server informing that,
367     // LoadLibrary is invoked.
368     HMODULE hModule = LoadLibraryExA(lpLibFileName, hFile, dwFlags);
369
370     //hookModule(hModule, lpLibFileName);
371     hookAllModules();
372
373     return hModule;
374 }
375
376 static HMODULE WINAPI
377 MyLoadLibraryA(LPCSTR lpLibFileName)
378 {
379     if (VERBOSITY >= 2) {
380         debugPrintf("%s(\"%s\")\n", __FUNCTION__, lpLibFileName);
381     }
382
383     if (VERBOSITY > 0) {
384         const char *szBaseName = getBaseName(lpLibFileName);
385         for (unsigned i = 0; i < numReplacements; ++i) {
386             if (stricmp(szBaseName, replacements[i].szMatchModule) == 0) {
387                 debugPrintf("%s(\"%s\")\n", __FUNCTION__, lpLibFileName);
388 #ifdef __GNUC__
389                 void *caller = __builtin_return_address (0);
390
391                 HMODULE hModule = 0;
392                 BOOL bRet = GetModuleHandleEx(GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS,
393                                          (LPCTSTR)caller,
394                                          &hModule);
395                 assert(bRet);
396                 char szCaller[256];
397                 DWORD dwRet = GetModuleFileNameA(hModule, szCaller, sizeof szCaller);
398                 assert(dwRet);
399                 debugPrintf("  called from %s\n", szCaller);
400 #endif
401                 break;
402             }
403         }
404     }
405
406     return MyLoadLibrary(lpLibFileName);
407 }
408
409 static HMODULE WINAPI
410 MyLoadLibraryW(LPCWSTR lpLibFileName)
411 {
412     if (VERBOSITY >= 2) {
413         debugPrintf("%s(L\"%S\")\n", __FUNCTION__, lpLibFileName);
414     }
415
416     char szFileName[256];
417     wcstombs(szFileName, lpLibFileName, sizeof szFileName);
418
419     return MyLoadLibrary(szFileName);
420 }
421
422 static HMODULE WINAPI
423 MyLoadLibraryExA(LPCSTR lpLibFileName, HANDLE hFile, DWORD dwFlags)
424 {
425     if (VERBOSITY >= 2) {
426         debugPrintf("%s(\"%s\")\n", __FUNCTION__, lpLibFileName);
427     }
428     return MyLoadLibrary(lpLibFileName, hFile, dwFlags);
429 }
430
431 static HMODULE WINAPI
432 MyLoadLibraryExW(LPCWSTR lpLibFileName, HANDLE hFile, DWORD dwFlags)
433 {
434     if (VERBOSITY >= 2) {
435         debugPrintf("%s(L\"%S\")\n", __FUNCTION__, lpLibFileName);
436     }
437
438     char szFileName[256];
439     wcstombs(szFileName, lpLibFileName, sizeof szFileName);
440
441     return MyLoadLibrary(szFileName, hFile, dwFlags);
442 }
443
444 static FARPROC WINAPI
445 MyGetProcAddress(HMODULE hModule, LPCSTR lpProcName) {
446
447     if (VERBOSITY >= 99) {
448         /* XXX this can cause segmentation faults */
449         debugPrintf("%s(\"%s\")\n", __FUNCTION__, lpProcName);
450     }
451
452     assert(hModule != g_hThisModule);
453     for (unsigned i = 0; i < numReplacements; ++i) {
454         if (hModule == replacements[i].hReplaceModule) {
455             return GetProcAddress(hModule, lpProcName);
456         }
457     }
458
459     if (!NOOP) {
460         char szModule[256];
461         DWORD dwRet = GetModuleFileNameA(hModule, szModule, sizeof szModule);
462         assert(dwRet);
463         const char *szBaseName = getBaseName(szModule);
464
465         for (unsigned i = 0; i < numReplacements; ++i) {
466
467             if (stricmp(szBaseName, replacements[i].szMatchModule) == 0) {
468                 if (VERBOSITY > 0) {
469                     debugPrintf("  %s(\"%s\", \"%s\")\n", __FUNCTION__, szModule, lpProcName);
470                 }
471                 FARPROC pProcAddress = GetProcAddress(replacements[i].hReplaceModule, lpProcName);
472                 if (pProcAddress) {
473                     if (VERBOSITY >= 2) {
474                         debugPrintf("      replacing %s!%s\n", szBaseName, lpProcName);
475                     }
476                     return pProcAddress;
477                 } else {
478                     if (VERBOSITY > 0) {
479                         debugPrintf("      ignoring %s!%s\n", szBaseName, lpProcName);
480                     }
481                     break;
482                 }
483             }
484         }
485     }
486
487     return GetProcAddress(hModule, lpProcName);
488 }
489
490
491 EXTERN_C BOOL WINAPI
492 DllMain(HINSTANCE hinstDLL, DWORD fdwReason, LPVOID lpReserved)
493 {
494     const char *szNewDllName = NULL;
495     HMODULE hNewModule = NULL;
496     const char *szNewDllBaseName;
497
498     switch (fdwReason) {
499     case DLL_PROCESS_ATTACH:
500         if (VERBOSITY > 0) {
501             debugPrintf("DLL_PROCESS_ATTACH\n");
502         }
503
504         g_hThisModule = hinstDLL;
505
506         {
507             char szProcess[MAX_PATH];
508             GetModuleFileNameA(NULL, szProcess, sizeof szProcess);
509             if (VERBOSITY > 0) {
510                 debugPrintf("  attached to %s\n", szProcess);
511             }
512         }
513
514         /*
515          * Calling LoadLibrary inside DllMain is strongly discouraged.  But it
516          * works quite well, provided that the loaded DLL does not require or do
517          * anything special in its DllMain, which seems to be the general case.
518          *
519          * See also:
520          * - http://stackoverflow.com/questions/4370812/calling-loadlibrary-from-dllmain
521          * - http://msdn.microsoft.com/en-us/library/ms682583
522          */
523
524 #if !USE_SHARED_MEM
525         szNewDllName = getenv("INJECT_DLL");
526         if (!szNewDllName) {
527             debugPrintf("warning: INJECT_DLL not set\n");
528             return FALSE;
529         }
530 #else
531         static char szSharedMemCopy[MAX_PATH];
532         GetSharedMem(szSharedMemCopy, sizeof szSharedMemCopy);
533         szNewDllName = szSharedMemCopy;
534 #endif
535         if (VERBOSITY > 0) {
536             debugPrintf("  injecting %s\n", szNewDllName);
537         }
538
539         hNewModule = LoadLibraryA(szNewDllName);
540         if (!hNewModule) {
541             debugPrintf("warning: failed to load %s\n", szNewDllName);
542             return FALSE;
543         }
544
545         szNewDllBaseName = getBaseName(szNewDllName);
546         if (stricmp(szNewDllBaseName, "dxgitrace.dll") == 0) {
547             replacements[numReplacements].szMatchModule = "dxgi.dll";
548             replacements[numReplacements].hReplaceModule = hNewModule;
549             ++numReplacements;
550
551             replacements[numReplacements].szMatchModule = "d3d10.dll";
552             replacements[numReplacements].hReplaceModule = hNewModule;
553             ++numReplacements;
554
555             replacements[numReplacements].szMatchModule = "d3d10_1.dll";
556             replacements[numReplacements].hReplaceModule = hNewModule;
557             ++numReplacements;
558
559             replacements[numReplacements].szMatchModule = "d3d11.dll";
560             replacements[numReplacements].hReplaceModule = hNewModule;
561             ++numReplacements;
562         } else {
563             replacements[numReplacements].szMatchModule = szNewDllBaseName;
564             replacements[numReplacements].hReplaceModule = hNewModule;
565             ++numReplacements;
566         }
567
568         hookAllModules();
569         break;
570
571     case DLL_THREAD_ATTACH:
572         break;
573
574     case DLL_THREAD_DETACH:
575         break;
576
577     case DLL_PROCESS_DETACH:
578         if (VERBOSITY > 0) {
579             debugPrintf("DLL_PROCESS_DETACH\n");
580         }
581         break;
582     }
583     return TRUE;
584 }