]> git.cworth.org Git - apitrace/blob - wrappers/trace.py
10f71a4dc7a32362dfafe3d96d68705ed3578db8
[apitrace] / wrappers / trace.py
1 ##########################################################################
2 #
3 # Copyright 2008-2010 VMware, Inc.
4 # All Rights Reserved.
5 #
6 # Permission is hereby granted, free of charge, to any person obtaining a copy
7 # of this software and associated documentation files (the "Software"), to deal
8 # in the Software without restriction, including without limitation the rights
9 # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 # copies of the Software, and to permit persons to whom the Software is
11 # furnished to do so, subject to the following conditions:
12 #
13 # The above copyright notice and this permission notice shall be included in
14 # all copies or substantial portions of the Software.
15 #
16 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
22 # THE SOFTWARE.
23 #
24 ##########################################################################/
25
26 """Common trace code generation."""
27
28
29 # Adjust path
30 import os.path
31 import sys
32 sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
33
34
35 import specs.stdapi as stdapi
36
37
38 def getWrapperInterfaceName(interface):
39     return "Wrap" + interface.expr
40
41
42
43 class ComplexValueSerializer(stdapi.OnceVisitor):
44     '''Type visitors which generates serialization functions for
45     complex types.
46     
47     Simple types are serialized inline.
48     '''
49
50     def __init__(self, serializer):
51         stdapi.OnceVisitor.__init__(self)
52         self.serializer = serializer
53
54     def visitVoid(self, literal):
55         pass
56
57     def visitLiteral(self, literal):
58         pass
59
60     def visitString(self, string):
61         pass
62
63     def visitConst(self, const):
64         self.visit(const.type)
65
66     def visitStruct(self, struct):
67         print 'static const char * _struct%s_members[%u] = {' % (struct.tag, len(struct.members))
68         for type, name,  in struct.members:
69             if name is None:
70                 print '    "",'
71             else:
72                 print '    "%s",' % (name,)
73         print '};'
74         print 'static const trace::StructSig _struct%s_sig = {' % (struct.tag,)
75         if struct.name is None:
76             structName = '""'
77         else:
78             structName = '"%s"' % struct.name
79         print '    %u, %s, %u, _struct%s_members' % (struct.id, structName, len(struct.members), struct.tag)
80         print '};'
81         print
82
83     def visitArray(self, array):
84         self.visit(array.type)
85
86     def visitAttribArray(self, array):
87         pass
88
89     def visitBlob(self, array):
90         pass
91
92     def visitEnum(self, enum):
93         print 'static const trace::EnumValue _enum%s_values[] = {' % (enum.tag)
94         for value in enum.values:
95             print '    {"%s", %s},' % (value, value)
96         print '};'
97         print
98         print 'static const trace::EnumSig _enum%s_sig = {' % (enum.tag)
99         print '    %u, %u, _enum%s_values' % (enum.id, len(enum.values), enum.tag)
100         print '};'
101         print
102
103     def visitBitmask(self, bitmask):
104         print 'static const trace::BitmaskFlag _bitmask%s_flags[] = {' % (bitmask.tag)
105         for value in bitmask.values:
106             print '    {"%s", %s},' % (value, value)
107         print '};'
108         print
109         print 'static const trace::BitmaskSig _bitmask%s_sig = {' % (bitmask.tag)
110         print '    %u, %u, _bitmask%s_flags' % (bitmask.id, len(bitmask.values), bitmask.tag)
111         print '};'
112         print
113
114     def visitPointer(self, pointer):
115         self.visit(pointer.type)
116
117     def visitIntPointer(self, pointer):
118         pass
119
120     def visitObjPointer(self, pointer):
121         self.visit(pointer.type)
122
123     def visitLinearPointer(self, pointer):
124         self.visit(pointer.type)
125
126     def visitHandle(self, handle):
127         self.visit(handle.type)
128
129     def visitReference(self, reference):
130         self.visit(reference.type)
131
132     def visitAlias(self, alias):
133         self.visit(alias.type)
134
135     def visitOpaque(self, opaque):
136         pass
137
138     def visitInterface(self, interface):
139         pass
140
141     def visitPolymorphic(self, polymorphic):
142         if not polymorphic.contextLess:
143             return
144         print 'static void _write__%s(int selector, %s const & value) {' % (polymorphic.tag, polymorphic.expr)
145         print '    switch (selector) {'
146         for cases, type in polymorphic.iterSwitch():
147             for case in cases:
148                 print '    %s:' % case
149             self.serializer.visit(type, 'static_cast<%s>(value)' % (type,))
150             print '        break;'
151         print '    }'
152         print '}'
153         print
154
155
156 class ValueSerializer(stdapi.Visitor, stdapi.ExpanderMixin):
157     '''Visitor which generates code to serialize any type.
158     
159     Simple types are serialized inline here, whereas the serialization of
160     complex types is dispatched to the serialization functions generated by
161     ComplexValueSerializer visitor above.
162     '''
163
164     def visitLiteral(self, literal, instance):
165         print '    trace::localWriter.write%s(%s);' % (literal.kind, instance)
166
167     def visitString(self, string, instance):
168         if not string.wide:
169             cast = 'const char *'
170             suffix = 'String'
171         else:
172             cast = 'const wchar_t *'
173             suffix = 'WString'
174         if cast != string.expr:
175             # reinterpret_cast is necessary for GLubyte * <=> char *
176             instance = 'reinterpret_cast<%s>(%s)' % (cast, instance)
177         if string.length is not None:
178             length = ', %s' % self.expand(string.length)
179         else:
180             length = ''
181         print '    trace::localWriter.write%s(%s%s);' % (suffix, instance, length)
182
183     def visitConst(self, const, instance):
184         self.visit(const.type, instance)
185
186     def visitStruct(self, struct, instance):
187         print '    trace::localWriter.beginStruct(&_struct%s_sig);' % (struct.tag,)
188         for member in struct.members:
189             self.visitMember(member, instance)
190         print '    trace::localWriter.endStruct();'
191
192     def visitArray(self, array, instance):
193         length = '_c' + array.type.tag
194         index = '_i' + array.type.tag
195         array_length = self.expand(array.length)
196         print '    if (%s) {' % instance
197         print '        size_t %s = %s > 0 ? %s : 0;' % (length, array_length, array_length)
198         print '        trace::localWriter.beginArray(%s);' % length
199         print '        for (size_t %s = 0; %s < %s; ++%s) {' % (index, index, length, index)
200         print '            trace::localWriter.beginElement();'
201         self.visitElement(index, array.type, '(%s)[%s]' % (instance, index))
202         print '            trace::localWriter.endElement();'
203         print '        }'
204         print '        trace::localWriter.endArray();'
205         print '    } else {'
206         print '        trace::localWriter.writeNull();'
207         print '    }'
208
209     def visitAttribArray(self, array, instance):
210         # For each element, decide if it is a key or a value (which depends on the previous key).
211         # If it is a value, store it as the right type - usually int, some bitfield, or some enum.
212         # It is currently assumed that an unknown key means that it is followed by an int value.
213
214         # determine the array length which must be passed to writeArray() up front
215         count = '_c' + array.keyType.tag
216         print '    {'
217         print '    int %s;' % count
218         print '    for (%(c)s = 0; %(array)s && %(array)s[%(c)s] != %(terminator)s; %(c)s += 2) {' \
219               % {'c': count, 'array': instance, 'terminator': array.terminator}
220         if array.hasKeysWithoutValues:
221             print '        switch (int(%(array)s[%(c)s])) {' % {'array': instance, 'c': count}
222             for key, valueType in array.valueTypes:
223                 if valueType is None:
224                     print '        case %s:' % key
225             print '            %s--;' % count # the next value is a key again and checked if it's the terminator
226             print '            break;'
227             print '        }'
228         print '    }'
229         print '    %(c)s += %(array)s ? 1 : 0;' % {'c': count, 'array': instance}
230         print '    trace::localWriter.beginArray(%s);' % count
231
232         # for each key / key-value pair write the key and the value, if the key requires one
233
234         index = '_i' + array.keyType.tag
235         print '    for (int %(i)s = 0; %(i)s < %(count)s; %(i)s++) {' % {'i': index, 'count': count}
236         print '        trace::localWriter.beginElement();'
237         self.visitEnum(array.keyType, "%(array)s[%(i)s]" % {'array': instance, 'i': index})
238         print '        trace::localWriter.endElement();'
239         print '        if (%(i)s + 1 >= %(count)s) {' % {'i': index, 'count': count}
240         print '            break;'
241         print '        }'
242         print '        switch (int(%(array)s[%(i)s++])) {' % {'array': instance, 'i': index}
243         # write generic value the usual way
244         for key, valueType in array.valueTypes:
245             if valueType is not None:
246                 print '        case %s:' % key
247                 print '            trace::localWriter.beginElement();'
248                 self.visitElement(index, valueType, '(%(array)s)[%(i)s]' % {'array': instance, 'i': index})
249                 print '            trace::localWriter.endElement();'
250                 print '            break;'
251         # known key with no value, just decrease the index so we treat the next value as a key
252         if array.hasKeysWithoutValues:
253             for key, valueType in array.valueTypes:
254                 if valueType is None:
255                     print '        case %s:' % key
256             print '            %s--;' % index
257             print '            break;'
258         # unknown key, write an int value
259         print '        default:'
260         print '            trace::localWriter.beginElement();'
261         print '            os::log("apitrace: warning: %s: unknown key 0x%04X, interpreting value as int\\n", ' + \
262                            '__FUNCTION__, int(%(array)s[%(i)s]));'  % {'array': instance, 'i': index}
263         print '            trace::localWriter.writeSInt(%(array)s[%(i)s]);' % {'array': instance, 'i': index}
264         print '            trace::localWriter.endElement();'
265         print '            break;'
266         print '        }'
267         print '    }'
268         print '    trace::localWriter.endArray();'
269         print '    }'
270
271
272     def visitBlob(self, blob, instance):
273         print '    trace::localWriter.writeBlob(%s, %s);' % (instance, self.expand(blob.size))
274
275     def visitEnum(self, enum, instance):
276         print '    trace::localWriter.writeEnum(&_enum%s_sig, %s);' % (enum.tag, instance)
277
278     def visitBitmask(self, bitmask, instance):
279         print '    trace::localWriter.writeBitmask(&_bitmask%s_sig, %s);' % (bitmask.tag, instance)
280
281     def visitPointer(self, pointer, instance):
282         print '    if (%s) {' % instance
283         print '        trace::localWriter.beginArray(1);'
284         print '        trace::localWriter.beginElement();'
285         self.visit(pointer.type, "*" + instance)
286         print '        trace::localWriter.endElement();'
287         print '        trace::localWriter.endArray();'
288         print '    } else {'
289         print '        trace::localWriter.writeNull();'
290         print '    }'
291
292     def visitIntPointer(self, pointer, instance):
293         print '    trace::localWriter.writePointer((uintptr_t)%s);' % instance
294
295     def visitObjPointer(self, pointer, instance):
296         print '    trace::localWriter.writePointer((uintptr_t)%s);' % instance
297
298     def visitLinearPointer(self, pointer, instance):
299         print '    trace::localWriter.writePointer((uintptr_t)%s);' % instance
300
301     def visitReference(self, reference, instance):
302         self.visit(reference.type, instance)
303
304     def visitHandle(self, handle, instance):
305         self.visit(handle.type, instance)
306
307     def visitAlias(self, alias, instance):
308         self.visit(alias.type, instance)
309
310     def visitOpaque(self, opaque, instance):
311         print '    trace::localWriter.writePointer((uintptr_t)%s);' % instance
312
313     def visitInterface(self, interface, instance):
314         assert False
315
316     def visitPolymorphic(self, polymorphic, instance):
317         if polymorphic.contextLess:
318             print '    _write__%s(%s, %s);' % (polymorphic.tag, polymorphic.switchExpr, instance)
319         else:
320             switchExpr = self.expand(polymorphic.switchExpr)
321             print '    switch (%s) {' % switchExpr
322             for cases, type in polymorphic.iterSwitch():
323                 for case in cases:
324                     print '    %s:' % case
325                 caseInstance = instance
326                 if type.expr is not None:
327                     caseInstance = 'static_cast<%s>(%s)' % (type, caseInstance)
328                 self.visit(type, caseInstance)
329                 print '        break;'
330             if polymorphic.defaultType is None:
331                 print r'    default:'
332                 print r'        os::log("apitrace: warning: %%s: unexpected polymorphic case %%i\n", __FUNCTION__, (int)%s);' % (switchExpr,)
333                 print r'        trace::localWriter.writeNull();'
334                 print r'        break;'
335             print '    }'
336
337
338 class WrapDecider(stdapi.Traverser):
339     '''Type visitor which will decide wheter this type will need wrapping or not.
340     
341     For complex types (arrays, structures), we need to know this before hand.
342     '''
343
344     def __init__(self):
345         self.needsWrapping = False
346
347     def visitLinearPointer(self, void):
348         pass
349
350     def visitInterface(self, interface):
351         self.needsWrapping = True
352
353
354 class ValueWrapper(stdapi.Traverser, stdapi.ExpanderMixin):
355     '''Type visitor which will generate the code to wrap an instance.
356     
357     Wrapping is necessary mostly for interfaces, however interface pointers can
358     appear anywhere inside complex types.
359     '''
360
361     def visitStruct(self, struct, instance):
362         for member in struct.members:
363             self.visitMember(member, instance)
364
365     def visitArray(self, array, instance):
366         array_length = self.expand(array.length)
367         print "    if (%s) {" % instance
368         print "        for (size_t _i = 0, _s = %s; _i < _s; ++_i) {" % array_length
369         self.visitElement('_i', array.type, instance + "[_i]")
370         print "        }"
371         print "    }"
372
373     def visitPointer(self, pointer, instance):
374         print "    if (%s) {" % instance
375         self.visit(pointer.type, "*" + instance)
376         print "    }"
377     
378     def visitObjPointer(self, pointer, instance):
379         elem_type = pointer.type.mutable()
380         if isinstance(elem_type, stdapi.Interface):
381             self.visitInterfacePointer(elem_type, instance)
382         elif isinstance(elem_type, stdapi.Alias) and isinstance(elem_type.type, stdapi.Interface):
383             self.visitInterfacePointer(elem_type.type, instance)
384         else:
385             self.visitPointer(pointer, instance)
386     
387     def visitInterface(self, interface, instance):
388         raise NotImplementedError
389
390     def visitInterfacePointer(self, interface, instance):
391         print "    if (%s) {" % instance
392         print "        %s = %s::_Create(__FUNCTION__, %s);" % (instance, getWrapperInterfaceName(interface), instance)
393         print "    }"
394     
395     def visitPolymorphic(self, type, instance):
396         # XXX: There might be polymorphic values that need wrapping in the future
397         raise NotImplementedError
398
399
400 class ValueUnwrapper(ValueWrapper):
401     '''Reverse of ValueWrapper.'''
402
403     allocated = False
404
405     def visitStruct(self, struct, instance):
406         if not self.allocated:
407             # Argument is constant. We need to create a non const
408             print '    {'
409             print "        %s * _t = static_cast<%s *>(alloca(sizeof *_t));" % (struct, struct)
410             print '        *_t = %s;' % (instance,)
411             assert instance.startswith('*')
412             print '        %s = _t;' % (instance[1:],)
413             instance = '*_t'
414             self.allocated = True
415             try:
416                 return ValueWrapper.visitStruct(self, struct, instance)
417             finally:
418                 print '    }'
419         else:
420             return ValueWrapper.visitStruct(self, struct, instance)
421
422     def visitArray(self, array, instance):
423         if self.allocated or isinstance(instance, stdapi.Interface):
424             return ValueWrapper.visitArray(self, array, instance)
425         array_length = self.expand(array.length)
426         elem_type = array.type.mutable()
427         print "    if (%s && %s) {" % (instance, array_length)
428         print "        %s * _t = static_cast<%s *>(alloca(%s * sizeof *_t));" % (elem_type, elem_type, array_length)
429         print "        for (size_t _i = 0, _s = %s; _i < _s; ++_i) {" % array_length
430         print "            _t[_i] = %s[_i];" % instance 
431         self.allocated = True
432         self.visit(array.type, "_t[_i]")
433         print "        }"
434         print "        %s = _t;" % instance
435         print "    }"
436
437     def visitInterfacePointer(self, interface, instance):
438         print r'    if (%s) {' % instance
439         print r'        const %s *pWrapper = static_cast<const %s*>(%s);' % (getWrapperInterfaceName(interface), getWrapperInterfaceName(interface), instance)
440         print r'        if (pWrapper && pWrapper->m_dwMagic == 0xd8365d6c) {'
441         print r'            %s = pWrapper->m_pInstance;' % (instance,)
442         print r'        } else {'
443         print r'            os::log("apitrace: warning: %%s: unexpected %%s pointer\n", __FUNCTION__, "%s");' % interface.name
444         print r'        }'
445         print r'    }'
446
447
448 class Tracer:
449     '''Base class to orchestrate the code generation of API tracing.'''
450
451     # 0-3 are reserved to memcpy, malloc, free, and realloc
452     __id = 4
453
454     def __init__(self):
455         self.api = None
456
457     def serializerFactory(self):
458         '''Create a serializer.
459         
460         Can be overriden by derived classes to inject their own serialzer.
461         '''
462
463         return ValueSerializer()
464
465     def traceApi(self, api):
466         self.api = api
467
468         self.header(api)
469
470         # Includes
471         for module in api.modules:
472             for header in module.headers:
473                 print header
474         print
475
476         # Generate the serializer functions
477         types = api.getAllTypes()
478         visitor = ComplexValueSerializer(self.serializerFactory())
479         map(visitor.visit, types)
480         print
481
482         # Interfaces wrapers
483         self.traceInterfaces(api)
484
485         # Function wrappers
486         self.interface = None
487         self.base = None
488         for function in api.getAllFunctions():
489             self.traceFunctionDecl(function)
490         for function in api.getAllFunctions():
491             self.traceFunctionImpl(function)
492         print
493
494         self.footer(api)
495
496     def header(self, api):
497         print '#ifdef _WIN32'
498         print '#  include <malloc.h> // alloca'
499         print '#  ifndef alloca'
500         print '#    define alloca _alloca'
501         print '#  endif'
502         print '#else'
503         print '#  include <alloca.h> // alloca'
504         print '#endif'
505         print
506         print '#include "trace.hpp"'
507         print
508         print 'static std::map<void *, void *> g_WrappedObjects;'
509
510     def footer(self, api):
511         pass
512
513     def traceFunctionDecl(self, function):
514         # Per-function declarations
515
516         if not function.internal:
517             if function.args:
518                 print 'static const char * _%s_args[%u] = {%s};' % (function.name, len(function.args), ', '.join(['"%s"' % arg.name for arg in function.args]))
519             else:
520                 print 'static const char ** _%s_args = NULL;' % (function.name,)
521             print 'static const trace::FunctionSig _%s_sig = {%u, "%s", %u, _%s_args};' % (function.name, self.getFunctionSigId(), function.name, len(function.args), function.name)
522             print
523
524     def getFunctionSigId(self):
525         id = Tracer.__id
526         Tracer.__id += 1
527         return id
528
529     def isFunctionPublic(self, function):
530         return True
531
532     def traceFunctionImpl(self, function):
533         if self.isFunctionPublic(function):
534             print 'extern "C" PUBLIC'
535         else:
536             print 'extern "C" PRIVATE'
537         print function.prototype() + ' {'
538         if function.type is not stdapi.Void:
539             print '    %s _result;' % function.type
540
541         # No-op if tracing is disabled
542         print '    if (!trace::isTracingEnabled()) {'
543         self.doInvokeFunction(function)
544         if function.type is not stdapi.Void:
545             print '        return _result;'
546         else:
547             print '        return;'
548         print '    }'
549
550         self.traceFunctionImplBody(function)
551         if function.type is not stdapi.Void:
552             print '    return _result;'
553         print '}'
554         print
555
556     def traceFunctionImplBody(self, function):
557         if not function.internal:
558             print '    unsigned _call = trace::localWriter.beginEnter(&_%s_sig);' % (function.name,)
559             for arg in function.args:
560                 if not arg.output:
561                     self.unwrapArg(function, arg)
562             for arg in function.args:
563                 if not arg.output:
564                     self.serializeArg(function, arg)
565             print '    trace::localWriter.endEnter();'
566         self.invokeFunction(function)
567         if not function.internal:
568             print '    trace::localWriter.beginLeave(_call);'
569             print '    if (%s) {' % self.wasFunctionSuccessful(function)
570             for arg in function.args:
571                 if arg.output:
572                     self.serializeArg(function, arg)
573                     self.wrapArg(function, arg)
574             print '    }'
575             if function.type is not stdapi.Void:
576                 self.serializeRet(function, "_result")
577             if function.type is not stdapi.Void:
578                 self.wrapRet(function, "_result")
579             print '    trace::localWriter.endLeave();'
580
581     def invokeFunction(self, function):
582         self.doInvokeFunction(function)
583
584     def doInvokeFunction(self, function, prefix='_', suffix=''):
585         # Same as invokeFunction() but called both when trace is enabled or disabled.
586         if function.type is stdapi.Void:
587             result = ''
588         else:
589             result = '_result = '
590         dispatch = prefix + function.name + suffix
591         print '    %s%s(%s);' % (result, dispatch, ', '.join([str(arg.name) for arg in function.args]))
592
593     def wasFunctionSuccessful(self, function):
594         if function.type is stdapi.Void:
595             return 'true'
596         if str(function.type) == 'HRESULT':
597             return 'SUCCEEDED(_result)'
598         return 'true'
599
600     def serializeArg(self, function, arg):
601         print '    trace::localWriter.beginArg(%u);' % (arg.index,)
602         self.serializeArgValue(function, arg)
603         print '    trace::localWriter.endArg();'
604
605     def serializeArgValue(self, function, arg):
606         self.serializeValue(arg.type, arg.name)
607
608     def wrapArg(self, function, arg):
609         assert not isinstance(arg.type, stdapi.ObjPointer)
610
611         from specs.winapi import REFIID
612         riid = None
613         for other_arg in function.args:
614             if not other_arg.output and other_arg.type is REFIID:
615                 riid = other_arg
616         if riid is not None \
617            and isinstance(arg.type, stdapi.Pointer) \
618            and isinstance(arg.type.type, stdapi.ObjPointer):
619             self.wrapIid(function, riid, arg)
620             return
621
622         self.wrapValue(arg.type, arg.name)
623
624     def unwrapArg(self, function, arg):
625         self.unwrapValue(arg.type, arg.name)
626
627     def serializeRet(self, function, instance):
628         print '    trace::localWriter.beginReturn();'
629         self.serializeValue(function.type, instance)
630         print '    trace::localWriter.endReturn();'
631
632     def serializeValue(self, type, instance):
633         serializer = self.serializerFactory()
634         serializer.visit(type, instance)
635
636     def wrapRet(self, function, instance):
637         self.wrapValue(function.type, instance)
638
639     def needsWrapping(self, type):
640         visitor = WrapDecider()
641         visitor.visit(type)
642         return visitor.needsWrapping
643
644     def wrapValue(self, type, instance):
645         if self.needsWrapping(type):
646             visitor = ValueWrapper()
647             visitor.visit(type, instance)
648
649     def unwrapValue(self, type, instance):
650         if self.needsWrapping(type):
651             visitor = ValueUnwrapper()
652             visitor.visit(type, instance)
653
654     def traceInterfaces(self, api):
655         interfaces = api.getAllInterfaces()
656         if not interfaces:
657             return
658         map(self.declareWrapperInterface, interfaces)
659         self.implementIidWrapper(api)
660         map(self.implementWrapperInterface, interfaces)
661         print
662
663     def declareWrapperInterface(self, interface):
664         print "class %s : public %s " % (getWrapperInterfaceName(interface), interface.name)
665         print "{"
666         print "private:"
667         print "    %s(%s * pInstance);" % (getWrapperInterfaceName(interface), interface.name)
668         print "    virtual ~%s();" % getWrapperInterfaceName(interface)
669         print "public:"
670         print "    static %s* _Create(const char *functionName, %s * pInstance);" % (getWrapperInterfaceName(interface), interface.name)
671         print
672         for method in interface.iterMethods():
673             print "    " + method.prototype() + ";"
674         print
675         #print "private:"
676         for type, name, value in self.enumWrapperInterfaceVariables(interface):
677             print '    %s %s;' % (type, name)
678         for i in range(64):
679             print r'    virtual void _dummy%i(void) const {' % i
680             print r'        os::log("error: %s: unexpected virtual method\n");' % interface.name
681             print r'        os::abort();'
682             print r'    }'
683         print "};"
684         print
685
686     def enumWrapperInterfaceVariables(self, interface):
687         return [
688             ("DWORD", "m_dwMagic", "0xd8365d6c"),
689             ("%s *" % interface.name, "m_pInstance", "pInstance"),
690             ("void *", "m_pVtbl", "*(void **)pInstance"),
691             ("UINT", "m_NumMethods", len(list(interface.iterBaseMethods()))),
692         ] 
693
694     def implementWrapperInterface(self, interface):
695         self.interface = interface
696
697         # Private constructor
698         print '%s::%s(%s * pInstance) {' % (getWrapperInterfaceName(interface), getWrapperInterfaceName(interface), interface.name)
699         for type, name, value in self.enumWrapperInterfaceVariables(interface):
700             if value is not None:
701                 print '    %s = %s;' % (name, value)
702         print '}'
703         print
704
705         # Public constructor
706         print '%s *%s::_Create(const char *functionName, %s * pInstance) {' % (getWrapperInterfaceName(interface), getWrapperInterfaceName(interface), interface.name)
707         print r'    std::map<void *, void *>::const_iterator it = g_WrappedObjects.find(pInstance);'
708         print r'    if (it != g_WrappedObjects.end()) {'
709         print r'        Wrap%s *pWrapper = (Wrap%s *)it->second;' % (interface.name, interface.name)
710         print r'        assert(pWrapper);'
711         print r'        assert(pWrapper->m_dwMagic == 0xd8365d6c);'
712         print r'        assert(pWrapper->m_pInstance == pInstance);'
713         print r'        if (pWrapper->m_pVtbl == *(void **)pInstance &&'
714         print r'            pWrapper->m_NumMethods >= %s) {' % len(list(interface.iterBaseMethods()))
715         #print r'            os::log("%s: fetched pvObj=%p pWrapper=%p pVtbl=%p\n", functionName, pInstance, pWrapper, pWrapper->m_pVtbl);'
716         print r'            return pWrapper;'
717         print r'        }'
718         print r'    }'
719         print r'    Wrap%s *pWrapper = new Wrap%s(pInstance);' % (interface.name, interface.name)
720         #print r'    os::log("%%s: created %s pvObj=%%p pWrapper=%%p pVtbl=%%p\n", functionName, pInstance, pWrapper, pWrapper->m_pVtbl);' % interface.name
721         print r'    g_WrappedObjects[pInstance] = pWrapper;'
722         print r'    return pWrapper;'
723         print '}'
724         print
725
726         # Destructor
727         print '%s::~%s() {' % (getWrapperInterfaceName(interface), getWrapperInterfaceName(interface))
728         #print r'        os::log("%s::Release: deleted pvObj=%%p pWrapper=%%p pVtbl=%%p\n", m_pInstance, this, m_pVtbl);' % interface.name
729         print r'        g_WrappedObjects.erase(m_pInstance);'
730         print '}'
731         print
732         
733         for base, method in interface.iterBaseMethods():
734             self.base = base
735             self.implementWrapperInterfaceMethod(interface, base, method)
736
737         print
738
739     def implementWrapperInterfaceMethod(self, interface, base, method):
740         print method.prototype(getWrapperInterfaceName(interface) + '::' + method.name) + ' {'
741
742         if False:
743             print r'    os::log("%%s(%%p -> %%p)\n", "%s", this, m_pInstance);' % (getWrapperInterfaceName(interface) + '::' + method.name)
744
745         if method.type is not stdapi.Void:
746             print '    %s _result;' % method.type
747     
748         self.implementWrapperInterfaceMethodBody(interface, base, method)
749     
750         if method.type is not stdapi.Void:
751             print '    return _result;'
752         print '}'
753         print
754
755     def implementWrapperInterfaceMethodBody(self, interface, base, method):
756         assert not method.internal
757
758         print '    static const char * _args[%u] = {%s};' % (len(method.args) + 1, ', '.join(['"this"'] + ['"%s"' % arg.name for arg in method.args]))
759         print '    static const trace::FunctionSig _sig = {%u, "%s", %u, _args};' % (self.getFunctionSigId(), interface.name + '::' + method.name, len(method.args) + 1)
760
761         print '    %s *_this = static_cast<%s *>(m_pInstance);' % (base, base)
762
763         print '    unsigned _call = trace::localWriter.beginEnter(&_sig);'
764         print '    trace::localWriter.beginArg(0);'
765         print '    trace::localWriter.writePointer((uintptr_t)m_pInstance);'
766         print '    trace::localWriter.endArg();'
767         for arg in method.args:
768             if not arg.output:
769                 self.unwrapArg(method, arg)
770         for arg in method.args:
771             if not arg.output:
772                 self.serializeArg(method, arg)
773         print '    trace::localWriter.endEnter();'
774         
775         self.invokeMethod(interface, base, method)
776
777         print '    trace::localWriter.beginLeave(_call);'
778
779         print '    if (%s) {' % self.wasFunctionSuccessful(method)
780         for arg in method.args:
781             if arg.output:
782                 self.serializeArg(method, arg)
783                 self.wrapArg(method, arg)
784         print '    }'
785
786         if method.type is not stdapi.Void:
787             self.serializeRet(method, '_result')
788         if method.type is not stdapi.Void:
789             self.wrapRet(method, '_result')
790
791         if method.name == 'Release':
792             assert method.type is not stdapi.Void
793             print r'    if (!_result) {'
794             print r'        delete this;'
795             print r'    }'
796         
797         print '    trace::localWriter.endLeave();'
798
799     def implementIidWrapper(self, api):
800         print r'static void'
801         print r'warnIID(const char *functionName, REFIID riid, const char *reason) {'
802         print r'    os::log("apitrace: warning: %s: %s IID {0x%08lX,0x%04X,0x%04X,{0x%02X,0x%02X,0x%02X,0x%02X,0x%02X,0x%02X,0x%02X,0x%02X}}\n",'
803         print r'            functionName, reason,'
804         print r'            riid.Data1, riid.Data2, riid.Data3,'
805         print r'            riid.Data4[0], riid.Data4[1], riid.Data4[2], riid.Data4[3], riid.Data4[4], riid.Data4[5], riid.Data4[6], riid.Data4[7]);'
806         print r'}'
807         print 
808         print r'static void'
809         print r'wrapIID(const char *functionName, REFIID riid, void * * ppvObj) {'
810         print r'    if (!ppvObj || !*ppvObj) {'
811         print r'        return;'
812         print r'    }'
813         else_ = ''
814         for iface in api.getAllInterfaces():
815             print r'    %sif (riid == IID_%s) {' % (else_, iface.name)
816             print r'        *ppvObj = Wrap%s::_Create(functionName, (%s *) *ppvObj);' % (iface.name, iface.name)
817             print r'    }'
818             else_ = 'else '
819         print r'    %s{' % else_
820         print r'        warnIID(functionName, riid, "unknown");'
821         print r'    }'
822         print r'}'
823         print
824
825     def wrapIid(self, function, riid, out):
826         # Cast output arg to `void **` if necessary
827         out_name = out.name
828         obj_type = out.type.type.type
829         if not obj_type is stdapi.Void:
830             assert isinstance(obj_type, stdapi.Interface)
831             out_name = 'reinterpret_cast<void * *>(%s)' % out_name
832
833         print r'    if (%s && *%s) {' % (out.name, out.name)
834         functionName = function.name
835         else_ = ''
836         if self.interface is not None:
837             functionName = self.interface.name + '::' + functionName
838             print r'        if (*%s == m_pInstance &&' % (out_name,)
839             print r'            (%s)) {' % ' || '.join('%s == IID_%s' % (riid.name, iface.name) for iface in self.interface.iterBases())
840             print r'            *%s = this;' % (out_name,)
841             print r'        }'
842             else_ = 'else '
843         print r'        %s{' % else_
844         print r'             wrapIID("%s", %s, %s);' % (functionName, riid.name, out_name)
845         print r'        }'
846         print r'    }'
847
848     def invokeMethod(self, interface, base, method):
849         if method.type is stdapi.Void:
850             result = ''
851         else:
852             result = '_result = '
853         print '    %s_this->%s(%s);' % (result, method.name, ', '.join([str(arg.name) for arg in method.args]))
854     
855     def emit_memcpy(self, dest, src, length):
856         print '        unsigned _call = trace::localWriter.beginEnter(&trace::memcpy_sig, true);'
857         print '        trace::localWriter.beginArg(0);'
858         print '        trace::localWriter.writePointer((uintptr_t)%s);' % dest
859         print '        trace::localWriter.endArg();'
860         print '        trace::localWriter.beginArg(1);'
861         print '        trace::localWriter.writeBlob(%s, %s);' % (src, length)
862         print '        trace::localWriter.endArg();'
863         print '        trace::localWriter.beginArg(2);'
864         print '        trace::localWriter.writeUInt(%s);' % length
865         print '        trace::localWriter.endArg();'
866         print '        trace::localWriter.endEnter();'
867         print '        trace::localWriter.beginLeave(_call);'
868         print '        trace::localWriter.endLeave();'
869     
870     def fake_call(self, function, args):
871         print '            unsigned _fake_call = trace::localWriter.beginEnter(&_%s_sig);' % (function.name,)
872         for arg, instance in zip(function.args, args):
873             assert not arg.output
874             print '            trace::localWriter.beginArg(%u);' % (arg.index,)
875             self.serializeValue(arg.type, instance)
876             print '            trace::localWriter.endArg();'
877         print '            trace::localWriter.endEnter();'
878         print '            trace::localWriter.beginLeave(_fake_call);'
879         print '            trace::localWriter.endLeave();'
880