1 ##########################################################################
3 # Copyright 2008-2010 VMware, Inc.
6 # Permission is hereby granted, free of charge, to any person obtaining a copy
7 # of this software and associated documentation files (the "Software"), to deal
8 # in the Software without restriction, including without limitation the rights
9 # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 # copies of the Software, and to permit persons to whom the Software is
11 # furnished to do so, subject to the following conditions:
13 # The above copyright notice and this permission notice shall be included in
14 # all copies or substantial portions of the Software.
16 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
24 ##########################################################################/
26 """Common trace code generation."""
32 sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
35 import specs.stdapi as stdapi
def getWrapperInterfaceName(interface):
    """Return the name of the generated C++ wrapper class for *interface*.

    The name is the interface's C expression (``interface.expr``) with a
    ``Wrap`` prefix, e.g. an interface whose expr is ``IFoo`` maps to
    ``WrapIFoo``.
    """
    prefix = "Wrap"
    return prefix + interface.expr
# Visitor that prints the C++ helper definitions that ValueSerializer's
# inline code refers to:
#   - a `_write__<tag>()` function per struct and per polymorphic type;
#   - `_enum<tag>_sig` / `_bitmask<tag>_sig` tables per enum / bitmask.
# The code that writes individual struct members and polymorphic cases is
# produced by the `serializer` visitor passed to the constructor.
# NOTE(review): several statements of this class (some method bodies and
# closing-brace prints) are not visible in this listing; the comments below
# describe only the visible lines.
42 class ComplexValueSerializer(stdapi.OnceVisitor):
43 '''Type visitors which generates serialization functions for
46 Simple types are serialized inline.
# serializer: visitor invoked as serializer.visit(type, expr) to emit the
# code that writes one struct member or one polymorphic case.
49 def __init__(self, serializer):
50 stdapi.OnceVisitor.__init__(self)
51 self.serializer = serializer
# NOTE(review): visitVoid, visitLiteral and visitString show no body here;
# per the class docstring, simple types need no helper definition.
53 def visitVoid(self, literal):
56 def visitLiteral(self, literal):
59 def visitString(self, string):
62 def visitConst(self, const):
63 self.visit(const.type)
# Struct: prints `static void _write__<tag>(const <expr> &value)`. That
# function holds a static trace::StructSig (id, name, member count,
# member-name table) and writes each member between beginStruct() and
# endStruct().
# NOTE(review): the first loop over struct.members shows no body here;
# presumably it visits each member type so their helpers are printed first.
65 def visitStruct(self, struct):
66 for type, name in struct.members:
68 print 'static void _write__%s(const %s &value) {' % (struct.tag, struct.expr)
69 print ' static const char * members[%u] = {' % (len(struct.members),)
70 for type, name, in struct.members:
71 print ' "%s",' % (name,)
73 print ' static const trace::StructSig sig = {'
74 print ' %u, "%s", %u, members' % (struct.id, struct.name, len(struct.members))
76 print ' trace::localWriter.beginStruct(&sig);'
77 for type, name in struct.members:
78 self.serializer.visit(type, 'value.%s' % (name,))
79 print ' trace::localWriter.endStruct();'
83 def visitArray(self, array):
84 self.visit(array.type)
# NOTE(review): visitBlob, visitIntPointer, visitOpaque and visitInterface
# show no body here; presumably they print nothing.
86 def visitBlob(self, array):
# Enum: prints a `_enum<tag>_values` table of {"NAME", NAME} pairs and the
# `_enum<tag>_sig` EnumSig (id, value count, table) that
# ValueSerializer.visitEnum passes to writeEnum().
89 def visitEnum(self, enum):
90 print 'static const trace::EnumValue _enum%s_values[] = {' % (enum.tag)
91 for value in enum.values:
92 print ' {"%s", %s},' % (value, value)
95 print 'static const trace::EnumSig _enum%s_sig = {' % (enum.tag)
96 print ' %u, %u, _enum%s_values' % (enum.id, len(enum.values), enum.tag)
# Bitmask: same scheme as enums, producing `_bitmask<tag>_flags` and the
# `_bitmask<tag>_sig` BitmaskSig used by ValueSerializer.visitBitmask.
100 def visitBitmask(self, bitmask):
101 print 'static const trace::BitmaskFlag _bitmask%s_flags[] = {' % (bitmask.tag)
102 for value in bitmask.values:
103 print ' {"%s", %s},' % (value, value)
106 print 'static const trace::BitmaskSig _bitmask%s_sig = {' % (bitmask.tag)
107 print ' %u, %u, _bitmask%s_flags' % (bitmask.id, len(bitmask.values), bitmask.tag)
# Pointer-like, handle, reference and alias types: recurse into the
# referenced type so any helpers it needs are printed.
111 def visitPointer(self, pointer):
112 self.visit(pointer.type)
114 def visitIntPointer(self, pointer):
117 def visitObjPointer(self, pointer):
118 self.visit(pointer.type)
120 def visitLinearPointer(self, pointer):
121 self.visit(pointer.type)
123 def visitHandle(self, handle):
124 self.visit(handle.type)
126 def visitReference(self, reference):
127 self.visit(reference.type)
129 def visitAlias(self, alias):
130 self.visit(alias.type)
132 def visitOpaque(self, opaque):
135 def visitInterface(self, interface):
# Polymorphic: prints
# `static void _write__<tag>(int selector, const <expr> & value)`, which
# switches on `selector` and serializes `value` cast to each case's type.
# NOTE(review): the statement under `if not polymorphic.contextLess:` is not
# shown; presumably it returns early, so only context-less polymorphic types
# get this helper (ValueSerializer.visitPolymorphic calls it only for them).
138 def visitPolymorphic(self, polymorphic):
139 if not polymorphic.contextLess:
141 print 'static void _write__%s(int selector, const %s & value) {' % (polymorphic.tag, polymorphic.expr)
142 print ' switch (selector) {'
143 for cases, type in polymorphic.iterSwitch():
146 self.serializer.visit(type, 'static_cast<%s>(value)' % (type,))
# Visitor that prints the C++ statements serializing one value through
# trace::localWriter; `instance` is always a C++ expression string.
# Structs, enums, bitmasks and context-less polymorphic values are written
# through the helpers ComplexValueSerializer printed beforehand.
# NOTE(review): some statements of this class are not visible in this
# listing, e.g. `else:` lines, the assignments of `suffix` and of the
# default `length` in visitString, and several closing-brace prints.
153 class ValueSerializer(stdapi.Visitor):
154 '''Visitor which generates code to serialize any type.
156 Simple types are serialized inline here, whereas the serialization of
157 complex types is dispatched to the serialization functions generated by
158 ComplexValueSerializer visitor above.
# Literal: calls trace::localWriter.write<kind>(instance).
161 def visitLiteral(self, literal, instance):
162 print ' trace::localWriter.write%s(%s);' % (literal.kind, instance)
# String: casts to `const char *` / `const wchar_t *` when the declared C
# type differs, and passes the explicit length when the string type has one.
164 def visitString(self, string, instance):
166 cast = 'const char *'
169 cast = 'const wchar_t *'
171 if cast != string.expr:
172 # reinterpret_cast is necessary for GLubyte * <=> char *
173 instance = 'reinterpret_cast<%s>(%s)' % (cast, instance)
174 if string.length is not None:
175 length = ', %s' % string.length
178 print ' trace::localWriter.write%s(%s%s);' % (suffix, instance, length)
180 def visitConst(self, const, instance):
181 self.visit(const.type, instance)
# Struct: delegate to the _write__<tag>() helper.
183 def visitStruct(self, struct, instance):
184 print ' _write__%s(%s);' % (struct.tag, instance)
# Array: for a non-null pointer, clamps a negative length to 0 and writes
# each element between beginArray/endArray. The loop variables are named
# after the element type's tag (_c<tag> for the count, _i<tag> for the index).
# The writeNull() line handles a null pointer.
# NOTE(review): array.length is pasted into the C++ twice, so it is
# evaluated twice there.
186 def visitArray(self, array, instance):
187 length = '_c' + array.type.tag
188 index = '_i' + array.type.tag
189 print ' if (%s) {' % instance
190 print ' size_t %s = %s > 0 ? %s : 0;' % (length, array.length, array.length)
191 print ' trace::localWriter.beginArray(%s);' % length
192 print ' for (size_t %s = 0; %s < %s; ++%s) {' % (index, index, length, index)
193 print ' trace::localWriter.beginElement();'
194 self.visit(array.type, '(%s)[%s]' % (instance, index))
195 print ' trace::localWriter.endElement();'
197 print ' trace::localWriter.endArray();'
199 print ' trace::localWriter.writeNull();'
202 def visitBlob(self, blob, instance):
203 print ' trace::localWriter.writeBlob(%s, %s);' % (instance, blob.size)
# Enum / bitmask: reference the signature tables printed by
# ComplexValueSerializer.
205 def visitEnum(self, enum, instance):
206 print ' trace::localWriter.writeEnum(&_enum%s_sig, %s);' % (enum.tag, instance)
208 def visitBitmask(self, bitmask, instance):
209 print ' trace::localWriter.writeBitmask(&_bitmask%s_sig, %s);' % (bitmask.tag, instance)
# Pointer: a non-null pointer is written as a one-element array holding the
# pointee; the writeNull() line handles a null pointer.
211 def visitPointer(self, pointer, instance):
212 print ' if (%s) {' % instance
213 print ' trace::localWriter.beginArray(1);'
214 print ' trace::localWriter.beginElement();'
215 self.visit(pointer.type, "*" + instance)
216 print ' trace::localWriter.endElement();'
217 print ' trace::localWriter.endArray();'
219 print ' trace::localWriter.writeNull();'
# Int/object/linear pointers: only the address is recorded, not the pointee.
222 def visitIntPointer(self, pointer, instance):
223 print ' trace::localWriter.writePointer((uintptr_t)%s);' % instance
225 def visitObjPointer(self, pointer, instance):
226 print ' trace::localWriter.writePointer((uintptr_t)%s);' % instance
228 def visitLinearPointer(self, pointer, instance):
229 print ' trace::localWriter.writePointer((uintptr_t)%s);' % instance
231 def visitReference(self, reference, instance):
232 self.visit(reference.type, instance)
234 def visitHandle(self, handle, instance):
235 self.visit(handle.type, instance)
237 def visitAlias(self, alias, instance):
238 self.visit(alias.type, instance)
# Opaque values are recorded by address only.
240 def visitOpaque(self, opaque, instance):
241 print ' trace::localWriter.writePointer((uintptr_t)%s);' % instance
# NOTE(review): visitInterface shows no body here; interfaces appear to be
# reached only through object pointers, which are recorded by address.
243 def visitInterface(self, interface, instance):
# Polymorphic: context-less types call the _write__<tag>(selector, value)
# helper. Other polymorphic types get an inline switch on switchExpr that
# serializes the value cast to each case's type.
246 def visitPolymorphic(self, polymorphic, instance):
247 if polymorphic.contextLess:
248 print ' _write__%s(%s, %s);' % (polymorphic.tag, polymorphic.switchExpr, instance)
250 print ' switch (%s) {' % polymorphic.switchExpr
251 for cases, type in polymorphic.iterSwitch():
254 self.visit(type, 'static_cast<%s>(%s)' % (type, instance))
class WrapDecider(stdapi.Traverser):
    '''Type visitor which decides whether a type needs wrapping or not.

    For complex types (arrays, structures), we need to know this beforehand,
    so the whole type is traversed and ``needsWrapping`` is set as soon as an
    interface is reached.
    '''

    def __init__(self):
        # Verdict of the traversal; flipped to True by visitInterface.
        self.needsWrapping = False

    def visitLinearPointer(self, void):
        # Linear pointers are not traversed, so an interface reachable only
        # through one does not by itself mark the type as needing wrapping.
        pass

    def visitInterface(self, interface):
        self.needsWrapping = True
# Traverser that prints C++ code replacing every interface pointer
# reachable from `instance` (a C++ lvalue expression string) with a new
# Wrap<Interface> object around it.
# NOTE(review): closing-brace prints and the `else:` before the final
# visitPointer call in visitObjPointer are not visible in this listing.
275 class ValueWrapper(stdapi.Traverser):
276 '''Type visitor which will generate the code to wrap an instance.
278 Wrapping is necessary mostly for interfaces, however interface pointers can
279 appear anywhere inside complex types.
# Struct: wrap each member in place.
282 def visitStruct(self, struct, instance):
283 for type, name in struct.members:
284 self.visit(type, "(%s).%s" % (instance, name))
# Array: wrap each element inside a loop over array.length when the array
# pointer is non-null.
# NOTE(review): the element expression is `instance + "[_i]"` without
# parentheses, unlike ValueSerializer's `(%s)[%s]`. For an instance such as
# `*pp` this yields `*pp[_i]`, which C++ parses as `*(pp[_i])`.
286 def visitArray(self, array, instance):
287 print " if (%s) {" % instance
288 print " for (size_t _i = 0, _s = %s; _i < _s; ++_i) {" % array.length
289 self.visit(array.type, instance + "[_i]")
# Pointer: wrap the pointee when the pointer is non-null.
293 def visitPointer(self, pointer, instance):
294 print " if (%s) {" % instance
295 self.visit(pointer.type, "*" + instance)
# Object pointer: interface pointers (directly or through an alias) are
# replaced by a wrapper; any other pointee is handled like a plain pointer.
298 def visitObjPointer(self, pointer, instance):
299 elem_type = pointer.type.mutable()
300 if isinstance(elem_type, stdapi.Interface):
301 self.visitInterfacePointer(elem_type, instance)
302 elif isinstance(elem_type, stdapi.Alias) and isinstance(elem_type.type, stdapi.Interface):
303 self.visitInterfacePointer(elem_type.type, instance)
305 self.visitPointer(pointer, instance)
# Interfaces are only expected behind object pointers.
307 def visitInterface(self, interface, instance):
308 raise NotImplementedError
# Replaces a non-null interface pointer with `new Wrap<Interface>(ptr)`.
310 def visitInterfacePointer(self, interface, instance):
311 print " if (%s) {" % instance
312 print " %s = new %s(%s);" % (instance, getWrapperInterfaceName(interface), instance)
315 def visitPolymorphic(self, type, instance):
316 # XXX: There might be polymorphic values that need wrapping in the future
317 raise NotImplementedError
# Traverser that prints C++ code replacing wrapper pointers found in an
# input value with the real objects they wrap. Tracer.unwrapArg runs it
# before the real call, and Tracer.unwrapValue creates a fresh instance per
# value, so `allocated` tracks state for one value only.
# NOTE(review): several lines of this class are not visible in this
# listing: a class-level default for `allocated` (it is read below but only
# ever set to True here), the reassignment of `instance` to the copy in
# visitStruct, `try:`/`finally:`/`else:` lines, and closing-brace prints.
320 class ValueUnwrapper(ValueWrapper):
321 '''Reverse of ValueWrapper.'''
# Struct: the first time a struct is reached (before any copy exists), it is
# copied into an alloca'd temporary `_t` and the caller's pointer is
# redirected to it. Interface members can then be rewritten without touching
# the caller's const object. Requires `instance` to be a dereference (`*p`).
325 def visitStruct(self, struct, instance):
326 if not self.allocated:
327 # Argument is constant. We need to create a non const
329 print " %s * _t = static_cast<%s *>(alloca(sizeof *_t));" % (struct, struct)
330 print ' *_t = %s;' % (instance,)
331 assert instance.startswith('*')
332 print ' %s = _t;' % (instance[1:],)
334 self.allocated = True
336 return ValueWrapper.visitStruct(self, struct, instance)
340 return ValueWrapper.visitStruct(self, struct, instance)
# Array: copies a non-empty input array into an alloca'd temporary `_t`,
# unwraps each copied element, then points `instance` at the copy.
# NOTE(review): `isinstance(instance, stdapi.Interface)` can never be true,
# because `instance` is always a C++ expression string here; confirm what
# was intended (perhaps the element type).
# NOTE(review): `%s[_i]` is not parenthesised; for an instance like `*pp`
# C++ parses `*pp[_i]` as `*(pp[_i])`.
342 def visitArray(self, array, instance):
343 if self.allocated or isinstance(instance, stdapi.Interface):
344 return ValueWrapper.visitArray(self, array, instance)
345 elem_type = array.type.mutable()
346 print " if (%s && %s) {" % (instance, array.length)
347 print " %s * _t = static_cast<%s *>(alloca(%s * sizeof *_t));" % (elem_type, elem_type, array.length)
348 print " for (size_t _i = 0, _s = %s; _i < _s; ++_i) {" % array.length
349 print " _t[_i] = %s[_i];" % instance
350 self.allocated = True
351 self.visit(array.type, "_t[_i]")
353 print " %s = _t;" % instance
# Interface pointer: the pointer is treated as a Wrap<Interface> and trusted
# only if m_dwMagic holds the marker the wrapper constructor stores. In that
# case it is replaced by the wrapped m_pInstance; otherwise a warning naming
# the interface is logged.
356 def visitInterfacePointer(self, interface, instance):
357 print r' if (%s) {' % instance
358 print r' const %s *pWrapper = static_cast<const %s*>(%s);' % (getWrapperInterfaceName(interface), getWrapperInterfaceName(interface), instance)
359 print r' if (pWrapper && pWrapper->m_dwMagic == 0xd8365d6c) {'
360 print r' %s = pWrapper->m_pInstance;' % (instance,)
362 print r' os::log("apitrace: warning: %%s: unexpected %%s pointer\n", __FUNCTION__, "%s");' % interface.name
368 '''Base class to orchestrate the code generation of API tracing.'''
def serializerFactory(self):
    '''Return a new serializer used to emit value-serialization code.

    Derived tracers may override this hook to inject their own serializer;
    the base implementation uses ValueSerializer.
    '''
    serializer_class = ValueSerializer
    return serializer_class()
# Top-level driver: prints the tracing wrapper source for `api`, in this
# order:
#   1. the per-type serializer helpers;
#   2. the interface wrapper classes;
#   3. a signature declaration and a wrapper definition for every function.
# NOTE(review): some lines of this method are not visible in this listing,
# e.g. the body of the loop over api.headers.
381 def traceApi(self, api):
387 for header in api.headers:
391 # Generate the serializer functions
392 types = api.getAllTypes()
393 visitor = ComplexValueSerializer(self.serializerFactory())
# NOTE(review): map() is used only for its side effects. This relies on
# Python 2's eager map(); under Python 3 these lines would generate nothing.
394 map(visitor.visit, types)
398 self.traceInterfaces(api)
# Free functions are not interface methods; wrapIid reads self.interface to
# know whether it is generating code inside a wrapper method.
401 self.interface = None
403 map(self.traceFunctionDecl, api.functions)
404 map(self.traceFunctionImpl, api.functions)
# Prints the preamble of the generated C++ source. It makes alloca()
# available (the unwrapping code allocates temporaries with it) and
# includes trace.hpp: on _WIN32 via <malloc.h>, mapping alloca to _alloca
# if needed, otherwise via <alloca.h>. `api` is not used by the visible lines.
# NOTE(review): the `#else` / `#endif` print lines are not visible in this
# listing.
409 def header(self, api):
410 print '#ifdef _WIN32'
411 print '# include <malloc.h> // alloca'
412 print '# ifndef alloca'
413 print '# define alloca _alloca'
416 print '# include <alloca.h> // alloca'
419 print '#include "trace.hpp"'
# Hook for output after the generated wrappers (going by its name).
# NOTE(review): its body is not visible in this listing; presumably a no-op
# that subclasses override.
422 def footer(self, api):
# Prints the static data describing `function` in the trace, for
# non-internal functions:
#   - `_<name>_args`, the argument-name table;
#   - `_<name>_sig`, a trace::FunctionSig (id, name, argument count, names)
#     that traceFunctionImplBody and fake_call pass to beginEnter().
# NOTE(review): the array form and the NULL-pointer form of `_<name>_args`
# are alternatives chosen by a condition not visible here. Presumably it is
# whether the function has arguments, since a zero-length array would be
# ill-formed C++.
425 def traceFunctionDecl(self, function):
426 # Per-function declarations
428 if not function.internal:
430 print 'static const char * _%s_args[%u] = {%s};' % (function.name, len(function.args), ', '.join(['"%s"' % arg.name for arg in function.args]))
432 print 'static const char ** _%s_args = NULL;' % (function.name,)
433 print 'static const trace::FunctionSig _%s_sig = {%u, "%s", %u, _%s_args};' % (function.name, function.id, function.name, len(function.args), function.name)
# Chooses the linkage traceFunctionImpl emits for a function wrapper: a true
# result selects `extern "C" PUBLIC`, false selects `extern "C" PRIVATE`.
# NOTE(review): its body is not visible in this listing.
436 def isFunctionPublic(self, function):
# Prints the C++ definition of the exported wrapper for `function`:
#   - its linkage;
#   - a `_result` local for non-void functions;
#   - a fast path that only forwards the call when tracing is disabled;
#   - then the traced body produced by traceFunctionImplBody.
# The fast path calls Tracer.invokeFunction explicitly, bypassing any
# subclass override of invokeFunction; the traced body uses
# self.invokeFunction.
# NOTE(review): the `else:` of the linkage choice and the end of the
# fast-path block are not visible in this listing.
439 def traceFunctionImpl(self, function):
440 if self.isFunctionPublic(function):
441 print 'extern "C" PUBLIC'
443 print 'extern "C" PRIVATE'
444 print function.prototype() + ' {'
445 if function.type is not stdapi.Void:
446 print ' %s _result;' % function.type
448 # No-op if tracing is disabled
449 print ' if (!trace::isTracingEnabled()) {'
450 Tracer.invokeFunction(self, function)
451 if function.type is not stdapi.Void:
452 print ' return _result;'
457 self.traceFunctionImplBody(function)
458 if function.type is not stdapi.Void:
459 print ' return _result;'
# Prints the traced call in three parts:
#   1. an enter record (`_<name>_sig`) in which arguments are unwrapped and
#      then serialized;
#   2. the real call;
#   3. a leave record in which arguments are serialized and then wrapped,
#      plus the return value for non-void functions, which is also wrapped
#      after endLeave().
# The trace records are guarded by `not function.internal`.
# NOTE(review): the per-argument conditions inside both loops are not
# visible here; presumably inputs are handled on enter and outputs on leave
# (arg.output).
463 def traceFunctionImplBody(self, function):
464 if not function.internal:
465 print ' unsigned _call = trace::localWriter.beginEnter(&_%s_sig);' % (function.name,)
466 for arg in function.args:
468 self.unwrapArg(function, arg)
469 self.serializeArg(function, arg)
470 print ' trace::localWriter.endEnter();'
471 self.invokeFunction(function)
472 if not function.internal:
473 print ' trace::localWriter.beginLeave(_call);'
474 for arg in function.args:
476 self.serializeArg(function, arg)
477 self.wrapArg(function, arg)
478 if function.type is not stdapi.Void:
479 self.serializeRet(function, "_result")
480 print ' trace::localWriter.endLeave();'
481 if function.type is not stdapi.Void:
482 self.wrapRet(function, "_result")
def invokeFunction(self, function, prefix='_', suffix=''):
    '''Emit the C++ statement that calls the real implementation of *function*.

    The callee name is ``prefix + function.name + suffix``; the arguments
    are passed by their C names.  For non-void functions the result is
    stored into the ``_result`` local that traceFunctionImpl declares.
    Void functions are called as a bare statement, because no ``_result``
    exists for them.
    '''
    if function.type is stdapi.Void:
        result = ''
    else:
        result = '_result = '
    dispatch = prefix + function.name + suffix
    args = ', '.join([str(arg.name) for arg in function.args])
    print(' %s%s(%s);' % (result, dispatch, args))
def serializeArg(self, function, arg):
    '''Emit the code that records argument *arg* of *function* in the trace.

    The value is written by serializeArgValue, bracketed by beginArg/endArg
    calls tagged with the argument's index.
    '''
    opening = ' trace::localWriter.beginArg(%u);' % (arg.index,)
    closing = ' trace::localWriter.endArg();'
    print(opening)
    self.serializeArgValue(function, arg)
    print(closing)
def serializeArgValue(self, function, arg):
    '''Emit the serialization code for the value of argument *arg*.

    The argument's C name is the expression serialized; *function* is not
    consulted by this base implementation.
    '''
    arg_type, arg_name = arg.type, arg.name
    self.serializeValue(arg_type, arg_name)
# Prints code that wraps the value written through argument `arg` after the
# call. Two cases:
#   - The function also takes an input REFIID argument and `arg` is a
#     pointer to an object pointer (e.g. a QueryInterface-style `void **`
#     out-parameter): the object is wrapped according to the runtime IID via
#     wrapIid.
#   - Otherwise: the generic wrapValue path is used.
# NOTE(review): the lines initialising/assigning `riid` and the early exit
# after wrapIid are not visible in this listing.
500 def wrapArg(self, function, arg):
501 assert not isinstance(arg.type, stdapi.ObjPointer)
# Imported here rather than at module level; presumably so non-Windows APIs
# need not load the Windows specs.
503 from specs.winapi import REFIID
505 for other_arg in function.args:
506 if not other_arg.output and other_arg.type is REFIID:
508 if riid is not None \
509 and isinstance(arg.type, stdapi.Pointer) \
510 and isinstance(arg.type.type, stdapi.ObjPointer):
511 self.wrapIid(function, riid, arg)
514 self.wrapValue(arg.type, arg.name)
def unwrapArg(self, function, arg):
    '''Emit code that unwraps argument *arg* before the real call.

    Delegates to unwrapValue with the argument's type and C name.
    '''
    arg_type = arg.type
    self.unwrapValue(arg_type, arg.name)
def serializeRet(self, function, instance):
    '''Emit code recording *instance* as the return value of *function*.

    The value is serialized with the function's return type between
    beginReturn/endReturn calls.
    '''
    begin, end = (' trace::localWriter.beginReturn();',
                  ' trace::localWriter.endReturn();')
    print(begin)
    self.serializeValue(function.type, instance)
    print(end)
def serializeValue(self, type, instance):
    '''Emit code serializing the C++ expression *instance* of the given *type*.

    A fresh serializer from serializerFactory is used for every value.
    '''
    self.serializerFactory().visit(type, instance)
def wrapRet(self, function, instance):
    '''Emit code that wraps the return value *instance* of *function*.'''
    ret_type = function.type
    self.wrapValue(ret_type, instance)
def unwrapRet(self, function, instance):
    '''Emit code that unwraps the return value *instance* of *function*.'''
    ret_type = function.type
    self.unwrapValue(ret_type, instance)
def needsWrapping(self, type):
    '''Return True when some part of *type* must be wrapped.

    A fresh WrapDecider traverses the whole type; its ``needsWrapping``
    flag is only meaningful after that traversal has run.
    '''
    visitor = WrapDecider()
    visitor.visit(type)
    return visitor.needsWrapping
def wrapValue(self, type, instance):
    '''Emit code that wraps *instance* in place when *type* requires it.

    Nothing is emitted when needsWrapping reports that no part of the type
    needs a wrapper.
    '''
    if not self.needsWrapping(type):
        return
    ValueWrapper().visit(type, instance)
def unwrapValue(self, type, instance):
    '''Emit code that unwraps *instance* in place when *type* requires it.

    A fresh ValueUnwrapper is created for each value; nothing is emitted
    when needsWrapping reports no wrapper can be present.
    '''
    if not self.needsWrapping(type):
        return
    ValueUnwrapper().visit(type, instance)
# Prints the interface wrappers for `api` in this order:
#   1. every wrapper class declaration;
#   2. the warnIID/wrapIID helpers (implementIidWrapper);
#   3. every wrapper implementation.
# The declarations come first because wrapIID instantiates wrapper classes.
# NOTE(review): lines between collecting `interfaces` and the declarations
# are not visible here; confirm whether APIs without interfaces skip this
# (implementIidWrapper emits REFIID-based code).
# NOTE(review): map() is used for its side effects (Python 2 eager map).
549 def traceInterfaces(self, api):
550 interfaces = api.getAllInterfaces()
553 map(self.declareWrapperInterface, interfaces)
554 self.implementIidWrapper(api)
555 map(self.implementWrapperInterface, interfaces)
# Prints the C++ declaration of class Wrap<Interface>, deriving from the
# interface, with:
#   - a constructor taking the real instance;
#   - a virtual destructor;
#   - a declaration for every method from interface.iterMethods();
#   - one data member per enumWrapperInterfaceVariables() entry.
# NOTE(review): the opening brace, access specifiers and closing `};` print
# lines are not visible in this listing.
558 def declareWrapperInterface(self, interface):
559 print "class %s : public %s " % (getWrapperInterfaceName(interface), interface.name)
562 print " %s(%s * pInstance);" % (getWrapperInterfaceName(interface), interface.name)
563 print " virtual ~%s();" % getWrapperInterfaceName(interface)
565 for method in interface.iterMethods():
566 print " " + method.prototype() + ";"
569 for type, name, value in self.enumWrapperInterfaceVariables(interface):
570 print ' %s %s;' % (type, name)
def enumWrapperInterfaceVariables(self, interface):
    '''Return the data members of the wrapper class for *interface*.

    Each entry is a ``(C++ type, member name, initializer)`` triple; the
    initializer is the C++ expression assigned in the wrapper constructor.
    ``m_dwMagic`` holds the marker value that the unwrapping code checks
    before trusting ``m_pInstance``, the real wrapped object.
    '''
    return [
        ("DWORD", "m_dwMagic", "0xd8365d6c"),
        ("%s *" % interface.name, "m_pInstance", "pInstance"),
    ]
# Prints the out-of-class definitions for Wrap<Interface>:
#   - the constructor, assigning each wrapper variable its initializer;
#   - the destructor;
#   - every method from interface.iterBaseMethods(), each paired with the
#     base interface that declares it.
# Sets self.interface so wrapIid knows it is generating code inside this
# wrapper.
# NOTE(review): closing-brace prints and any per-method statement before
# the call in the last loop are not visible in this listing.
580 def implementWrapperInterface(self, interface):
581 self.interface = interface
583 print '%s::%s(%s * pInstance) {' % (getWrapperInterfaceName(interface), getWrapperInterfaceName(interface), interface.name)
584 for type, name, value in self.enumWrapperInterfaceVariables(interface):
585 print ' %s = %s;' % (name, value)
588 print '%s::~%s() {' % (getWrapperInterfaceName(interface), getWrapperInterfaceName(interface))
592 for base, method in interface.iterBaseMethods():
594 self.implementWrapperInterfaceMethod(interface, base, method)
# Prints the definition of Wrap<Interface>::<method>: a `_result` local for
# non-void methods, the traced body from implementWrapperInterfaceMethodBody,
# and the `return _result;` for non-void methods.
598 def implementWrapperInterfaceMethod(self, interface, base, method):
599 print method.prototype(getWrapperInterfaceName(interface) + '::' + method.name) + ' {'
600 if method.type is not stdapi.Void:
601 print ' %s _result;' % method.type
603 self.implementWrapperInterfaceMethodBody(interface, base, method)
605 if method.type is not stdapi.Void:
606 print ' return _result;'
# Prints the traced body of a wrapper method.
# - The signature `_sig` is named "<Interface>::<method>". It has
#   len(args)+1 argument names, "this" first: the real instance pointer
#   (m_pInstance) is recorded as argument 0.
# - The real call goes through `_this`, m_pInstance cast to the base
#   interface that declares the method.
# - The enter/leave structure mirrors traceFunctionImplBody: arguments are
#   unwrapped then serialized before the call, and serialized then wrapped
#   after it.
# NOTE(review): the per-argument conditions inside both loops are not
# visible here; presumably they select input vs output arguments.
610 def implementWrapperInterfaceMethodBody(self, interface, base, method):
611 assert not method.internal
613 print ' static const char * _args[%u] = {%s};' % (len(method.args) + 1, ', '.join(['"this"'] + ['"%s"' % arg.name for arg in method.args]))
614 print ' static const trace::FunctionSig _sig = {%u, "%s", %u, _args};' % (method.id, interface.name + '::' + method.name, len(method.args) + 1)
616 print ' %s *_this = static_cast<%s *>(m_pInstance);' % (base, base)
618 print ' unsigned _call = trace::localWriter.beginEnter(&_sig);'
619 print ' trace::localWriter.beginArg(0);'
620 print ' trace::localWriter.writePointer((uintptr_t)m_pInstance);'
621 print ' trace::localWriter.endArg();'
622 for arg in method.args:
624 self.unwrapArg(method, arg)
625 self.serializeArg(method, arg)
626 print ' trace::localWriter.endEnter();'
628 self.invokeMethod(interface, base, method)
630 print ' trace::localWriter.beginLeave(_call);'
631 for arg in method.args:
633 self.serializeArg(method, arg)
634 self.wrapArg(method, arg)
636 if method.type is not stdapi.Void:
637 self.serializeRet(method, '_result')
638 print ' trace::localWriter.endLeave();'
639 if method.type is not stdapi.Void:
640 self.wrapRet(method, '_result')
# When the real object's Release() returns 0, the wrapper deletes itself.
642 if method.name == 'Release':
643 assert method.type is not stdapi.Void
644 print ' if (!_result)'
645 print ' delete this;'
# Prints two C++ helpers:
#   - warnIID(functionName, riid, reason): logs the IID, formatted as a GUID,
#     together with the function name and reason.
#   - wrapIID(functionName, riid, ppvObj): compares riid against
#     IID_<Interface> for every interface of the API, replaces *ppvObj with a
#     new Wrap<Interface> on a match, and warns with reason "unknown"
#     otherwise.
# NOTE(review): the return type lines, the assignments of `else_`, the body
# of the null-pointer check and closing-brace prints are not visible in this
# listing.
# NOTE(review): the wrapper class is named `Wrap%s` % iface.name here,
# whereas getWrapperInterfaceName uses interface.expr; these disagree if an
# interface's expr ever differs from its name.
647 def implementIidWrapper(self, api):
649 print r'warnIID(const char *functionName, REFIID riid, const char *reason) {'
650 print r' os::log("apitrace: warning: %s: %s IID {0x%08lX,0x%04X,0x%04X,{0x%02X,0x%02X,0x%02X,0x%02X,0x%02X,0x%02X,0x%02X,0x%02X}}\n",'
651 print r' functionName, reason,'
652 print r' riid.Data1, riid.Data2, riid.Data3,'
653 print r' riid.Data4[0], riid.Data4[1], riid.Data4[2], riid.Data4[3], riid.Data4[4], riid.Data4[5], riid.Data4[6], riid.Data4[7]);'
657 print r'wrapIID(const char *functionName, REFIID riid, void * * ppvObj) {'
658 print r' if (!ppvObj || !*ppvObj) {'
662 for iface in api.getAllInterfaces():
663 print r' %sif (riid == IID_%s) {' % (else_, iface.name)
664 print r' *ppvObj = new Wrap%s((%s *) *ppvObj);' % (iface.name, iface.name)
667 print r' %s{' % else_
668 print r' warnIID(functionName, riid, "unknown");'
# Prints code that wraps the object returned through `out` according to the
# runtime IID argument `riid`. Per wrapArg, `out` is a pointer to an object
# pointer, so out.type.type.type is the object type.
# Inside an interface wrapper (self.interface set), if the call handed back
# the wrapped object itself for one of the wrapper's base IIDs, `this` is
# returned instead of creating a second wrapper. Otherwise wrapIID() picks
# the wrapper class. Warnings name "<Interface>::<method>" in that case.
# NOTE(review): the initial `out_name` assignment (read below before any
# visible assignment), `else_`, and closing-brace prints are not visible in
# this listing.
673 def wrapIid(self, function, riid, out):
674 # Cast output arg to `void **` if necessary
676 obj_type = out.type.type.type
677 if not obj_type is stdapi.Void:
678 assert isinstance(obj_type, stdapi.Interface)
679 out_name = 'reinterpret_cast<void * *>(%s)' % out_name
681 print r' if (%s && *%s) {' % (out.name, out.name)
682 functionName = function.name
684 if self.interface is not None:
685 functionName = self.interface.name + '::' + functionName
686 print r' if (*%s == m_pInstance &&' % (out_name,)
687 print r' (%s)) {' % ' || '.join('%s == IID_%s' % (riid.name, iface.name) for iface in self.interface.iterBases())
688 print r' *%s = this;' % (out_name,)
691 print r' %s{' % else_
692 print r' wrapIID("%s", %s, %s);' % (functionName, riid.name, out_name)
def invokeMethod(self, interface, base, method):
    '''Emit the C++ call of *method* on the real object ``_this``.

    Non-void results are stored into the ``_result`` local declared by
    implementWrapperInterfaceMethod.  Void methods are called as a bare
    statement, because no ``_result`` exists for them.  *interface* and
    *base* are not used by this base implementation.
    '''
    if method.type is stdapi.Void:
        result = ''
    else:
        result = '_result = '
    args = ', '.join([str(arg.name) for arg in method.args])
    print(' %s_this->%s(%s);' % (result, method.name, args))
def emit_memcpy(self, dest, src, length):
    '''Emit code that records a synthetic ``memcpy`` call in the trace.

    The call uses ``trace::memcpy_sig`` and has three arguments: the
    destination pointer *dest*, the copied bytes (*src*, *length*) as a
    blob, and *length* itself.  The leave record is empty.  All three
    parameters are C++ expressions pasted into the generated code.
    '''
    def _statements():
        # Generated lazily so each line is formatted only after the
        # previous one has been printed.
        yield ' unsigned _call = trace::localWriter.beginEnter(&trace::memcpy_sig);'
        # Argument 0: destination address.
        yield ' trace::localWriter.beginArg(0);'
        yield ' trace::localWriter.writePointer((uintptr_t)%s);' % dest
        yield ' trace::localWriter.endArg();'
        # Argument 1: the copied bytes.
        yield ' trace::localWriter.beginArg(1);'
        yield ' trace::localWriter.writeBlob(%s, %s);' % (src, length)
        yield ' trace::localWriter.endArg();'
        # Argument 2: the byte count.
        yield ' trace::localWriter.beginArg(2);'
        yield ' trace::localWriter.writeUInt(%s);' % length
        yield ' trace::localWriter.endArg();'
        yield ' trace::localWriter.endEnter();'
        yield ' trace::localWriter.beginLeave(_call);'
        yield ' trace::localWriter.endLeave();'

    for statement in _statements():
        print(statement)
# Prints a synthetic call record of `function`, using its `_<name>_sig`
# signature. Each C++ expression in `args` is serialized as the matching
# input argument, and the leave record is empty. Only input arguments are
# allowed (asserted). zip() stops at the shorter sequence, so surplus
# expressions or trailing arguments without an expression are silently
# skipped.
# The local is named `_fake_call`, presumably so it can be emitted inside a
# traced body that already declares `_call`.
718 def fake_call(self, function, args):
719 print ' unsigned _fake_call = trace::localWriter.beginEnter(&_%s_sig);' % (function.name,)
720 for arg, instance in zip(function.args, args):
721 assert not arg.output
722 print ' trace::localWriter.beginArg(%u);' % (arg.index,)
723 self.serializeValue(arg.type, instance)
724 print ' trace::localWriter.endArg();'
725 print ' trace::localWriter.endEnter();'
726 print ' trace::localWriter.beginLeave(_fake_call);'
727 print ' trace::localWriter.endLeave();'