"""Trace code generation for Windows DLLs."""
+import ntpath
+
from trace import Tracer
from dispatch import Dispatcher
+from specs.stdapi import API
-class DllTracer(Tracer):
+class DllDispatcher(Dispatcher):
- def __init__(self, dllname):
- self.dllname = dllname
-
- def header(self, api):
- print '''
-static HINSTANCE g_hDll = NULL;
+ def dispatchModule(self, module):
+ tag = module.name.upper()
+ print r'HMODULE g_h%sModule = NULL;' % (tag,)
+ print r''
+ print r'static PROC'
+ print r'_get%sProcAddress(LPCSTR lpProcName) {' % tag
+ print r' if (!g_h%sModule) {' % tag
+ print r' char szDll[MAX_PATH] = {0};'
+ print r' if (!GetSystemDirectoryA(szDll, MAX_PATH)) {'
+ print r' return NULL;'
+ print r' }'
+ print r' strcat(szDll, "\\\\%s.dll");' % module.name
+ print r' g_h%sModule = LoadLibraryA(szDll);' % tag
+ print r' if (!g_h%sModule) {' % tag
+ print r' return NULL;'
+ print r' }'
+ print r' }'
+ print r' return GetProcAddress(g_h%sModule, lpProcName);' % tag
+ print r'}'
+ print r''
-static PROC
-_getPublicProcAddress(LPCSTR lpProcName)
-{
- if (!g_hDll) {
- char szDll[MAX_PATH] = {0};
-
- if (!GetSystemDirectoryA(szDll, MAX_PATH)) {
- return NULL;
- }
-
- strcat(szDll, "\\\\%s");
-
- g_hDll = LoadLibraryA(szDll);
- if (!g_hDll) {
- return NULL;
- }
- }
-
- return GetProcAddress(g_hDll, lpProcName);
-}
+ Dispatcher.dispatchModule(self, module)
-''' % self.dllname
+ def getProcAddressName(self, module, function):
+ assert self.isFunctionPublic(module, function)
+ return '_get%sProcAddress' % (module.name.upper())
- dispatcher = Dispatcher()
- dispatcher.dispatch_api(api)
- Tracer.header(self, api)
+class DllTracer(Tracer):
+ def header(self, api):
+
+ for module in api.modules:
+ dispatcher = DllDispatcher()
+ dispatcher.dispatchModule(module)
+
+ Tracer.header(self, api)