1 ##########################################################################
3 # Copyright 2008-2010 VMware, Inc.
6 # Permission is hereby granted, free of charge, to any person obtaining a copy
7 # of this software and associated documentation files (the "Software"), to deal
8 # in the Software without restriction, including without limitation the rights
9 # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 # copies of the Software, and to permit persons to whom the Software is
11 # furnished to do so, subject to the following conditions:
13 # The above copyright notice and this permission notice shall be included in
14 # all copies or substantial portions of the Software.
16 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
24 ##########################################################################/
"""Common trace code generation."""


# Make the parent directory importable so that the `specs` package resolves.
import os.path
import sys
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))


import specs.stdapi as stdapi
def getWrapperInterfaceName(interface):
    """Return the C++ class name of the tracing wrapper for *interface*.

    The name is the interface's C expression (``interface.expr``) with a
    ``Wrap`` prefix, e.g. ``IFoo`` -> ``WrapIFoo``.
    """
    return "".join(["Wrap", interface.expr])
class ExpanderMixin:
    '''Mixin class that provides a bunch of methods to expand C expressions
    from the specifications.

    Two private stacks are kept as nested ``(top, rest)`` tuples (``None``
    when empty): the C expression of the struct whose member is being visited,
    and the C index variable of the array element being visited.  expand()
    substitutes them for ``{self}`` and ``{i}`` respectively.

    Subclasses must provide ``visit(type, *args, **kwargs)``.
    '''

    # Set to True to log every expression rewritten by expand() to stderr.
    debugExpand = False

    __structs = None
    __indices = None

    def expand(self, expr):
        '''Expand a C expression, replacing certain variables.

        Non-string expressions are returned unchanged.  In string expressions
        ``{self}`` becomes the parenthesized innermost struct instance and
        ``{i}`` the innermost element index (see visitMember/visitElement).
        The substitution uses str.format, so a ``{self}``/``{i}`` field with
        nothing on the corresponding stack raises KeyError.
        '''
        if not isinstance(expr, basestring):
            return expr
        variables = {}

        if self.__structs is not None:
            variables['self'] = '(%s)' % self.__structs[0]
        if self.__indices is not None:
            variables['i'] = self.__indices[0]

        expandedExpr = expr.format(**variables)
        if self.debugExpand and expandedExpr != expr:
            sys.stderr.write(" %r -> %r\n" % (expr, expandedExpr))
        return expandedExpr

    def visitMember(self, structInstance, member_type, *args, **kwargs):
        '''Visit *member_type* with *structInstance* pushed as ``{self}``.

        Returns whatever visit() returns; the struct stack is restored even
        if the visit raises.
        '''
        self.__structs = (structInstance, self.__structs)
        try:
            return self.visit(member_type, *args, **kwargs)
        finally:
            _, self.__structs = self.__structs

    def visitElement(self, element_index, element_type, *args, **kwargs):
        '''Visit *element_type* with *element_index* pushed as ``{i}``.

        Returns whatever visit() returns; the index stack is restored even
        if the visit raises.
        '''
        self.__indices = (element_index, self.__indices)
        try:
            return self.visit(element_type, *args, **kwargs)
        finally:
            _, self.__indices = self.__indices
81 class ComplexValueSerializer(stdapi.OnceVisitor):
82 '''Type visitors which generates serialization functions for
85 Simple types are serialized inline.
88 def __init__(self, serializer):
89 stdapi.OnceVisitor.__init__(self)
90 self.serializer = serializer
92 def visitVoid(self, literal):
95 def visitLiteral(self, literal):
98 def visitString(self, string):
101 def visitConst(self, const):
102 self.visit(const.type)
104 def visitStruct(self, struct):
105 print 'static const char * _struct%s_members[%u] = {' % (struct.tag, len(struct.members))
106 for type, name, in struct.members:
107 print ' "%s",' % (name,)
109 print 'static const trace::StructSig _struct%s_sig = {' % (struct.tag,)
110 print ' %u, "%s", %u, _struct%s_members' % (struct.id, struct.name, len(struct.members), struct.tag)
114 def visitArray(self, array):
115 self.visit(array.type)
117 def visitBlob(self, array):
120 def visitEnum(self, enum):
121 print 'static const trace::EnumValue _enum%s_values[] = {' % (enum.tag)
122 for value in enum.values:
123 print ' {"%s", %s},' % (value, value)
126 print 'static const trace::EnumSig _enum%s_sig = {' % (enum.tag)
127 print ' %u, %u, _enum%s_values' % (enum.id, len(enum.values), enum.tag)
131 def visitBitmask(self, bitmask):
132 print 'static const trace::BitmaskFlag _bitmask%s_flags[] = {' % (bitmask.tag)
133 for value in bitmask.values:
134 print ' {"%s", %s},' % (value, value)
137 print 'static const trace::BitmaskSig _bitmask%s_sig = {' % (bitmask.tag)
138 print ' %u, %u, _bitmask%s_flags' % (bitmask.id, len(bitmask.values), bitmask.tag)
142 def visitPointer(self, pointer):
143 self.visit(pointer.type)
145 def visitIntPointer(self, pointer):
148 def visitObjPointer(self, pointer):
149 self.visit(pointer.type)
151 def visitLinearPointer(self, pointer):
152 self.visit(pointer.type)
154 def visitHandle(self, handle):
155 self.visit(handle.type)
157 def visitReference(self, reference):
158 self.visit(reference.type)
160 def visitAlias(self, alias):
161 self.visit(alias.type)
163 def visitOpaque(self, opaque):
166 def visitInterface(self, interface):
169 def visitPolymorphic(self, polymorphic):
170 if not polymorphic.contextLess:
172 print 'static void _write__%s(int selector, const %s & value) {' % (polymorphic.tag, polymorphic.expr)
173 print ' switch (selector) {'
174 for cases, type in polymorphic.iterSwitch():
177 self.serializer.visit(type, 'static_cast<%s>(value)' % (type,))
184 class ValueSerializer(stdapi.Visitor, ExpanderMixin):
185 '''Visitor which generates code to serialize any type.
187 Simple types are serialized inline here, whereas the serialization of
188 complex types is dispatched to the serialization functions generated by
189 ComplexValueSerializer visitor above.
193 #stdapi.Visitor.__init__(self)
197 def visitLiteral(self, literal, instance):
198 print ' trace::localWriter.write%s(%s);' % (literal.kind, instance)
200 def visitString(self, string, instance):
202 cast = 'const char *'
205 cast = 'const wchar_t *'
207 if cast != string.expr:
208 # reinterpret_cast is necessary for GLubyte * <=> char *
209 instance = 'reinterpret_cast<%s>(%s)' % (cast, instance)
210 if string.length is not None:
211 length = ', %s' % string.length
214 print ' trace::localWriter.write%s(%s%s);' % (suffix, instance, length)
216 def visitConst(self, const, instance):
217 self.visit(const.type, instance)
219 def visitStruct(self, struct, instance):
220 print ' trace::localWriter.beginStruct(&_struct%s_sig);' % (struct.tag,)
221 for type, name in struct.members:
222 self.visitMember(instance, type, '(%s).%s' % (instance, name,))
223 print ' trace::localWriter.endStruct();'
225 def visitArray(self, array, instance):
226 length = '_c' + array.type.tag
227 index = '_i' + array.type.tag
228 array_length = self.expand(array.length)
229 print ' if (%s) {' % instance
230 print ' size_t %s = %s > 0 ? %s : 0;' % (length, array_length, array_length)
231 print ' trace::localWriter.beginArray(%s);' % length
232 print ' for (size_t %s = 0; %s < %s; ++%s) {' % (index, index, length, index)
233 print ' trace::localWriter.beginElement();'
234 self.visitElement(index, array.type, '(%s)[%s]' % (instance, index))
235 print ' trace::localWriter.endElement();'
237 print ' trace::localWriter.endArray();'
239 print ' trace::localWriter.writeNull();'
242 def visitBlob(self, blob, instance):
243 print ' trace::localWriter.writeBlob(%s, %s);' % (instance, self.expand(blob.size))
245 def visitEnum(self, enum, instance):
246 print ' trace::localWriter.writeEnum(&_enum%s_sig, %s);' % (enum.tag, instance)
248 def visitBitmask(self, bitmask, instance):
249 print ' trace::localWriter.writeBitmask(&_bitmask%s_sig, %s);' % (bitmask.tag, instance)
251 def visitPointer(self, pointer, instance):
252 print ' if (%s) {' % instance
253 print ' trace::localWriter.beginArray(1);'
254 print ' trace::localWriter.beginElement();'
255 self.visit(pointer.type, "*" + instance)
256 print ' trace::localWriter.endElement();'
257 print ' trace::localWriter.endArray();'
259 print ' trace::localWriter.writeNull();'
262 def visitIntPointer(self, pointer, instance):
263 print ' trace::localWriter.writePointer((uintptr_t)%s);' % instance
265 def visitObjPointer(self, pointer, instance):
266 print ' trace::localWriter.writePointer((uintptr_t)%s);' % instance
268 def visitLinearPointer(self, pointer, instance):
269 print ' trace::localWriter.writePointer((uintptr_t)%s);' % instance
271 def visitReference(self, reference, instance):
272 self.visit(reference.type, instance)
274 def visitHandle(self, handle, instance):
275 self.visit(handle.type, instance)
277 def visitAlias(self, alias, instance):
278 self.visit(alias.type, instance)
280 def visitOpaque(self, opaque, instance):
281 print ' trace::localWriter.writePointer((uintptr_t)%s);' % instance
283 def visitInterface(self, interface, instance):
286 def visitPolymorphic(self, polymorphic, instance):
287 if polymorphic.contextLess:
288 print ' _write__%s(%s, %s);' % (polymorphic.tag, polymorphic.switchExpr, instance)
290 print ' switch (%s) {' % polymorphic.switchExpr
291 for cases, type in polymorphic.iterSwitch():
294 self.visit(type, 'static_cast<%s>(%s)' % (type, instance))
class WrapDecider(stdapi.Traverser):
    '''Type visitor which will decide whether this type will need wrapping or not.

    For complex types (arrays, structures), we need to know this before hand.
    After visit(type), ``needsWrapping`` is True iff an interface type was
    reached while traversing it.
    '''

    def __init__(self):
        self.needsWrapping = False

    def visitLinearPointer(self, void):
        # Linear pointers are only recorded by address (see
        # ValueSerializer.visitLinearPointer), so their pointee is
        # deliberately not traversed.
        pass

    def visitInterface(self, interface):
        self.needsWrapping = True
315 class ValueWrapper(stdapi.Traverser, ExpanderMixin):
316 '''Type visitor which will generate the code to wrap an instance.
318 Wrapping is necessary mostly for interfaces, however interface pointers can
319 appear anywhere inside complex types.
322 def visitStruct(self, struct, instance):
323 for type, name in struct.members:
324 self.visitMember(instance, type, "(%s).%s" % (instance, name))
326 def visitArray(self, array, instance):
327 array_length = self.expand(array.length)
328 print " if (%s) {" % instance
329 print " for (size_t _i = 0, _s = %s; _i < _s; ++_i) {" % array_length
330 self.visitElement('_i', array.type, instance + "[_i]")
334 def visitPointer(self, pointer, instance):
335 print " if (%s) {" % instance
336 self.visit(pointer.type, "*" + instance)
339 def visitObjPointer(self, pointer, instance):
340 elem_type = pointer.type.mutable()
341 if isinstance(elem_type, stdapi.Interface):
342 self.visitInterfacePointer(elem_type, instance)
344 self.visitPointer(pointer, instance)
346 def visitInterface(self, interface, instance):
347 raise NotImplementedError
349 def visitInterfacePointer(self, interface, instance):
350 print " if (%s) {" % instance
351 print " %s = new %s(%s);" % (instance, getWrapperInterfaceName(interface), instance)
354 def visitPolymorphic(self, type, instance):
355 # XXX: There might be polymorphic values that need wrapping in the future
356 raise NotImplementedError
359 class ValueUnwrapper(ValueWrapper):
360 '''Reverse of ValueWrapper.'''
364 def visitArray(self, array, instance):
365 if self.allocated or isinstance(instance, stdapi.Interface):
366 return ValueWrapper.visitArray(self, array, instance)
367 array_length = self.expand(array.length)
368 elem_type = array.type.mutable()
369 print " if (%s && %s) {" % (instance, array_length)
370 print " %s * _t = static_cast<%s *>(alloca(%s * sizeof *_t));" % (elem_type, elem_type, array_length)
371 print " for (size_t _i = 0, _s = %s; _i < _s; ++_i) {" % array_length
372 print " _t[_i] = %s[_i];" % instance
373 self.allocated = True
374 self.visit(array.type, "_t[_i]")
376 print " %s = _t;" % instance
379 def visitInterfacePointer(self, interface, instance):
380 print r' if (%s) {' % instance
381 print r' const %s *pWrapper = static_cast<const %s*>(%s);' % (getWrapperInterfaceName(interface), getWrapperInterfaceName(interface), instance)
382 print r' if (pWrapper && pWrapper->m_dwMagic == 0xd8365d6c) {'
383 print r' %s = pWrapper->m_pInstance;' % (instance,)
385 print r' os::log("apitrace: warning: %%s: unexpected %%s pointer\n", __FUNCTION__, "%s");' % interface.name
391 '''Base class to orchestrate the code generation of API tracing.'''
396 def serializerFactory(self):
397 '''Create a serializer.
399 Can be overriden by derived classes to inject their own serialzer.
402 return ValueSerializer()
404 def traceApi(self, api):
410 for header in api.headers:
414 # Generate the serializer functions
415 types = api.getAllTypes()
416 visitor = ComplexValueSerializer(self.serializerFactory())
417 map(visitor.visit, types)
421 self.traceInterfaces(api)
424 self.interface = None
426 map(self.traceFunctionDecl, api.functions)
427 map(self.traceFunctionImpl, api.functions)
432 def header(self, api):
433 print '#ifdef _WIN32'
434 print '# include <malloc.h> // alloca'
435 print '# ifndef alloca'
436 print '# define alloca _alloca'
439 print '# include <alloca.h> // alloca'
442 print '#include "trace.hpp"'
445 def footer(self, api):
448 def traceFunctionDecl(self, function):
449 # Per-function declarations
451 if not function.internal:
453 print 'static const char * _%s_args[%u] = {%s};' % (function.name, len(function.args), ', '.join(['"%s"' % arg.name for arg in function.args]))
455 print 'static const char ** _%s_args = NULL;' % (function.name,)
456 print 'static const trace::FunctionSig _%s_sig = {%u, "%s", %u, _%s_args};' % (function.name, function.id, function.name, len(function.args), function.name)
459 def isFunctionPublic(self, function):
462 def traceFunctionImpl(self, function):
463 if self.isFunctionPublic(function):
464 print 'extern "C" PUBLIC'
466 print 'extern "C" PRIVATE'
467 print function.prototype() + ' {'
468 if function.type is not stdapi.Void:
469 print ' %s _result;' % function.type
471 # No-op if tracing is disabled
472 print ' if (!trace::isTracingEnabled()) {'
473 Tracer.invokeFunction(self, function)
474 if function.type is not stdapi.Void:
475 print ' return _result;'
480 self.traceFunctionImplBody(function)
481 if function.type is not stdapi.Void:
482 print ' return _result;'
486 def traceFunctionImplBody(self, function):
487 if not function.internal:
488 print ' unsigned _call = trace::localWriter.beginEnter(&_%s_sig);' % (function.name,)
489 for arg in function.args:
491 self.unwrapArg(function, arg)
492 self.serializeArg(function, arg)
493 print ' trace::localWriter.endEnter();'
494 self.invokeFunction(function)
495 if not function.internal:
496 print ' trace::localWriter.beginLeave(_call);'
497 for arg in function.args:
499 self.serializeArg(function, arg)
500 self.wrapArg(function, arg)
501 if function.type is not stdapi.Void:
502 self.serializeRet(function, "_result")
503 print ' trace::localWriter.endLeave();'
504 if function.type is not stdapi.Void:
505 self.wrapRet(function, "_result")
507 def invokeFunction(self, function, prefix='_', suffix=''):
508 if function.type is stdapi.Void:
511 result = '_result = '
512 dispatch = prefix + function.name + suffix
513 print ' %s%s(%s);' % (result, dispatch, ', '.join([str(arg.name) for arg in function.args]))
515 def serializeArg(self, function, arg):
516 print ' trace::localWriter.beginArg(%u);' % (arg.index,)
517 self.serializeArgValue(function, arg)
518 print ' trace::localWriter.endArg();'
520 def serializeArgValue(self, function, arg):
521 self.serializeValue(arg.type, arg.name)
523 def wrapArg(self, function, arg):
524 assert not isinstance(arg.type, stdapi.ObjPointer)
526 from specs.winapi import REFIID
528 for other_arg in function.args:
529 if not other_arg.output and other_arg.type is REFIID:
531 if riid is not None \
532 and isinstance(arg.type, stdapi.Pointer) \
533 and isinstance(arg.type.type, stdapi.ObjPointer):
534 self.wrapIid(function, riid, arg)
537 self.wrapValue(arg.type, arg.name)
539 def unwrapArg(self, function, arg):
540 self.unwrapValue(arg.type, arg.name)
542 def serializeRet(self, function, instance):
543 print ' trace::localWriter.beginReturn();'
544 self.serializeValue(function.type, instance)
545 print ' trace::localWriter.endReturn();'
547 def serializeValue(self, type, instance):
548 serializer = self.serializerFactory()
549 serializer.visit(type, instance)
551 def wrapRet(self, function, instance):
552 self.wrapValue(function.type, instance)
554 def unwrapRet(self, function, instance):
555 self.unwrapValue(function.type, instance)
557 def needsWrapping(self, type):
558 visitor = WrapDecider()
560 return visitor.needsWrapping
562 def wrapValue(self, type, instance):
563 if self.needsWrapping(type):
564 visitor = ValueWrapper()
565 visitor.visit(type, instance)
567 def unwrapValue(self, type, instance):
568 if self.needsWrapping(type):
569 visitor = ValueUnwrapper()
570 visitor.visit(type, instance)
572 def traceInterfaces(self, api):
573 interfaces = api.getAllInterfaces()
576 map(self.declareWrapperInterface, interfaces)
577 self.implementIidWrapper(api)
578 map(self.implementWrapperInterface, interfaces)
581 def declareWrapperInterface(self, interface):
582 print "class %s : public %s " % (getWrapperInterfaceName(interface), interface.name)
585 print " %s(%s * pInstance);" % (getWrapperInterfaceName(interface), interface.name)
586 print " virtual ~%s();" % getWrapperInterfaceName(interface)
588 for method in interface.iterMethods():
589 print " " + method.prototype() + ";"
592 for type, name, value in self.enumWrapperInterfaceVariables(interface):
593 print ' %s %s;' % (type, name)
597 def enumWrapperInterfaceVariables(self, interface):
599 ("DWORD", "m_dwMagic", "0xd8365d6c"),
600 ("%s *" % interface.name, "m_pInstance", "pInstance"),
603 def implementWrapperInterface(self, interface):
604 self.interface = interface
606 print '%s::%s(%s * pInstance) {' % (getWrapperInterfaceName(interface), getWrapperInterfaceName(interface), interface.name)
607 for type, name, value in self.enumWrapperInterfaceVariables(interface):
608 print ' %s = %s;' % (name, value)
611 print '%s::~%s() {' % (getWrapperInterfaceName(interface), getWrapperInterfaceName(interface))
615 for base, method in interface.iterBaseMethods():
617 self.implementWrapperInterfaceMethod(interface, base, method)
621 def implementWrapperInterfaceMethod(self, interface, base, method):
622 print method.prototype(getWrapperInterfaceName(interface) + '::' + method.name) + ' {'
623 if method.type is not stdapi.Void:
624 print ' %s _result;' % method.type
626 self.implementWrapperInterfaceMethodBody(interface, base, method)
628 if method.type is not stdapi.Void:
629 print ' return _result;'
633 def implementWrapperInterfaceMethodBody(self, interface, base, method):
634 assert not method.internal
636 print ' static const char * _args[%u] = {%s};' % (len(method.args) + 1, ', '.join(['"this"'] + ['"%s"' % arg.name for arg in method.args]))
637 print ' static const trace::FunctionSig _sig = {%u, "%s", %u, _args};' % (method.id, interface.name + '::' + method.name, len(method.args) + 1)
639 print ' %s *_this = static_cast<%s *>(m_pInstance);' % (base, base)
641 print ' unsigned _call = trace::localWriter.beginEnter(&_sig);'
642 print ' trace::localWriter.beginArg(0);'
643 print ' trace::localWriter.writePointer((uintptr_t)m_pInstance);'
644 print ' trace::localWriter.endArg();'
645 for arg in method.args:
647 self.unwrapArg(method, arg)
648 self.serializeArg(method, arg)
649 print ' trace::localWriter.endEnter();'
651 self.invokeMethod(interface, base, method)
653 print ' trace::localWriter.beginLeave(_call);'
654 for arg in method.args:
656 self.serializeArg(method, arg)
657 self.wrapArg(method, arg)
659 if method.type is not stdapi.Void:
660 self.serializeRet(method, '_result')
661 print ' trace::localWriter.endLeave();'
662 if method.type is not stdapi.Void:
663 self.wrapRet(method, '_result')
665 if method.name == 'Release':
666 assert method.type is not stdapi.Void
667 print ' if (!_result)'
668 print ' delete this;'
670 def implementIidWrapper(self, api):
672 print r'warnIID(const char *functionName, REFIID riid, const char *reason) {'
673 print r' os::log("apitrace: warning: %s: %s IID {0x%08lX,0x%04X,0x%04X,{0x%02X,0x%02X,0x%02X,0x%02X,0x%02X,0x%02X,0x%02X,0x%02X}}\n",'
674 print r' functionName, reason,'
675 print r' riid.Data1, riid.Data2, riid.Data3,'
676 print r' riid.Data4[0], riid.Data4[1], riid.Data4[2], riid.Data4[3], riid.Data4[4], riid.Data4[5], riid.Data4[6], riid.Data4[7]);'
680 print r'wrapIID(const char *functionName, REFIID riid, void * * ppvObj) {'
681 print r' if (!ppvObj || !*ppvObj) {'
685 for iface in api.getAllInterfaces():
686 print r' %sif (riid == IID_%s) {' % (else_, iface.name)
687 print r' *ppvObj = new Wrap%s((%s *) *ppvObj);' % (iface.name, iface.name)
690 print r' %s{' % else_
691 print r' warnIID(functionName, riid, "unknown");'
696 def wrapIid(self, function, riid, out):
697 # Cast output arg to `void **` if necessary
699 obj_type = out.type.type.type
700 if not obj_type is stdapi.Void:
701 assert isinstance(obj_type, stdapi.Interface)
702 out_name = 'reinterpret_cast<void * *>(%s)' % out_name
704 print r' if (%s && *%s) {' % (out.name, out.name)
705 functionName = function.name
707 if self.interface is not None:
708 functionName = self.interface.name + '::' + functionName
709 print r' if (*%s == m_pInstance &&' % (out_name,)
710 print r' (%s)) {' % ' || '.join('%s == IID_%s' % (riid.name, iface.name) for iface in self.interface.iterBases())
711 print r' *%s = this;' % (out_name,)
714 print r' %s{' % else_
715 print r' wrapIID("%s", %s, %s);' % (functionName, riid.name, out_name)
719 def invokeMethod(self, interface, base, method):
720 if method.type is stdapi.Void:
723 result = '_result = '
724 print ' %s_this->%s(%s);' % (result, method.name, ', '.join([str(arg.name) for arg in method.args]))
726 def emit_memcpy(self, dest, src, length):
727 print ' unsigned _call = trace::localWriter.beginEnter(&trace::memcpy_sig);'
728 print ' trace::localWriter.beginArg(0);'
729 print ' trace::localWriter.writePointer((uintptr_t)%s);' % dest
730 print ' trace::localWriter.endArg();'
731 print ' trace::localWriter.beginArg(1);'
732 print ' trace::localWriter.writeBlob(%s, %s);' % (src, length)
733 print ' trace::localWriter.endArg();'
734 print ' trace::localWriter.beginArg(2);'
735 print ' trace::localWriter.writeUInt(%s);' % length
736 print ' trace::localWriter.endArg();'
737 print ' trace::localWriter.endEnter();'
738 print ' trace::localWriter.beginLeave(_call);'
739 print ' trace::localWriter.endLeave();'
741 def fake_call(self, function, args):
742 print ' unsigned _fake_call = trace::localWriter.beginEnter(&_%s_sig);' % (function.name,)
743 for arg, instance in zip(function.args, args):
744 assert not arg.output
745 print ' trace::localWriter.beginArg(%u);' % (arg.index,)
746 self.serializeValue(arg.type, instance)
747 print ' trace::localWriter.endArg();'
748 print ' trace::localWriter.endEnter();'
749 print ' trace::localWriter.beginLeave(_fake_call);'
750 print ' trace::localWriter.endLeave();'