]> git.cworth.org Git - apitrace/blob - wrappers/trace.py
0d7741eb391e55ef5c6d51ef0429d7b206ef8ae3
[apitrace] / wrappers / trace.py
1 ##########################################################################
2 #
3 # Copyright 2008-2010 VMware, Inc.
4 # All Rights Reserved.
5 #
6 # Permission is hereby granted, free of charge, to any person obtaining a copy
7 # of this software and associated documentation files (the "Software"), to deal
8 # in the Software without restriction, including without limitation the rights
9 # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 # copies of the Software, and to permit persons to whom the Software is
11 # furnished to do so, subject to the following conditions:
12 #
13 # The above copyright notice and this permission notice shall be included in
14 # all copies or substantial portions of the Software.
15 #
16 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
22 # THE SOFTWARE.
23 #
24 ##########################################################################/
25
26 """Common trace code generation."""
27
28
29 # Adjust path
30 import os.path
31 import sys
32 sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
33
34
35 import specs.stdapi as stdapi
36
37
38 def getWrapperInterfaceName(interface):
39     return "Wrap" + interface.expr
40
41
42
43 class ComplexValueSerializer(stdapi.OnceVisitor):
44     '''Type visitors which generates serialization functions for
45     complex types.
46     
47     Simple types are serialized inline.
48     '''
49
50     def __init__(self, serializer):
51         stdapi.OnceVisitor.__init__(self)
52         self.serializer = serializer
53
54     def visitVoid(self, literal):
55         pass
56
57     def visitLiteral(self, literal):
58         pass
59
60     def visitString(self, string):
61         pass
62
63     def visitConst(self, const):
64         self.visit(const.type)
65
66     def visitStruct(self, struct):
67         print 'static const char * _struct%s_members[%u] = {' % (struct.tag, len(struct.members))
68         for type, name,  in struct.members:
69             if name is None:
70                 print '    "",'
71             else:
72                 print '    "%s",' % (name,)
73         print '};'
74         print 'static const trace::StructSig _struct%s_sig = {' % (struct.tag,)
75         if struct.name is None:
76             structName = '""'
77         else:
78             structName = '"%s"' % struct.name
79         print '    %u, %s, %u, _struct%s_members' % (struct.id, structName, len(struct.members), struct.tag)
80         print '};'
81         print
82
83     def visitArray(self, array):
84         self.visit(array.type)
85
86     def visitAttribArray(self, array):
87         pass
88
89     def visitBlob(self, array):
90         pass
91
92     def visitEnum(self, enum):
93         print 'static const trace::EnumValue _enum%s_values[] = {' % (enum.tag)
94         for value in enum.values:
95             print '    {"%s", %s},' % (value, value)
96         print '};'
97         print
98         print 'static const trace::EnumSig _enum%s_sig = {' % (enum.tag)
99         print '    %u, %u, _enum%s_values' % (enum.id, len(enum.values), enum.tag)
100         print '};'
101         print
102
103     def visitBitmask(self, bitmask):
104         print 'static const trace::BitmaskFlag _bitmask%s_flags[] = {' % (bitmask.tag)
105         for value in bitmask.values:
106             print '    {"%s", %s},' % (value, value)
107         print '};'
108         print
109         print 'static const trace::BitmaskSig _bitmask%s_sig = {' % (bitmask.tag)
110         print '    %u, %u, _bitmask%s_flags' % (bitmask.id, len(bitmask.values), bitmask.tag)
111         print '};'
112         print
113
114     def visitPointer(self, pointer):
115         self.visit(pointer.type)
116
117     def visitIntPointer(self, pointer):
118         pass
119
120     def visitObjPointer(self, pointer):
121         self.visit(pointer.type)
122
123     def visitLinearPointer(self, pointer):
124         self.visit(pointer.type)
125
126     def visitHandle(self, handle):
127         self.visit(handle.type)
128
129     def visitReference(self, reference):
130         self.visit(reference.type)
131
132     def visitAlias(self, alias):
133         self.visit(alias.type)
134
135     def visitOpaque(self, opaque):
136         pass
137
138     def visitInterface(self, interface):
139         pass
140
141     def visitPolymorphic(self, polymorphic):
142         if not polymorphic.contextLess:
143             return
144         print 'static void _write__%s(int selector, %s const & value) {' % (polymorphic.tag, polymorphic.expr)
145         print '    switch (selector) {'
146         for cases, type in polymorphic.iterSwitch():
147             for case in cases:
148                 print '    %s:' % case
149             self.serializer.visit(type, 'static_cast<%s>(value)' % (type,))
150             print '        break;'
151         print '    }'
152         print '}'
153         print
154
155
156 class ValueSerializer(stdapi.Visitor, stdapi.ExpanderMixin):
157     '''Visitor which generates code to serialize any type.
158     
159     Simple types are serialized inline here, whereas the serialization of
160     complex types is dispatched to the serialization functions generated by
161     ComplexValueSerializer visitor above.
162     '''
163
164     def visitLiteral(self, literal, instance):
165         print '    trace::localWriter.write%s(%s);' % (literal.kind, instance)
166
167     def visitString(self, string, instance):
168         if not string.wide:
169             cast = 'const char *'
170             suffix = 'String'
171         else:
172             cast = 'const wchar_t *'
173             suffix = 'WString'
174         if cast != string.expr:
175             # reinterpret_cast is necessary for GLubyte * <=> char *
176             instance = 'reinterpret_cast<%s>(%s)' % (cast, instance)
177         if string.length is not None:
178             length = ', %s' % self.expand(string.length)
179         else:
180             length = ''
181         print '    trace::localWriter.write%s(%s%s);' % (suffix, instance, length)
182
183     def visitConst(self, const, instance):
184         self.visit(const.type, instance)
185
186     def visitStruct(self, struct, instance):
187         print '    trace::localWriter.beginStruct(&_struct%s_sig);' % (struct.tag,)
188         for member in struct.members:
189             self.visitMember(member, instance)
190         print '    trace::localWriter.endStruct();'
191
192     def visitArray(self, array, instance):
193         length = '_c' + array.type.tag
194         index = '_i' + array.type.tag
195         array_length = self.expand(array.length)
196         print '    if (%s) {' % instance
197         print '        size_t %s = %s > 0 ? %s : 0;' % (length, array_length, array_length)
198         print '        trace::localWriter.beginArray(%s);' % length
199         print '        for (size_t %s = 0; %s < %s; ++%s) {' % (index, index, length, index)
200         print '            trace::localWriter.beginElement();'
201         self.visitElement(index, array.type, '(%s)[%s]' % (instance, index))
202         print '            trace::localWriter.endElement();'
203         print '        }'
204         print '        trace::localWriter.endArray();'
205         print '    } else {'
206         print '        trace::localWriter.writeNull();'
207         print '    }'
208
209     def visitAttribArray(self, array, instance):
210         # For each element, decide if it is a key or a value (which depends on the previous key).
211         # If it is a value, store it as the right type - usually int, some bitfield, or some enum.
212         # It is currently assumed that an unknown key means that it is followed by an int value.
213
214         # determine the array length which must be passed to writeArray() up front
215         count = '_c' + array.keyType.tag
216         print '    int %s;' % count
217         print '    for (%(c)s = 0; %(array)s && %(array)s[%(c)s] != %(terminator)s; %(c)s += 2) {' \
218               % {'c': count, 'array': instance, 'terminator': array.terminator}
219         if array.hasKeysWithoutValues:
220             print '        switch (%(array)s[%(c)s]) {' % {'array': instance, 'c': count}
221             for key, valueType in array.valueTypes:
222                 if valueType is None:
223                     print '        case %s:' % key
224             print '            %s--;' % count # the next value is a key again and checked if it's the terminator
225             print '            break;'
226             print '        }'
227         print '    }'
228         print '    %(c)s += %(array)s ? 1 : 0;' % {'c': count, 'array': instance}
229         print '    trace::localWriter.beginArray(%s);' % count
230
231         # for each key / key-value pair write the key and the value, if the key requires one
232
233         index = '_i' + array.keyType.tag
234         print '    for (int %(i)s = 0; %(i)s < %(count)s; %(i)s++) {' % {'i': index, 'count': count}
235         print '        trace::localWriter.beginElement();'
236         self.visitEnum(array.keyType, "%(array)s[%(i)s]" % {'array': instance, 'i': index})
237         print '        trace::localWriter.endElement();'
238         print '        if (%(i)s + 1 >= %(count)s) {' % {'i': index, 'count': count}
239         print '            break;'
240         print '        }'
241         print '        switch (%(array)s[%(i)s++]) {' % {'array': instance, 'i': index}
242         # write generic value the usual way
243         for key, valueType in array.valueTypes:
244             if valueType is not None:
245                 print '        case %s:' % key
246                 print '            trace::localWriter.beginElement();'
247                 self.visitElement(index, valueType, '(%(array)s)[%(i)s]' % {'array': instance, 'i': index})
248                 print '            trace::localWriter.endElement();'
249                 print '            break;'
250         # known key with no value, just decrease the index so we treat the next value as a key
251         if array.hasKeysWithoutValues:
252             for key, valueType in array.valueTypes:
253                 if valueType is None:
254                     print '        case %s:' % key
255             print '            %s--;' % index
256             print '            break;'
257         # unknown key, write an int value
258         print '        default:'
259         print '            trace::localWriter.beginElement();'
260         print '            trace::localWriter.writeSInt(%(array)s[%(i)s]);'  % {'array': instance, 'i': index}
261         print '            trace::localWriter.endElement();'
262         print '            break;'
263         print '        }'
264         print '    }'
265         print '    trace::localWriter.endArray();'
266
267
268     def visitBlob(self, blob, instance):
269         print '    trace::localWriter.writeBlob(%s, %s);' % (instance, self.expand(blob.size))
270
271     def visitEnum(self, enum, instance):
272         print '    trace::localWriter.writeEnum(&_enum%s_sig, %s);' % (enum.tag, instance)
273
274     def visitBitmask(self, bitmask, instance):
275         print '    trace::localWriter.writeBitmask(&_bitmask%s_sig, %s);' % (bitmask.tag, instance)
276
277     def visitPointer(self, pointer, instance):
278         print '    if (%s) {' % instance
279         print '        trace::localWriter.beginArray(1);'
280         print '        trace::localWriter.beginElement();'
281         self.visit(pointer.type, "*" + instance)
282         print '        trace::localWriter.endElement();'
283         print '        trace::localWriter.endArray();'
284         print '    } else {'
285         print '        trace::localWriter.writeNull();'
286         print '    }'
287
288     def visitIntPointer(self, pointer, instance):
289         print '    trace::localWriter.writePointer((uintptr_t)%s);' % instance
290
291     def visitObjPointer(self, pointer, instance):
292         print '    trace::localWriter.writePointer((uintptr_t)%s);' % instance
293
294     def visitLinearPointer(self, pointer, instance):
295         print '    trace::localWriter.writePointer((uintptr_t)%s);' % instance
296
297     def visitReference(self, reference, instance):
298         self.visit(reference.type, instance)
299
300     def visitHandle(self, handle, instance):
301         self.visit(handle.type, instance)
302
303     def visitAlias(self, alias, instance):
304         self.visit(alias.type, instance)
305
306     def visitOpaque(self, opaque, instance):
307         print '    trace::localWriter.writePointer((uintptr_t)%s);' % instance
308
309     def visitInterface(self, interface, instance):
310         assert False
311
312     def visitPolymorphic(self, polymorphic, instance):
313         if polymorphic.contextLess:
314             print '    _write__%s(%s, %s);' % (polymorphic.tag, polymorphic.switchExpr, instance)
315         else:
316             switchExpr = self.expand(polymorphic.switchExpr)
317             print '    switch (%s) {' % switchExpr
318             for cases, type in polymorphic.iterSwitch():
319                 for case in cases:
320                     print '    %s:' % case
321                 caseInstance = instance
322                 if type.expr is not None:
323                     caseInstance = 'static_cast<%s>(%s)' % (type, caseInstance)
324                 self.visit(type, caseInstance)
325                 print '        break;'
326             if polymorphic.defaultType is None:
327                 print r'    default:'
328                 print r'        os::log("apitrace: warning: %%s: unexpected polymorphic case %%i\n", __FUNCTION__, (int)%s);' % (switchExpr,)
329                 print r'        trace::localWriter.writeNull();'
330                 print r'        break;'
331             print '    }'
332
333
334 class WrapDecider(stdapi.Traverser):
335     '''Type visitor which will decide wheter this type will need wrapping or not.
336     
337     For complex types (arrays, structures), we need to know this before hand.
338     '''
339
340     def __init__(self):
341         self.needsWrapping = False
342
343     def visitLinearPointer(self, void):
344         pass
345
346     def visitInterface(self, interface):
347         self.needsWrapping = True
348
349
350 class ValueWrapper(stdapi.Traverser, stdapi.ExpanderMixin):
351     '''Type visitor which will generate the code to wrap an instance.
352     
353     Wrapping is necessary mostly for interfaces, however interface pointers can
354     appear anywhere inside complex types.
355     '''
356
357     def visitStruct(self, struct, instance):
358         for member in struct.members:
359             self.visitMember(member, instance)
360
361     def visitArray(self, array, instance):
362         array_length = self.expand(array.length)
363         print "    if (%s) {" % instance
364         print "        for (size_t _i = 0, _s = %s; _i < _s; ++_i) {" % array_length
365         self.visitElement('_i', array.type, instance + "[_i]")
366         print "        }"
367         print "    }"
368
369     def visitPointer(self, pointer, instance):
370         print "    if (%s) {" % instance
371         self.visit(pointer.type, "*" + instance)
372         print "    }"
373     
374     def visitObjPointer(self, pointer, instance):
375         elem_type = pointer.type.mutable()
376         if isinstance(elem_type, stdapi.Interface):
377             self.visitInterfacePointer(elem_type, instance)
378         elif isinstance(elem_type, stdapi.Alias) and isinstance(elem_type.type, stdapi.Interface):
379             self.visitInterfacePointer(elem_type.type, instance)
380         else:
381             self.visitPointer(pointer, instance)
382     
383     def visitInterface(self, interface, instance):
384         raise NotImplementedError
385
386     def visitInterfacePointer(self, interface, instance):
387         print "    if (%s) {" % instance
388         print "        %s = %s::_Create(__FUNCTION__, %s);" % (instance, getWrapperInterfaceName(interface), instance)
389         print "    }"
390     
391     def visitPolymorphic(self, type, instance):
392         # XXX: There might be polymorphic values that need wrapping in the future
393         raise NotImplementedError
394
395
396 class ValueUnwrapper(ValueWrapper):
397     '''Reverse of ValueWrapper.'''
398
399     allocated = False
400
401     def visitStruct(self, struct, instance):
402         if not self.allocated:
403             # Argument is constant. We need to create a non const
404             print '    {'
405             print "        %s * _t = static_cast<%s *>(alloca(sizeof *_t));" % (struct, struct)
406             print '        *_t = %s;' % (instance,)
407             assert instance.startswith('*')
408             print '        %s = _t;' % (instance[1:],)
409             instance = '*_t'
410             self.allocated = True
411             try:
412                 return ValueWrapper.visitStruct(self, struct, instance)
413             finally:
414                 print '    }'
415         else:
416             return ValueWrapper.visitStruct(self, struct, instance)
417
418     def visitArray(self, array, instance):
419         if self.allocated or isinstance(instance, stdapi.Interface):
420             return ValueWrapper.visitArray(self, array, instance)
421         array_length = self.expand(array.length)
422         elem_type = array.type.mutable()
423         print "    if (%s && %s) {" % (instance, array_length)
424         print "        %s * _t = static_cast<%s *>(alloca(%s * sizeof *_t));" % (elem_type, elem_type, array_length)
425         print "        for (size_t _i = 0, _s = %s; _i < _s; ++_i) {" % array_length
426         print "            _t[_i] = %s[_i];" % instance 
427         self.allocated = True
428         self.visit(array.type, "_t[_i]")
429         print "        }"
430         print "        %s = _t;" % instance
431         print "    }"
432
433     def visitInterfacePointer(self, interface, instance):
434         print r'    if (%s) {' % instance
435         print r'        const %s *pWrapper = static_cast<const %s*>(%s);' % (getWrapperInterfaceName(interface), getWrapperInterfaceName(interface), instance)
436         print r'        if (pWrapper && pWrapper->m_dwMagic == 0xd8365d6c) {'
437         print r'            %s = pWrapper->m_pInstance;' % (instance,)
438         print r'        } else {'
439         print r'            os::log("apitrace: warning: %%s: unexpected %%s pointer\n", __FUNCTION__, "%s");' % interface.name
440         print r'        }'
441         print r'    }'
442
443
444 class Tracer:
445     '''Base class to orchestrate the code generation of API tracing.'''
446
447     # 0-3 are reserved to memcpy, malloc, free, and realloc
448     __id = 4
449
450     def __init__(self):
451         self.api = None
452
453     def serializerFactory(self):
454         '''Create a serializer.
455         
456         Can be overriden by derived classes to inject their own serialzer.
457         '''
458
459         return ValueSerializer()
460
461     def traceApi(self, api):
462         self.api = api
463
464         self.header(api)
465
466         # Includes
467         for module in api.modules:
468             for header in module.headers:
469                 print header
470         print
471
472         # Generate the serializer functions
473         types = api.getAllTypes()
474         visitor = ComplexValueSerializer(self.serializerFactory())
475         map(visitor.visit, types)
476         print
477
478         # Interfaces wrapers
479         self.traceInterfaces(api)
480
481         # Function wrappers
482         self.interface = None
483         self.base = None
484         for function in api.getAllFunctions():
485             self.traceFunctionDecl(function)
486         for function in api.getAllFunctions():
487             self.traceFunctionImpl(function)
488         print
489
490         self.footer(api)
491
492     def header(self, api):
493         print '#ifdef _WIN32'
494         print '#  include <malloc.h> // alloca'
495         print '#  ifndef alloca'
496         print '#    define alloca _alloca'
497         print '#  endif'
498         print '#else'
499         print '#  include <alloca.h> // alloca'
500         print '#endif'
501         print
502         print '#include "trace.hpp"'
503         print
504         print 'static std::map<void *, void *> g_WrappedObjects;'
505
506     def footer(self, api):
507         pass
508
509     def traceFunctionDecl(self, function):
510         # Per-function declarations
511
512         if not function.internal:
513             if function.args:
514                 print 'static const char * _%s_args[%u] = {%s};' % (function.name, len(function.args), ', '.join(['"%s"' % arg.name for arg in function.args]))
515             else:
516                 print 'static const char ** _%s_args = NULL;' % (function.name,)
517             print 'static const trace::FunctionSig _%s_sig = {%u, "%s", %u, _%s_args};' % (function.name, self.getFunctionSigId(), function.name, len(function.args), function.name)
518             print
519
520     def getFunctionSigId(self):
521         id = Tracer.__id
522         Tracer.__id += 1
523         return id
524
525     def isFunctionPublic(self, function):
526         return True
527
528     def traceFunctionImpl(self, function):
529         if self.isFunctionPublic(function):
530             print 'extern "C" PUBLIC'
531         else:
532             print 'extern "C" PRIVATE'
533         print function.prototype() + ' {'
534         if function.type is not stdapi.Void:
535             print '    %s _result;' % function.type
536
537         # No-op if tracing is disabled
538         print '    if (!trace::isTracingEnabled()) {'
539         self.doInvokeFunction(function)
540         if function.type is not stdapi.Void:
541             print '        return _result;'
542         else:
543             print '        return;'
544         print '    }'
545
546         self.traceFunctionImplBody(function)
547         if function.type is not stdapi.Void:
548             print '    return _result;'
549         print '}'
550         print
551
552     def traceFunctionImplBody(self, function):
553         if not function.internal:
554             print '    unsigned _call = trace::localWriter.beginEnter(&_%s_sig);' % (function.name,)
555             for arg in function.args:
556                 if not arg.output:
557                     self.unwrapArg(function, arg)
558             for arg in function.args:
559                 if not arg.output:
560                     self.serializeArg(function, arg)
561             print '    trace::localWriter.endEnter();'
562         self.invokeFunction(function)
563         if not function.internal:
564             print '    trace::localWriter.beginLeave(_call);'
565             print '    if (%s) {' % self.wasFunctionSuccessful(function)
566             for arg in function.args:
567                 if arg.output:
568                     self.serializeArg(function, arg)
569                     self.wrapArg(function, arg)
570             print '    }'
571             if function.type is not stdapi.Void:
572                 self.serializeRet(function, "_result")
573             if function.type is not stdapi.Void:
574                 self.wrapRet(function, "_result")
575             print '    trace::localWriter.endLeave();'
576
577     def invokeFunction(self, function):
578         self.doInvokeFunction(function)
579
580     def doInvokeFunction(self, function, prefix='_', suffix=''):
581         # Same as invokeFunction() but called both when trace is enabled or disabled.
582         if function.type is stdapi.Void:
583             result = ''
584         else:
585             result = '_result = '
586         dispatch = prefix + function.name + suffix
587         print '    %s%s(%s);' % (result, dispatch, ', '.join([str(arg.name) for arg in function.args]))
588
589     def wasFunctionSuccessful(self, function):
590         if function.type is stdapi.Void:
591             return 'true'
592         if str(function.type) == 'HRESULT':
593             return 'SUCCEEDED(_result)'
594         return 'true'
595
596     def serializeArg(self, function, arg):
597         print '    trace::localWriter.beginArg(%u);' % (arg.index,)
598         self.serializeArgValue(function, arg)
599         print '    trace::localWriter.endArg();'
600
601     def serializeArgValue(self, function, arg):
602         self.serializeValue(arg.type, arg.name)
603
604     def wrapArg(self, function, arg):
605         assert not isinstance(arg.type, stdapi.ObjPointer)
606
607         from specs.winapi import REFIID
608         riid = None
609         for other_arg in function.args:
610             if not other_arg.output and other_arg.type is REFIID:
611                 riid = other_arg
612         if riid is not None \
613            and isinstance(arg.type, stdapi.Pointer) \
614            and isinstance(arg.type.type, stdapi.ObjPointer):
615             self.wrapIid(function, riid, arg)
616             return
617
618         self.wrapValue(arg.type, arg.name)
619
620     def unwrapArg(self, function, arg):
621         self.unwrapValue(arg.type, arg.name)
622
623     def serializeRet(self, function, instance):
624         print '    trace::localWriter.beginReturn();'
625         self.serializeValue(function.type, instance)
626         print '    trace::localWriter.endReturn();'
627
628     def serializeValue(self, type, instance):
629         serializer = self.serializerFactory()
630         serializer.visit(type, instance)
631
632     def wrapRet(self, function, instance):
633         self.wrapValue(function.type, instance)
634
635     def needsWrapping(self, type):
636         visitor = WrapDecider()
637         visitor.visit(type)
638         return visitor.needsWrapping
639
640     def wrapValue(self, type, instance):
641         if self.needsWrapping(type):
642             visitor = ValueWrapper()
643             visitor.visit(type, instance)
644
645     def unwrapValue(self, type, instance):
646         if self.needsWrapping(type):
647             visitor = ValueUnwrapper()
648             visitor.visit(type, instance)
649
650     def traceInterfaces(self, api):
651         interfaces = api.getAllInterfaces()
652         if not interfaces:
653             return
654         map(self.declareWrapperInterface, interfaces)
655         self.implementIidWrapper(api)
656         map(self.implementWrapperInterface, interfaces)
657         print
658
659     def declareWrapperInterface(self, interface):
660         print "class %s : public %s " % (getWrapperInterfaceName(interface), interface.name)
661         print "{"
662         print "private:"
663         print "    %s(%s * pInstance);" % (getWrapperInterfaceName(interface), interface.name)
664         print "    virtual ~%s();" % getWrapperInterfaceName(interface)
665         print "public:"
666         print "    static %s* _Create(const char *functionName, %s * pInstance);" % (getWrapperInterfaceName(interface), interface.name)
667         print
668         for method in interface.iterMethods():
669             print "    " + method.prototype() + ";"
670         print
671         #print "private:"
672         for type, name, value in self.enumWrapperInterfaceVariables(interface):
673             print '    %s %s;' % (type, name)
674         for i in range(64):
675             print r'    virtual void _dummy%i(void) const {' % i
676             print r'        os::log("error: %s: unexpected virtual method\n");' % interface.name
677             print r'        os::abort();'
678             print r'    }'
679         print "};"
680         print
681
682     def enumWrapperInterfaceVariables(self, interface):
683         return [
684             ("DWORD", "m_dwMagic", "0xd8365d6c"),
685             ("%s *" % interface.name, "m_pInstance", "pInstance"),
686             ("void *", "m_pVtbl", "*(void **)pInstance"),
687             ("UINT", "m_NumMethods", len(list(interface.iterBaseMethods()))),
688         ] 
689
690     def implementWrapperInterface(self, interface):
691         self.interface = interface
692
693         # Private constructor
694         print '%s::%s(%s * pInstance) {' % (getWrapperInterfaceName(interface), getWrapperInterfaceName(interface), interface.name)
695         for type, name, value in self.enumWrapperInterfaceVariables(interface):
696             if value is not None:
697                 print '    %s = %s;' % (name, value)
698         print '}'
699         print
700
701         # Public constructor
702         print '%s *%s::_Create(const char *functionName, %s * pInstance) {' % (getWrapperInterfaceName(interface), getWrapperInterfaceName(interface), interface.name)
703         print r'    std::map<void *, void *>::const_iterator it = g_WrappedObjects.find(pInstance);'
704         print r'    if (it != g_WrappedObjects.end()) {'
705         print r'        Wrap%s *pWrapper = (Wrap%s *)it->second;' % (interface.name, interface.name)
706         print r'        assert(pWrapper);'
707         print r'        assert(pWrapper->m_dwMagic == 0xd8365d6c);'
708         print r'        assert(pWrapper->m_pInstance == pInstance);'
709         print r'        if (pWrapper->m_pVtbl == *(void **)pInstance &&'
710         print r'            pWrapper->m_NumMethods >= %s) {' % len(list(interface.iterBaseMethods()))
711         #print r'            os::log("%s: fetched pvObj=%p pWrapper=%p pVtbl=%p\n", functionName, pInstance, pWrapper, pWrapper->m_pVtbl);'
712         print r'            return pWrapper;'
713         print r'        }'
714         print r'    }'
715         print r'    Wrap%s *pWrapper = new Wrap%s(pInstance);' % (interface.name, interface.name)
716         #print r'    os::log("%%s: created %s pvObj=%%p pWrapper=%%p pVtbl=%%p\n", functionName, pInstance, pWrapper, pWrapper->m_pVtbl);' % interface.name
717         print r'    g_WrappedObjects[pInstance] = pWrapper;'
718         print r'    return pWrapper;'
719         print '}'
720         print
721
722         # Destructor
723         print '%s::~%s() {' % (getWrapperInterfaceName(interface), getWrapperInterfaceName(interface))
724         #print r'        os::log("%s::Release: deleted pvObj=%%p pWrapper=%%p pVtbl=%%p\n", m_pInstance, this, m_pVtbl);' % interface.name
725         print r'        g_WrappedObjects.erase(m_pInstance);'
726         print '}'
727         print
728         
729         for base, method in interface.iterBaseMethods():
730             self.base = base
731             self.implementWrapperInterfaceMethod(interface, base, method)
732
733         print
734
735     def implementWrapperInterfaceMethod(self, interface, base, method):
736         print method.prototype(getWrapperInterfaceName(interface) + '::' + method.name) + ' {'
737
738         if False:
739             print r'    os::log("%%s(%%p -> %%p)\n", "%s", this, m_pInstance);' % (getWrapperInterfaceName(interface) + '::' + method.name)
740
741         if method.type is not stdapi.Void:
742             print '    %s _result;' % method.type
743     
744         self.implementWrapperInterfaceMethodBody(interface, base, method)
745     
746         if method.type is not stdapi.Void:
747             print '    return _result;'
748         print '}'
749         print
750
751     def implementWrapperInterfaceMethodBody(self, interface, base, method):
752         assert not method.internal
753
754         print '    static const char * _args[%u] = {%s};' % (len(method.args) + 1, ', '.join(['"this"'] + ['"%s"' % arg.name for arg in method.args]))
755         print '    static const trace::FunctionSig _sig = {%u, "%s", %u, _args};' % (self.getFunctionSigId(), interface.name + '::' + method.name, len(method.args) + 1)
756
757         print '    %s *_this = static_cast<%s *>(m_pInstance);' % (base, base)
758
759         print '    unsigned _call = trace::localWriter.beginEnter(&_sig);'
760         print '    trace::localWriter.beginArg(0);'
761         print '    trace::localWriter.writePointer((uintptr_t)m_pInstance);'
762         print '    trace::localWriter.endArg();'
763         for arg in method.args:
764             if not arg.output:
765                 self.unwrapArg(method, arg)
766         for arg in method.args:
767             if not arg.output:
768                 self.serializeArg(method, arg)
769         print '    trace::localWriter.endEnter();'
770         
771         self.invokeMethod(interface, base, method)
772
773         print '    trace::localWriter.beginLeave(_call);'
774
775         print '    if (%s) {' % self.wasFunctionSuccessful(method)
776         for arg in method.args:
777             if arg.output:
778                 self.serializeArg(method, arg)
779                 self.wrapArg(method, arg)
780         print '    }'
781
782         if method.type is not stdapi.Void:
783             self.serializeRet(method, '_result')
784         if method.type is not stdapi.Void:
785             self.wrapRet(method, '_result')
786
787         if method.name == 'Release':
788             assert method.type is not stdapi.Void
789             print r'    if (!_result) {'
790             print r'        delete this;'
791             print r'    }'
792         
793         print '    trace::localWriter.endLeave();'
794
795     def implementIidWrapper(self, api):
796         print r'static void'
797         print r'warnIID(const char *functionName, REFIID riid, const char *reason) {'
798         print r'    os::log("apitrace: warning: %s: %s IID {0x%08lX,0x%04X,0x%04X,{0x%02X,0x%02X,0x%02X,0x%02X,0x%02X,0x%02X,0x%02X,0x%02X}}\n",'
799         print r'            functionName, reason,'
800         print r'            riid.Data1, riid.Data2, riid.Data3,'
801         print r'            riid.Data4[0], riid.Data4[1], riid.Data4[2], riid.Data4[3], riid.Data4[4], riid.Data4[5], riid.Data4[6], riid.Data4[7]);'
802         print r'}'
803         print 
804         print r'static void'
805         print r'wrapIID(const char *functionName, REFIID riid, void * * ppvObj) {'
806         print r'    if (!ppvObj || !*ppvObj) {'
807         print r'        return;'
808         print r'    }'
809         else_ = ''
810         for iface in api.getAllInterfaces():
811             print r'    %sif (riid == IID_%s) {' % (else_, iface.name)
812             print r'        *ppvObj = Wrap%s::_Create(functionName, (%s *) *ppvObj);' % (iface.name, iface.name)
813             print r'    }'
814             else_ = 'else '
815         print r'    %s{' % else_
816         print r'        warnIID(functionName, riid, "unknown");'
817         print r'    }'
818         print r'}'
819         print
820
821     def wrapIid(self, function, riid, out):
822         # Cast output arg to `void **` if necessary
823         out_name = out.name
824         obj_type = out.type.type.type
825         if not obj_type is stdapi.Void:
826             assert isinstance(obj_type, stdapi.Interface)
827             out_name = 'reinterpret_cast<void * *>(%s)' % out_name
828
829         print r'    if (%s && *%s) {' % (out.name, out.name)
830         functionName = function.name
831         else_ = ''
832         if self.interface is not None:
833             functionName = self.interface.name + '::' + functionName
834             print r'        if (*%s == m_pInstance &&' % (out_name,)
835             print r'            (%s)) {' % ' || '.join('%s == IID_%s' % (riid.name, iface.name) for iface in self.interface.iterBases())
836             print r'            *%s = this;' % (out_name,)
837             print r'        }'
838             else_ = 'else '
839         print r'        %s{' % else_
840         print r'             wrapIID("%s", %s, %s);' % (functionName, riid.name, out_name)
841         print r'        }'
842         print r'    }'
843
844     def invokeMethod(self, interface, base, method):
845         if method.type is stdapi.Void:
846             result = ''
847         else:
848             result = '_result = '
849         print '    %s_this->%s(%s);' % (result, method.name, ', '.join([str(arg.name) for arg in method.args]))
850     
851     def emit_memcpy(self, dest, src, length):
852         print '        unsigned _call = trace::localWriter.beginEnter(&trace::memcpy_sig, true);'
853         print '        trace::localWriter.beginArg(0);'
854         print '        trace::localWriter.writePointer((uintptr_t)%s);' % dest
855         print '        trace::localWriter.endArg();'
856         print '        trace::localWriter.beginArg(1);'
857         print '        trace::localWriter.writeBlob(%s, %s);' % (src, length)
858         print '        trace::localWriter.endArg();'
859         print '        trace::localWriter.beginArg(2);'
860         print '        trace::localWriter.writeUInt(%s);' % length
861         print '        trace::localWriter.endArg();'
862         print '        trace::localWriter.endEnter();'
863         print '        trace::localWriter.beginLeave(_call);'
864         print '        trace::localWriter.endLeave();'
865     
866     def fake_call(self, function, args):
867         print '            unsigned _fake_call = trace::localWriter.beginEnter(&_%s_sig);' % (function.name,)
868         for arg, instance in zip(function.args, args):
869             assert not arg.output
870             print '            trace::localWriter.beginArg(%u);' % (arg.index,)
871             self.serializeValue(arg.type, instance)
872             print '            trace::localWriter.endArg();'
873         print '            trace::localWriter.endEnter();'
874         print '            trace::localWriter.beginLeave(_fake_call);'
875         print '            trace::localWriter.endLeave();'
876