1 ##########################################################################
3 # Copyright 2008-2010 VMware, Inc.
6 # Permission is hereby granted, free of charge, to any person obtaining a copy
7 # of this software and associated documentation files (the "Software"), to deal
8 # in the Software without restriction, including without limitation the rights
9 # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 # copies of the Software, and to permit persons to whom the Software is
11 # furnished to do so, subject to the following conditions:
13 # The above copyright notice and this permission notice shall be included in
14 # all copies or substantial portions of the Software.
16 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
24 ##########################################################################/
26 """Common trace code generation."""
# Adjust path so that the `specs` package in the parent directory resolves.
import os.path
import sys
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))

import specs.stdapi as stdapi
def getWrapperInterfaceName(interface):
    """Return the name of the C++ wrapper class generated for *interface*.

    The name is the interface's C expression (``interface.expr``) with the
    ``"Wrap"`` prefix prepended.
    """
    prefix = "Wrap"
    return prefix + interface.expr
class ExpanderMixin:
    '''Mixin class that provides methods to expand C expressions from the
    specifications.

    Expressions are ``str.format`` templates which may reference ``{self}``
    (the innermost struct instance being visited, parenthesized) and ``{i}``
    (the innermost array element index being visited).  Classes using this
    mixin must also provide a ``visit`` method.
    '''

    # Stacks of enclosing struct instances / element indices, stored as
    # nested ``(head, tail)`` tuples with ``None`` meaning "empty".  They are
    # defined at class level so that ``expand`` works before any member or
    # element has been visited.
    __structs = None
    __indices = None

    # Set to True to log every expansion that changed its expression.
    _logExpansions = False

    def expand(self, expr):
        '''Expand a C expression, replacing ``{self}`` and ``{i}``.

        Non-string values are returned unchanged.
        '''
        if not isinstance(expr, basestring):
            return expr

        variables = {}
        if self.__structs is not None:
            variables['self'] = '(%s)' % self.__structs[0]
        if self.__indices is not None:
            variables['i'] = self.__indices[0]

        expandedExpr = expr.format(**variables)
        if self._logExpansions and expandedExpr != expr:
            sys.stderr.write(" %r -> %r\n" % (expr, expandedExpr))
        return expandedExpr

    def visitMember(self, structInstance, member_type, *args, **kwargs):
        '''Visit ``member_type`` with ``structInstance`` bound to ``{self}``.'''
        self.__structs = (structInstance, self.__structs)
        try:
            return self.visit(member_type, *args, **kwargs)
        finally:
            # Pop even if visiting raised, keeping the stack balanced.
            _, self.__structs = self.__structs

    def visitElement(self, element_index, element_type, *args, **kwargs):
        '''Visit ``element_type`` with ``element_index`` bound to ``{i}``.'''
        self.__indices = (element_index, self.__indices)
        try:
            return self.visit(element_type, *args, **kwargs)
        finally:
            # Pop even if visiting raised, keeping the stack balanced.
            _, self.__indices = self.__indices
81 class ComplexValueSerializer(stdapi.OnceVisitor):
82 '''Type visitors which generates serialization functions for
85 Simple types are serialized inline.
88 def __init__(self, serializer):
89 stdapi.OnceVisitor.__init__(self)
90 self.serializer = serializer
92 def visitVoid(self, literal):
95 def visitLiteral(self, literal):
98 def visitString(self, string):
101 def visitConst(self, const):
102 self.visit(const.type)
104 def visitStruct(self, struct):
105 print 'static const char * _struct%s_members[%u] = {' % (struct.tag, len(struct.members))
106 for type, name, in struct.members:
107 print ' "%s",' % (name,)
109 print 'static const trace::StructSig _struct%s_sig = {' % (struct.tag,)
110 print ' %u, "%s", %u, _struct%s_members' % (struct.id, struct.name, len(struct.members), struct.tag)
114 def visitArray(self, array):
115 self.visit(array.type)
117 def visitBlob(self, array):
120 def visitEnum(self, enum):
121 print 'static const trace::EnumValue _enum%s_values[] = {' % (enum.tag)
122 for value in enum.values:
123 print ' {"%s", %s},' % (value, value)
126 print 'static const trace::EnumSig _enum%s_sig = {' % (enum.tag)
127 print ' %u, %u, _enum%s_values' % (enum.id, len(enum.values), enum.tag)
131 def visitBitmask(self, bitmask):
132 print 'static const trace::BitmaskFlag _bitmask%s_flags[] = {' % (bitmask.tag)
133 for value in bitmask.values:
134 print ' {"%s", %s},' % (value, value)
137 print 'static const trace::BitmaskSig _bitmask%s_sig = {' % (bitmask.tag)
138 print ' %u, %u, _bitmask%s_flags' % (bitmask.id, len(bitmask.values), bitmask.tag)
142 def visitPointer(self, pointer):
143 self.visit(pointer.type)
145 def visitIntPointer(self, pointer):
148 def visitObjPointer(self, pointer):
149 self.visit(pointer.type)
151 def visitLinearPointer(self, pointer):
152 self.visit(pointer.type)
154 def visitHandle(self, handle):
155 self.visit(handle.type)
157 def visitReference(self, reference):
158 self.visit(reference.type)
160 def visitAlias(self, alias):
161 self.visit(alias.type)
163 def visitOpaque(self, opaque):
166 def visitInterface(self, interface):
169 def visitPolymorphic(self, polymorphic):
170 if not polymorphic.contextLess:
172 print 'static void _write__%s(int selector, const %s & value) {' % (polymorphic.tag, polymorphic.expr)
173 print ' switch (selector) {'
174 for cases, type in polymorphic.iterSwitch():
177 self.serializer.visit(type, 'static_cast<%s>(value)' % (type,))
184 class ValueSerializer(stdapi.Visitor, ExpanderMixin):
185 '''Visitor which generates code to serialize any type.
187 Simple types are serialized inline here, whereas the serialization of
188 complex types is dispatched to the serialization functions generated by
189 ComplexValueSerializer visitor above.
193 #stdapi.Visitor.__init__(self)
197 def visitLiteral(self, literal, instance):
198 print ' trace::localWriter.write%s(%s);' % (literal.kind, instance)
200 def visitString(self, string, instance):
202 cast = 'const char *'
205 cast = 'const wchar_t *'
207 if cast != string.expr:
208 # reinterpret_cast is necessary for GLubyte * <=> char *
209 instance = 'reinterpret_cast<%s>(%s)' % (cast, instance)
210 if string.length is not None:
211 length = ', %s' % string.length
214 print ' trace::localWriter.write%s(%s%s);' % (suffix, instance, length)
216 def visitConst(self, const, instance):
217 self.visit(const.type, instance)
219 def visitStruct(self, struct, instance):
220 print ' trace::localWriter.beginStruct(&_struct%s_sig);' % (struct.tag,)
221 for type, name in struct.members:
222 self.visitMember(instance, type, '(%s).%s' % (instance, name,))
223 print ' trace::localWriter.endStruct();'
225 def visitArray(self, array, instance):
226 length = '_c' + array.type.tag
227 index = '_i' + array.type.tag
228 array_length = self.expand(array.length)
229 print ' if (%s) {' % instance
230 print ' size_t %s = %s > 0 ? %s : 0;' % (length, array_length, array_length)
231 print ' trace::localWriter.beginArray(%s);' % length
232 print ' for (size_t %s = 0; %s < %s; ++%s) {' % (index, index, length, index)
233 print ' trace::localWriter.beginElement();'
234 self.visitElement(index, array.type, '(%s)[%s]' % (instance, index))
235 print ' trace::localWriter.endElement();'
237 print ' trace::localWriter.endArray();'
239 print ' trace::localWriter.writeNull();'
242 def visitBlob(self, blob, instance):
243 print ' trace::localWriter.writeBlob(%s, %s);' % (instance, self.expand(blob.size))
245 def visitEnum(self, enum, instance):
246 print ' trace::localWriter.writeEnum(&_enum%s_sig, %s);' % (enum.tag, instance)
248 def visitBitmask(self, bitmask, instance):
249 print ' trace::localWriter.writeBitmask(&_bitmask%s_sig, %s);' % (bitmask.tag, instance)
251 def visitPointer(self, pointer, instance):
252 print ' if (%s) {' % instance
253 print ' trace::localWriter.beginArray(1);'
254 print ' trace::localWriter.beginElement();'
255 self.visit(pointer.type, "*" + instance)
256 print ' trace::localWriter.endElement();'
257 print ' trace::localWriter.endArray();'
259 print ' trace::localWriter.writeNull();'
262 def visitIntPointer(self, pointer, instance):
263 print ' trace::localWriter.writePointer((uintptr_t)%s);' % instance
265 def visitObjPointer(self, pointer, instance):
266 print ' trace::localWriter.writePointer((uintptr_t)%s);' % instance
268 def visitLinearPointer(self, pointer, instance):
269 print ' trace::localWriter.writePointer((uintptr_t)%s);' % instance
271 def visitReference(self, reference, instance):
272 self.visit(reference.type, instance)
274 def visitHandle(self, handle, instance):
275 self.visit(handle.type, instance)
277 def visitAlias(self, alias, instance):
278 self.visit(alias.type, instance)
280 def visitOpaque(self, opaque, instance):
281 print ' trace::localWriter.writePointer((uintptr_t)%s);' % instance
283 def visitInterface(self, interface, instance):
286 def visitPolymorphic(self, polymorphic, instance):
287 if polymorphic.contextLess:
288 print ' _write__%s(%s, %s);' % (polymorphic.tag, polymorphic.switchExpr, instance)
290 print ' switch (%s) {' % polymorphic.switchExpr
291 for cases, type in polymorphic.iterSwitch():
294 self.visit(type, 'static_cast<%s>(%s)' % (type, instance))
class WrapDecider(stdapi.Traverser):
    '''Type visitor which decides whether a type needs wrapping or not.

    For complex types (arrays, structures), we need to know this beforehand.
    After ``visit(type)``, ``needsWrapping`` is True iff an interface was
    reached during the traversal.
    '''

    def __init__(self):
        self.needsWrapping = False

    def visitLinearPointer(self, void):
        # Do not descend into linear pointers: whatever they point to never
        # needs wrapping.
        pass

    def visitInterface(self, interface):
        self.needsWrapping = True
315 class ValueWrapper(stdapi.Traverser, ExpanderMixin):
316 '''Type visitor which will generate the code to wrap an instance.
318 Wrapping is necessary mostly for interfaces, however interface pointers can
319 appear anywhere inside complex types.
322 def visitStruct(self, struct, instance):
323 for type, name in struct.members:
324 self.visitMember(instance, type, "(%s).%s" % (instance, name))
326 def visitArray(self, array, instance):
327 array_length = self.expand(array.length)
328 print " if (%s) {" % instance
329 print " for (size_t _i = 0, _s = %s; _i < _s; ++_i) {" % array_length
330 self.visitElement('_i', array.type, instance + "[_i]")
334 def visitPointer(self, pointer, instance):
335 print " if (%s) {" % instance
336 self.visit(pointer.type, "*" + instance)
339 def visitObjPointer(self, pointer, instance):
340 elem_type = pointer.type.mutable()
341 if isinstance(elem_type, stdapi.Interface):
342 self.visitInterfacePointer(elem_type, instance)
343 elif isinstance(elem_type, stdapi.Alias) and isinstance(elem_type.type, stdapi.Interface):
344 self.visitInterfacePointer(elem_type.type, instance)
346 self.visitPointer(pointer, instance)
348 def visitInterface(self, interface, instance):
349 raise NotImplementedError
351 def visitInterfacePointer(self, interface, instance):
352 print " if (%s) {" % instance
353 print " %s = new %s(%s);" % (instance, getWrapperInterfaceName(interface), instance)
356 def visitPolymorphic(self, type, instance):
357 # XXX: There might be polymorphic values that need wrapping in the future
358 raise NotImplementedError
361 class ValueUnwrapper(ValueWrapper):
362 '''Reverse of ValueWrapper.'''
366 def visitStruct(self, struct, instance):
367 if not self.allocated:
368 # Argument is constant. We need to create a non const
370 print " %s * _t = static_cast<%s *>(alloca(sizeof *_t));" % (struct, struct)
371 print ' *_t = %s;' % (instance,)
372 assert instance.startswith('*')
373 print ' %s = _t;' % (instance[1:],)
375 self.allocated = True
377 return ValueWrapper.visitStruct(self, struct, instance)
381 return ValueWrapper.visitStruct(self, struct, instance)
383 def visitArray(self, array, instance):
384 if self.allocated or isinstance(instance, stdapi.Interface):
385 return ValueWrapper.visitArray(self, array, instance)
386 array_length = self.expand(array.length)
387 elem_type = array.type.mutable()
388 print " if (%s && %s) {" % (instance, array_length)
389 print " %s * _t = static_cast<%s *>(alloca(%s * sizeof *_t));" % (elem_type, elem_type, array_length)
390 print " for (size_t _i = 0, _s = %s; _i < _s; ++_i) {" % array_length
391 print " _t[_i] = %s[_i];" % instance
392 self.allocated = True
393 self.visit(array.type, "_t[_i]")
395 print " %s = _t;" % instance
398 def visitInterfacePointer(self, interface, instance):
399 print r' if (%s) {' % instance
400 print r' const %s *pWrapper = static_cast<const %s*>(%s);' % (getWrapperInterfaceName(interface), getWrapperInterfaceName(interface), instance)
401 print r' if (pWrapper && pWrapper->m_dwMagic == 0xd8365d6c) {'
402 print r' %s = pWrapper->m_pInstance;' % (instance,)
404 print r' os::log("apitrace: warning: %%s: unexpected %%s pointer\n", __FUNCTION__, "%s");' % interface.name
410 '''Base class to orchestrate the code generation of API tracing.'''
415 def serializerFactory(self):
416 '''Create a serializer.
418 Can be overriden by derived classes to inject their own serialzer.
421 return ValueSerializer()
423 def traceApi(self, api):
429 for header in api.headers:
433 # Generate the serializer functions
434 types = api.getAllTypes()
435 visitor = ComplexValueSerializer(self.serializerFactory())
436 map(visitor.visit, types)
440 self.traceInterfaces(api)
443 self.interface = None
445 map(self.traceFunctionDecl, api.functions)
446 map(self.traceFunctionImpl, api.functions)
451 def header(self, api):
452 print '#ifdef _WIN32'
453 print '# include <malloc.h> // alloca'
454 print '# ifndef alloca'
455 print '# define alloca _alloca'
458 print '# include <alloca.h> // alloca'
461 print '#include "trace.hpp"'
464 def footer(self, api):
467 def traceFunctionDecl(self, function):
468 # Per-function declarations
470 if not function.internal:
472 print 'static const char * _%s_args[%u] = {%s};' % (function.name, len(function.args), ', '.join(['"%s"' % arg.name for arg in function.args]))
474 print 'static const char ** _%s_args = NULL;' % (function.name,)
475 print 'static const trace::FunctionSig _%s_sig = {%u, "%s", %u, _%s_args};' % (function.name, function.id, function.name, len(function.args), function.name)
478 def isFunctionPublic(self, function):
481 def traceFunctionImpl(self, function):
482 if self.isFunctionPublic(function):
483 print 'extern "C" PUBLIC'
485 print 'extern "C" PRIVATE'
486 print function.prototype() + ' {'
487 if function.type is not stdapi.Void:
488 print ' %s _result;' % function.type
490 # No-op if tracing is disabled
491 print ' if (!trace::isTracingEnabled()) {'
492 Tracer.invokeFunction(self, function)
493 if function.type is not stdapi.Void:
494 print ' return _result;'
499 self.traceFunctionImplBody(function)
500 if function.type is not stdapi.Void:
501 print ' return _result;'
505 def traceFunctionImplBody(self, function):
506 if not function.internal:
507 print ' unsigned _call = trace::localWriter.beginEnter(&_%s_sig);' % (function.name,)
508 for arg in function.args:
510 self.unwrapArg(function, arg)
511 self.serializeArg(function, arg)
512 print ' trace::localWriter.endEnter();'
513 self.invokeFunction(function)
514 if not function.internal:
515 print ' trace::localWriter.beginLeave(_call);'
516 for arg in function.args:
518 self.serializeArg(function, arg)
519 self.wrapArg(function, arg)
520 if function.type is not stdapi.Void:
521 self.serializeRet(function, "_result")
522 print ' trace::localWriter.endLeave();'
523 if function.type is not stdapi.Void:
524 self.wrapRet(function, "_result")
526 def invokeFunction(self, function, prefix='_', suffix=''):
527 if function.type is stdapi.Void:
530 result = '_result = '
531 dispatch = prefix + function.name + suffix
532 print ' %s%s(%s);' % (result, dispatch, ', '.join([str(arg.name) for arg in function.args]))
534 def serializeArg(self, function, arg):
535 print ' trace::localWriter.beginArg(%u);' % (arg.index,)
536 self.serializeArgValue(function, arg)
537 print ' trace::localWriter.endArg();'
539 def serializeArgValue(self, function, arg):
540 self.serializeValue(arg.type, arg.name)
542 def wrapArg(self, function, arg):
543 assert not isinstance(arg.type, stdapi.ObjPointer)
545 from specs.winapi import REFIID
547 for other_arg in function.args:
548 if not other_arg.output and other_arg.type is REFIID:
550 if riid is not None \
551 and isinstance(arg.type, stdapi.Pointer) \
552 and isinstance(arg.type.type, stdapi.ObjPointer):
553 self.wrapIid(function, riid, arg)
556 self.wrapValue(arg.type, arg.name)
558 def unwrapArg(self, function, arg):
559 self.unwrapValue(arg.type, arg.name)
561 def serializeRet(self, function, instance):
562 print ' trace::localWriter.beginReturn();'
563 self.serializeValue(function.type, instance)
564 print ' trace::localWriter.endReturn();'
566 def serializeValue(self, type, instance):
567 serializer = self.serializerFactory()
568 serializer.visit(type, instance)
570 def wrapRet(self, function, instance):
571 self.wrapValue(function.type, instance)
573 def unwrapRet(self, function, instance):
574 self.unwrapValue(function.type, instance)
576 def needsWrapping(self, type):
577 visitor = WrapDecider()
579 return visitor.needsWrapping
581 def wrapValue(self, type, instance):
582 if self.needsWrapping(type):
583 visitor = ValueWrapper()
584 visitor.visit(type, instance)
586 def unwrapValue(self, type, instance):
587 if self.needsWrapping(type):
588 visitor = ValueUnwrapper()
589 visitor.visit(type, instance)
591 def traceInterfaces(self, api):
592 interfaces = api.getAllInterfaces()
595 map(self.declareWrapperInterface, interfaces)
596 self.implementIidWrapper(api)
597 map(self.implementWrapperInterface, interfaces)
600 def declareWrapperInterface(self, interface):
601 print "class %s : public %s " % (getWrapperInterfaceName(interface), interface.name)
604 print " %s(%s * pInstance);" % (getWrapperInterfaceName(interface), interface.name)
605 print " virtual ~%s();" % getWrapperInterfaceName(interface)
607 for method in interface.iterMethods():
608 print " " + method.prototype() + ";"
611 for type, name, value in self.enumWrapperInterfaceVariables(interface):
612 print ' %s %s;' % (type, name)
616 def enumWrapperInterfaceVariables(self, interface):
618 ("DWORD", "m_dwMagic", "0xd8365d6c"),
619 ("%s *" % interface.name, "m_pInstance", "pInstance"),
622 def implementWrapperInterface(self, interface):
623 self.interface = interface
625 print '%s::%s(%s * pInstance) {' % (getWrapperInterfaceName(interface), getWrapperInterfaceName(interface), interface.name)
626 for type, name, value in self.enumWrapperInterfaceVariables(interface):
627 print ' %s = %s;' % (name, value)
630 print '%s::~%s() {' % (getWrapperInterfaceName(interface), getWrapperInterfaceName(interface))
634 for base, method in interface.iterBaseMethods():
636 self.implementWrapperInterfaceMethod(interface, base, method)
640 def implementWrapperInterfaceMethod(self, interface, base, method):
641 print method.prototype(getWrapperInterfaceName(interface) + '::' + method.name) + ' {'
642 if method.type is not stdapi.Void:
643 print ' %s _result;' % method.type
645 self.implementWrapperInterfaceMethodBody(interface, base, method)
647 if method.type is not stdapi.Void:
648 print ' return _result;'
652 def implementWrapperInterfaceMethodBody(self, interface, base, method):
653 assert not method.internal
655 print ' static const char * _args[%u] = {%s};' % (len(method.args) + 1, ', '.join(['"this"'] + ['"%s"' % arg.name for arg in method.args]))
656 print ' static const trace::FunctionSig _sig = {%u, "%s", %u, _args};' % (method.id, interface.name + '::' + method.name, len(method.args) + 1)
658 print ' %s *_this = static_cast<%s *>(m_pInstance);' % (base, base)
660 print ' unsigned _call = trace::localWriter.beginEnter(&_sig);'
661 print ' trace::localWriter.beginArg(0);'
662 print ' trace::localWriter.writePointer((uintptr_t)m_pInstance);'
663 print ' trace::localWriter.endArg();'
664 for arg in method.args:
666 self.unwrapArg(method, arg)
667 self.serializeArg(method, arg)
668 print ' trace::localWriter.endEnter();'
670 self.invokeMethod(interface, base, method)
672 print ' trace::localWriter.beginLeave(_call);'
673 for arg in method.args:
675 self.serializeArg(method, arg)
676 self.wrapArg(method, arg)
678 if method.type is not stdapi.Void:
679 self.serializeRet(method, '_result')
680 print ' trace::localWriter.endLeave();'
681 if method.type is not stdapi.Void:
682 self.wrapRet(method, '_result')
684 if method.name == 'Release':
685 assert method.type is not stdapi.Void
686 print ' if (!_result)'
687 print ' delete this;'
689 def implementIidWrapper(self, api):
691 print r'warnIID(const char *functionName, REFIID riid, const char *reason) {'
692 print r' os::log("apitrace: warning: %s: %s IID {0x%08lX,0x%04X,0x%04X,{0x%02X,0x%02X,0x%02X,0x%02X,0x%02X,0x%02X,0x%02X,0x%02X}}\n",'
693 print r' functionName, reason,'
694 print r' riid.Data1, riid.Data2, riid.Data3,'
695 print r' riid.Data4[0], riid.Data4[1], riid.Data4[2], riid.Data4[3], riid.Data4[4], riid.Data4[5], riid.Data4[6], riid.Data4[7]);'
699 print r'wrapIID(const char *functionName, REFIID riid, void * * ppvObj) {'
700 print r' if (!ppvObj || !*ppvObj) {'
704 for iface in api.getAllInterfaces():
705 print r' %sif (riid == IID_%s) {' % (else_, iface.name)
706 print r' *ppvObj = new Wrap%s((%s *) *ppvObj);' % (iface.name, iface.name)
709 print r' %s{' % else_
710 print r' warnIID(functionName, riid, "unknown");'
715 def wrapIid(self, function, riid, out):
716 # Cast output arg to `void **` if necessary
718 obj_type = out.type.type.type
719 if not obj_type is stdapi.Void:
720 assert isinstance(obj_type, stdapi.Interface)
721 out_name = 'reinterpret_cast<void * *>(%s)' % out_name
723 print r' if (%s && *%s) {' % (out.name, out.name)
724 functionName = function.name
726 if self.interface is not None:
727 functionName = self.interface.name + '::' + functionName
728 print r' if (*%s == m_pInstance &&' % (out_name,)
729 print r' (%s)) {' % ' || '.join('%s == IID_%s' % (riid.name, iface.name) for iface in self.interface.iterBases())
730 print r' *%s = this;' % (out_name,)
733 print r' %s{' % else_
734 print r' wrapIID("%s", %s, %s);' % (functionName, riid.name, out_name)
738 def invokeMethod(self, interface, base, method):
739 if method.type is stdapi.Void:
742 result = '_result = '
743 print ' %s_this->%s(%s);' % (result, method.name, ', '.join([str(arg.name) for arg in method.args]))
745 def emit_memcpy(self, dest, src, length):
746 print ' unsigned _call = trace::localWriter.beginEnter(&trace::memcpy_sig);'
747 print ' trace::localWriter.beginArg(0);'
748 print ' trace::localWriter.writePointer((uintptr_t)%s);' % dest
749 print ' trace::localWriter.endArg();'
750 print ' trace::localWriter.beginArg(1);'
751 print ' trace::localWriter.writeBlob(%s, %s);' % (src, length)
752 print ' trace::localWriter.endArg();'
753 print ' trace::localWriter.beginArg(2);'
754 print ' trace::localWriter.writeUInt(%s);' % length
755 print ' trace::localWriter.endArg();'
756 print ' trace::localWriter.endEnter();'
757 print ' trace::localWriter.beginLeave(_call);'
758 print ' trace::localWriter.endLeave();'
760 def fake_call(self, function, args):
761 print ' unsigned _fake_call = trace::localWriter.beginEnter(&_%s_sig);' % (function.name,)
762 for arg, instance in zip(function.args, args):
763 assert not arg.output
764 print ' trace::localWriter.beginArg(%u);' % (arg.index,)
765 self.serializeValue(arg.type, instance)
766 print ' trace::localWriter.endArg();'
767 print ' trace::localWriter.endEnter();'
768 print ' trace::localWriter.beginLeave(_fake_call);'
769 print ' trace::localWriter.endLeave();'