X-Git-Url: https://git.cworth.org/git?a=blobdiff_plain;f=wrappers%2Ftrace.py;h=61a2cb95fcf8249c13719464a36bee5891f8f76a;hb=48c661ea6c0f2bd9b76a3385cd946b7d07bc9b5f;hp=eea2fdc22d6f21a57cb928407ead4cebfcf310e6;hpb=a0e97860386786fcdb106471a3908b4ba66242e6;p=apitrace diff --git a/wrappers/trace.py b/wrappers/trace.py index eea2fdc..61a2cb9 100644 --- a/wrappers/trace.py +++ b/wrappers/trace.py @@ -39,6 +39,7 @@ def getWrapperInterfaceName(interface): return "Wrap" + interface.expr + class ComplexValueSerializer(stdapi.OnceVisitor): '''Type visitors which generates serialization functions for complex types. @@ -63,21 +64,20 @@ class ComplexValueSerializer(stdapi.OnceVisitor): self.visit(const.type) def visitStruct(self, struct): - for type, name in struct.members: - self.visit(type) - print 'static void _write__%s(const %s &value) {' % (struct.tag, struct.expr) - print ' static const char * members[%u] = {' % (len(struct.members),) + print 'static const char * _struct%s_members[%u] = {' % (struct.tag, len(struct.members)) for type, name, in struct.members: - print ' "%s",' % (name,) - print ' };' - print ' static const trace::StructSig sig = {' - print ' %u, "%s", %u, members' % (struct.id, struct.name, len(struct.members)) - print ' };' - print ' trace::localWriter.beginStruct(&sig);' - for type, name in struct.members: - self.serializer.visit(type, 'value.%s' % (name,)) - print ' trace::localWriter.endStruct();' - print '}' + if name is None: + print ' "",' + else: + print ' "%s",' % (name,) + print '};' + print 'static const trace::StructSig _struct%s_sig = {' % (struct.tag,) + if struct.name is None: + structName = '""' + else: + structName = '"%s"' % struct.name + print ' %u, %s, %u, _struct%s_members' % (struct.id, structName, len(struct.members), struct.tag) + print '};' print def visitArray(self, array): @@ -89,22 +89,22 @@ class ComplexValueSerializer(stdapi.OnceVisitor): def visitEnum(self, enum): print 'static const trace::EnumValue _enum%s_values[] = {' % (enum.tag) for value in enum.values: - print ' {"%s", %s},' % (value, value) + print ' {"%s", %s},' % (value, value) print '};' print print 'static const trace::EnumSig _enum%s_sig = {' % (enum.tag) - print ' %u, %u, _enum%s_values' % (enum.id, len(enum.values), enum.tag) + print ' %u, %u, _enum%s_values' % (enum.id, len(enum.values), enum.tag) print '};' print def visitBitmask(self, bitmask): print 'static const trace::BitmaskFlag _bitmask%s_flags[] = {' % (bitmask.tag) for value in bitmask.values: - print ' {"%s", %s},' % (value, value) + print ' {"%s", %s},' % (value, value) print '};' print print 'static const trace::BitmaskSig _bitmask%s_sig = {' % (bitmask.tag) - print ' %u, %u, _bitmask%s_flags' % (bitmask.id, len(bitmask.values), bitmask.tag) + print ' %u, %u, _bitmask%s_flags' % (bitmask.id, len(bitmask.values), bitmask.tag) print '};' print @@ -150,7 +150,7 @@ class ComplexValueSerializer(stdapi.OnceVisitor): print -class ValueSerializer(stdapi.Visitor): +class ValueSerializer(stdapi.Visitor, stdapi.ExpanderMixin): '''Visitor which generates code to serialize any type. Simple types are serialized inline here, whereas the serialization of @@ -162,36 +162,40 @@ class ValueSerializer(stdapi.Visitor): print ' trace::localWriter.write%s(%s);' % (literal.kind, instance) def visitString(self, string, instance): - if string.kind == 'String': + if not string.wide: cast = 'const char *' - elif string.kind == 'WString': - cast = 'const wchar_t *' + suffix = 'String' else: - assert False + cast = 'const wchar_t *' + suffix = 'WString' if cast != string.expr: # reinterpret_cast is necessary for GLubyte * <=> char * instance = 'reinterpret_cast<%s>(%s)' % (cast, instance) if string.length is not None: - length = ', %s' % string.length + length = ', %s' % self.expand(string.length) else: length = '' - print ' trace::localWriter.write%s(%s%s);' % (string.kind, instance, length) + print ' trace::localWriter.write%s(%s%s);' % (suffix, instance, length) def visitConst(self, const, instance): self.visit(const.type, instance) def visitStruct(self, struct, instance): - print ' _write__%s(%s);' % (struct.tag, instance) + print ' trace::localWriter.beginStruct(&_struct%s_sig);' % (struct.tag,) + for member in struct.members: + self.visitMember(member, instance) + print ' trace::localWriter.endStruct();' def visitArray(self, array, instance): length = '_c' + array.type.tag index = '_i' + array.type.tag + array_length = self.expand(array.length) print ' if (%s) {' % instance - print ' size_t %s = %s;' % (length, array.length) + print ' size_t %s = %s > 0 ? %s : 0;' % (length, array_length, array_length) print ' trace::localWriter.beginArray(%s);' % length print ' for (size_t %s = 0; %s < %s; ++%s) {' % (index, index, length, index) print ' trace::localWriter.beginElement();' - self.visit(array.type, '(%s)[%s]' % (instance, index)) + self.visitElement(index, array.type, '(%s)[%s]' % (instance, index)) print ' trace::localWriter.endElement();' print ' }' print ' trace::localWriter.endArray();' @@ -200,7 +204,7 @@ class ValueSerializer(stdapi.Visitor): print ' }' def visitBlob(self, blob, instance): - print ' trace::localWriter.writeBlob(%s, %s);' % (instance, blob.size) + print ' trace::localWriter.writeBlob(%s, %s);' % (instance, self.expand(blob.size)) def visitEnum(self, enum, instance): print ' trace::localWriter.writeEnum(&_enum%s_sig, %s);' % (enum.tag, instance) @@ -247,12 +251,21 @@ class ValueSerializer(stdapi.Visitor): if polymorphic.contextLess: print ' _write__%s(%s, %s);' % (polymorphic.tag, polymorphic.switchExpr, instance) else: - print ' switch (%s) {' % polymorphic.switchExpr + switchExpr = self.expand(polymorphic.switchExpr) + print ' switch (%s) {' % switchExpr for cases, type in polymorphic.iterSwitch(): for case in cases: print ' %s:' % case - self.visit(type, 'static_cast<%s>(%s)' % (type, instance)) + caseInstance = instance + if type.expr is not None: + caseInstance = 'static_cast<%s>(%s)' % (type, caseInstance) + self.visit(type, caseInstance) print ' break;' + if polymorphic.defaultType is None: + print r' default:' + print r' os::log("apitrace: warning: %%s: unexpected polymorphic case %%i\n", __FUNCTION__, (int)%s);' % (switchExpr,) + print r' trace::localWriter.writeNull();' + print r' break;' print ' }' @@ -272,7 +285,7 @@ class WrapDecider(stdapi.Traverser): self.needsWrapping = True -class ValueWrapper(stdapi.Traverser): +class ValueWrapper(stdapi.Traverser, stdapi.ExpanderMixin): '''Type visitor which will generate the code to wrap an instance. Wrapping is necessary mostly for interfaces, however interface pointers can @@ -280,13 +293,14 @@ class ValueWrapper(stdapi.Traverser): ''' def visitStruct(self, struct, instance): - for type, name in struct.members: - self.visit(type, "(%s).%s" % (instance, name)) + for member in struct.members: + self.visitMember(member, instance) def visitArray(self, array, instance): + array_length = self.expand(array.length) print " if (%s) {" % instance - print " for (size_t _i = 0, _s = %s; _i < _s; ++_i) {" % array.length - self.visit(array.type, instance + "[_i]") + print " for (size_t _i = 0, _s = %s; _i < _s; ++_i) {" % array_length + self.visitElement('_i', array.type, instance + "[_i]") print " }" print " }" @@ -299,15 +313,17 @@ class ValueWrapper(stdapi.Traverser): elem_type = pointer.type.mutable() if isinstance(elem_type, stdapi.Interface): self.visitInterfacePointer(elem_type, instance) + elif isinstance(elem_type, stdapi.Alias) and isinstance(elem_type.type, stdapi.Interface): + self.visitInterfacePointer(elem_type.type, instance) else: - self.visitPointer(self, pointer, instance) + self.visitPointer(pointer, instance) def visitInterface(self, interface, instance): raise NotImplementedError def visitInterfacePointer(self, interface, instance): print " if (%s) {" % instance - print " %s = new %s(%s);" % (instance, getWrapperInterfaceName(interface), instance) + print " %s = %s::_Create(__FUNCTION__, %s);" % (instance, getWrapperInterfaceName(interface), instance) print " }" def visitPolymorphic(self, type, instance): @@ -320,13 +336,31 @@ class ValueUnwrapper(ValueWrapper): allocated = False + def visitStruct(self, struct, instance): + if not self.allocated: + # Argument is constant. We need to create a non const + print ' {' + print " %s * _t = static_cast<%s *>(alloca(sizeof *_t));" % (struct, struct) + print ' *_t = %s;' % (instance,) + assert instance.startswith('*') + print ' %s = _t;' % (instance[1:],) + instance = '*_t' + self.allocated = True + try: + return ValueWrapper.visitStruct(self, struct, instance) + finally: + print ' }' + else: + return ValueWrapper.visitStruct(self, struct, instance) + def visitArray(self, array, instance): if self.allocated or isinstance(instance, stdapi.Interface): return ValueWrapper.visitArray(self, array, instance) + array_length = self.expand(array.length) elem_type = array.type.mutable() - print " if (%s && %s) {" % (instance, array.length) - print " %s * _t = static_cast<%s *>(alloca(%s * sizeof *_t));" % (elem_type, elem_type, array.length) - print " for (size_t _i = 0, _s = %s; _i < _s; ++_i) {" % array.length + print " if (%s && %s) {" % (instance, array_length) + print " %s * _t = static_cast<%s *>(alloca(%s * sizeof *_t));" % (elem_type, elem_type, array_length) + print " for (size_t _i = 0, _s = %s; _i < _s; ++_i) {" % array_length print " _t[_i] = %s[_i];" % instance self.allocated = True self.visit(array.type, "_t[_i]") @@ -348,6 +382,9 @@ class ValueUnwrapper(ValueWrapper): class Tracer: '''Base class to orchestrate the code generation of API tracing.''' + # 0-3 are reserved to memcpy, malloc, free, and realloc + __id = 4 + def __init__(self): self.api = None @@ -365,8 +402,9 @@ class Tracer: self.header(api) # Includes - for header in api.headers: - print header + for module in api.modules: + for header in module.headers: + print header print # Generate the serializer functions @@ -381,14 +419,27 @@ class Tracer: # Function wrappers self.interface = None self.base = None - map(self.traceFunctionDecl, api.functions) - map(self.traceFunctionImpl, api.functions) + for function in api.getAllFunctions(): + self.traceFunctionDecl(function) + for function in api.getAllFunctions(): + self.traceFunctionImpl(function) print self.footer(api) def header(self, api): - pass + print '#ifdef _WIN32' + print '# include // alloca' + print '# ifndef alloca' + print '# define alloca _alloca' + print '# endif' + print '#else' + print '# include // alloca' + print '#endif' + print + print '#include "trace.hpp"' + print + print 'static std::map g_WrappedObjects;' def footer(self, api): pass @@ -396,12 +447,18 @@ class Tracer: def traceFunctionDecl(self, function): # Per-function declarations - if function.args: - print 'static const char * _%s_args[%u] = {%s};' % (function.name, len(function.args), ', '.join(['"%s"' % arg.name for arg in function.args])) - else: - print 'static const char ** _%s_args = NULL;' % (function.name,) - print 'static const trace::FunctionSig _%s_sig = {%u, "%s", %u, _%s_args};' % (function.name, function.id, function.name, len(function.args), function.name) - print + if not function.internal: + if function.args: + print 'static const char * _%s_args[%u] = {%s};' % (function.name, len(function.args), ', '.join(['"%s"' % arg.name for arg in function.args])) + else: + print 'static const char ** _%s_args = NULL;' % (function.name,) + print 'static const trace::FunctionSig _%s_sig = {%u, "%s", %u, _%s_args};' % (function.name, self.getFunctionSigId(), function.name, len(function.args), function.name) + print + + def getFunctionSigId(self): + id = Tracer.__id + Tracer.__id += 1 + return id def isFunctionPublic(self, function): return True @@ -414,29 +471,46 @@ class Tracer: print function.prototype() + ' {' if function.type is not stdapi.Void: print ' %s _result;' % function.type + + # No-op if tracing is disabled + print ' if (!trace::isTracingEnabled()) {' + Tracer.invokeFunction(self, function) + if function.type is not stdapi.Void: + print ' return _result;' + else: + print ' return;' + print ' }' + self.traceFunctionImplBody(function) if function.type is not stdapi.Void: - self.wrapRet(function, "_result") print ' return _result;' print '}' print def traceFunctionImplBody(self, function): - print ' unsigned _call = trace::localWriter.beginEnter(&_%s_sig);' % (function.name,) - for arg in function.args: - if not arg.output: - self.unwrapArg(function, arg) - self.serializeArg(function, arg) - print ' trace::localWriter.endEnter();' + if not function.internal: + print ' unsigned _call = trace::localWriter.beginEnter(&_%s_sig);' % (function.name,) + for arg in function.args: + if not arg.output: + self.unwrapArg(function, arg) + for arg in function.args: + if not arg.output: + self.serializeArg(function, arg) + print ' trace::localWriter.endEnter();' self.invokeFunction(function) - print ' trace::localWriter.beginLeave(_call);' - for arg in function.args: - if arg.output: - self.serializeArg(function, arg) - self.wrapArg(function, arg) - if function.type is not stdapi.Void: - self.serializeRet(function, "_result") - print ' trace::localWriter.endLeave();' + if not function.internal: + print ' trace::localWriter.beginLeave(_call);' + print ' if (%s) {' % self.wasFunctionSuccessful(function) + for arg in function.args: + if arg.output: + self.serializeArg(function, arg) + self.wrapArg(function, arg) + print ' }' + if function.type is not stdapi.Void: + self.serializeRet(function, "_result") + if function.type is not stdapi.Void: + self.wrapRet(function, "_result") + print ' trace::localWriter.endLeave();' def invokeFunction(self, function, prefix='_', suffix=''): if function.type is stdapi.Void: @@ -446,6 +520,13 @@ class Tracer: dispatch = prefix + function.name + suffix print ' %s%s(%s);' % (result, dispatch, ', '.join([str(arg.name) for arg in function.args])) + def wasFunctionSuccessful(self, function): + if function.type is stdapi.Void: + return 'true' + if str(function.type) == 'HRESULT': + return 'SUCCEEDED(_result)' + return 'true' + def serializeArg(self, function, arg): print ' trace::localWriter.beginArg(%u);' % (arg.index,) self.serializeArgValue(function, arg) @@ -485,9 +566,6 @@ class Tracer: def wrapRet(self, function, instance): self.wrapValue(function.type, instance) - def unwrapRet(self, function, instance): - self.unwrapValue(function.type, instance) - def needsWrapping(self, type): visitor = WrapDecider() visitor.visit(type) @@ -515,31 +593,69 @@ class Tracer: def declareWrapperInterface(self, interface): print "class %s : public %s " % (getWrapperInterfaceName(interface), interface.name) print "{" - print "public:" + print "private:" print " %s(%s * pInstance);" % (getWrapperInterfaceName(interface), interface.name) print " virtual ~%s();" % getWrapperInterfaceName(interface) + print "public:" + print " static %s* _Create(const char *functionName, %s * pInstance);" % (getWrapperInterfaceName(interface), interface.name) print for method in interface.iterMethods(): print " " + method.prototype() + ";" print - self.declareWrapperInterfaceVariables(interface) + #print "private:" + for type, name, value in self.enumWrapperInterfaceVariables(interface): + print ' %s %s;' % (type, name) + for i in range(64): + print r' virtual void _dummy%i(void) const {' % i + print r' os::log("error: %s: unexpected virtual method\n");' % interface.name + print r' os::abort();' + print r' }' print "};" print - def declareWrapperInterfaceVariables(self, interface): - #print "private:" - print " DWORD m_dwMagic;" - print " %s * m_pInstance;" % (interface.name,) + def enumWrapperInterfaceVariables(self, interface): + return [ + ("DWORD", "m_dwMagic", "0xd8365d6c"), + ("%s *" % interface.name, "m_pInstance", "pInstance"), + ("void *", "m_pVtbl", "*(void **)pInstance"), + ("UINT", "m_NumMethods", len(list(interface.iterBaseMethods()))), + ] def implementWrapperInterface(self, interface): self.interface = interface + # Private constructor print '%s::%s(%s * pInstance) {' % (getWrapperInterfaceName(interface), getWrapperInterfaceName(interface), interface.name) - print ' m_dwMagic = 0xd8365d6c;' - print ' m_pInstance = pInstance;' + for type, name, value in self.enumWrapperInterfaceVariables(interface): + print ' %s = %s;' % (name, value) print '}' print + + # Public constructor + print '%s *%s::_Create(const char *functionName, %s * pInstance) {' % (getWrapperInterfaceName(interface), getWrapperInterfaceName(interface), interface.name) + print r' std::map::const_iterator it = g_WrappedObjects.find(pInstance);' + print r' if (it != g_WrappedObjects.end()) {' + print r' Wrap%s *pWrapper = (Wrap%s *)it->second;' % (interface.name, interface.name) + print r' assert(pWrapper);' + print r' assert(pWrapper->m_dwMagic == 0xd8365d6c);' + print r' assert(pWrapper->m_pInstance == pInstance);' + print r' if (pWrapper->m_pVtbl == *(void **)pInstance &&' + print r' pWrapper->m_NumMethods >= %s) {' % len(list(interface.iterBaseMethods())) + #print r' os::log("%s: fetched pvObj=%p pWrapper=%p pVtbl=%p\n", functionName, pInstance, pWrapper, pWrapper->m_pVtbl);' + print r' return pWrapper;' + print r' }' + print r' }' + print r' Wrap%s *pWrapper = new Wrap%s(pInstance);' % (interface.name, interface.name) + #print r' os::log("%%s: created %s pvObj=%%p pWrapper=%%p pVtbl=%%p\n", functionName, pInstance, pWrapper, pWrapper->m_pVtbl);' % interface.name + print r' g_WrappedObjects[pInstance] = pWrapper;' + print r' return pWrapper;' + print '}' + print + + # Destructor print '%s::~%s() {' % (getWrapperInterfaceName(interface), getWrapperInterfaceName(interface)) + #print r' os::log("%s::Release: deleted pvObj=%%p pWrapper=%%p pVtbl=%%p\n", m_pInstance, this, m_pVtbl);' % interface.name + print r' g_WrappedObjects.erase(m_pInstance);' print '}' print @@ -551,6 +667,10 @@ class Tracer: def implementWrapperInterfaceMethod(self, interface, base, method): print method.prototype(getWrapperInterfaceName(interface) + '::' + method.name) + ' {' + + if False: + print r' os::log("%%s(%%p -> %%p)\n", "%s", this, m_pInstance);' % (getWrapperInterfaceName(interface) + '::' + method.name) + if method.type is not stdapi.Void: print ' %s _result;' % method.type @@ -562,8 +682,10 @@ class Tracer: print def implementWrapperInterfaceMethodBody(self, interface, base, method): + assert not method.internal + print ' static const char * _args[%u] = {%s};' % (len(method.args) + 1, ', '.join(['"this"'] + ['"%s"' % arg.name for arg in method.args])) - print ' static const trace::FunctionSig _sig = {%u, "%s", %u, _args};' % (method.id, interface.name + '::' + method.name, len(method.args) + 1) + print ' static const trace::FunctionSig _sig = {%u, "%s", %u, _args};' % (self.getFunctionSigId(), interface.name + '::' + method.name, len(method.args) + 1) print ' %s *_this = static_cast<%s *>(m_pInstance);' % (base, base) @@ -574,27 +696,34 @@ class Tracer: for arg in method.args: if not arg.output: self.unwrapArg(method, arg) + for arg in method.args: + if not arg.output: self.serializeArg(method, arg) print ' trace::localWriter.endEnter();' self.invokeMethod(interface, base, method) print ' trace::localWriter.beginLeave(_call);' + + print ' if (%s) {' % self.wasFunctionSuccessful(method) for arg in method.args: if arg.output: self.serializeArg(method, arg) self.wrapArg(method, arg) + print ' }' if method.type is not stdapi.Void: - print ' trace::localWriter.beginReturn();' - self.serializeValue(method.type, "_result") - print ' trace::localWriter.endReturn();' - self.wrapValue(method.type, '_result') - print ' trace::localWriter.endLeave();' + self.serializeRet(method, '_result') + if method.type is not stdapi.Void: + self.wrapRet(method, '_result') + if method.name == 'Release': assert method.type is not stdapi.Void - print ' if (!_result)' - print ' delete this;' + print r' if (!_result) {' + print r' delete this;' + print r' }' + + print ' trace::localWriter.endLeave();' def implementIidWrapper(self, api): print r'static void' @@ -613,7 +742,7 @@ class Tracer: else_ = '' for iface in api.getAllInterfaces(): print r' %sif (riid == IID_%s) {' % (else_, iface.name) - print r' *ppvObj = new Wrap%s((%s *) *ppvObj);' % (iface.name, iface.name) + print r' *ppvObj = Wrap%s::_Create(functionName, (%s *) *ppvObj);' % (iface.name, iface.name) print r' }' else_ = 'else ' print r' %s{' % else_ @@ -666,4 +795,15 @@ class Tracer: print ' trace::localWriter.endEnter();' print ' trace::localWriter.beginLeave(_call);' print ' trace::localWriter.endLeave();' + + def fake_call(self, function, args): + print ' unsigned _fake_call = trace::localWriter.beginEnter(&_%s_sig);' % (function.name,) + for arg, instance in zip(function.args, args): + assert not arg.output + print ' trace::localWriter.beginArg(%u);' % (arg.index,) + self.serializeValue(arg.type, instance) + print ' trace::localWriter.endArg();' + print ' trace::localWriter.endEnter();' + print ' trace::localWriter.beginLeave(_fake_call);' + print ' trace::localWriter.endLeave();'