1 ##########################################################################
3 # Copyright 2008-2010 VMware, Inc.
6 # Permission is hereby granted, free of charge, to any person obtaining a copy
7 # of this software and associated documentation files (the "Software"), to deal
8 # in the Software without restriction, including without limitation the rights
9 # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 # copies of the Software, and to permit persons to whom the Software is
11 # furnished to do so, subject to the following conditions:
13 # The above copyright notice and this permission notice shall be included in
14 # all copies or substantial portions of the Software.
16 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
24 ##########################################################################/
26 """Common trace code generation."""
32 sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
35 import specs.stdapi as stdapi
def getWrapperInterfaceName(interface):
    '''Return the wrapper class name for *interface*.

    The name is the literal prefix "Wrap" followed by ``interface.expr``.
    '''
    prefix = "Wrap"
    return prefix + interface.expr
44 '''Mixin class that provides a bunch of methods to expand C expressions
45 from the specifications.'''
50 def expand(self, expr):
51 # Expand a C expression, replacing certain variables
52 if not isinstance(expr, basestring):
56 if self.__structs is not None:
57 variables['self'] = '(%s)' % self.__structs[0]
58 if self.__indices is not None:
59 variables['i'] = self.__indices[0]
61 expandedExpr = expr.format(**variables)
62 if expandedExpr != expr and 0:
63 sys.stderr.write(" %r -> %r\n" % (expr, expandedExpr))
66 def visitMember(self, structInstance, member_type, *args, **kwargs):
67 self.__structs = (structInstance, self.__structs)
69 return self.visit(member_type, *args, **kwargs)
71 _, self.__structs = self.__structs
73 def visitElement(self, element_index, element_type, *args, **kwargs):
74 self.__indices = (element_index, self.__indices)
76 return self.visit(element_type, *args, **kwargs)
78 _, self.__indices = self.__indices
81 class ComplexValueSerializer(stdapi.OnceVisitor):
82 '''Type visitors which generates serialization functions for
85 Simple types are serialized inline.
def __init__(self, serializer):
    '''Initialise the OnceVisitor base and remember *serializer*.

    The serializer is stored as ``self.serializer`` for later use by the
    visit methods of this class.
    '''
    base = stdapi.OnceVisitor
    base.__init__(self)
    self.serializer = serializer
92 def visitVoid(self, literal):
95 def visitLiteral(self, literal):
98 def visitString(self, string):
def visitConst(self, const):
    '''Visit the type qualified by *const*; the qualifier itself emits nothing.'''
    qualified = const.type
    self.visit(qualified)
104 def visitStruct(self, struct):
105 print 'static const char * _struct%s_members[%u] = {' % (struct.tag, len(struct.members))
106 for type, name, in struct.members:
107 print ' "%s",' % (name,)
109 print 'static const trace::StructSig _struct%s_sig = {' % (struct.tag,)
110 print ' %u, "%s", %u, _struct%s_members' % (struct.id, struct.name, len(struct.members), struct.tag)
def visitArray(self, array):
    '''Visit the element type of *array*; nothing is emitted for the array itself.'''
    element_type = array.type
    self.visit(element_type)
117 def visitBlob(self, array):
120 def visitEnum(self, enum):
121 print 'static const trace::EnumValue _enum%s_values[] = {' % (enum.tag)
122 for value in enum.values:
123 print ' {"%s", %s},' % (value, value)
126 print 'static const trace::EnumSig _enum%s_sig = {' % (enum.tag)
127 print ' %u, %u, _enum%s_values' % (enum.id, len(enum.values), enum.tag)
131 def visitBitmask(self, bitmask):
132 print 'static const trace::BitmaskFlag _bitmask%s_flags[] = {' % (bitmask.tag)
133 for value in bitmask.values:
134 print ' {"%s", %s},' % (value, value)
137 print 'static const trace::BitmaskSig _bitmask%s_sig = {' % (bitmask.tag)
138 print ' %u, %u, _bitmask%s_flags' % (bitmask.id, len(bitmask.values), bitmask.tag)
def visitPointer(self, pointer):
    '''Visit the pointee type of *pointer*.'''
    pointee = pointer.type
    self.visit(pointee)
145 def visitIntPointer(self, pointer):
def visitObjPointer(self, pointer):
    '''Visit the type that the object pointer *pointer* refers to.'''
    pointee = pointer.type
    self.visit(pointee)
def visitLinearPointer(self, pointer):
    '''Visit the type that the linear pointer *pointer* refers to.'''
    pointee = pointer.type
    self.visit(pointee)
def visitHandle(self, handle):
    '''Visit the underlying type of *handle*.'''
    underlying = handle.type
    self.visit(underlying)
def visitReference(self, reference):
    '''Visit the type that *reference* refers to.'''
    referred = reference.type
    self.visit(referred)
def visitAlias(self, alias):
    '''Visit the type that *alias* stands for.'''
    aliased = alias.type
    self.visit(aliased)
163 def visitOpaque(self, opaque):
166 def visitInterface(self, interface):
169 def visitPolymorphic(self, polymorphic):
170 if not polymorphic.contextLess:
172 print 'static void _write__%s(int selector, const %s & value) {' % (polymorphic.tag, polymorphic.expr)
173 print ' switch (selector) {'
174 for cases, type in polymorphic.iterSwitch():
177 self.serializer.visit(type, 'static_cast<%s>(value)' % (type,))
184 class ValueSerializer(stdapi.Visitor, ExpanderMixin):
185 '''Visitor which generates code to serialize any type.
187 Simple types are serialized inline here, whereas the serialization of
188 complex types is dispatched to the serialization functions generated by
189 ComplexValueSerializer visitor above.
193 #stdapi.Visitor.__init__(self)
def visitLiteral(self, literal, instance):
    '''Print the C++ statement that writes a literal value.

    The writer method name is "write" followed by ``literal.kind``;
    *instance* is the C expression passed to it.
    '''
    # Parenthesised single-argument print behaves the same on Python 2 and 3.
    print(' trace::localWriter.write%s(%s);' % (literal.kind, instance))
200 def visitString(self, string, instance):
202 cast = 'const char *'
205 cast = 'const wchar_t *'
207 if cast != string.expr:
208 # reinterpret_cast is necessary for GLubyte * <=> char *
209 instance = 'reinterpret_cast<%s>(%s)' % (cast, instance)
210 if string.length is not None:
211 length = ', %s' % string.length
214 print ' trace::localWriter.write%s(%s%s);' % (suffix, instance, length)
def visitConst(self, const, instance):
    '''Serialize *instance* as the type qualified by *const*.'''
    qualified = const.type
    self.visit(qualified, instance)
def visitStruct(self, struct, instance):
    '''Print the C++ code that serializes every member of a struct.

    Prints a beginStruct call referencing ``_struct<tag>_sig``, visits each
    member through ``visitMember`` using the C expression
    ``(instance).name``, then prints the matching endStruct call.
    '''
    # Parenthesised single-argument print behaves the same on Python 2 and 3.
    print(' trace::localWriter.beginStruct(&_struct%s_sig);' % (struct.tag,))
    # Each member is a (type, name) pair; avoid shadowing the builtin `type`.
    for member_type, member_name in struct.members:
        member_expr = '(%s).%s' % (instance, member_name,)
        self.visitMember(instance, member_type, member_expr)
    print(' trace::localWriter.endStruct();')
225 def visitArray(self, array, instance):
226 length = '_c' + array.type.tag
227 index = '_i' + array.type.tag
228 array_length = self.expand(array.length)
229 print ' if (%s) {' % instance
230 print ' size_t %s = %s > 0 ? %s : 0;' % (length, array_length, array_length)
231 print ' trace::localWriter.beginArray(%s);' % length
232 print ' for (size_t %s = 0; %s < %s; ++%s) {' % (index, index, length, index)
233 print ' trace::localWriter.beginElement();'
234 self.visitElement(index, array.type, '(%s)[%s]' % (instance, index))
235 print ' trace::localWriter.endElement();'
237 print ' trace::localWriter.endArray();'
239 print ' trace::localWriter.writeNull();'
def visitBlob(self, blob, instance):
    '''Print the C++ statement that writes *instance* as a blob.

    The size argument is ``blob.size`` after being passed through
    ``self.expand``.
    '''
    size_expr = self.expand(blob.size)
    # Parenthesised single-argument print behaves the same on Python 2 and 3.
    print(' trace::localWriter.writeBlob(%s, %s);' % (instance, size_expr))
def visitEnum(self, enum, instance):
    '''Print the C++ statement that writes *instance* as an enum.

    The emitted call references the signature ``_enum<tag>_sig``.
    '''
    # Parenthesised single-argument print behaves the same on Python 2 and 3.
    print(' trace::localWriter.writeEnum(&_enum%s_sig, %s);' % (enum.tag, instance))
def visitBitmask(self, bitmask, instance):
    '''Print the C++ statement that writes *instance* as a bitmask.

    The emitted call references the signature ``_bitmask<tag>_sig``.
    '''
    # Parenthesised single-argument print behaves the same on Python 2 and 3.
    print(' trace::localWriter.writeBitmask(&_bitmask%s_sig, %s);' % (bitmask.tag, instance))
251 def visitPointer(self, pointer, instance):
252 print ' if (%s) {' % instance
253 print ' trace::localWriter.beginArray(1);'
254 print ' trace::localWriter.beginElement();'
255 self.visit(pointer.type, "*" + instance)
256 print ' trace::localWriter.endElement();'
257 print ' trace::localWriter.endArray();'
259 print ' trace::localWriter.writeNull();'
def visitIntPointer(self, pointer, instance):
    '''Print the C++ statement that writes *instance* as a raw pointer value.

    *pointer* is not inspected; the expression is cast to ``uintptr_t``.
    '''
    # Parenthesised single-argument print behaves the same on Python 2 and 3.
    print(' trace::localWriter.writePointer((uintptr_t)%s);' % instance)
def visitObjPointer(self, pointer, instance):
    '''Print the C++ statement that writes *instance* as a raw pointer value.

    *pointer* is not inspected; the expression is cast to ``uintptr_t``.
    '''
    # Parenthesised single-argument print behaves the same on Python 2 and 3.
    print(' trace::localWriter.writePointer((uintptr_t)%s);' % instance)
def visitLinearPointer(self, pointer, instance):
    '''Print the C++ statement that writes *instance* as a raw pointer value.

    *pointer* is not inspected; the expression is cast to ``uintptr_t``.
    '''
    # Parenthesised single-argument print behaves the same on Python 2 and 3.
    print(' trace::localWriter.writePointer((uintptr_t)%s);' % instance)
def visitReference(self, reference, instance):
    '''Serialize *instance* as the type that *reference* refers to.'''
    referred = reference.type
    self.visit(referred, instance)
def visitHandle(self, handle, instance):
    '''Serialize *instance* as the underlying type of *handle*.'''
    underlying = handle.type
    self.visit(underlying, instance)
def visitAlias(self, alias, instance):
    '''Serialize *instance* as the type that *alias* stands for.'''
    aliased = alias.type
    self.visit(aliased, instance)
def visitOpaque(self, opaque, instance):
    '''Print the C++ statement that writes *instance* as a raw pointer value.

    *opaque* is not inspected; the expression is cast to ``uintptr_t``.
    '''
    # Parenthesised single-argument print behaves the same on Python 2 and 3.
    print(' trace::localWriter.writePointer((uintptr_t)%s);' % instance)
283 def visitInterface(self, interface, instance):
286 def visitPolymorphic(self, polymorphic, instance):
287 if polymorphic.contextLess:
288 print ' _write__%s(%s, %s);' % (polymorphic.tag, polymorphic.switchExpr, instance)
290 print ' switch (%s) {' % polymorphic.switchExpr
291 for cases, type in polymorphic.iterSwitch():
294 self.visit(type, 'static_cast<%s>(%s)' % (type, instance))
299 class WrapDecider(stdapi.Traverser):
'''Type visitor which will decide whether this type will need wrapping or not.
302 For complex types (arrays, structures), we need to know this before hand.
306 self.needsWrapping = False
308 def visitLinearPointer(self, void):
def visitInterface(self, interface):
    '''Record that wrapping is needed once an interface type is reached.'''
    # The interface itself is not inspected; reaching it is enough.
    self.needsWrapping = True
315 class ValueWrapper(stdapi.Traverser, ExpanderMixin):
316 '''Type visitor which will generate the code to wrap an instance.
318 Wrapping is necessary mostly for interfaces, however interface pointers can
319 appear anywhere inside complex types.
def visitStruct(self, struct, instance):
    '''Visit every member of *struct*, addressed as ``(instance).name``.'''
    for member in struct.members:
        member_type, member_name = member
        member_expr = "(%s).%s" % (instance, member_name)
        self.visitMember(instance, member_type, member_expr)
326 def visitArray(self, array, instance):
327 array_length = self.expand(array.length)
328 print " if (%s) {" % instance
329 print " for (size_t _i = 0, _s = %s; _i < _s; ++_i) {" % array_length
330 self.visitElement('_i', array.type, instance + "[_i]")
334 def visitPointer(self, pointer, instance):
335 print " if (%s) {" % instance
336 self.visit(pointer.type, "*" + instance)
339 def visitObjPointer(self, pointer, instance):
340 elem_type = pointer.type.mutable()
341 if isinstance(elem_type, stdapi.Interface):
342 self.visitInterfacePointer(elem_type, instance)
344 self.visitPointer(pointer, instance)
def visitInterface(self, interface, instance):
    '''Always raise NotImplementedError; this visitor has no code for a bare interface.'''
    raise NotImplementedError()
349 def visitInterfacePointer(self, interface, instance):
350 print " if (%s) {" % instance
351 print " %s = new %s(%s);" % (instance, getWrapperInterfaceName(interface), instance)
def visitPolymorphic(self, type, instance):
    '''Always raise NotImplementedError; polymorphic values are not handled.'''
    # XXX: There might be polymorphic values that need wrapping in the future
    raise NotImplementedError()
359 class ValueUnwrapper(ValueWrapper):
360 '''Reverse of ValueWrapper.'''
364 def visitArray(self, array, instance):
365 if self.allocated or isinstance(instance, stdapi.Interface):
366 return ValueWrapper.visitArray(self, array, instance)
367 array_length = self.expand(array.length)
368 elem_type = array.type.mutable()
369 print " if (%s && %s) {" % (instance, array_length)
370 print " %s * _t = static_cast<%s *>(alloca(%s * sizeof *_t));" % (elem_type, elem_type, array_length)
371 print " for (size_t _i = 0, _s = %s; _i < _s; ++_i) {" % array_length
372 print " _t[_i] = %s[_i];" % instance
373 self.allocated = True
374 self.visit(array.type, "_t[_i]")
376 print " %s = _t;" % instance
379 def visitInterfacePointer(self, interface, instance):
380 print r' if (%s) {' % instance
381 print r' const %s *pWrapper = static_cast<const %s*>(%s);' % (getWrapperInterfaceName(interface), getWrapperInterfaceName(interface), instance)
382 print r' if (pWrapper && pWrapper->m_dwMagic == 0xd8365d6c) {'
383 print r' %s = pWrapper->m_pInstance;' % (instance,)
385 print r' os::log("apitrace: warning: %%s: unexpected %%s pointer\n", __FUNCTION__, "%s");' % interface.name
391 '''Base class to orchestrate the code generation of API tracing.'''
396 def serializerFactory(self):
397 '''Create a serializer.
Can be overridden by derived classes to inject their own serializer.
402 return ValueSerializer()
404 def traceApi(self, api):
410 for header in api.headers:
414 # Generate the serializer functions
415 types = api.getAllTypes()
416 visitor = ComplexValueSerializer(self.serializerFactory())
417 map(visitor.visit, types)
421 self.traceInterfaces(api)
424 self.interface = None
426 map(self.traceFunctionDecl, api.functions)
427 map(self.traceFunctionImpl, api.functions)
432 def header(self, api):
433 print '#ifdef _WIN32'
434 print '# include <malloc.h> // alloca'
435 print '# ifndef alloca'
436 print '# define alloca _alloca'
439 print '# include <alloca.h> // alloca'
442 print '#include "trace.hpp"'
445 def footer(self, api):
448 def traceFunctionDecl(self, function):
449 # Per-function declarations
451 if not function.internal:
453 print 'static const char * _%s_args[%u] = {%s};' % (function.name, len(function.args), ', '.join(['"%s"' % arg.name for arg in function.args]))
455 print 'static const char ** _%s_args = NULL;' % (function.name,)
456 print 'static const trace::FunctionSig _%s_sig = {%u, "%s", %u, _%s_args};' % (function.name, function.id, function.name, len(function.args), function.name)
459 def isFunctionPublic(self, function):
462 def traceFunctionImpl(self, function):
463 if self.isFunctionPublic(function):
464 print 'extern "C" PUBLIC'
466 print 'extern "C" PRIVATE'
467 print function.prototype() + ' {'
468 if function.type is not stdapi.Void:
469 print ' %s _result;' % function.type
471 # No-op if tracing is disabled
472 print ' if (!trace::isTracingEnabled()) {'
473 Tracer.invokeFunction(self, function)
474 if function.type is not stdapi.Void:
475 print ' return _result;'
480 self.traceFunctionImplBody(function)
481 if function.type is not stdapi.Void:
482 print ' return _result;'
486 def traceFunctionImplBody(self, function):
487 if not function.internal:
488 print ' unsigned _call = trace::localWriter.beginEnter(&_%s_sig);' % (function.name,)
489 for arg in function.args:
491 self.unwrapArg(function, arg)
492 self.serializeArg(function, arg)
493 print ' trace::localWriter.endEnter();'
494 self.invokeFunction(function)
495 if not function.internal:
496 print ' trace::localWriter.beginLeave(_call);'
497 print ' if (%s) {' % self.wasFunctionSuccessful(function)
498 for arg in function.args:
500 self.serializeArg(function, arg)
501 self.wrapArg(function, arg)
503 if function.type is not stdapi.Void:
504 self.serializeRet(function, "_result")
505 print ' trace::localWriter.endLeave();'
506 if function.type is not stdapi.Void:
507 self.wrapRet(function, "_result")
509 def invokeFunction(self, function, prefix='_', suffix=''):
510 if function.type is stdapi.Void:
513 result = '_result = '
514 dispatch = prefix + function.name + suffix
515 print ' %s%s(%s);' % (result, dispatch, ', '.join([str(arg.name) for arg in function.args]))
517 def wasFunctionSuccessful(self, function):
518 if function.type is stdapi.Void:
520 if str(function.type) == 'HRESULT':
521 return 'SUCCEEDED(_result)'
def serializeArg(self, function, arg):
    '''Print the C++ code that records one argument of a call.

    Prints a beginArg call with ``arg.index``, delegates the value to
    ``serializeArgValue``, then prints the matching endArg call.
    '''
    # Parenthesised single-argument print behaves the same on Python 2 and 3.
    print(' trace::localWriter.beginArg(%u);' % (arg.index,))
    self.serializeArgValue(function, arg)
    print(' trace::localWriter.endArg();')
def serializeArgValue(self, function, arg):
    '''Serialize the value of *arg*, using its name as the C expression.'''
    arg_type, arg_expr = arg.type, arg.name
    self.serializeValue(arg_type, arg_expr)
532 def wrapArg(self, function, arg):
533 assert not isinstance(arg.type, stdapi.ObjPointer)
535 from specs.winapi import REFIID
537 for other_arg in function.args:
538 if not other_arg.output and other_arg.type is REFIID:
540 if riid is not None \
541 and isinstance(arg.type, stdapi.Pointer) \
542 and isinstance(arg.type.type, stdapi.ObjPointer):
543 self.wrapIid(function, riid, arg)
546 self.wrapValue(arg.type, arg.name)
def unwrapArg(self, function, arg):
    '''Unwrap the value of *arg*, using its name as the C expression.'''
    arg_type, arg_expr = arg.type, arg.name
    self.unwrapValue(arg_type, arg_expr)
def serializeRet(self, function, instance):
    '''Print the C++ code that records the return value of a call.

    Prints a beginReturn call, serializes *instance* as ``function.type``
    through ``serializeValue``, then prints the matching endReturn call.
    '''
    # Parenthesised single-argument print behaves the same on Python 2 and 3.
    print(' trace::localWriter.beginReturn();')
    self.serializeValue(function.type, instance)
    print(' trace::localWriter.endReturn();')
def serializeValue(self, type, instance):
    '''Serialize *instance* of *type* with a fresh serializer.

    A new serializer is obtained from ``serializerFactory`` on every call.
    '''
    self.serializerFactory().visit(type, instance)
def wrapRet(self, function, instance):
    '''Wrap the return value *instance*, typed as ``function.type``.'''
    ret_type = function.type
    self.wrapValue(ret_type, instance)
def unwrapRet(self, function, instance):
    '''Unwrap the return value *instance*, typed as ``function.type``.'''
    ret_type = function.type
    self.unwrapValue(ret_type, instance)
566 def needsWrapping(self, type):
567 visitor = WrapDecider()
569 return visitor.needsWrapping
def wrapValue(self, type, instance):
    '''Run a ValueWrapper over *instance* when *type* needs wrapping.

    Does nothing when ``needsWrapping(type)`` is false.
    '''
    if not self.needsWrapping(type):
        return
    ValueWrapper().visit(type, instance)
def unwrapValue(self, type, instance):
    '''Run a ValueUnwrapper over *instance* when *type* needs wrapping.

    Does nothing when ``needsWrapping(type)`` is false.
    '''
    if not self.needsWrapping(type):
        return
    ValueUnwrapper().visit(type, instance)
581 def traceInterfaces(self, api):
582 interfaces = api.getAllInterfaces()
585 map(self.declareWrapperInterface, interfaces)
586 self.implementIidWrapper(api)
587 map(self.implementWrapperInterface, interfaces)
590 def declareWrapperInterface(self, interface):
591 print "class %s : public %s " % (getWrapperInterfaceName(interface), interface.name)
594 print " %s(%s * pInstance);" % (getWrapperInterfaceName(interface), interface.name)
595 print " virtual ~%s();" % getWrapperInterfaceName(interface)
597 for method in interface.iterMethods():
598 print " " + method.prototype() + ";"
601 for type, name, value in self.enumWrapperInterfaceVariables(interface):
602 print ' %s %s;' % (type, name)
606 def enumWrapperInterfaceVariables(self, interface):
608 ("DWORD", "m_dwMagic", "0xd8365d6c"),
609 ("%s *" % interface.name, "m_pInstance", "pInstance"),
612 def implementWrapperInterface(self, interface):
613 self.interface = interface
615 print '%s::%s(%s * pInstance) {' % (getWrapperInterfaceName(interface), getWrapperInterfaceName(interface), interface.name)
616 for type, name, value in self.enumWrapperInterfaceVariables(interface):
617 print ' %s = %s;' % (name, value)
620 print '%s::~%s() {' % (getWrapperInterfaceName(interface), getWrapperInterfaceName(interface))
624 for base, method in interface.iterBaseMethods():
626 self.implementWrapperInterfaceMethod(interface, base, method)
630 def implementWrapperInterfaceMethod(self, interface, base, method):
631 print method.prototype(getWrapperInterfaceName(interface) + '::' + method.name) + ' {'
632 if method.type is not stdapi.Void:
633 print ' %s _result;' % method.type
635 self.implementWrapperInterfaceMethodBody(interface, base, method)
637 if method.type is not stdapi.Void:
638 print ' return _result;'
642 def implementWrapperInterfaceMethodBody(self, interface, base, method):
643 assert not method.internal
645 print ' static const char * _args[%u] = {%s};' % (len(method.args) + 1, ', '.join(['"this"'] + ['"%s"' % arg.name for arg in method.args]))
646 print ' static const trace::FunctionSig _sig = {%u, "%s", %u, _args};' % (method.id, interface.name + '::' + method.name, len(method.args) + 1)
648 print ' %s *_this = static_cast<%s *>(m_pInstance);' % (base, base)
650 print ' unsigned _call = trace::localWriter.beginEnter(&_sig);'
651 print ' trace::localWriter.beginArg(0);'
652 print ' trace::localWriter.writePointer((uintptr_t)m_pInstance);'
653 print ' trace::localWriter.endArg();'
654 for arg in method.args:
656 self.unwrapArg(method, arg)
657 self.serializeArg(method, arg)
658 print ' trace::localWriter.endEnter();'
660 self.invokeMethod(interface, base, method)
662 print ' trace::localWriter.beginLeave(_call);'
664 print ' if (%s) {' % self.wasFunctionSuccessful(method)
665 for arg in method.args:
667 self.serializeArg(method, arg)
668 self.wrapArg(method, arg)
671 if method.type is not stdapi.Void:
672 self.serializeRet(method, '_result')
673 print ' trace::localWriter.endLeave();'
674 if method.type is not stdapi.Void:
675 self.wrapRet(method, '_result')
677 if method.name == 'Release':
678 assert method.type is not stdapi.Void
679 print ' if (!_result)'
680 print ' delete this;'
682 def implementIidWrapper(self, api):
684 print r'warnIID(const char *functionName, REFIID riid, const char *reason) {'
685 print r' os::log("apitrace: warning: %s: %s IID {0x%08lX,0x%04X,0x%04X,{0x%02X,0x%02X,0x%02X,0x%02X,0x%02X,0x%02X,0x%02X,0x%02X}}\n",'
686 print r' functionName, reason,'
687 print r' riid.Data1, riid.Data2, riid.Data3,'
688 print r' riid.Data4[0], riid.Data4[1], riid.Data4[2], riid.Data4[3], riid.Data4[4], riid.Data4[5], riid.Data4[6], riid.Data4[7]);'
692 print r'wrapIID(const char *functionName, REFIID riid, void * * ppvObj) {'
693 print r' if (!ppvObj || !*ppvObj) {'
697 for iface in api.getAllInterfaces():
698 print r' %sif (riid == IID_%s) {' % (else_, iface.name)
699 print r' *ppvObj = new Wrap%s((%s *) *ppvObj);' % (iface.name, iface.name)
702 print r' %s{' % else_
703 print r' warnIID(functionName, riid, "unknown");'
708 def wrapIid(self, function, riid, out):
709 # Cast output arg to `void **` if necessary
711 obj_type = out.type.type.type
712 if not obj_type is stdapi.Void:
713 assert isinstance(obj_type, stdapi.Interface)
714 out_name = 'reinterpret_cast<void * *>(%s)' % out_name
716 print r' if (%s && *%s) {' % (out.name, out.name)
717 functionName = function.name
719 if self.interface is not None:
720 functionName = self.interface.name + '::' + functionName
721 print r' if (*%s == m_pInstance &&' % (out_name,)
722 print r' (%s)) {' % ' || '.join('%s == IID_%s' % (riid.name, iface.name) for iface in self.interface.iterBases())
723 print r' *%s = this;' % (out_name,)
726 print r' %s{' % else_
727 print r' wrapIID("%s", %s, %s);' % (functionName, riid.name, out_name)
731 def invokeMethod(self, interface, base, method):
732 if method.type is stdapi.Void:
735 result = '_result = '
736 print ' %s_this->%s(%s);' % (result, method.name, ', '.join([str(arg.name) for arg in method.args]))
def emit_memcpy(self, dest, src, length):
    '''Print the C++ code that records a call using ``trace::memcpy_sig``.

    *dest*, *src* and *length* are C expressions.  Argument 0 is *dest*
    written as a pointer, argument 1 is *src* written as a blob of *length*,
    and argument 2 is *length* written as an unsigned integer.  The enter
    record is followed immediately by an empty leave record.
    '''
    # Parenthesised single-argument print behaves the same on Python 2 and 3.
    print(' unsigned _call = trace::localWriter.beginEnter(&trace::memcpy_sig);')
    print(' trace::localWriter.beginArg(0);')
    print(' trace::localWriter.writePointer((uintptr_t)%s);' % dest)
    print(' trace::localWriter.endArg();')
    print(' trace::localWriter.beginArg(1);')
    print(' trace::localWriter.writeBlob(%s, %s);' % (src, length))
    print(' trace::localWriter.endArg();')
    print(' trace::localWriter.beginArg(2);')
    print(' trace::localWriter.writeUInt(%s);' % length)
    print(' trace::localWriter.endArg();')
    print(' trace::localWriter.endEnter();')
    print(' trace::localWriter.beginLeave(_call);')
    print(' trace::localWriter.endLeave();')
def fake_call(self, function, args):
    '''Print the C++ code that records an enter/leave pair for *function*.

    *args* holds one C expression per entry of ``function.args``; the two
    sequences are paired with ``zip``, so surplus entries on either side are
    ignored.  Each paired argument is serialized through ``serializeValue``
    between beginArg/endArg calls.  No argument may be an output argument.
    '''
    # Parenthesised single-argument print behaves the same on Python 2 and 3.
    print(' unsigned _fake_call = trace::localWriter.beginEnter(&_%s_sig);' % (function.name,))
    for arg, instance in zip(function.args, args):
        assert not arg.output
        print(' trace::localWriter.beginArg(%u);' % (arg.index,))
        self.serializeValue(arg.type, instance)
        print(' trace::localWriter.endArg();')
    print(' trace::localWriter.endEnter();')
    print(' trace::localWriter.beginLeave(_fake_call);')
    print(' trace::localWriter.endLeave();')