X-Git-Url: https://git.cworth.org/git?a=blobdiff_plain;f=wrappers%2Ftrace.py;h=958c07279a1bf55eab84d44d349ef6a3fded764f;hb=537c507874cdde0b507d306ac058767f506da8e2;hp=23a537a1e669d56359845b28bcab2c0c135ff089;hpb=9782b29faa129e4bac5931a929602f18d2eb56ce;p=apitrace diff --git a/wrappers/trace.py b/wrappers/trace.py index 23a537a..958c072 100644 --- a/wrappers/trace.py +++ b/wrappers/trace.py @@ -87,24 +87,24 @@ class ComplexValueSerializer(stdapi.OnceVisitor): pass def visitEnum(self, enum): - print 'static const trace::EnumValue __enum%s_values[] = {' % (enum.tag) + print 'static const trace::EnumValue _enum%s_values[] = {' % (enum.tag) for value in enum.values: print ' {"%s", %s},' % (value, value) print '};' print - print 'static const trace::EnumSig __enum%s_sig = {' % (enum.tag) - print ' %u, %u, __enum%s_values' % (enum.id, len(enum.values), enum.tag) + print 'static const trace::EnumSig _enum%s_sig = {' % (enum.tag) + print ' %u, %u, _enum%s_values' % (enum.id, len(enum.values), enum.tag) print '};' print def visitBitmask(self, bitmask): - print 'static const trace::BitmaskFlag __bitmask%s_flags[] = {' % (bitmask.tag) + print 'static const trace::BitmaskFlag _bitmask%s_flags[] = {' % (bitmask.tag) for value in bitmask.values: print ' {"%s", %s},' % (value, value) print '};' print - print 'static const trace::BitmaskSig __bitmask%s_sig = {' % (bitmask.tag) - print ' %u, %u, __bitmask%s_flags' % (bitmask.id, len(bitmask.values), bitmask.tag) + print 'static const trace::BitmaskSig _bitmask%s_sig = {' % (bitmask.tag) + print ' %u, %u, _bitmask%s_flags' % (bitmask.id, len(bitmask.values), bitmask.tag) print '};' print @@ -136,6 +136,8 @@ class ComplexValueSerializer(stdapi.OnceVisitor): pass def visitPolymorphic(self, polymorphic): + if not polymorphic.contextLess: + return print 'static void _write__%s(int selector, const %s & value) {' % (polymorphic.tag, polymorphic.expr) print ' switch (selector) {' for cases, type in polymorphic.iterSwitch(): @@ -182,10 +184,10 @@ class ValueSerializer(stdapi.Visitor): print ' _write__%s(%s);' % (struct.tag, instance) def visitArray(self, array, instance): - length = '__c' + array.type.tag - index = '__i' + array.type.tag + length = '_c' + array.type.tag + index = '_i' + array.type.tag print ' if (%s) {' % instance - print ' size_t %s = %s;' % (length, array.length) + print ' size_t %s = %s > 0 ? %s : 0;' % (length, array.length, array.length) print ' trace::localWriter.beginArray(%s);' % length print ' for (size_t %s = 0; %s < %s; ++%s) {' % (index, index, length, index) print ' trace::localWriter.beginElement();' @@ -201,10 +203,10 @@ class ValueSerializer(stdapi.Visitor): print ' trace::localWriter.writeBlob(%s, %s);' % (instance, blob.size) def visitEnum(self, enum, instance): - print ' trace::localWriter.writeEnum(&__enum%s_sig, %s);' % (enum.tag, instance) + print ' trace::localWriter.writeEnum(&_enum%s_sig, %s);' % (enum.tag, instance) def visitBitmask(self, bitmask, instance): - print ' trace::localWriter.writeBitmask(&__bitmask%s_sig, %s);' % (bitmask.tag, instance) + print ' trace::localWriter.writeBitmask(&_bitmask%s_sig, %s);' % (bitmask.tag, instance) def visitPointer(self, pointer, instance): print ' if (%s) {' % instance @@ -218,13 +220,13 @@ class ValueSerializer(stdapi.Visitor): print ' }' def visitIntPointer(self, pointer, instance): - print ' trace::localWriter.writeOpaque((const void *)%s);' % instance + print ' trace::localWriter.writePointer((uintptr_t)%s);' % instance def visitObjPointer(self, pointer, instance): - print ' trace::localWriter.writeOpaque((const void *)%s);' % instance + print ' trace::localWriter.writePointer((uintptr_t)%s);' % instance def visitLinearPointer(self, pointer, instance): - print ' trace::localWriter.writeOpaque((const void *)%s);' % instance + print ' trace::localWriter.writePointer((uintptr_t)%s);' % instance def visitReference(self, reference, instance): self.visit(reference.type, instance) @@ -236,13 +238,22 @@ class ValueSerializer(stdapi.Visitor): self.visit(alias.type, instance) def visitOpaque(self, opaque, instance): - print ' trace::localWriter.writeOpaque((const void *)%s);' % instance + print ' trace::localWriter.writePointer((uintptr_t)%s);' % instance def visitInterface(self, interface, instance): - print ' trace::localWriter.writeOpaque((const void *)&%s);' % instance + assert False def visitPolymorphic(self, polymorphic, instance): - print ' _write__%s(%s, %s);' % (polymorphic.tag, polymorphic.switchExpr, instance) + if polymorphic.contextLess: + print ' _write__%s(%s, %s);' % (polymorphic.tag, polymorphic.switchExpr, instance) + else: + print ' switch (%s) {' % polymorphic.switchExpr + for cases, type in polymorphic.iterSwitch(): + for case in cases: + print ' %s:' % case + self.visit(type, 'static_cast<%s>(%s)' % (type, instance)) + print ' break;' + print ' }' class WrapDecider(stdapi.Traverser): @@ -289,7 +300,7 @@ class ValueWrapper(stdapi.Traverser): if isinstance(elem_type, stdapi.Interface): self.visitInterfacePointer(elem_type, instance) else: - self.visitPointer(self, pointer, instance) + self.visitPointer(pointer, instance) def visitInterface(self, interface, instance): raise NotImplementedError @@ -348,7 +359,7 @@ class Tracer: return ValueSerializer() - def trace_api(self, api): + def traceApi(self, api): self.api = api self.header(api) @@ -365,12 +376,11 @@ class Tracer: print # Interfaces wrapers - interfaces = api.getAllInterfaces() - map(self.declareWrapperInterface, interfaces) - map(self.implementWrapperInterface, interfaces) - print + self.traceInterfaces(api) # Function wrappers + self.interface = None + self.base = None map(self.traceFunctionDecl, api.functions) map(self.traceFunctionImpl, api.functions) print @@ -378,7 +388,17 @@ class Tracer: self.footer(api) def header(self, api): - pass + print '#ifdef _WIN32' + print '# include // alloca' + print '# ifndef alloca' + print '# define alloca _alloca' + print '# endif' + print '#else' + print '# include // alloca' + print '#endif' + print + print '#include "trace.hpp"' + print def footer(self, api): pass @@ -386,12 +406,13 @@ class Tracer: def traceFunctionDecl(self, function): # Per-function declarations - if function.args: - print 'static const char * __%s_args[%u] = {%s};' % (function.name, len(function.args), ', '.join(['"%s"' % arg.name for arg in function.args])) - else: - print 'static const char ** __%s_args = NULL;' % (function.name,) - print 'static const trace::FunctionSig __%s_sig = {%u, "%s", %u, __%s_args};' % (function.name, function.id, function.name, len(function.args), function.name) - print + if not function.internal: + if function.args: + print 'static const char * _%s_args[%u] = {%s};' % (function.name, len(function.args), ', '.join(['"%s"' % arg.name for arg in function.args])) + else: + print 'static const char ** _%s_args = NULL;' % (function.name,) + print 'static const trace::FunctionSig _%s_sig = {%u, "%s", %u, _%s_args};' % (function.name, function.id, function.name, len(function.args), function.name) + print def isFunctionPublic(self, function): return True @@ -403,36 +424,49 @@ class Tracer: print 'extern "C" PRIVATE' print function.prototype() + ' {' if function.type is not stdapi.Void: - print ' %s __result;' % function.type + print ' %s _result;' % function.type + + # No-op if tracing is disabled + print ' if (!trace::isTracingEnabled()) {' + Tracer.invokeFunction(self, function) + if function.type is not stdapi.Void: + print ' return _result;' + else: + print ' return;' + print ' }' + self.traceFunctionImplBody(function) if function.type is not stdapi.Void: - self.wrapRet(function, "__result") - print ' return __result;' + print ' return _result;' print '}' print def traceFunctionImplBody(self, function): - print ' unsigned __call = trace::localWriter.beginEnter(&__%s_sig);' % (function.name,) - for arg in function.args: - if not arg.output: - self.unwrapArg(function, arg) - self.serializeArg(function, arg) - print ' trace::localWriter.endEnter();' + if not function.internal: + print ' unsigned _call = trace::localWriter.beginEnter(&_%s_sig);' % (function.name,) + for arg in function.args: + if not arg.output: + self.unwrapArg(function, arg) + self.serializeArg(function, arg) + print ' trace::localWriter.endEnter();' self.invokeFunction(function) - print ' trace::localWriter.beginLeave(__call);' - for arg in function.args: - if arg.output: - self.serializeArg(function, arg) - self.wrapArg(function, arg) - if function.type is not stdapi.Void: - self.serializeRet(function, "__result") - print ' trace::localWriter.endLeave();' - - def invokeFunction(self, function, prefix='__', suffix=''): + if not function.internal: + print ' trace::localWriter.beginLeave(_call);' + for arg in function.args: + if arg.output: + self.serializeArg(function, arg) + self.wrapArg(function, arg) + if function.type is not stdapi.Void: + self.serializeRet(function, "_result") + print ' trace::localWriter.endLeave();' + if function.type is not stdapi.Void: + self.wrapRet(function, "_result") + + def invokeFunction(self, function, prefix='_', suffix=''): if function.type is stdapi.Void: result = '' else: - result = '__result = ' + result = '_result = ' dispatch = prefix + function.name + suffix print ' %s%s(%s);' % (result, dispatch, ', '.join([str(arg.name) for arg in function.args])) @@ -446,6 +480,18 @@ class Tracer: def wrapArg(self, function, arg): assert not isinstance(arg.type, stdapi.ObjPointer) + + from specs.winapi import REFIID + riid = None + for other_arg in function.args: + if not other_arg.output and other_arg.type is REFIID: + riid = other_arg + if riid is not None \ + and isinstance(arg.type, stdapi.Pointer) \ + and isinstance(arg.type.type, stdapi.ObjPointer): + self.wrapIid(function, riid, arg) + return + self.wrapValue(arg.type, arg.name) def unwrapArg(self, function, arg): @@ -481,6 +527,15 @@ class Tracer: visitor = ValueUnwrapper() visitor.visit(type, instance) + def traceInterfaces(self, api): + interfaces = api.getAllInterfaces() + if not interfaces: + return + map(self.declareWrapperInterface, interfaces) + self.implementIidWrapper(api) + map(self.implementWrapperInterface, interfaces) + print + def declareWrapperInterface(self, interface): print "class %s : public %s " % (getWrapperInterfaceName(interface), interface.name) print "{" @@ -491,120 +546,145 @@ class Tracer: for method in interface.iterMethods(): print " " + method.prototype() + ";" print - self.declareWrapperInterfaceVariables(interface) + #print "private:" + for type, name, value in self.enumWrapperInterfaceVariables(interface): + print ' %s %s;' % (type, name) print "};" print - def declareWrapperInterfaceVariables(self, interface): - #print "private:" - print " DWORD m_dwMagic;" - print " %s * m_pInstance;" % (interface.name,) + def enumWrapperInterfaceVariables(self, interface): + return [ + ("DWORD", "m_dwMagic", "0xd8365d6c"), + ("%s *" % interface.name, "m_pInstance", "pInstance"), + ] def implementWrapperInterface(self, interface): + self.interface = interface + print '%s::%s(%s * pInstance) {' % (getWrapperInterfaceName(interface), getWrapperInterfaceName(interface), interface.name) - print ' m_dwMagic = 0xd8365d6c;' - print ' m_pInstance = pInstance;' + for type, name, value in self.enumWrapperInterfaceVariables(interface): + print ' %s = %s;' % (name, value) print '}' print print '%s::~%s() {' % (getWrapperInterfaceName(interface), getWrapperInterfaceName(interface)) print '}' print + for base, method in interface.iterBaseMethods(): + self.base = base self.implementWrapperInterfaceMethod(interface, base, method) + print def implementWrapperInterfaceMethod(self, interface, base, method): print method.prototype(getWrapperInterfaceName(interface) + '::' + method.name) + ' {' if method.type is not stdapi.Void: - print ' %s __result;' % method.type + print ' %s _result;' % method.type self.implementWrapperInterfaceMethodBody(interface, base, method) if method.type is not stdapi.Void: - print ' return __result;' + print ' return _result;' print '}' print def implementWrapperInterfaceMethodBody(self, interface, base, method): - print ' static const char * __args[%u] = {%s};' % (len(method.args) + 1, ', '.join(['"this"'] + ['"%s"' % arg.name for arg in method.args])) - print ' static const trace::FunctionSig __sig = {%u, "%s", %u, __args};' % (method.id, interface.name + '::' + method.name, len(method.args) + 1) - print ' unsigned __call = trace::localWriter.beginEnter(&__sig);' - print ' trace::localWriter.beginArg(0);' - print ' trace::localWriter.writeOpaque((const void *)m_pInstance);' - print ' trace::localWriter.endArg();' + assert not method.internal - from specs.winapi import REFIID + print ' static const char * _args[%u] = {%s};' % (len(method.args) + 1, ', '.join(['"this"'] + ['"%s"' % arg.name for arg in method.args])) + print ' static const trace::FunctionSig _sig = {%u, "%s", %u, _args};' % (method.id, interface.name + '::' + method.name, len(method.args) + 1) - riid = None + print ' %s *_this = static_cast<%s *>(m_pInstance);' % (base, base) + + print ' unsigned _call = trace::localWriter.beginEnter(&_sig);' + print ' trace::localWriter.beginArg(0);' + print ' trace::localWriter.writePointer((uintptr_t)m_pInstance);' + print ' trace::localWriter.endArg();' for arg in method.args: if not arg.output: self.unwrapArg(method, arg) self.serializeArg(method, arg) - if arg.type is REFIID: - riid = arg print ' trace::localWriter.endEnter();' self.invokeMethod(interface, base, method) - print ' trace::localWriter.beginLeave(__call);' + print ' trace::localWriter.beginLeave(_call);' for arg in method.args: if arg.output: self.serializeArg(method, arg) self.wrapArg(method, arg) - if riid is not None and isinstance(arg.type, stdapi.Pointer): - assert isinstance(arg.type.type, stdapi.ObjPointer) - obj_type = arg.type.type.type - assert obj_type is stdapi.Void - self.wrapIid(interface, method, riid, arg) - riid = None - assert riid is None if method.type is not stdapi.Void: - print ' trace::localWriter.beginReturn();' - self.serializeValue(method.type, "__result") - print ' trace::localWriter.endReturn();' - self.wrapValue(method.type, '__result') + self.serializeRet(method, '_result') print ' trace::localWriter.endLeave();' + if method.type is not stdapi.Void: + self.wrapRet(method, '_result') + if method.name == 'Release': assert method.type is not stdapi.Void - print ' if (!__result)' + print ' if (!_result)' print ' delete this;' - def wrapIid(self, interface, method, riid, out): - print ' if (%s && *%s) {' % (out.name, out.name) - print ' if (*%s == m_pInstance) {' % (out.name,) - print ' *%s = this;' % (out.name,) - print ' }' - for iface in self.api.getAllInterfaces(): - print r' else if (%s == IID_%s) {' % (riid.name, iface.name) - print r' *%s = new Wrap%s((%s *) *%s);' % (out.name, iface.name, iface.name, out.name) - print r' }' - print r' else {' - print r' os::log("apitrace: warning: %s::%s: unknown IID {0x%08lX,0x%04X,0x%04X,{0x%02X,0x%02X,0x%02X,0x%02X,0x%02X,0x%02X,0x%02X,0x%02X}}\n",' - print r' "%s", "%s",' % (interface.name, method.name) - print r' %s.Data1, %s.Data2, %s.Data3,' % (riid.name, riid.name, riid.name) - print r' %s.Data4[0],' % (riid.name,) - print r' %s.Data4[1],' % (riid.name,) - print r' %s.Data4[2],' % (riid.name,) - print r' %s.Data4[3],' % (riid.name,) - print r' %s.Data4[4],' % (riid.name,) - print r' %s.Data4[5],' % (riid.name,) - print r' %s.Data4[6],' % (riid.name,) - print r' %s.Data4[7]);' % (riid.name,) + def implementIidWrapper(self, api): + print r'static void' + print r'warnIID(const char *functionName, REFIID riid, const char *reason) {' + print r' os::log("apitrace: warning: %s: %s IID {0x%08lX,0x%04X,0x%04X,{0x%02X,0x%02X,0x%02X,0x%02X,0x%02X,0x%02X,0x%02X,0x%02X}}\n",' + print r' functionName, reason,' + print r' riid.Data1, riid.Data2, riid.Data3,' + print r' riid.Data4[0], riid.Data4[1], riid.Data4[2], riid.Data4[3], riid.Data4[4], riid.Data4[5], riid.Data4[6], riid.Data4[7]);' + print r'}' + print + print r'static void' + print r'wrapIID(const char *functionName, REFIID riid, void * * ppvObj) {' + print r' if (!ppvObj || !*ppvObj) {' + print r' return;' + print r' }' + else_ = '' + for iface in api.getAllInterfaces(): + print r' %sif (riid == IID_%s) {' % (else_, iface.name) + print r' *ppvObj = new Wrap%s((%s *) *ppvObj);' % (iface.name, iface.name) + print r' }' + else_ = 'else ' + print r' %s{' % else_ + print r' warnIID(functionName, riid, "unknown");' + print r' }' + print r'}' + print + + def wrapIid(self, function, riid, out): + # Cast output arg to `void **` if necessary + out_name = out.name + obj_type = out.type.type.type + if not obj_type is stdapi.Void: + assert isinstance(obj_type, stdapi.Interface) + out_name = 'reinterpret_cast(%s)' % out_name + + print r' if (%s && *%s) {' % (out.name, out.name) + functionName = function.name + else_ = '' + if self.interface is not None: + functionName = self.interface.name + '::' + functionName + print r' if (*%s == m_pInstance &&' % (out_name,) + print r' (%s)) {' % ' || '.join('%s == IID_%s' % (riid.name, iface.name) for iface in self.interface.iterBases()) + print r' *%s = this;' % (out_name,) print r' }' - print ' }' + else_ = 'else ' + print r' %s{' % else_ + print r' wrapIID("%s", %s, %s);' % (functionName, riid.name, out_name) + print r' }' + print r' }' def invokeMethod(self, interface, base, method): if method.type is stdapi.Void: result = '' else: - result = '__result = ' - print ' %sstatic_cast<%s *>(m_pInstance)->%s(%s);' % (result, base, method.name, ', '.join([str(arg.name) for arg in method.args])) + result = '_result = ' + print ' %s_this->%s(%s);' % (result, method.name, ', '.join([str(arg.name) for arg in method.args])) def emit_memcpy(self, dest, src, length): - print ' unsigned __call = trace::localWriter.beginEnter(&trace::memcpy_sig);' + print ' unsigned _call = trace::localWriter.beginEnter(&trace::memcpy_sig);' print ' trace::localWriter.beginArg(0);' - print ' trace::localWriter.writeOpaque(%s);' % dest + print ' trace::localWriter.writePointer((uintptr_t)%s);' % dest print ' trace::localWriter.endArg();' print ' trace::localWriter.beginArg(1);' print ' trace::localWriter.writeBlob(%s, %s);' % (src, length) @@ -613,6 +693,6 @@ class Tracer: print ' trace::localWriter.writeUInt(%s);' % length print ' trace::localWriter.endArg();' print ' trace::localWriter.endEnter();' - print ' trace::localWriter.beginLeave(__call);' + print ' trace::localWriter.beginLeave(_call);' print ' trace::localWriter.endLeave();'