]> git.cworth.org Git - apitrace/blob - wrappers/trace.py
Warn on unknown attrib_list keys.
[apitrace] / wrappers / trace.py
1 ##########################################################################
2 #
3 # Copyright 2008-2010 VMware, Inc.
4 # All Rights Reserved.
5 #
6 # Permission is hereby granted, free of charge, to any person obtaining a copy
7 # of this software and associated documentation files (the "Software"), to deal
8 # in the Software without restriction, including without limitation the rights
9 # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 # copies of the Software, and to permit persons to whom the Software is
11 # furnished to do so, subject to the following conditions:
12 #
13 # The above copyright notice and this permission notice shall be included in
14 # all copies or substantial portions of the Software.
15 #
16 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
22 # THE SOFTWARE.
23 #
24 ##########################################################################/
25
26 """Common trace code generation."""
27
28
29 # Adjust path
30 import os.path
31 import sys
32 sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
33
34
35 import specs.stdapi as stdapi
36
37
38 def getWrapperInterfaceName(interface):
39     return "Wrap" + interface.expr
40
41
42
43 class ComplexValueSerializer(stdapi.OnceVisitor):
44     '''Type visitors which generates serialization functions for
45     complex types.
46     
47     Simple types are serialized inline.
48     '''
49
50     def __init__(self, serializer):
51         stdapi.OnceVisitor.__init__(self)
52         self.serializer = serializer
53
54     def visitVoid(self, literal):
55         pass
56
57     def visitLiteral(self, literal):
58         pass
59
60     def visitString(self, string):
61         pass
62
63     def visitConst(self, const):
64         self.visit(const.type)
65
66     def visitStruct(self, struct):
67         print 'static const char * _struct%s_members[%u] = {' % (struct.tag, len(struct.members))
68         for type, name,  in struct.members:
69             if name is None:
70                 print '    "",'
71             else:
72                 print '    "%s",' % (name,)
73         print '};'
74         print 'static const trace::StructSig _struct%s_sig = {' % (struct.tag,)
75         if struct.name is None:
76             structName = '""'
77         else:
78             structName = '"%s"' % struct.name
79         print '    %u, %s, %u, _struct%s_members' % (struct.id, structName, len(struct.members), struct.tag)
80         print '};'
81         print
82
83     def visitArray(self, array):
84         self.visit(array.type)
85
86     def visitAttribArray(self, array):
87         pass
88
89     def visitBlob(self, array):
90         pass
91
92     def visitEnum(self, enum):
93         print 'static const trace::EnumValue _enum%s_values[] = {' % (enum.tag)
94         for value in enum.values:
95             print '    {"%s", %s},' % (value, value)
96         print '};'
97         print
98         print 'static const trace::EnumSig _enum%s_sig = {' % (enum.tag)
99         print '    %u, %u, _enum%s_values' % (enum.id, len(enum.values), enum.tag)
100         print '};'
101         print
102
103     def visitBitmask(self, bitmask):
104         print 'static const trace::BitmaskFlag _bitmask%s_flags[] = {' % (bitmask.tag)
105         for value in bitmask.values:
106             print '    {"%s", %s},' % (value, value)
107         print '};'
108         print
109         print 'static const trace::BitmaskSig _bitmask%s_sig = {' % (bitmask.tag)
110         print '    %u, %u, _bitmask%s_flags' % (bitmask.id, len(bitmask.values), bitmask.tag)
111         print '};'
112         print
113
114     def visitPointer(self, pointer):
115         self.visit(pointer.type)
116
117     def visitIntPointer(self, pointer):
118         pass
119
120     def visitObjPointer(self, pointer):
121         self.visit(pointer.type)
122
123     def visitLinearPointer(self, pointer):
124         self.visit(pointer.type)
125
126     def visitHandle(self, handle):
127         self.visit(handle.type)
128
129     def visitReference(self, reference):
130         self.visit(reference.type)
131
132     def visitAlias(self, alias):
133         self.visit(alias.type)
134
135     def visitOpaque(self, opaque):
136         pass
137
138     def visitInterface(self, interface):
139         pass
140
141     def visitPolymorphic(self, polymorphic):
142         if not polymorphic.contextLess:
143             return
144         print 'static void _write__%s(int selector, %s const & value) {' % (polymorphic.tag, polymorphic.expr)
145         print '    switch (selector) {'
146         for cases, type in polymorphic.iterSwitch():
147             for case in cases:
148                 print '    %s:' % case
149             self.serializer.visit(type, 'static_cast<%s>(value)' % (type,))
150             print '        break;'
151         print '    }'
152         print '}'
153         print
154
155
156 class ValueSerializer(stdapi.Visitor, stdapi.ExpanderMixin):
157     '''Visitor which generates code to serialize any type.
158     
159     Simple types are serialized inline here, whereas the serialization of
160     complex types is dispatched to the serialization functions generated by
161     ComplexValueSerializer visitor above.
162     '''
163
164     def visitLiteral(self, literal, instance):
165         print '    trace::localWriter.write%s(%s);' % (literal.kind, instance)
166
167     def visitString(self, string, instance):
168         if not string.wide:
169             cast = 'const char *'
170             suffix = 'String'
171         else:
172             cast = 'const wchar_t *'
173             suffix = 'WString'
174         if cast != string.expr:
175             # reinterpret_cast is necessary for GLubyte * <=> char *
176             instance = 'reinterpret_cast<%s>(%s)' % (cast, instance)
177         if string.length is not None:
178             length = ', %s' % self.expand(string.length)
179         else:
180             length = ''
181         print '    trace::localWriter.write%s(%s%s);' % (suffix, instance, length)
182
183     def visitConst(self, const, instance):
184         self.visit(const.type, instance)
185
186     def visitStruct(self, struct, instance):
187         print '    trace::localWriter.beginStruct(&_struct%s_sig);' % (struct.tag,)
188         for member in struct.members:
189             self.visitMember(member, instance)
190         print '    trace::localWriter.endStruct();'
191
192     def visitArray(self, array, instance):
193         length = '_c' + array.type.tag
194         index = '_i' + array.type.tag
195         array_length = self.expand(array.length)
196         print '    if (%s) {' % instance
197         print '        size_t %s = %s > 0 ? %s : 0;' % (length, array_length, array_length)
198         print '        trace::localWriter.beginArray(%s);' % length
199         print '        for (size_t %s = 0; %s < %s; ++%s) {' % (index, index, length, index)
200         print '            trace::localWriter.beginElement();'
201         self.visitElement(index, array.type, '(%s)[%s]' % (instance, index))
202         print '            trace::localWriter.endElement();'
203         print '        }'
204         print '        trace::localWriter.endArray();'
205         print '    } else {'
206         print '        trace::localWriter.writeNull();'
207         print '    }'
208
209     def visitAttribArray(self, array, instance):
210         # For each element, decide if it is a key or a value (which depends on the previous key).
211         # If it is a value, store it as the right type - usually int, some bitfield, or some enum.
212         # It is currently assumed that an unknown key means that it is followed by an int value.
213
214         # determine the array length which must be passed to writeArray() up front
215         count = '_c' + array.keyType.tag
216         print '    int %s;' % count
217         print '    for (%(c)s = 0; %(array)s && %(array)s[%(c)s] != %(terminator)s; %(c)s += 2) {' \
218               % {'c': count, 'array': instance, 'terminator': array.terminator}
219         if array.hasKeysWithoutValues:
220             print '        switch (%(array)s[%(c)s]) {' % {'array': instance, 'c': count}
221             for key, valueType in array.valueTypes:
222                 if valueType is None:
223                     print '        case %s:' % key
224             print '            %s--;' % count # the next value is a key again and checked if it's the terminator
225             print '            break;'
226             print '        }'
227         print '    }'
228         print '    %(c)s += %(array)s ? 1 : 0;' % {'c': count, 'array': instance}
229         print '    trace::localWriter.beginArray(%s);' % count
230
231         # for each key / key-value pair write the key and the value, if the key requires one
232
233         index = '_i' + array.keyType.tag
234         print '    for (int %(i)s = 0; %(i)s < %(count)s; %(i)s++) {' % {'i': index, 'count': count}
235         print '        trace::localWriter.beginElement();'
236         self.visitEnum(array.keyType, "%(array)s[%(i)s]" % {'array': instance, 'i': index})
237         print '        trace::localWriter.endElement();'
238         print '        if (%(i)s + 1 >= %(count)s) {' % {'i': index, 'count': count}
239         print '            break;'
240         print '        }'
241         print '        switch (%(array)s[%(i)s++]) {' % {'array': instance, 'i': index}
242         # write generic value the usual way
243         for key, valueType in array.valueTypes:
244             if valueType is not None:
245                 print '        case %s:' % key
246                 print '            trace::localWriter.beginElement();'
247                 self.visitElement(index, valueType, '(%(array)s)[%(i)s]' % {'array': instance, 'i': index})
248                 print '            trace::localWriter.endElement();'
249                 print '            break;'
250         # known key with no value, just decrease the index so we treat the next value as a key
251         if array.hasKeysWithoutValues:
252             for key, valueType in array.valueTypes:
253                 if valueType is None:
254                     print '        case %s:' % key
255             print '            %s--;' % index
256             print '            break;'
257         # unknown key, write an int value
258         print '        default:'
259         print '            trace::localWriter.beginElement();'
260         print '            os::log("apitrace: warning: %s: unknown key 0x%04X, interpreting value as int\\n", ' + \
261                            '__FUNCTION__, %(array)s[%(i)s]);'  % {'array': instance, 'i': index}
262         print '            trace::localWriter.writeSInt(%(array)s[%(i)s]);' % {'array': instance, 'i': index}
263         print '            trace::localWriter.endElement();'
264         print '            break;'
265         print '        }'
266         print '    }'
267         print '    trace::localWriter.endArray();'
268
269
270     def visitBlob(self, blob, instance):
271         print '    trace::localWriter.writeBlob(%s, %s);' % (instance, self.expand(blob.size))
272
273     def visitEnum(self, enum, instance):
274         print '    trace::localWriter.writeEnum(&_enum%s_sig, %s);' % (enum.tag, instance)
275
276     def visitBitmask(self, bitmask, instance):
277         print '    trace::localWriter.writeBitmask(&_bitmask%s_sig, %s);' % (bitmask.tag, instance)
278
279     def visitPointer(self, pointer, instance):
280         print '    if (%s) {' % instance
281         print '        trace::localWriter.beginArray(1);'
282         print '        trace::localWriter.beginElement();'
283         self.visit(pointer.type, "*" + instance)
284         print '        trace::localWriter.endElement();'
285         print '        trace::localWriter.endArray();'
286         print '    } else {'
287         print '        trace::localWriter.writeNull();'
288         print '    }'
289
290     def visitIntPointer(self, pointer, instance):
291         print '    trace::localWriter.writePointer((uintptr_t)%s);' % instance
292
293     def visitObjPointer(self, pointer, instance):
294         print '    trace::localWriter.writePointer((uintptr_t)%s);' % instance
295
296     def visitLinearPointer(self, pointer, instance):
297         print '    trace::localWriter.writePointer((uintptr_t)%s);' % instance
298
299     def visitReference(self, reference, instance):
300         self.visit(reference.type, instance)
301
302     def visitHandle(self, handle, instance):
303         self.visit(handle.type, instance)
304
305     def visitAlias(self, alias, instance):
306         self.visit(alias.type, instance)
307
308     def visitOpaque(self, opaque, instance):
309         print '    trace::localWriter.writePointer((uintptr_t)%s);' % instance
310
311     def visitInterface(self, interface, instance):
312         assert False
313
314     def visitPolymorphic(self, polymorphic, instance):
315         if polymorphic.contextLess:
316             print '    _write__%s(%s, %s);' % (polymorphic.tag, polymorphic.switchExpr, instance)
317         else:
318             switchExpr = self.expand(polymorphic.switchExpr)
319             print '    switch (%s) {' % switchExpr
320             for cases, type in polymorphic.iterSwitch():
321                 for case in cases:
322                     print '    %s:' % case
323                 caseInstance = instance
324                 if type.expr is not None:
325                     caseInstance = 'static_cast<%s>(%s)' % (type, caseInstance)
326                 self.visit(type, caseInstance)
327                 print '        break;'
328             if polymorphic.defaultType is None:
329                 print r'    default:'
330                 print r'        os::log("apitrace: warning: %%s: unexpected polymorphic case %%i\n", __FUNCTION__, (int)%s);' % (switchExpr,)
331                 print r'        trace::localWriter.writeNull();'
332                 print r'        break;'
333             print '    }'
334
335
336 class WrapDecider(stdapi.Traverser):
337     '''Type visitor which will decide wheter this type will need wrapping or not.
338     
339     For complex types (arrays, structures), we need to know this before hand.
340     '''
341
342     def __init__(self):
343         self.needsWrapping = False
344
345     def visitLinearPointer(self, void):
346         pass
347
348     def visitInterface(self, interface):
349         self.needsWrapping = True
350
351
352 class ValueWrapper(stdapi.Traverser, stdapi.ExpanderMixin):
353     '''Type visitor which will generate the code to wrap an instance.
354     
355     Wrapping is necessary mostly for interfaces, however interface pointers can
356     appear anywhere inside complex types.
357     '''
358
359     def visitStruct(self, struct, instance):
360         for member in struct.members:
361             self.visitMember(member, instance)
362
363     def visitArray(self, array, instance):
364         array_length = self.expand(array.length)
365         print "    if (%s) {" % instance
366         print "        for (size_t _i = 0, _s = %s; _i < _s; ++_i) {" % array_length
367         self.visitElement('_i', array.type, instance + "[_i]")
368         print "        }"
369         print "    }"
370
371     def visitPointer(self, pointer, instance):
372         print "    if (%s) {" % instance
373         self.visit(pointer.type, "*" + instance)
374         print "    }"
375     
376     def visitObjPointer(self, pointer, instance):
377         elem_type = pointer.type.mutable()
378         if isinstance(elem_type, stdapi.Interface):
379             self.visitInterfacePointer(elem_type, instance)
380         elif isinstance(elem_type, stdapi.Alias) and isinstance(elem_type.type, stdapi.Interface):
381             self.visitInterfacePointer(elem_type.type, instance)
382         else:
383             self.visitPointer(pointer, instance)
384     
385     def visitInterface(self, interface, instance):
386         raise NotImplementedError
387
388     def visitInterfacePointer(self, interface, instance):
389         print "    if (%s) {" % instance
390         print "        %s = %s::_Create(__FUNCTION__, %s);" % (instance, getWrapperInterfaceName(interface), instance)
391         print "    }"
392     
393     def visitPolymorphic(self, type, instance):
394         # XXX: There might be polymorphic values that need wrapping in the future
395         raise NotImplementedError
396
397
398 class ValueUnwrapper(ValueWrapper):
399     '''Reverse of ValueWrapper.'''
400
401     allocated = False
402
403     def visitStruct(self, struct, instance):
404         if not self.allocated:
405             # Argument is constant. We need to create a non const
406             print '    {'
407             print "        %s * _t = static_cast<%s *>(alloca(sizeof *_t));" % (struct, struct)
408             print '        *_t = %s;' % (instance,)
409             assert instance.startswith('*')
410             print '        %s = _t;' % (instance[1:],)
411             instance = '*_t'
412             self.allocated = True
413             try:
414                 return ValueWrapper.visitStruct(self, struct, instance)
415             finally:
416                 print '    }'
417         else:
418             return ValueWrapper.visitStruct(self, struct, instance)
419
420     def visitArray(self, array, instance):
421         if self.allocated or isinstance(instance, stdapi.Interface):
422             return ValueWrapper.visitArray(self, array, instance)
423         array_length = self.expand(array.length)
424         elem_type = array.type.mutable()
425         print "    if (%s && %s) {" % (instance, array_length)
426         print "        %s * _t = static_cast<%s *>(alloca(%s * sizeof *_t));" % (elem_type, elem_type, array_length)
427         print "        for (size_t _i = 0, _s = %s; _i < _s; ++_i) {" % array_length
428         print "            _t[_i] = %s[_i];" % instance 
429         self.allocated = True
430         self.visit(array.type, "_t[_i]")
431         print "        }"
432         print "        %s = _t;" % instance
433         print "    }"
434
435     def visitInterfacePointer(self, interface, instance):
436         print r'    if (%s) {' % instance
437         print r'        const %s *pWrapper = static_cast<const %s*>(%s);' % (getWrapperInterfaceName(interface), getWrapperInterfaceName(interface), instance)
438         print r'        if (pWrapper && pWrapper->m_dwMagic == 0xd8365d6c) {'
439         print r'            %s = pWrapper->m_pInstance;' % (instance,)
440         print r'        } else {'
441         print r'            os::log("apitrace: warning: %%s: unexpected %%s pointer\n", __FUNCTION__, "%s");' % interface.name
442         print r'        }'
443         print r'    }'
444
445
446 class Tracer:
447     '''Base class to orchestrate the code generation of API tracing.'''
448
449     # 0-3 are reserved to memcpy, malloc, free, and realloc
450     __id = 4
451
452     def __init__(self):
453         self.api = None
454
455     def serializerFactory(self):
456         '''Create a serializer.
457         
458         Can be overriden by derived classes to inject their own serialzer.
459         '''
460
461         return ValueSerializer()
462
463     def traceApi(self, api):
464         self.api = api
465
466         self.header(api)
467
468         # Includes
469         for module in api.modules:
470             for header in module.headers:
471                 print header
472         print
473
474         # Generate the serializer functions
475         types = api.getAllTypes()
476         visitor = ComplexValueSerializer(self.serializerFactory())
477         map(visitor.visit, types)
478         print
479
480         # Interfaces wrapers
481         self.traceInterfaces(api)
482
483         # Function wrappers
484         self.interface = None
485         self.base = None
486         for function in api.getAllFunctions():
487             self.traceFunctionDecl(function)
488         for function in api.getAllFunctions():
489             self.traceFunctionImpl(function)
490         print
491
492         self.footer(api)
493
494     def header(self, api):
495         print '#ifdef _WIN32'
496         print '#  include <malloc.h> // alloca'
497         print '#  ifndef alloca'
498         print '#    define alloca _alloca'
499         print '#  endif'
500         print '#else'
501         print '#  include <alloca.h> // alloca'
502         print '#endif'
503         print
504         print '#include "trace.hpp"'
505         print
506         print 'static std::map<void *, void *> g_WrappedObjects;'
507
508     def footer(self, api):
509         pass
510
511     def traceFunctionDecl(self, function):
512         # Per-function declarations
513
514         if not function.internal:
515             if function.args:
516                 print 'static const char * _%s_args[%u] = {%s};' % (function.name, len(function.args), ', '.join(['"%s"' % arg.name for arg in function.args]))
517             else:
518                 print 'static const char ** _%s_args = NULL;' % (function.name,)
519             print 'static const trace::FunctionSig _%s_sig = {%u, "%s", %u, _%s_args};' % (function.name, self.getFunctionSigId(), function.name, len(function.args), function.name)
520             print
521
522     def getFunctionSigId(self):
523         id = Tracer.__id
524         Tracer.__id += 1
525         return id
526
527     def isFunctionPublic(self, function):
528         return True
529
530     def traceFunctionImpl(self, function):
531         if self.isFunctionPublic(function):
532             print 'extern "C" PUBLIC'
533         else:
534             print 'extern "C" PRIVATE'
535         print function.prototype() + ' {'
536         if function.type is not stdapi.Void:
537             print '    %s _result;' % function.type
538
539         # No-op if tracing is disabled
540         print '    if (!trace::isTracingEnabled()) {'
541         self.doInvokeFunction(function)
542         if function.type is not stdapi.Void:
543             print '        return _result;'
544         else:
545             print '        return;'
546         print '    }'
547
548         self.traceFunctionImplBody(function)
549         if function.type is not stdapi.Void:
550             print '    return _result;'
551         print '}'
552         print
553
554     def traceFunctionImplBody(self, function):
555         if not function.internal:
556             print '    unsigned _call = trace::localWriter.beginEnter(&_%s_sig);' % (function.name,)
557             for arg in function.args:
558                 if not arg.output:
559                     self.unwrapArg(function, arg)
560             for arg in function.args:
561                 if not arg.output:
562                     self.serializeArg(function, arg)
563             print '    trace::localWriter.endEnter();'
564         self.invokeFunction(function)
565         if not function.internal:
566             print '    trace::localWriter.beginLeave(_call);'
567             print '    if (%s) {' % self.wasFunctionSuccessful(function)
568             for arg in function.args:
569                 if arg.output:
570                     self.serializeArg(function, arg)
571                     self.wrapArg(function, arg)
572             print '    }'
573             if function.type is not stdapi.Void:
574                 self.serializeRet(function, "_result")
575             if function.type is not stdapi.Void:
576                 self.wrapRet(function, "_result")
577             print '    trace::localWriter.endLeave();'
578
579     def invokeFunction(self, function):
580         self.doInvokeFunction(function)
581
582     def doInvokeFunction(self, function, prefix='_', suffix=''):
583         # Same as invokeFunction() but called both when trace is enabled or disabled.
584         if function.type is stdapi.Void:
585             result = ''
586         else:
587             result = '_result = '
588         dispatch = prefix + function.name + suffix
589         print '    %s%s(%s);' % (result, dispatch, ', '.join([str(arg.name) for arg in function.args]))
590
591     def wasFunctionSuccessful(self, function):
592         if function.type is stdapi.Void:
593             return 'true'
594         if str(function.type) == 'HRESULT':
595             return 'SUCCEEDED(_result)'
596         return 'true'
597
598     def serializeArg(self, function, arg):
599         print '    trace::localWriter.beginArg(%u);' % (arg.index,)
600         self.serializeArgValue(function, arg)
601         print '    trace::localWriter.endArg();'
602
603     def serializeArgValue(self, function, arg):
604         self.serializeValue(arg.type, arg.name)
605
606     def wrapArg(self, function, arg):
607         assert not isinstance(arg.type, stdapi.ObjPointer)
608
609         from specs.winapi import REFIID
610         riid = None
611         for other_arg in function.args:
612             if not other_arg.output and other_arg.type is REFIID:
613                 riid = other_arg
614         if riid is not None \
615            and isinstance(arg.type, stdapi.Pointer) \
616            and isinstance(arg.type.type, stdapi.ObjPointer):
617             self.wrapIid(function, riid, arg)
618             return
619
620         self.wrapValue(arg.type, arg.name)
621
622     def unwrapArg(self, function, arg):
623         self.unwrapValue(arg.type, arg.name)
624
625     def serializeRet(self, function, instance):
626         print '    trace::localWriter.beginReturn();'
627         self.serializeValue(function.type, instance)
628         print '    trace::localWriter.endReturn();'
629
630     def serializeValue(self, type, instance):
631         serializer = self.serializerFactory()
632         serializer.visit(type, instance)
633
634     def wrapRet(self, function, instance):
635         self.wrapValue(function.type, instance)
636
637     def needsWrapping(self, type):
638         visitor = WrapDecider()
639         visitor.visit(type)
640         return visitor.needsWrapping
641
642     def wrapValue(self, type, instance):
643         if self.needsWrapping(type):
644             visitor = ValueWrapper()
645             visitor.visit(type, instance)
646
647     def unwrapValue(self, type, instance):
648         if self.needsWrapping(type):
649             visitor = ValueUnwrapper()
650             visitor.visit(type, instance)
651
652     def traceInterfaces(self, api):
653         interfaces = api.getAllInterfaces()
654         if not interfaces:
655             return
656         map(self.declareWrapperInterface, interfaces)
657         self.implementIidWrapper(api)
658         map(self.implementWrapperInterface, interfaces)
659         print
660
661     def declareWrapperInterface(self, interface):
662         print "class %s : public %s " % (getWrapperInterfaceName(interface), interface.name)
663         print "{"
664         print "private:"
665         print "    %s(%s * pInstance);" % (getWrapperInterfaceName(interface), interface.name)
666         print "    virtual ~%s();" % getWrapperInterfaceName(interface)
667         print "public:"
668         print "    static %s* _Create(const char *functionName, %s * pInstance);" % (getWrapperInterfaceName(interface), interface.name)
669         print
670         for method in interface.iterMethods():
671             print "    " + method.prototype() + ";"
672         print
673         #print "private:"
674         for type, name, value in self.enumWrapperInterfaceVariables(interface):
675             print '    %s %s;' % (type, name)
676         for i in range(64):
677             print r'    virtual void _dummy%i(void) const {' % i
678             print r'        os::log("error: %s: unexpected virtual method\n");' % interface.name
679             print r'        os::abort();'
680             print r'    }'
681         print "};"
682         print
683
684     def enumWrapperInterfaceVariables(self, interface):
685         return [
686             ("DWORD", "m_dwMagic", "0xd8365d6c"),
687             ("%s *" % interface.name, "m_pInstance", "pInstance"),
688             ("void *", "m_pVtbl", "*(void **)pInstance"),
689             ("UINT", "m_NumMethods", len(list(interface.iterBaseMethods()))),
690         ] 
691
692     def implementWrapperInterface(self, interface):
693         self.interface = interface
694
695         # Private constructor
696         print '%s::%s(%s * pInstance) {' % (getWrapperInterfaceName(interface), getWrapperInterfaceName(interface), interface.name)
697         for type, name, value in self.enumWrapperInterfaceVariables(interface):
698             if value is not None:
699                 print '    %s = %s;' % (name, value)
700         print '}'
701         print
702
703         # Public constructor
704         print '%s *%s::_Create(const char *functionName, %s * pInstance) {' % (getWrapperInterfaceName(interface), getWrapperInterfaceName(interface), interface.name)
705         print r'    std::map<void *, void *>::const_iterator it = g_WrappedObjects.find(pInstance);'
706         print r'    if (it != g_WrappedObjects.end()) {'
707         print r'        Wrap%s *pWrapper = (Wrap%s *)it->second;' % (interface.name, interface.name)
708         print r'        assert(pWrapper);'
709         print r'        assert(pWrapper->m_dwMagic == 0xd8365d6c);'
710         print r'        assert(pWrapper->m_pInstance == pInstance);'
711         print r'        if (pWrapper->m_pVtbl == *(void **)pInstance &&'
712         print r'            pWrapper->m_NumMethods >= %s) {' % len(list(interface.iterBaseMethods()))
713         #print r'            os::log("%s: fetched pvObj=%p pWrapper=%p pVtbl=%p\n", functionName, pInstance, pWrapper, pWrapper->m_pVtbl);'
714         print r'            return pWrapper;'
715         print r'        }'
716         print r'    }'
717         print r'    Wrap%s *pWrapper = new Wrap%s(pInstance);' % (interface.name, interface.name)
718         #print r'    os::log("%%s: created %s pvObj=%%p pWrapper=%%p pVtbl=%%p\n", functionName, pInstance, pWrapper, pWrapper->m_pVtbl);' % interface.name
719         print r'    g_WrappedObjects[pInstance] = pWrapper;'
720         print r'    return pWrapper;'
721         print '}'
722         print
723
724         # Destructor
725         print '%s::~%s() {' % (getWrapperInterfaceName(interface), getWrapperInterfaceName(interface))
726         #print r'        os::log("%s::Release: deleted pvObj=%%p pWrapper=%%p pVtbl=%%p\n", m_pInstance, this, m_pVtbl);' % interface.name
727         print r'        g_WrappedObjects.erase(m_pInstance);'
728         print '}'
729         print
730         
731         for base, method in interface.iterBaseMethods():
732             self.base = base
733             self.implementWrapperInterfaceMethod(interface, base, method)
734
735         print
736
737     def implementWrapperInterfaceMethod(self, interface, base, method):
738         print method.prototype(getWrapperInterfaceName(interface) + '::' + method.name) + ' {'
739
740         if False:
741             print r'    os::log("%%s(%%p -> %%p)\n", "%s", this, m_pInstance);' % (getWrapperInterfaceName(interface) + '::' + method.name)
742
743         if method.type is not stdapi.Void:
744             print '    %s _result;' % method.type
745     
746         self.implementWrapperInterfaceMethodBody(interface, base, method)
747     
748         if method.type is not stdapi.Void:
749             print '    return _result;'
750         print '}'
751         print
752
753     def implementWrapperInterfaceMethodBody(self, interface, base, method):
754         assert not method.internal
755
756         print '    static const char * _args[%u] = {%s};' % (len(method.args) + 1, ', '.join(['"this"'] + ['"%s"' % arg.name for arg in method.args]))
757         print '    static const trace::FunctionSig _sig = {%u, "%s", %u, _args};' % (self.getFunctionSigId(), interface.name + '::' + method.name, len(method.args) + 1)
758
759         print '    %s *_this = static_cast<%s *>(m_pInstance);' % (base, base)
760
761         print '    unsigned _call = trace::localWriter.beginEnter(&_sig);'
762         print '    trace::localWriter.beginArg(0);'
763         print '    trace::localWriter.writePointer((uintptr_t)m_pInstance);'
764         print '    trace::localWriter.endArg();'
765         for arg in method.args:
766             if not arg.output:
767                 self.unwrapArg(method, arg)
768         for arg in method.args:
769             if not arg.output:
770                 self.serializeArg(method, arg)
771         print '    trace::localWriter.endEnter();'
772         
773         self.invokeMethod(interface, base, method)
774
775         print '    trace::localWriter.beginLeave(_call);'
776
777         print '    if (%s) {' % self.wasFunctionSuccessful(method)
778         for arg in method.args:
779             if arg.output:
780                 self.serializeArg(method, arg)
781                 self.wrapArg(method, arg)
782         print '    }'
783
784         if method.type is not stdapi.Void:
785             self.serializeRet(method, '_result')
786         if method.type is not stdapi.Void:
787             self.wrapRet(method, '_result')
788
789         if method.name == 'Release':
790             assert method.type is not stdapi.Void
791             print r'    if (!_result) {'
792             print r'        delete this;'
793             print r'    }'
794         
795         print '    trace::localWriter.endLeave();'
796
797     def implementIidWrapper(self, api):
798         print r'static void'
799         print r'warnIID(const char *functionName, REFIID riid, const char *reason) {'
800         print r'    os::log("apitrace: warning: %s: %s IID {0x%08lX,0x%04X,0x%04X,{0x%02X,0x%02X,0x%02X,0x%02X,0x%02X,0x%02X,0x%02X,0x%02X}}\n",'
801         print r'            functionName, reason,'
802         print r'            riid.Data1, riid.Data2, riid.Data3,'
803         print r'            riid.Data4[0], riid.Data4[1], riid.Data4[2], riid.Data4[3], riid.Data4[4], riid.Data4[5], riid.Data4[6], riid.Data4[7]);'
804         print r'}'
805         print 
806         print r'static void'
807         print r'wrapIID(const char *functionName, REFIID riid, void * * ppvObj) {'
808         print r'    if (!ppvObj || !*ppvObj) {'
809         print r'        return;'
810         print r'    }'
811         else_ = ''
812         for iface in api.getAllInterfaces():
813             print r'    %sif (riid == IID_%s) {' % (else_, iface.name)
814             print r'        *ppvObj = Wrap%s::_Create(functionName, (%s *) *ppvObj);' % (iface.name, iface.name)
815             print r'    }'
816             else_ = 'else '
817         print r'    %s{' % else_
818         print r'        warnIID(functionName, riid, "unknown");'
819         print r'    }'
820         print r'}'
821         print
822
823     def wrapIid(self, function, riid, out):
824         # Cast output arg to `void **` if necessary
825         out_name = out.name
826         obj_type = out.type.type.type
827         if not obj_type is stdapi.Void:
828             assert isinstance(obj_type, stdapi.Interface)
829             out_name = 'reinterpret_cast<void * *>(%s)' % out_name
830
831         print r'    if (%s && *%s) {' % (out.name, out.name)
832         functionName = function.name
833         else_ = ''
834         if self.interface is not None:
835             functionName = self.interface.name + '::' + functionName
836             print r'        if (*%s == m_pInstance &&' % (out_name,)
837             print r'            (%s)) {' % ' || '.join('%s == IID_%s' % (riid.name, iface.name) for iface in self.interface.iterBases())
838             print r'            *%s = this;' % (out_name,)
839             print r'        }'
840             else_ = 'else '
841         print r'        %s{' % else_
842         print r'             wrapIID("%s", %s, %s);' % (functionName, riid.name, out_name)
843         print r'        }'
844         print r'    }'
845
846     def invokeMethod(self, interface, base, method):
847         if method.type is stdapi.Void:
848             result = ''
849         else:
850             result = '_result = '
851         print '    %s_this->%s(%s);' % (result, method.name, ', '.join([str(arg.name) for arg in method.args]))
852     
853     def emit_memcpy(self, dest, src, length):
854         print '        unsigned _call = trace::localWriter.beginEnter(&trace::memcpy_sig, true);'
855         print '        trace::localWriter.beginArg(0);'
856         print '        trace::localWriter.writePointer((uintptr_t)%s);' % dest
857         print '        trace::localWriter.endArg();'
858         print '        trace::localWriter.beginArg(1);'
859         print '        trace::localWriter.writeBlob(%s, %s);' % (src, length)
860         print '        trace::localWriter.endArg();'
861         print '        trace::localWriter.beginArg(2);'
862         print '        trace::localWriter.writeUInt(%s);' % length
863         print '        trace::localWriter.endArg();'
864         print '        trace::localWriter.endEnter();'
865         print '        trace::localWriter.beginLeave(_call);'
866         print '        trace::localWriter.endLeave();'
867     
868     def fake_call(self, function, args):
869         print '            unsigned _fake_call = trace::localWriter.beginEnter(&_%s_sig);' % (function.name,)
870         for arg, instance in zip(function.args, args):
871             assert not arg.output
872             print '            trace::localWriter.beginArg(%u);' % (arg.index,)
873             self.serializeValue(arg.type, instance)
874             print '            trace::localWriter.endArg();'
875         print '            trace::localWriter.endEnter();'
876         print '            trace::localWriter.beginLeave(_fake_call);'
877         print '            trace::localWriter.endLeave();'
878