]> git.cworth.org Git - apitrace/blob - wrappers/trace.py
17e4e9c9baf03a0158dbc4ebf02ee6a3afdd02f8
[apitrace] / wrappers / trace.py
1 ##########################################################################
2 #
3 # Copyright 2008-2010 VMware, Inc.
4 # All Rights Reserved.
5 #
6 # Permission is hereby granted, free of charge, to any person obtaining a copy
7 # of this software and associated documentation files (the "Software"), to deal
8 # in the Software without restriction, including without limitation the rights
9 # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 # copies of the Software, and to permit persons to whom the Software is
11 # furnished to do so, subject to the following conditions:
12 #
13 # The above copyright notice and this permission notice shall be included in
14 # all copies or substantial portions of the Software.
15 #
16 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
22 # THE SOFTWARE.
23 #
24 ##########################################################################/
25
26 """Common trace code generation."""
27
28
29 # Adjust path
30 import os.path
31 import sys
32 sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
33
34
35 import specs.stdapi as stdapi
36
37
38 def getWrapperInterfaceName(interface):
39     return "Wrap" + interface.expr
40
41
42
43 class ComplexValueSerializer(stdapi.OnceVisitor):
44     '''Type visitors which generates serialization functions for
45     complex types.
46     
47     Simple types are serialized inline.
48     '''
49
50     def __init__(self, serializer):
51         stdapi.OnceVisitor.__init__(self)
52         self.serializer = serializer
53
54     def visitVoid(self, literal):
55         pass
56
57     def visitLiteral(self, literal):
58         pass
59
60     def visitString(self, string):
61         pass
62
63     def visitConst(self, const):
64         self.visit(const.type)
65
66     def visitStruct(self, struct):
67         print 'static const char * _struct%s_members[%u] = {' % (struct.tag, len(struct.members))
68         for type, name,  in struct.members:
69             if name is None:
70                 print '    "",'
71             else:
72                 print '    "%s",' % (name,)
73         print '};'
74         print 'static const trace::StructSig _struct%s_sig = {' % (struct.tag,)
75         if struct.name is None:
76             structName = '""'
77         else:
78             structName = '"%s"' % struct.name
79         print '    %u, %s, %u, _struct%s_members' % (struct.id, structName, len(struct.members), struct.tag)
80         print '};'
81         print
82
83     def visitArray(self, array):
84         self.visit(array.type)
85
86     def visitAttribArray(self, array):
87         pass
88
89     def visitBlob(self, array):
90         pass
91
92     def visitEnum(self, enum):
93         print 'static const trace::EnumValue _enum%s_values[] = {' % (enum.tag)
94         for value in enum.values:
95             print '    {"%s", %s},' % (value, value)
96         print '};'
97         print
98         print 'static const trace::EnumSig _enum%s_sig = {' % (enum.tag)
99         print '    %u, %u, _enum%s_values' % (enum.id, len(enum.values), enum.tag)
100         print '};'
101         print
102
103     def visitBitmask(self, bitmask):
104         print 'static const trace::BitmaskFlag _bitmask%s_flags[] = {' % (bitmask.tag)
105         for value in bitmask.values:
106             print '    {"%s", %s},' % (value, value)
107         print '};'
108         print
109         print 'static const trace::BitmaskSig _bitmask%s_sig = {' % (bitmask.tag)
110         print '    %u, %u, _bitmask%s_flags' % (bitmask.id, len(bitmask.values), bitmask.tag)
111         print '};'
112         print
113
114     def visitPointer(self, pointer):
115         self.visit(pointer.type)
116
117     def visitIntPointer(self, pointer):
118         pass
119
120     def visitObjPointer(self, pointer):
121         self.visit(pointer.type)
122
123     def visitLinearPointer(self, pointer):
124         self.visit(pointer.type)
125
126     def visitHandle(self, handle):
127         self.visit(handle.type)
128
129     def visitReference(self, reference):
130         self.visit(reference.type)
131
132     def visitAlias(self, alias):
133         self.visit(alias.type)
134
135     def visitOpaque(self, opaque):
136         pass
137
138     def visitInterface(self, interface):
139         pass
140
141     def visitPolymorphic(self, polymorphic):
142         if not polymorphic.contextLess:
143             return
144         print 'static void _write__%s(int selector, %s const & value) {' % (polymorphic.tag, polymorphic.expr)
145         print '    switch (selector) {'
146         for cases, type in polymorphic.iterSwitch():
147             for case in cases:
148                 print '    %s:' % case
149             self.serializer.visit(type, 'static_cast<%s>(value)' % (type,))
150             print '        break;'
151         print '    }'
152         print '}'
153         print
154
155
156 class ValueSerializer(stdapi.Visitor, stdapi.ExpanderMixin):
157     '''Visitor which generates code to serialize any type.
158     
159     Simple types are serialized inline here, whereas the serialization of
160     complex types is dispatched to the serialization functions generated by
161     ComplexValueSerializer visitor above.
162     '''
163
164     def visitLiteral(self, literal, instance):
165         print '    trace::localWriter.write%s(%s);' % (literal.kind, instance)
166
167     def visitString(self, string, instance):
168         if not string.wide:
169             cast = 'const char *'
170             suffix = 'String'
171         else:
172             cast = 'const wchar_t *'
173             suffix = 'WString'
174         if cast != string.expr:
175             # reinterpret_cast is necessary for GLubyte * <=> char *
176             instance = 'reinterpret_cast<%s>(%s)' % (cast, instance)
177         if string.length is not None:
178             length = ', %s' % self.expand(string.length)
179         else:
180             length = ''
181         print '    trace::localWriter.write%s(%s%s);' % (suffix, instance, length)
182
183     def visitConst(self, const, instance):
184         self.visit(const.type, instance)
185
186     def visitStruct(self, struct, instance):
187         print '    trace::localWriter.beginStruct(&_struct%s_sig);' % (struct.tag,)
188         for member in struct.members:
189             self.visitMember(member, instance)
190         print '    trace::localWriter.endStruct();'
191
192     def visitArray(self, array, instance):
193         length = '_c' + array.type.tag
194         index = '_i' + array.type.tag
195         array_length = self.expand(array.length)
196         print '    if (%s) {' % instance
197         print '        size_t %s = %s > 0 ? %s : 0;' % (length, array_length, array_length)
198         print '        trace::localWriter.beginArray(%s);' % length
199         print '        for (size_t %s = 0; %s < %s; ++%s) {' % (index, index, length, index)
200         print '            trace::localWriter.beginElement();'
201         self.visitElement(index, array.type, '(%s)[%s]' % (instance, index))
202         print '            trace::localWriter.endElement();'
203         print '        }'
204         print '        trace::localWriter.endArray();'
205         print '    } else {'
206         print '        trace::localWriter.writeNull();'
207         print '    }'
208
209     def visitAttribArray(self, array, instance):
210         # iterate element by element and for each, decide if it is a key or value (which depends on the
211         # previous key). If it is a value, look up what it means and store it as the right type - usually
212         # int, some bitfield, or some enum.
213
214         # determine the array length, which is unfortunately needed by writeArray() up front
215         count = '_c' + array.keyType.tag
216         print '    int %s;' % count
217         print '    for (%(c)s = 0; %(array)s && %(array)s[%(c)s]; %(c)s++) {' % {'c': count, 'array': instance}
218         print '        switch (%(array)s[%(c)s]) {' % {'array': instance, 'c': count}
219         for key, valueType in array.valueTypes:
220             if valueType is not None:
221                 print '        case %s:' % key
222         print '            %s++;' % count # only a null key marks the end; skip null values
223         print '            break;'
224         print '        }'
225         print '    }'
226         # ### not handling null attrib_list differently from empty
227         print '    trace::localWriter.beginArray(%s);' % count
228
229         # for each key / key-value pair write the key and the value, if the key requires one
230
231         index = '_i' + array.keyType.tag
232         print '    for (int %(i)s = 0; %(i)s < %(count)s; %(i)s++) {' % {'i': index, 'count': count}
233         self.visitEnum(array.keyType, "%(array)s[%(i)s]" % {'array': instance, 'i': index})
234         print '        switch (%(array)s[%(i)s]) {' % {'array': instance, 'i': index}
235         # write generic value the usual way
236         for key, valueType in array.valueTypes:
237             if valueType is not None:
238                 print '        case %s:' % key
239                 print '            %s++;' % index
240                 print '            trace::localWriter.beginElement();'
241                 self.visitElement(index, valueType, '(%(array)s)[%(i)s]' % {'array': instance, 'i': index})
242                 print '            trace::localWriter.endElement();'
243                 print '            break;'
244         # unknown key, write an int value
245         print '        default:'
246         print '            %s++;' % index
247         print '            trace::localWriter.beginElement();'
248         print '            trace::localWriter.writeSInt(%(array)s[%(i)s]);'  % {'array': instance, 'i': index}
249         print '            trace::localWriter.endElement();'
250         # known key with no value, do nothing
251         for key, valueType in array.valueTypes:
252             if valueType is None:
253                 print '        case %s:' % key
254         print '            break;'
255         print '        }'
256         print '    }'
257         print '    trace::localWriter.endArray();'
258
259
260     def visitBlob(self, blob, instance):
261         print '    trace::localWriter.writeBlob(%s, %s);' % (instance, self.expand(blob.size))
262
263     def visitEnum(self, enum, instance):
264         print '    trace::localWriter.writeEnum(&_enum%s_sig, %s);' % (enum.tag, instance)
265
266     def visitBitmask(self, bitmask, instance):
267         print '    trace::localWriter.writeBitmask(&_bitmask%s_sig, %s);' % (bitmask.tag, instance)
268
269     def visitPointer(self, pointer, instance):
270         print '    if (%s) {' % instance
271         print '        trace::localWriter.beginArray(1);'
272         print '        trace::localWriter.beginElement();'
273         self.visit(pointer.type, "*" + instance)
274         print '        trace::localWriter.endElement();'
275         print '        trace::localWriter.endArray();'
276         print '    } else {'
277         print '        trace::localWriter.writeNull();'
278         print '    }'
279
280     def visitIntPointer(self, pointer, instance):
281         print '    trace::localWriter.writePointer((uintptr_t)%s);' % instance
282
283     def visitObjPointer(self, pointer, instance):
284         print '    trace::localWriter.writePointer((uintptr_t)%s);' % instance
285
286     def visitLinearPointer(self, pointer, instance):
287         print '    trace::localWriter.writePointer((uintptr_t)%s);' % instance
288
289     def visitReference(self, reference, instance):
290         self.visit(reference.type, instance)
291
292     def visitHandle(self, handle, instance):
293         self.visit(handle.type, instance)
294
295     def visitAlias(self, alias, instance):
296         self.visit(alias.type, instance)
297
298     def visitOpaque(self, opaque, instance):
299         print '    trace::localWriter.writePointer((uintptr_t)%s);' % instance
300
301     def visitInterface(self, interface, instance):
302         assert False
303
304     def visitPolymorphic(self, polymorphic, instance):
305         if polymorphic.contextLess:
306             print '    _write__%s(%s, %s);' % (polymorphic.tag, polymorphic.switchExpr, instance)
307         else:
308             switchExpr = self.expand(polymorphic.switchExpr)
309             print '    switch (%s) {' % switchExpr
310             for cases, type in polymorphic.iterSwitch():
311                 for case in cases:
312                     print '    %s:' % case
313                 caseInstance = instance
314                 if type.expr is not None:
315                     caseInstance = 'static_cast<%s>(%s)' % (type, caseInstance)
316                 self.visit(type, caseInstance)
317                 print '        break;'
318             if polymorphic.defaultType is None:
319                 print r'    default:'
320                 print r'        os::log("apitrace: warning: %%s: unexpected polymorphic case %%i\n", __FUNCTION__, (int)%s);' % (switchExpr,)
321                 print r'        trace::localWriter.writeNull();'
322                 print r'        break;'
323             print '    }'
324
325
326 class WrapDecider(stdapi.Traverser):
327     '''Type visitor which will decide wheter this type will need wrapping or not.
328     
329     For complex types (arrays, structures), we need to know this before hand.
330     '''
331
332     def __init__(self):
333         self.needsWrapping = False
334
335     def visitLinearPointer(self, void):
336         pass
337
338     def visitInterface(self, interface):
339         self.needsWrapping = True
340
341
342 class ValueWrapper(stdapi.Traverser, stdapi.ExpanderMixin):
343     '''Type visitor which will generate the code to wrap an instance.
344     
345     Wrapping is necessary mostly for interfaces, however interface pointers can
346     appear anywhere inside complex types.
347     '''
348
349     def visitStruct(self, struct, instance):
350         for member in struct.members:
351             self.visitMember(member, instance)
352
353     def visitArray(self, array, instance):
354         array_length = self.expand(array.length)
355         print "    if (%s) {" % instance
356         print "        for (size_t _i = 0, _s = %s; _i < _s; ++_i) {" % array_length
357         self.visitElement('_i', array.type, instance + "[_i]")
358         print "        }"
359         print "    }"
360
361     def visitPointer(self, pointer, instance):
362         print "    if (%s) {" % instance
363         self.visit(pointer.type, "*" + instance)
364         print "    }"
365     
366     def visitObjPointer(self, pointer, instance):
367         elem_type = pointer.type.mutable()
368         if isinstance(elem_type, stdapi.Interface):
369             self.visitInterfacePointer(elem_type, instance)
370         elif isinstance(elem_type, stdapi.Alias) and isinstance(elem_type.type, stdapi.Interface):
371             self.visitInterfacePointer(elem_type.type, instance)
372         else:
373             self.visitPointer(pointer, instance)
374     
375     def visitInterface(self, interface, instance):
376         raise NotImplementedError
377
378     def visitInterfacePointer(self, interface, instance):
379         print "    if (%s) {" % instance
380         print "        %s = %s::_Create(__FUNCTION__, %s);" % (instance, getWrapperInterfaceName(interface), instance)
381         print "    }"
382     
383     def visitPolymorphic(self, type, instance):
384         # XXX: There might be polymorphic values that need wrapping in the future
385         raise NotImplementedError
386
387
388 class ValueUnwrapper(ValueWrapper):
389     '''Reverse of ValueWrapper.'''
390
391     allocated = False
392
393     def visitStruct(self, struct, instance):
394         if not self.allocated:
395             # Argument is constant. We need to create a non const
396             print '    {'
397             print "        %s * _t = static_cast<%s *>(alloca(sizeof *_t));" % (struct, struct)
398             print '        *_t = %s;' % (instance,)
399             assert instance.startswith('*')
400             print '        %s = _t;' % (instance[1:],)
401             instance = '*_t'
402             self.allocated = True
403             try:
404                 return ValueWrapper.visitStruct(self, struct, instance)
405             finally:
406                 print '    }'
407         else:
408             return ValueWrapper.visitStruct(self, struct, instance)
409
410     def visitArray(self, array, instance):
411         if self.allocated or isinstance(instance, stdapi.Interface):
412             return ValueWrapper.visitArray(self, array, instance)
413         array_length = self.expand(array.length)
414         elem_type = array.type.mutable()
415         print "    if (%s && %s) {" % (instance, array_length)
416         print "        %s * _t = static_cast<%s *>(alloca(%s * sizeof *_t));" % (elem_type, elem_type, array_length)
417         print "        for (size_t _i = 0, _s = %s; _i < _s; ++_i) {" % array_length
418         print "            _t[_i] = %s[_i];" % instance 
419         self.allocated = True
420         self.visit(array.type, "_t[_i]")
421         print "        }"
422         print "        %s = _t;" % instance
423         print "    }"
424
425     def visitInterfacePointer(self, interface, instance):
426         print r'    if (%s) {' % instance
427         print r'        const %s *pWrapper = static_cast<const %s*>(%s);' % (getWrapperInterfaceName(interface), getWrapperInterfaceName(interface), instance)
428         print r'        if (pWrapper && pWrapper->m_dwMagic == 0xd8365d6c) {'
429         print r'            %s = pWrapper->m_pInstance;' % (instance,)
430         print r'        } else {'
431         print r'            os::log("apitrace: warning: %%s: unexpected %%s pointer\n", __FUNCTION__, "%s");' % interface.name
432         print r'        }'
433         print r'    }'
434
435
436 class Tracer:
437     '''Base class to orchestrate the code generation of API tracing.'''
438
439     # 0-3 are reserved to memcpy, malloc, free, and realloc
440     __id = 4
441
442     def __init__(self):
443         self.api = None
444
445     def serializerFactory(self):
446         '''Create a serializer.
447         
448         Can be overriden by derived classes to inject their own serialzer.
449         '''
450
451         return ValueSerializer()
452
453     def traceApi(self, api):
454         self.api = api
455
456         self.header(api)
457
458         # Includes
459         for module in api.modules:
460             for header in module.headers:
461                 print header
462         print
463
464         # Generate the serializer functions
465         types = api.getAllTypes()
466         visitor = ComplexValueSerializer(self.serializerFactory())
467         map(visitor.visit, types)
468         print
469
470         # Interfaces wrapers
471         self.traceInterfaces(api)
472
473         # Function wrappers
474         self.interface = None
475         self.base = None
476         for function in api.getAllFunctions():
477             self.traceFunctionDecl(function)
478         for function in api.getAllFunctions():
479             self.traceFunctionImpl(function)
480         print
481
482         self.footer(api)
483
484     def header(self, api):
485         print '#ifdef _WIN32'
486         print '#  include <malloc.h> // alloca'
487         print '#  ifndef alloca'
488         print '#    define alloca _alloca'
489         print '#  endif'
490         print '#else'
491         print '#  include <alloca.h> // alloca'
492         print '#endif'
493         print
494         print '#include "trace.hpp"'
495         print
496         print 'static std::map<void *, void *> g_WrappedObjects;'
497
498     def footer(self, api):
499         pass
500
501     def traceFunctionDecl(self, function):
502         # Per-function declarations
503
504         if not function.internal:
505             if function.args:
506                 print 'static const char * _%s_args[%u] = {%s};' % (function.name, len(function.args), ', '.join(['"%s"' % arg.name for arg in function.args]))
507             else:
508                 print 'static const char ** _%s_args = NULL;' % (function.name,)
509             print 'static const trace::FunctionSig _%s_sig = {%u, "%s", %u, _%s_args};' % (function.name, self.getFunctionSigId(), function.name, len(function.args), function.name)
510             print
511
512     def getFunctionSigId(self):
513         id = Tracer.__id
514         Tracer.__id += 1
515         return id
516
517     def isFunctionPublic(self, function):
518         return True
519
520     def traceFunctionImpl(self, function):
521         if self.isFunctionPublic(function):
522             print 'extern "C" PUBLIC'
523         else:
524             print 'extern "C" PRIVATE'
525         print function.prototype() + ' {'
526         if function.type is not stdapi.Void:
527             print '    %s _result;' % function.type
528
529         # No-op if tracing is disabled
530         print '    if (!trace::isTracingEnabled()) {'
531         self.doInvokeFunction(function)
532         if function.type is not stdapi.Void:
533             print '        return _result;'
534         else:
535             print '        return;'
536         print '    }'
537
538         self.traceFunctionImplBody(function)
539         if function.type is not stdapi.Void:
540             print '    return _result;'
541         print '}'
542         print
543
544     def traceFunctionImplBody(self, function):
545         if not function.internal:
546             print '    unsigned _call = trace::localWriter.beginEnter(&_%s_sig);' % (function.name,)
547             for arg in function.args:
548                 if not arg.output:
549                     self.unwrapArg(function, arg)
550             for arg in function.args:
551                 if not arg.output:
552                     self.serializeArg(function, arg)
553             print '    trace::localWriter.endEnter();'
554         self.invokeFunction(function)
555         if not function.internal:
556             print '    trace::localWriter.beginLeave(_call);'
557             print '    if (%s) {' % self.wasFunctionSuccessful(function)
558             for arg in function.args:
559                 if arg.output:
560                     self.serializeArg(function, arg)
561                     self.wrapArg(function, arg)
562             print '    }'
563             if function.type is not stdapi.Void:
564                 self.serializeRet(function, "_result")
565             if function.type is not stdapi.Void:
566                 self.wrapRet(function, "_result")
567             print '    trace::localWriter.endLeave();'
568
569     def invokeFunction(self, function):
570         self.doInvokeFunction(function)
571
572     def doInvokeFunction(self, function, prefix='_', suffix=''):
573         # Same as invokeFunction() but called both when trace is enabled or disabled.
574         if function.type is stdapi.Void:
575             result = ''
576         else:
577             result = '_result = '
578         dispatch = prefix + function.name + suffix
579         print '    %s%s(%s);' % (result, dispatch, ', '.join([str(arg.name) for arg in function.args]))
580
581     def wasFunctionSuccessful(self, function):
582         if function.type is stdapi.Void:
583             return 'true'
584         if str(function.type) == 'HRESULT':
585             return 'SUCCEEDED(_result)'
586         return 'true'
587
588     def serializeArg(self, function, arg):
589         print '    trace::localWriter.beginArg(%u);' % (arg.index,)
590         self.serializeArgValue(function, arg)
591         print '    trace::localWriter.endArg();'
592
593     def serializeArgValue(self, function, arg):
594         self.serializeValue(arg.type, arg.name)
595
596     def wrapArg(self, function, arg):
597         assert not isinstance(arg.type, stdapi.ObjPointer)
598
599         from specs.winapi import REFIID
600         riid = None
601         for other_arg in function.args:
602             if not other_arg.output and other_arg.type is REFIID:
603                 riid = other_arg
604         if riid is not None \
605            and isinstance(arg.type, stdapi.Pointer) \
606            and isinstance(arg.type.type, stdapi.ObjPointer):
607             self.wrapIid(function, riid, arg)
608             return
609
610         self.wrapValue(arg.type, arg.name)
611
612     def unwrapArg(self, function, arg):
613         self.unwrapValue(arg.type, arg.name)
614
615     def serializeRet(self, function, instance):
616         print '    trace::localWriter.beginReturn();'
617         self.serializeValue(function.type, instance)
618         print '    trace::localWriter.endReturn();'
619
620     def serializeValue(self, type, instance):
621         serializer = self.serializerFactory()
622         serializer.visit(type, instance)
623
624     def wrapRet(self, function, instance):
625         self.wrapValue(function.type, instance)
626
627     def needsWrapping(self, type):
628         visitor = WrapDecider()
629         visitor.visit(type)
630         return visitor.needsWrapping
631
632     def wrapValue(self, type, instance):
633         if self.needsWrapping(type):
634             visitor = ValueWrapper()
635             visitor.visit(type, instance)
636
637     def unwrapValue(self, type, instance):
638         if self.needsWrapping(type):
639             visitor = ValueUnwrapper()
640             visitor.visit(type, instance)
641
642     def traceInterfaces(self, api):
643         interfaces = api.getAllInterfaces()
644         if not interfaces:
645             return
646         map(self.declareWrapperInterface, interfaces)
647         self.implementIidWrapper(api)
648         map(self.implementWrapperInterface, interfaces)
649         print
650
651     def declareWrapperInterface(self, interface):
652         print "class %s : public %s " % (getWrapperInterfaceName(interface), interface.name)
653         print "{"
654         print "private:"
655         print "    %s(%s * pInstance);" % (getWrapperInterfaceName(interface), interface.name)
656         print "    virtual ~%s();" % getWrapperInterfaceName(interface)
657         print "public:"
658         print "    static %s* _Create(const char *functionName, %s * pInstance);" % (getWrapperInterfaceName(interface), interface.name)
659         print
660         for method in interface.iterMethods():
661             print "    " + method.prototype() + ";"
662         print
663         #print "private:"
664         for type, name, value in self.enumWrapperInterfaceVariables(interface):
665             print '    %s %s;' % (type, name)
666         for i in range(64):
667             print r'    virtual void _dummy%i(void) const {' % i
668             print r'        os::log("error: %s: unexpected virtual method\n");' % interface.name
669             print r'        os::abort();'
670             print r'    }'
671         print "};"
672         print
673
674     def enumWrapperInterfaceVariables(self, interface):
675         return [
676             ("DWORD", "m_dwMagic", "0xd8365d6c"),
677             ("%s *" % interface.name, "m_pInstance", "pInstance"),
678             ("void *", "m_pVtbl", "*(void **)pInstance"),
679             ("UINT", "m_NumMethods", len(list(interface.iterBaseMethods()))),
680         ] 
681
682     def implementWrapperInterface(self, interface):
683         self.interface = interface
684
685         # Private constructor
686         print '%s::%s(%s * pInstance) {' % (getWrapperInterfaceName(interface), getWrapperInterfaceName(interface), interface.name)
687         for type, name, value in self.enumWrapperInterfaceVariables(interface):
688             if value is not None:
689                 print '    %s = %s;' % (name, value)
690         print '}'
691         print
692
693         # Public constructor
694         print '%s *%s::_Create(const char *functionName, %s * pInstance) {' % (getWrapperInterfaceName(interface), getWrapperInterfaceName(interface), interface.name)
695         print r'    std::map<void *, void *>::const_iterator it = g_WrappedObjects.find(pInstance);'
696         print r'    if (it != g_WrappedObjects.end()) {'
697         print r'        Wrap%s *pWrapper = (Wrap%s *)it->second;' % (interface.name, interface.name)
698         print r'        assert(pWrapper);'
699         print r'        assert(pWrapper->m_dwMagic == 0xd8365d6c);'
700         print r'        assert(pWrapper->m_pInstance == pInstance);'
701         print r'        if (pWrapper->m_pVtbl == *(void **)pInstance &&'
702         print r'            pWrapper->m_NumMethods >= %s) {' % len(list(interface.iterBaseMethods()))
703         #print r'            os::log("%s: fetched pvObj=%p pWrapper=%p pVtbl=%p\n", functionName, pInstance, pWrapper, pWrapper->m_pVtbl);'
704         print r'            return pWrapper;'
705         print r'        }'
706         print r'    }'
707         print r'    Wrap%s *pWrapper = new Wrap%s(pInstance);' % (interface.name, interface.name)
708         #print r'    os::log("%%s: created %s pvObj=%%p pWrapper=%%p pVtbl=%%p\n", functionName, pInstance, pWrapper, pWrapper->m_pVtbl);' % interface.name
709         print r'    g_WrappedObjects[pInstance] = pWrapper;'
710         print r'    return pWrapper;'
711         print '}'
712         print
713
714         # Destructor
715         print '%s::~%s() {' % (getWrapperInterfaceName(interface), getWrapperInterfaceName(interface))
716         #print r'        os::log("%s::Release: deleted pvObj=%%p pWrapper=%%p pVtbl=%%p\n", m_pInstance, this, m_pVtbl);' % interface.name
717         print r'        g_WrappedObjects.erase(m_pInstance);'
718         print '}'
719         print
720         
721         for base, method in interface.iterBaseMethods():
722             self.base = base
723             self.implementWrapperInterfaceMethod(interface, base, method)
724
725         print
726
727     def implementWrapperInterfaceMethod(self, interface, base, method):
728         print method.prototype(getWrapperInterfaceName(interface) + '::' + method.name) + ' {'
729
730         if False:
731             print r'    os::log("%%s(%%p -> %%p)\n", "%s", this, m_pInstance);' % (getWrapperInterfaceName(interface) + '::' + method.name)
732
733         if method.type is not stdapi.Void:
734             print '    %s _result;' % method.type
735     
736         self.implementWrapperInterfaceMethodBody(interface, base, method)
737     
738         if method.type is not stdapi.Void:
739             print '    return _result;'
740         print '}'
741         print
742
743     def implementWrapperInterfaceMethodBody(self, interface, base, method):
744         assert not method.internal
745
746         print '    static const char * _args[%u] = {%s};' % (len(method.args) + 1, ', '.join(['"this"'] + ['"%s"' % arg.name for arg in method.args]))
747         print '    static const trace::FunctionSig _sig = {%u, "%s", %u, _args};' % (self.getFunctionSigId(), interface.name + '::' + method.name, len(method.args) + 1)
748
749         print '    %s *_this = static_cast<%s *>(m_pInstance);' % (base, base)
750
751         print '    unsigned _call = trace::localWriter.beginEnter(&_sig);'
752         print '    trace::localWriter.beginArg(0);'
753         print '    trace::localWriter.writePointer((uintptr_t)m_pInstance);'
754         print '    trace::localWriter.endArg();'
755         for arg in method.args:
756             if not arg.output:
757                 self.unwrapArg(method, arg)
758         for arg in method.args:
759             if not arg.output:
760                 self.serializeArg(method, arg)
761         print '    trace::localWriter.endEnter();'
762         
763         self.invokeMethod(interface, base, method)
764
765         print '    trace::localWriter.beginLeave(_call);'
766
767         print '    if (%s) {' % self.wasFunctionSuccessful(method)
768         for arg in method.args:
769             if arg.output:
770                 self.serializeArg(method, arg)
771                 self.wrapArg(method, arg)
772         print '    }'
773
774         if method.type is not stdapi.Void:
775             self.serializeRet(method, '_result')
776         if method.type is not stdapi.Void:
777             self.wrapRet(method, '_result')
778
779         if method.name == 'Release':
780             assert method.type is not stdapi.Void
781             print r'    if (!_result) {'
782             print r'        delete this;'
783             print r'    }'
784         
785         print '    trace::localWriter.endLeave();'
786
787     def implementIidWrapper(self, api):
788         print r'static void'
789         print r'warnIID(const char *functionName, REFIID riid, const char *reason) {'
790         print r'    os::log("apitrace: warning: %s: %s IID {0x%08lX,0x%04X,0x%04X,{0x%02X,0x%02X,0x%02X,0x%02X,0x%02X,0x%02X,0x%02X,0x%02X}}\n",'
791         print r'            functionName, reason,'
792         print r'            riid.Data1, riid.Data2, riid.Data3,'
793         print r'            riid.Data4[0], riid.Data4[1], riid.Data4[2], riid.Data4[3], riid.Data4[4], riid.Data4[5], riid.Data4[6], riid.Data4[7]);'
794         print r'}'
795         print 
796         print r'static void'
797         print r'wrapIID(const char *functionName, REFIID riid, void * * ppvObj) {'
798         print r'    if (!ppvObj || !*ppvObj) {'
799         print r'        return;'
800         print r'    }'
801         else_ = ''
802         for iface in api.getAllInterfaces():
803             print r'    %sif (riid == IID_%s) {' % (else_, iface.name)
804             print r'        *ppvObj = Wrap%s::_Create(functionName, (%s *) *ppvObj);' % (iface.name, iface.name)
805             print r'    }'
806             else_ = 'else '
807         print r'    %s{' % else_
808         print r'        warnIID(functionName, riid, "unknown");'
809         print r'    }'
810         print r'}'
811         print
812
813     def wrapIid(self, function, riid, out):
814         # Cast output arg to `void **` if necessary
815         out_name = out.name
816         obj_type = out.type.type.type
817         if not obj_type is stdapi.Void:
818             assert isinstance(obj_type, stdapi.Interface)
819             out_name = 'reinterpret_cast<void * *>(%s)' % out_name
820
821         print r'    if (%s && *%s) {' % (out.name, out.name)
822         functionName = function.name
823         else_ = ''
824         if self.interface is not None:
825             functionName = self.interface.name + '::' + functionName
826             print r'        if (*%s == m_pInstance &&' % (out_name,)
827             print r'            (%s)) {' % ' || '.join('%s == IID_%s' % (riid.name, iface.name) for iface in self.interface.iterBases())
828             print r'            *%s = this;' % (out_name,)
829             print r'        }'
830             else_ = 'else '
831         print r'        %s{' % else_
832         print r'             wrapIID("%s", %s, %s);' % (functionName, riid.name, out_name)
833         print r'        }'
834         print r'    }'
835
836     def invokeMethod(self, interface, base, method):
837         if method.type is stdapi.Void:
838             result = ''
839         else:
840             result = '_result = '
841         print '    %s_this->%s(%s);' % (result, method.name, ', '.join([str(arg.name) for arg in method.args]))
842     
843     def emit_memcpy(self, dest, src, length):
844         print '        unsigned _call = trace::localWriter.beginEnter(&trace::memcpy_sig, true);'
845         print '        trace::localWriter.beginArg(0);'
846         print '        trace::localWriter.writePointer((uintptr_t)%s);' % dest
847         print '        trace::localWriter.endArg();'
848         print '        trace::localWriter.beginArg(1);'
849         print '        trace::localWriter.writeBlob(%s, %s);' % (src, length)
850         print '        trace::localWriter.endArg();'
851         print '        trace::localWriter.beginArg(2);'
852         print '        trace::localWriter.writeUInt(%s);' % length
853         print '        trace::localWriter.endArg();'
854         print '        trace::localWriter.endEnter();'
855         print '        trace::localWriter.beginLeave(_call);'
856         print '        trace::localWriter.endLeave();'
857     
858     def fake_call(self, function, args):
859         print '            unsigned _fake_call = trace::localWriter.beginEnter(&_%s_sig);' % (function.name,)
860         for arg, instance in zip(function.args, args):
861             assert not arg.output
862             print '            trace::localWriter.beginArg(%u);' % (arg.index,)
863             self.serializeValue(arg.type, instance)
864             print '            trace::localWriter.endArg();'
865         print '            trace::localWriter.endEnter();'
866         print '            trace::localWriter.beginLeave(_fake_call);'
867         print '            trace::localWriter.endLeave();'
868