1 ##########################################################################
3 # Copyright 2008-2010 VMware, Inc.
6 # Permission is hereby granted, free of charge, to any person obtaining a copy
7 # of this software and associated documentation files (the "Software"), to deal
8 # in the Software without restriction, including without limitation the rights
9 # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 # copies of the Software, and to permit persons to whom the Software is
11 # furnished to do so, subject to the following conditions:
13 # The above copyright notice and this permission notice shall be included in
14 # all copies or substantial portions of the Software.
16 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
24 ##########################################################################/
26 """Common trace code generation."""
32 sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
35 import specs.stdapi as stdapi
def getWrapperInterfaceName(interface):
    '''Return the name of the C++ tracing wrapper class for `interface`.

    The name is the interface's C++ expression (`interface.expr`) with a
    "Wrap" prefix.
    '''
    return ''.join(('Wrap', interface.expr))
# ComplexValueSerializer prints out-of-line C++ definitions for complex types:
# a `_write__<tag>` function per struct (and per context-less polymorphic
# type), plus `_enum<tag>_*` / `_bitmask<tag>_*` signature tables. The inline
# code printed by ValueSerializer refers to these names.
# NOTE(review): this listing has lost its indentation, and its per-line
# numbering skips values, so some lines of this class are not shown here
# (e.g. the bodies of visitVoid, visitLiteral and visitString, and the lines
# that close the generated C++ braces). Confirm against the complete file.
42 class ComplexValueSerializer(stdapi.OnceVisitor):
43 '''Type visitor which generates serialization functions for
46 Simple types are serialized inline.
49 def __init__(self, serializer):
# `serializer` prints the per-value serialization code used inside the
# generated functions (Tracer.traceApi passes self.serializerFactory()).
50 stdapi.OnceVisitor.__init__(self)
51 self.serializer = serializer
# No body is shown for the next three visitors in this view (presumably they
# emit nothing, since simple types are serialized inline).
53 def visitVoid(self, literal):
56 def visitLiteral(self, literal):
59 def visitString(self, string):
# Const just recurses into the qualified type.
62 def visitConst(self, const):
63 self.visit(const.type)
# Prints `static void _write__<tag>(const <expr> &value)`. It writes a
# trace::StructSig {id, name, member count, member-name array}, then each
# member `value.<name>` via self.serializer, between beginStruct/endStruct.
# NOTE(review): the body of the first `for` loop over struct.members is not
# shown in this view (presumably it visits each member type first).
65 def visitStruct(self, struct):
66 for type, name in struct.members:
68 print 'static void _write__%s(const %s &value) {' % (struct.tag, struct.expr)
69 print ' static const char * members[%u] = {' % (len(struct.members),)
70 for type, name, in struct.members:
71 print ' "%s",' % (name,)
73 print ' static const trace::StructSig sig = {'
74 print ' %u, "%s", %u, members' % (struct.id, struct.name, len(struct.members))
76 print ' trace::localWriter.beginStruct(&sig);'
77 for type, name in struct.members:
78 self.serializer.visit(type, 'value.%s' % (name,))
79 print ' trace::localWriter.endStruct();'
# Arrays need helpers only for their element type.
83 def visitArray(self, array):
84 self.visit(array.type)
# No body is shown for visitBlob in this view.
86 def visitBlob(self, array):
# Prints the `_enum<tag>_values` table of {"<value>", <value>} pairs and the
# `_enum<tag>_sig` signature {id, value count, values}. ValueSerializer's
# visitEnum passes that signature to writeEnum.
89 def visitEnum(self, enum):
90 print 'static const trace::EnumValue _enum%s_values[] = {' % (enum.tag)
91 for value in enum.values:
92 print ' {"%s", %s},' % (value, value)
95 print 'static const trace::EnumSig _enum%s_sig = {' % (enum.tag)
96 print ' %u, %u, _enum%s_values' % (enum.id, len(enum.values), enum.tag)
# Same scheme for bitmasks: `_bitmask<tag>_flags` and `_bitmask<tag>_sig`,
# which ValueSerializer's visitBitmask passes to writeBitmask.
100 def visitBitmask(self, bitmask):
101 print 'static const trace::BitmaskFlag _bitmask%s_flags[] = {' % (bitmask.tag)
102 for value in bitmask.values:
103 print ' {"%s", %s},' % (value, value)
106 print 'static const trace::BitmaskSig _bitmask%s_sig = {' % (bitmask.tag)
107 print ' %u, %u, _bitmask%s_flags' % (bitmask.id, len(bitmask.values), bitmask.tag)
# Pointer-like and alias types recurse into the referenced type. The exception
# is visitIntPointer, whose body is not shown in this view.
111 def visitPointer(self, pointer):
112 self.visit(pointer.type)
114 def visitIntPointer(self, pointer):
117 def visitObjPointer(self, pointer):
118 self.visit(pointer.type)
120 def visitLinearPointer(self, pointer):
121 self.visit(pointer.type)
123 def visitHandle(self, handle):
124 self.visit(handle.type)
126 def visitReference(self, reference):
127 self.visit(reference.type)
129 def visitAlias(self, alias):
130 self.visit(alias.type)
# No bodies are shown for visitOpaque or visitInterface in this view.
132 def visitOpaque(self, opaque):
135 def visitInterface(self, interface):
# Prints `static void _write__<tag>(int selector, const <expr> & value)`. It
# switches on `selector` and serializes `value` cast to each case's type.
# NOTE(review): the statement guarded by `if not polymorphic.contextLess:` and
# the lines between the `for` header and the serializer call are not shown.
# ValueSerializer.visitPolymorphic calls `_write__<tag>` only for context-less
# types, which suggests an early return here — confirm.
138 def visitPolymorphic(self, polymorphic):
139 if not polymorphic.contextLess:
141 print 'static void _write__%s(int selector, const %s & value) {' % (polymorphic.tag, polymorphic.expr)
142 print ' switch (selector) {'
143 for cases, type in polymorphic.iterSwitch():
146 self.serializer.visit(type, 'static_cast<%s>(value)' % (type,))
# ValueSerializer prints inline C++ statements that write the value of the C++
# expression `instance` (passed in as a string) to trace::localWriter.
# Structs, enums, bitmasks and context-less polymorphic values refer to the
# `_write__<tag>` functions and signature tables printed by
# ComplexValueSerializer.
# NOTE(review): indentation is lost, and several lines of this class are not
# shown here. Examples: where `suffix` is assigned in visitString, the `else`
# branches of the null checks, and the closing braces of the generated blocks.
153 class ValueSerializer(stdapi.Visitor):
154 '''Visitor which generates code to serialize any type.
156 Simple types are serialized inline here, whereas the serialization of
157 complex types is dispatched to the serialization functions generated by
158 ComplexValueSerializer visitor above.
# Literals map directly onto trace::localWriter.write<kind>(instance).
161 def visitLiteral(self, literal, instance):
162 print ' trace::localWriter.write%s(%s);' % (literal.kind, instance)
# Strings are written through a `const char *` or `const wchar_t *` cast,
# with an optional explicit length argument. Some lines are not shown in this
# view: the condition that selects between the two casts, the assignments of
# `suffix`, and the assignment of `length` when string.length is None.
164 def visitString(self, string, instance):
166 cast = 'const char *'
169 cast = 'const wchar_t *'
171 if cast != string.expr:
172 # reinterpret_cast is necessary for GLubyte * <=> char *
173 instance = 'reinterpret_cast<%s>(%s)' % (cast, instance)
174 if string.length is not None:
175 length = ', %s' % string.length
178 print ' trace::localWriter.write%s(%s%s);' % (suffix, instance, length)
180 def visitConst(self, const, instance):
181 self.visit(const.type, instance)
# Structs delegate to the `_write__<tag>` helper printed by ComplexValueSerializer.
183 def visitStruct(self, struct, instance):
184 print ' _write__%s(%s);' % (struct.tag, instance)
# Arrays: a null array pointer is written as null (the `else` line before the
# writeNull call is not shown). Otherwise the element count is clamped to be
# non-negative, and each element `(instance)[_i<tag>]` is written between
# beginArray/endArray.
# NOTE(review): `array.length` is substituted twice into the generated
# `size_t` expression, so a length expression with side effects would be
# evaluated twice in the C++ code.
186 def visitArray(self, array, instance):
187 length = '_c' + array.type.tag
188 index = '_i' + array.type.tag
189 print ' if (%s) {' % instance
190 print ' size_t %s = %s > 0 ? %s : 0;' % (length, array.length, array.length)
191 print ' trace::localWriter.beginArray(%s);' % length
192 print ' for (size_t %s = 0; %s < %s; ++%s) {' % (index, index, length, index)
193 print ' trace::localWriter.beginElement();'
194 self.visit(array.type, '(%s)[%s]' % (instance, index))
195 print ' trace::localWriter.endElement();'
197 print ' trace::localWriter.endArray();'
199 print ' trace::localWriter.writeNull();'
# Blobs are written as raw bytes of size `blob.size`.
202 def visitBlob(self, blob, instance):
203 print ' trace::localWriter.writeBlob(%s, %s);' % (instance, blob.size)
# Enums and bitmasks reference the signature tables printed by ComplexValueSerializer.
205 def visitEnum(self, enum, instance):
206 print ' trace::localWriter.writeEnum(&_enum%s_sig, %s);' % (enum.tag, instance)
208 def visitBitmask(self, bitmask, instance):
209 print ' trace::localWriter.writeBitmask(&_bitmask%s_sig, %s);' % (bitmask.tag, instance)
# A non-null pointer is written as a one-element array holding `*instance`;
# a null pointer is written as null (the `else` line is not shown).
211 def visitPointer(self, pointer, instance):
212 print ' if (%s) {' % instance
213 print ' trace::localWriter.beginArray(1);'
214 print ' trace::localWriter.beginElement();'
215 self.visit(pointer.type, "*" + instance)
216 print ' trace::localWriter.endElement();'
217 print ' trace::localWriter.endArray();'
219 print ' trace::localWriter.writeNull();'
# Integer, object and linear pointers are recorded only as raw addresses.
222 def visitIntPointer(self, pointer, instance):
223 print ' trace::localWriter.writePointer((uintptr_t)%s);' % instance
225 def visitObjPointer(self, pointer, instance):
226 print ' trace::localWriter.writePointer((uintptr_t)%s);' % instance
228 def visitLinearPointer(self, pointer, instance):
229 print ' trace::localWriter.writePointer((uintptr_t)%s);' % instance
# References, handles and aliases are serialized as their underlying type.
231 def visitReference(self, reference, instance):
232 self.visit(reference.type, instance)
234 def visitHandle(self, handle, instance):
235 self.visit(handle.type, instance)
237 def visitAlias(self, alias, instance):
238 self.visit(alias.type, instance)
# Opaque values are recorded as raw addresses too.
240 def visitOpaque(self, opaque, instance):
241 print ' trace::localWriter.writePointer((uintptr_t)%s);' % instance
# No body is shown for visitInterface in this view.
243 def visitInterface(self, interface, instance):
# Context-less polymorphic values call `_write__<tag>(switchExpr, instance)`.
# Otherwise an inline `switch (switchExpr)` serializes `instance` cast to each
# case type. The `else` line and the case labels are not shown in this view.
246 def visitPolymorphic(self, polymorphic, instance):
247 if polymorphic.contextLess:
248 print ' _write__%s(%s, %s);' % (polymorphic.tag, polymorphic.switchExpr, instance)
250 print ' switch (%s) {' % polymorphic.switchExpr
251 for cases, type in polymorphic.iterSwitch():
254 self.visit(type, 'static_cast<%s>(%s)' % (type, instance))
# WrapDecider traverses a type and sets `needsWrapping` to True if it reaches
# an interface. Tracer.needsWrapping uses the flag to skip wrap/unwrap code for
# types that contain no interfaces.
# NOTE(review): the `def` line that owns `self.needsWrapping = False` is not
# shown in this view (presumably `__init__`), and neither is the body of
# visitLinearPointer.
259 class WrapDecider(stdapi.Traverser):
260 '''Type visitor which will decide whether this type will need wrapping or not.
262 For complex types (arrays, structures), we need to know this beforehand.
266 self.needsWrapping = False
# Overridden so linear pointers are not traversed like other pointers
# (presumably the body is a no-op — confirm).
268 def visitLinearPointer(self, void):
# Any interface reached during traversal means the value must be wrapped.
271 def visitInterface(self, interface):
272 self.needsWrapping = True
# ValueWrapper prints C++ code that replaces interface pointers found inside
# `instance` (a C++ expression string) with `new Wrap<Interface>(...)` wrapper
# objects. It recurses through struct members, arrays and pointers.
# NOTE(review): indentation is lost, and some lines (closing braces, the
# `else` in visitObjPointer) are not shown in this view.
275 class ValueWrapper(stdapi.Traverser):
276 '''Type visitor which will generate the code to wrap an instance.
278 Wrapping is necessary mostly for interfaces, however interface pointers can
279 appear anywhere inside complex types.
282 def visitStruct(self, struct, instance):
283 for type, name in struct.members:
284 self.visit(type, "(%s).%s" % (instance, name))
# NOTE(review): the element expression is built as `instance + "[_i]"`
# without parentheses; ValueSerializer.visitArray uses `(instance)[i]`. When
# `instance` is a dereference such as `*p` (as produced by visitPointer), the
# generated `*p[_i]` parses in C++ as `*(p[_i])`. Nested arrays also appear to
# reuse the same `_i`/`_s` names, so the inner loop shadows the outer index.
286 def visitArray(self, array, instance):
287 print " if (%s) {" % instance
288 print " for (size_t _i = 0, _s = %s; _i < _s; ++_i) {" % array.length
289 self.visit(array.type, instance + "[_i]")
# Non-null pointers are followed by visiting `*instance`.
293 def visitPointer(self, pointer, instance):
294 print " if (%s) {" % instance
295 self.visit(pointer.type, "*" + instance)
# Pointers to interfaces are wrapped directly. The `else` line before the
# visitPointer fallback is not shown in this view.
298 def visitObjPointer(self, pointer, instance):
299 elem_type = pointer.type.mutable()
300 if isinstance(elem_type, stdapi.Interface):
301 self.visitInterfacePointer(elem_type, instance)
303 self.visitPointer(pointer, instance)
# Visiting an interface by value is not supported.
305 def visitInterface(self, interface, instance):
306 raise NotImplementedError
# Replaces a non-null interface pointer in place with a newly allocated
# wrapper object (class name from getWrapperInterfaceName).
308 def visitInterfacePointer(self, interface, instance):
309 print " if (%s) {" % instance
310 print " %s = new %s(%s);" % (instance, getWrapperInterfaceName(interface), instance)
313 def visitPolymorphic(self, type, instance):
314 # XXX: There might be polymorphic values that need wrapping in the future
315 raise NotImplementedError
# ValueUnwrapper prints C++ code that replaces wrapper pointers inside
# `instance` with the real objects they wrap (the inverse of ValueWrapper).
# NOTE(review): the initialization of `self.allocated` and some closing/else
# lines are not shown in this view.
318 class ValueUnwrapper(ValueWrapper):
319 '''Reverse of ValueWrapper.'''
# Arrays are copied into a temporary `_t` obtained from alloca and unwrapped
# there; `instance` is then pointed at the copy (presumably so the caller's
# array is left untouched). alloca storage lives until the generated C++
# function returns. Once `self.allocated` is set, later arrays (including
# nested ones) take the in-place ValueWrapper.visitArray path. No reset of the
# flag is visible here.
# NOTE(review): `instance` is a C++ expression string everywhere in this file,
# so `isinstance(instance, stdapi.Interface)` looks like it can never be true.
# Confirm whether `array.type` was intended.
323 def visitArray(self, array, instance):
324 if self.allocated or isinstance(instance, stdapi.Interface):
325 return ValueWrapper.visitArray(self, array, instance)
326 elem_type = array.type.mutable()
327 print " if (%s && %s) {" % (instance, array.length)
328 print " %s * _t = static_cast<%s *>(alloca(%s * sizeof *_t));" % (elem_type, elem_type, array.length)
329 print " for (size_t _i = 0, _s = %s; _i < _s; ++_i) {" % array.length
330 print " _t[_i] = %s[_i];" % instance
331 self.allocated = True
332 self.visit(array.type, "_t[_i]")
334 print " %s = _t;" % instance
# Checks the m_dwMagic tag before replacing the pointer with the wrapped
# m_pInstance. The tag is the 0xd8365d6c value the wrapper constructor assigns
# from enumWrapperInterfaceVariables. Otherwise a warning is logged (the
# `else` line is not shown). `%%s` yields a literal `%s` after Python's `%`
# formatting, for os::log to fill in at run time.
337 def visitInterfacePointer(self, interface, instance):
338 print r' if (%s) {' % instance
339 print r' const %s *pWrapper = static_cast<const %s*>(%s);' % (getWrapperInterfaceName(interface), getWrapperInterfaceName(interface), instance)
340 print r' if (pWrapper && pWrapper->m_dwMagic == 0xd8365d6c) {'
341 print r' %s = pWrapper->m_pInstance;' % (instance,)
343 print r' os::log("apitrace: warning: %%s: unexpected %%s pointer\n", __FUNCTION__, "%s");' % interface.name
# NOTE(review): the `class` statement that owns this docstring is not present
# in this view. Code below refers to the class as `Tracer`
# (`Tracer.invokeFunction(self, function)` in traceFunctionImpl).
349 '''Base class to orchestrate the code generation of API tracing.'''
def serializerFactory(self):
    '''Create and return a fresh serializer visitor.

    Derived classes can override this method to inject their own serializer
    in place of ValueSerializer.
    '''
    serializer = ValueSerializer()
    return serializer
362 def traceApi(self, api):
# Top-level driver. It prints helper definitions for every type in `api`,
# then the interface wrappers, then a declaration and an implementation for
# each function.
# NOTE(review): several lines are not shown in this view, including the body
# of the `for header in api.headers` loop.
368 for header in api.headers:
372 # Generate the serializer functions
373 types = api.getAllTypes()
374 visitor = ComplexValueSerializer(self.serializerFactory())
# NOTE(review): `map` is used only for its side effects. It runs eagerly on
# Python 2, but on Python 3 `map` is lazy and these calls would never execute.
# A plain `for` loop works on both.
375 map(visitor.visit, types)
379 self.traceInterfaces(api)
# Reset so wrapIid treats the following code as free functions rather than
# interface-wrapper methods.
382 self.interface = None
384 map(self.traceFunctionDecl, api.functions)
385 map(self.traceFunctionImpl, api.functions)
# Prints preprocessor lines that make `alloca` available (the code printed by
# ValueUnwrapper.visitArray calls it) and include trace.hpp.
# NOTE(review): the preprocessor lines between these (presumably `#else` /
# `#endif`) are not shown in this view.
390 def header(self, api):
391 print '#ifdef _WIN32'
392 print '# include <malloc.h> // alloca'
393 print '# ifndef alloca'
394 print '# define alloca _alloca'
397 print '# include <alloca.h> // alloca'
400 print '#include "trace.hpp"'
# NOTE(review): the body of footer() is not shown in this view; presumably a
# hook for trailing output — confirm.
403 def footer(self, api):
# Prints two declarations for `function`. The first is the static
# argument-name array `_<name>_args`. The second is the trace::FunctionSig
# `_<name>_sig` {id, name, arg count, arg names}, which traceFunctionImplBody
# and fake_call reference.
# NOTE(review): the `NULL` form is presumably printed for functions with no
# arguments (a zero-length array is not valid C++). The condition selecting
# between the two forms is not shown, and with indentation lost it cannot be
# seen which lines sit under `if not function.internal:`.
406 def traceFunctionDecl(self, function):
407 # Per-function declarations
409 if not function.internal:
411 print 'static const char * _%s_args[%u] = {%s};' % (function.name, len(function.args), ', '.join(['"%s"' % arg.name for arg in function.args]))
413 print 'static const char ** _%s_args = NULL;' % (function.name,)
414 print 'static const trace::FunctionSig _%s_sig = {%u, "%s", %u, _%s_args};' % (function.name, function.id, function.name, len(function.args), function.name)
# Decides whether traceFunctionImpl prints `extern "C" PUBLIC` or
# `extern "C" PRIVATE` for `function`.
# NOTE(review): the body is not shown in this view.
417 def isFunctionPublic(self, function):
# Prints the C++ definition of the traced entry point for `function`. It
# declares `_result` for non-void functions, calls straight through when
# tracing is disabled, and otherwise emits the traced body from
# traceFunctionImplBody.
# NOTE(review): some lines are not shown in this view: the `else:` before
# PRIVATE, and the lines between the disabled-path `return _result;` and the
# traceFunctionImplBody call (closing braces, handling of void functions).
420 def traceFunctionImpl(self, function):
421 if self.isFunctionPublic(function):
422 print 'extern "C" PUBLIC'
424 print 'extern "C" PRIVATE'
425 print function.prototype() + ' {'
426 if function.type is not stdapi.Void:
427 print ' %s _result;' % function.type
429 # No-op if tracing is disabled
430 print ' if (!trace::isTracingEnabled()) {'
# Calls Tracer.invokeFunction explicitly rather than self.invokeFunction, so
# an overriding invokeFunction in a derived class is bypassed on this path.
431 Tracer.invokeFunction(self, function)
432 if function.type is not stdapi.Void:
433 print ' return _result;'
438 self.traceFunctionImplBody(function)
439 if function.type is not stdapi.Void:
440 print ' return _result;'
# Prints the traced body in three parts:
#   1. an enter record holding the arguments, each unwrapped and then serialized;
#   2. the real call;
#   3. a leave record holding the arguments and the return value, with the
#      arguments wrapped.
# The return value is wrapped after endLeave. `_call` (returned by beginEnter)
# links the enter and leave records.
# NOTE(review): the filter line inside each argument loop is not shown
# (presumably input vs. output arguments). With indentation lost, it also
# cannot be seen which statements sit under `if not function.internal:`.
444 def traceFunctionImplBody(self, function):
445 if not function.internal:
446 print ' unsigned _call = trace::localWriter.beginEnter(&_%s_sig);' % (function.name,)
447 for arg in function.args:
449 self.unwrapArg(function, arg)
450 self.serializeArg(function, arg)
451 print ' trace::localWriter.endEnter();'
452 self.invokeFunction(function)
453 if not function.internal:
454 print ' trace::localWriter.beginLeave(_call);'
455 for arg in function.args:
457 self.serializeArg(function, arg)
458 self.wrapArg(function, arg)
459 if function.type is not stdapi.Void:
460 self.serializeRet(function, "_result")
461 print ' trace::localWriter.endLeave();'
462 if function.type is not stdapi.Void:
463 self.wrapRet(function, "_result")
# Prints the call to the real implementation `<prefix><name><suffix>`
# (`_<name>` by default), passing the arguments by name. For non-void
# functions the call's value is assigned to `_result`.
# NOTE(review): the void branch's assignment of `result` and the `else:` line
# are not shown in this view.
465 def invokeFunction(self, function, prefix='_', suffix=''):
466 if function.type is stdapi.Void:
469 result = '_result = '
470 dispatch = prefix + function.name + suffix
471 print ' %s%s(%s);' % (result, dispatch, ', '.join([str(arg.name) for arg in function.args]))
def serializeArg(self, function, arg):
    '''Print the code that records argument `arg` of `function`.

    serializeArgValue emits the value itself, between the writer's beginArg
    and endArg calls for the argument's positional index.
    '''
    opener = ' trace::localWriter.beginArg(%u);' % (arg.index,)
    print(opener)
    self.serializeArgValue(function, arg)
    print(' trace::localWriter.endArg();')
def serializeArgValue(self, function, arg):
    '''Print the serialization code for the value of argument `arg`.

    `function` is not used here; derived classes presumably override this to
    special-case particular arguments.
    '''
    arg_type = arg.type
    arg_name = arg.name
    self.serializeValue(arg_type, arg_name)
# Prints the wrapping code applied to an argument after the call. One case is
# special: a pointer-to-object-pointer argument of a function that also takes
# an input REFIID argument goes to wrapIid, which picks the wrapper by
# comparing the IID at run time. All other arguments go through wrapValue.
# Direct ObjPointer arguments are not expected here (asserted).
# NOTE(review): the initialization and assignment of `riid`, and the lines
# after the wrapIid call (presumably a `return`), are not shown in this view.
481 def wrapArg(self, function, arg):
482 assert not isinstance(arg.type, stdapi.ObjPointer)
484 from specs.winapi import REFIID
486 for other_arg in function.args:
487 if not other_arg.output and other_arg.type is REFIID:
489 if riid is not None \
490 and isinstance(arg.type, stdapi.Pointer) \
491 and isinstance(arg.type.type, stdapi.ObjPointer):
492 self.wrapIid(function, riid, arg)
495 self.wrapValue(arg.type, arg.name)
def unwrapArg(self, function, arg):
    '''Print the unwrapping code for argument `arg` before the real call.

    `function` is not used here; derived classes may presumably use it.
    '''
    arg_type = arg.type
    arg_name = arg.name
    self.unwrapValue(arg_type, arg_name)
def serializeRet(self, function, instance):
    '''Print the code that records the return value `instance` of `function`.

    The value is serialized according to `function.type`, between the
    writer's beginReturn and endReturn calls.
    '''
    print(' trace::localWriter.beginReturn();')
    return_type = function.type
    self.serializeValue(return_type, instance)
    print(' trace::localWriter.endReturn();')
def serializeValue(self, type, instance):
    '''Print the code that serializes the C++ expression `instance` of `type`.

    A fresh serializer from serializerFactory() is used for every call.
    '''
    self.serializerFactory().visit(type, instance)
def wrapRet(self, function, instance):
    '''Print the wrapping code for the return value `instance` of `function`.'''
    return_type = function.type
    self.wrapValue(return_type, instance)
def unwrapRet(self, function, instance):
    '''Print the unwrapping code for the return value `instance` of `function`.'''
    return_type = function.type
    self.unwrapValue(return_type, instance)
# Returns the `needsWrapping` flag that a WrapDecider sets when it finds an
# interface inside `type`. wrapValue and unwrapValue use it to skip types
# with nothing to wrap.
# NOTE(review): the line that runs the visitor over `type` is not shown in
# this view.
515 def needsWrapping(self, type):
516 visitor = WrapDecider()
518 return visitor.needsWrapping
def wrapValue(self, type, instance):
    '''Print ValueWrapper code for `instance` of `type`.

    Nothing is printed when `type` contains no interfaces, as reported by
    needsWrapping().
    '''
    if not self.needsWrapping(type):
        return
    ValueWrapper().visit(type, instance)
def unwrapValue(self, type, instance):
    '''Print ValueUnwrapper code for `instance` of `type`.

    Nothing is printed when `type` contains no interfaces, as reported by
    needsWrapping().
    '''
    if not self.needsWrapping(type):
        return
    ValueUnwrapper().visit(type, instance)
# Prints, in order: the wrapper class declarations for every interface, the
# warnIID/wrapIID helpers, then the wrapper implementations.
# NOTE(review): the lines after getAllInterfaces() are not shown in this view
# (possibly an early exit for APIs without interfaces). `map` is used for its
# side effects, which only run eagerly on Python 2.
530 def traceInterfaces(self, api):
531 interfaces = api.getAllInterfaces()
534 map(self.declareWrapperInterface, interfaces)
535 self.implementIidWrapper(api)
536 map(self.implementWrapperInterface, interfaces)
# Prints the C++ declaration of the wrapper class for `interface`. The class
# derives from the interface and declares:
#   - a constructor taking the real instance;
#   - a virtual destructor;
#   - every method of the interface;
#   - the member variables listed by enumWrapperInterfaceVariables.
# NOTE(review): some lines of the generated declaration (braces etc.) are not
# shown in this view.
539 def declareWrapperInterface(self, interface):
540 print "class %s : public %s " % (getWrapperInterfaceName(interface), interface.name)
543 print " %s(%s * pInstance);" % (getWrapperInterfaceName(interface), interface.name)
544 print " virtual ~%s();" % getWrapperInterfaceName(interface)
546 for method in interface.iterMethods():
547 print " " + method.prototype() + ";"
550 for type, name, value in self.enumWrapperInterfaceVariables(interface):
551 print ' %s %s;' % (type, name)
# Describes the wrapper's member variables as (C++ type, member name,
# constructor initializer) triples:
#   - m_dwMagic, a magic tag that ValueUnwrapper.visitInterfacePointer checks;
#   - m_pInstance, the wrapped instance pointer.
# NOTE(review): the lines that collect these tuples (presumably a returned
# list) are not shown in this view.
555 def enumWrapperInterfaceVariables(self, interface):
557 ("DWORD", "m_dwMagic", "0xd8365d6c"),
558 ("%s *" % interface.name, "m_pInstance", "pInstance"),
# Prints the wrapper's constructor (which assigns each member from
# enumWrapperInterfaceVariables), its destructor, and one method definition
# per (base, method) pair.
# Sets `self.interface` so wrapIid knows it is generating code inside a
# wrapper method.
# NOTE(review): some lines (braces, destructor body) are not shown in this view.
561 def implementWrapperInterface(self, interface):
562 self.interface = interface
564 print '%s::%s(%s * pInstance) {' % (getWrapperInterfaceName(interface), getWrapperInterfaceName(interface), interface.name)
565 for type, name, value in self.enumWrapperInterfaceVariables(interface):
566 print ' %s = %s;' % (name, value)
569 print '%s::~%s() {' % (getWrapperInterfaceName(interface), getWrapperInterfaceName(interface))
573 for base, method in interface.iterBaseMethods():
575 self.implementWrapperInterfaceMethod(interface, base, method)
# Prints the definition of the wrapper method for `method`. It declares
# `_result` for non-void methods, emits the body from
# implementWrapperInterfaceMethodBody, and returns `_result`.
579 def implementWrapperInterfaceMethod(self, interface, base, method):
580 print method.prototype(getWrapperInterfaceName(interface) + '::' + method.name) + ' {'
581 if method.type is not stdapi.Void:
582 print ' %s _result;' % method.type
584 self.implementWrapperInterfaceMethodBody(interface, base, method)
586 if method.type is not stdapi.Void:
587 print ' return _result;'
# Prints the traced body of a wrapper method. The function signature names
# the call `<Interface>::<method>`, and argument 0 is the implicit `this`,
# recorded as the wrapped m_pInstance pointer. The real call goes through
# `_this`, which is m_pInstance cast to the method's declaring base class.
# NOTE(review): the filter lines inside the two argument loops are not shown
# in this view.
591 def implementWrapperInterfaceMethodBody(self, interface, base, method):
592 assert not method.internal
594 print ' static const char * _args[%u] = {%s};' % (len(method.args) + 1, ', '.join(['"this"'] + ['"%s"' % arg.name for arg in method.args]))
595 print ' static const trace::FunctionSig _sig = {%u, "%s", %u, _args};' % (method.id, interface.name + '::' + method.name, len(method.args) + 1)
597 print ' %s *_this = static_cast<%s *>(m_pInstance);' % (base, base)
599 print ' unsigned _call = trace::localWriter.beginEnter(&_sig);'
600 print ' trace::localWriter.beginArg(0);'
601 print ' trace::localWriter.writePointer((uintptr_t)m_pInstance);'
602 print ' trace::localWriter.endArg();'
603 for arg in method.args:
605 self.unwrapArg(method, arg)
606 self.serializeArg(method, arg)
607 print ' trace::localWriter.endEnter();'
609 self.invokeMethod(interface, base, method)
611 print ' trace::localWriter.beginLeave(_call);'
612 for arg in method.args:
614 self.serializeArg(method, arg)
615 self.wrapArg(method, arg)
617 if method.type is not stdapi.Void:
618 self.serializeRet(method, '_result')
619 print ' trace::localWriter.endLeave();'
620 if method.type is not stdapi.Void:
621 self.wrapRet(method, '_result')
# The wrapper deletes itself once Release returns 0 (presumably the remaining
# reference count).
623 if method.name == 'Release':
624 assert method.type is not stdapi.Void
625 print ' if (!_result)'
626 print ' delete this;'
# Prints two C++ helpers:
#   - warnIID logs an IID together with a reason;
#   - wrapIID replaces `*ppvObj` with a `Wrap<Interface>` object, chosen by
#     comparing `riid` against each known IID_<Interface>, and warns about
#     unknown IIDs.
# NOTE(review): some lines are not shown in this view: the return-type lines,
# the handling of a null ppvObj, and the initialization/update of `else_`.
# NOTE(review): the wrapper name is built here as `Wrap<iface.name>`, whereas
# getWrapperInterfaceName uses `"Wrap" + interface.expr`. The two agree only if
# an interface's `expr` equals its `name`.
628 def implementIidWrapper(self, api):
630 print r'warnIID(const char *functionName, REFIID riid, const char *reason) {'
631 print r' os::log("apitrace: warning: %s: %s IID {0x%08lX,0x%04X,0x%04X,{0x%02X,0x%02X,0x%02X,0x%02X,0x%02X,0x%02X,0x%02X,0x%02X}}\n",'
632 print r' functionName, reason,'
633 print r' riid.Data1, riid.Data2, riid.Data3,'
634 print r' riid.Data4[0], riid.Data4[1], riid.Data4[2], riid.Data4[3], riid.Data4[4], riid.Data4[5], riid.Data4[6], riid.Data4[7]);'
638 print r'wrapIID(const char *functionName, REFIID riid, void * * ppvObj) {'
639 print r' if (!ppvObj || !*ppvObj) {'
643 for iface in api.getAllInterfaces():
644 print r' %sif (riid == IID_%s) {' % (else_, iface.name)
645 print r' *ppvObj = new Wrap%s((%s *) *ppvObj);' % (iface.name, iface.name)
648 print r' %s{' % else_
649 print r' warnIID(functionName, riid, "unknown");'
# Prints code that wraps the object returned through output argument `out`,
# according to the run-time IID argument `riid`. When the pointee type is a
# specific interface (not void), the output is cast to `void **` so it can be
# passed to wrapIID.
# Inside a wrapper method (`self.interface` set), one case is handled first:
# if the returned pointer equals the wrapped instance and `riid` is one of the
# interface's base IIDs, the pointer is replaced with the wrapper itself
# (`this`) instead of being wrapped again.
# NOTE(review): the initial assignment of `out_name` and the `else_` lines are
# not shown in this view.
654 def wrapIid(self, function, riid, out):
655 # Cast output arg to `void **` if necessary
657 obj_type = out.type.type.type
658 if not obj_type is stdapi.Void:
659 assert isinstance(obj_type, stdapi.Interface)
660 out_name = 'reinterpret_cast<void * *>(%s)' % out_name
662 print r' if (%s && *%s) {' % (out.name, out.name)
663 functionName = function.name
665 if self.interface is not None:
666 functionName = self.interface.name + '::' + functionName
667 print r' if (*%s == m_pInstance &&' % (out_name,)
668 print r' (%s)) {' % ' || '.join('%s == IID_%s' % (riid.name, iface.name) for iface in self.interface.iterBases())
669 print r' *%s = this;' % (out_name,)
672 print r' %s{' % else_
673 print r' wrapIID("%s", %s, %s);' % (functionName, riid.name, out_name)
# Prints the call `_this-><method>(args)` on the real object. For non-void
# methods the call's value is assigned to `_result`.
# NOTE(review): the void branch (presumably an empty `result`) is not shown in
# this view.
677 def invokeMethod(self, interface, base, method):
678 if method.type is stdapi.Void:
681 result = '_result = '
682 print ' %s_this->%s(%s);' % (result, method.name, ', '.join([str(arg.name) for arg in method.args]))
def emit_memcpy(self, dest, src, length):
    '''Print C++ code that records a memcpy call in the trace.

    `dest`, `src` and `length` are C++ expression strings. The record uses
    trace::memcpy_sig and has three arguments: the destination as a raw
    pointer, the source contents as a blob of `length` bytes, and the length
    as an unsigned integer. The leave record that follows is empty.
    '''
    print(' unsigned _call = trace::localWriter.beginEnter(&trace::memcpy_sig);')
    # Argument 0: destination address.
    print(' trace::localWriter.beginArg(0);')
    print(' trace::localWriter.writePointer((uintptr_t)%s);' % dest)
    print(' trace::localWriter.endArg();')
    # Argument 1: the copied bytes.
    print(' trace::localWriter.beginArg(1);')
    print(' trace::localWriter.writeBlob(%s, %s);' % (src, length))
    print(' trace::localWriter.endArg();')
    # Argument 2: the byte count.
    print(' trace::localWriter.beginArg(2);')
    print(' trace::localWriter.writeUInt(%s);' % length)
    print(' trace::localWriter.endArg();')
    print(' trace::localWriter.endEnter();')
    # No return value is recorded.
    print(' trace::localWriter.beginLeave(_call);')
    print(' trace::localWriter.endLeave();')
# Records a synthetic call to `function` in the trace. The enter record's
# arguments are the given C++ expressions `args`, serialized with each
# argument's declared type; only input arguments are allowed (asserted). An
# empty leave record follows immediately. The record uses `_<name>_sig`, as
# printed by traceFunctionDecl.
# NOTE(review): `zip` stops at the shorter sequence, so a length mismatch
# between `args` and `function.args` silently drops arguments. This is also
# the last code visible here; confirm the method does not continue further.
699 def fake_call(self, function, args):
700 print ' unsigned _fake_call = trace::localWriter.beginEnter(&_%s_sig);' % (function.name,)
701 for arg, instance in zip(function.args, args):
702 assert not arg.output
703 print ' trace::localWriter.beginArg(%u);' % (arg.index,)
704 self.serializeValue(arg.type, instance)
705 print ' trace::localWriter.endArg();'
706 print ' trace::localWriter.endEnter();'
707 print ' trace::localWriter.beginLeave(_fake_call);'
708 print ' trace::localWriter.endLeave();'