]> git.cworth.org Git - apitrace/blob - wrappers/trace.py
Cleanup how pointers are handled.
[apitrace] / wrappers / trace.py
1 ##########################################################################
2 #
3 # Copyright 2008-2010 VMware, Inc.
4 # All Rights Reserved.
5 #
6 # Permission is hereby granted, free of charge, to any person obtaining a copy
7 # of this software and associated documentation files (the "Software"), to deal
8 # in the Software without restriction, including without limitation the rights
9 # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 # copies of the Software, and to permit persons to whom the Software is
11 # furnished to do so, subject to the following conditions:
12 #
13 # The above copyright notice and this permission notice shall be included in
14 # all copies or substantial portions of the Software.
15 #
16 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
22 # THE SOFTWARE.
23 #
24 ##########################################################################/
25
26 """Common trace code generation."""
27
28
29 # Adjust path
30 import os.path
31 import sys
32 sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
33
34
35 import specs.stdapi as stdapi
36
37
38 def getWrapperInterfaceName(interface):
39     return "Wrap" + interface.expr
40
41
42 class ComplexValueSerializer(stdapi.OnceVisitor):
43     '''Type visitors which generates serialization functions for
44     complex types.
45     
46     Simple types are serialized inline.
47     '''
48
49     def __init__(self, serializer):
50         stdapi.OnceVisitor.__init__(self)
51         self.serializer = serializer
52
53     def visitVoid(self, literal):
54         pass
55
56     def visitLiteral(self, literal):
57         pass
58
59     def visitString(self, string):
60         pass
61
62     def visitConst(self, const):
63         self.visit(const.type)
64
65     def visitStruct(self, struct):
66         for type, name in struct.members:
67             self.visit(type)
68         print 'static void _write__%s(const %s &value) {' % (struct.tag, struct.expr)
69         print '    static const char * members[%u] = {' % (len(struct.members),)
70         for type, name,  in struct.members:
71             print '        "%s",' % (name,)
72         print '    };'
73         print '    static const trace::StructSig sig = {'
74         print '       %u, "%s", %u, members' % (struct.id, struct.name, len(struct.members))
75         print '    };'
76         print '    trace::localWriter.beginStruct(&sig);'
77         for type, name in struct.members:
78             self.serializer.visit(type, 'value.%s' % (name,))
79         print '    trace::localWriter.endStruct();'
80         print '}'
81         print
82
83     def visitArray(self, array):
84         self.visit(array.type)
85
86     def visitBlob(self, array):
87         pass
88
89     def visitEnum(self, enum):
90         print 'static const trace::EnumValue __enum%s_values[] = {' % (enum.tag)
91         for value in enum.values:
92             print '   {"%s", %s},' % (value, value)
93         print '};'
94         print
95         print 'static const trace::EnumSig __enum%s_sig = {' % (enum.tag)
96         print '   %u, %u, __enum%s_values' % (enum.id, len(enum.values), enum.tag)
97         print '};'
98         print
99
100     def visitBitmask(self, bitmask):
101         print 'static const trace::BitmaskFlag __bitmask%s_flags[] = {' % (bitmask.tag)
102         for value in bitmask.values:
103             print '   {"%s", %s},' % (value, value)
104         print '};'
105         print
106         print 'static const trace::BitmaskSig __bitmask%s_sig = {' % (bitmask.tag)
107         print '   %u, %u, __bitmask%s_flags' % (bitmask.id, len(bitmask.values), bitmask.tag)
108         print '};'
109         print
110
111     def visitPointer(self, pointer):
112         self.visit(pointer.type)
113
114     def visitIntPointer(self, pointer):
115         pass
116
117     def visitObjPointer(self, pointer):
118         self.visit(pointer.type)
119
120     def visitLinearPointer(self, pointer):
121         self.visit(pointer.type)
122
123     def visitHandle(self, handle):
124         self.visit(handle.type)
125
126     def visitReference(self, reference):
127         self.visit(reference.type)
128
129     def visitAlias(self, alias):
130         self.visit(alias.type)
131
132     def visitOpaque(self, opaque):
133         pass
134
135     def visitInterface(self, interface):
136         pass
137
138     def visitPolymorphic(self, polymorphic):
139         print 'static void _write__%s(int selector, const %s & value) {' % (polymorphic.tag, polymorphic.expr)
140         print '    switch (selector) {'
141         for cases, type in polymorphic.iterSwitch():
142             for case in cases:
143                 print '    %s:' % case
144             self.serializer.visit(type, 'static_cast<%s>(value)' % (type,))
145             print '        break;'
146         print '    }'
147         print '}'
148         print
149
150
151 class ValueSerializer(stdapi.Visitor):
152     '''Visitor which generates code to serialize any type.
153     
154     Simple types are serialized inline here, whereas the serialization of
155     complex types is dispatched to the serialization functions generated by
156     ComplexValueSerializer visitor above.
157     '''
158
159     def visitLiteral(self, literal, instance):
160         print '    trace::localWriter.write%s(%s);' % (literal.kind, instance)
161
162     def visitString(self, string, instance):
163         if string.kind == 'String':
164             cast = 'const char *'
165         elif string.kind == 'WString':
166             cast = 'const wchar_t *'
167         else:
168             assert False
169         if cast != string.expr:
170             # reinterpret_cast is necessary for GLubyte * <=> char *
171             instance = 'reinterpret_cast<%s>(%s)' % (cast, instance)
172         if string.length is not None:
173             length = ', %s' % string.length
174         else:
175             length = ''
176         print '    trace::localWriter.write%s(%s%s);' % (string.kind, instance, length)
177
178     def visitConst(self, const, instance):
179         self.visit(const.type, instance)
180
181     def visitStruct(self, struct, instance):
182         print '    _write__%s(%s);' % (struct.tag, instance)
183
184     def visitArray(self, array, instance):
185         length = '__c' + array.type.tag
186         index = '__i' + array.type.tag
187         print '    if (%s) {' % instance
188         print '        size_t %s = %s;' % (length, array.length)
189         print '        trace::localWriter.beginArray(%s);' % length
190         print '        for (size_t %s = 0; %s < %s; ++%s) {' % (index, index, length, index)
191         print '            trace::localWriter.beginElement();'
192         self.visit(array.type, '(%s)[%s]' % (instance, index))
193         print '            trace::localWriter.endElement();'
194         print '        }'
195         print '        trace::localWriter.endArray();'
196         print '    } else {'
197         print '        trace::localWriter.writeNull();'
198         print '    }'
199
200     def visitBlob(self, blob, instance):
201         print '    trace::localWriter.writeBlob(%s, %s);' % (instance, blob.size)
202
203     def visitEnum(self, enum, instance):
204         print '    trace::localWriter.writeEnum(&__enum%s_sig, %s);' % (enum.tag, instance)
205
206     def visitBitmask(self, bitmask, instance):
207         print '    trace::localWriter.writeBitmask(&__bitmask%s_sig, %s);' % (bitmask.tag, instance)
208
209     def visitPointer(self, pointer, instance):
210         print '    if (%s) {' % instance
211         print '        trace::localWriter.beginArray(1);'
212         print '        trace::localWriter.beginElement();'
213         self.visit(pointer.type, "*" + instance)
214         print '        trace::localWriter.endElement();'
215         print '        trace::localWriter.endArray();'
216         print '    } else {'
217         print '        trace::localWriter.writeNull();'
218         print '    }'
219
220     def visitIntPointer(self, pointer, instance):
221         print '    trace::localWriter.writePointer((uintptr_t)%s);' % instance
222
223     def visitObjPointer(self, pointer, instance):
224         print '    trace::localWriter.writePointer((uintptr_t)%s);' % instance
225
226     def visitLinearPointer(self, pointer, instance):
227         print '    trace::localWriter.writePointer((uintptr_t)%s);' % instance
228
229     def visitReference(self, reference, instance):
230         self.visit(reference.type, instance)
231
232     def visitHandle(self, handle, instance):
233         self.visit(handle.type, instance)
234
235     def visitAlias(self, alias, instance):
236         self.visit(alias.type, instance)
237
238     def visitOpaque(self, opaque, instance):
239         print '    trace::localWriter.writePointer((uintptr_t)%s);' % instance
240
241     def visitInterface(self, interface, instance):
242         assert False
243
244     def visitPolymorphic(self, polymorphic, instance):
245         print '    _write__%s(%s, %s);' % (polymorphic.tag, polymorphic.switchExpr, instance)
246
247
248 class WrapDecider(stdapi.Traverser):
249     '''Type visitor which will decide wheter this type will need wrapping or not.
250     
251     For complex types (arrays, structures), we need to know this before hand.
252     '''
253
254     def __init__(self):
255         self.needsWrapping = False
256
257     def visitLinearPointer(self, void):
258         pass
259
260     def visitInterface(self, interface):
261         self.needsWrapping = True
262
263
264 class ValueWrapper(stdapi.Traverser):
265     '''Type visitor which will generate the code to wrap an instance.
266     
267     Wrapping is necessary mostly for interfaces, however interface pointers can
268     appear anywhere inside complex types.
269     '''
270
271     def visitStruct(self, struct, instance):
272         for type, name in struct.members:
273             self.visit(type, "(%s).%s" % (instance, name))
274
275     def visitArray(self, array, instance):
276         print "    if (%s) {" % instance
277         print "        for (size_t _i = 0, _s = %s; _i < _s; ++_i) {" % array.length
278         self.visit(array.type, instance + "[_i]")
279         print "        }"
280         print "    }"
281
282     def visitPointer(self, pointer, instance):
283         print "    if (%s) {" % instance
284         self.visit(pointer.type, "*" + instance)
285         print "    }"
286     
287     def visitObjPointer(self, pointer, instance):
288         elem_type = pointer.type.mutable()
289         if isinstance(elem_type, stdapi.Interface):
290             self.visitInterfacePointer(elem_type, instance)
291         else:
292             self.visitPointer(self, pointer, instance)
293     
294     def visitInterface(self, interface, instance):
295         raise NotImplementedError
296
297     def visitInterfacePointer(self, interface, instance):
298         print "    if (%s) {" % instance
299         print "        %s = new %s(%s);" % (instance, getWrapperInterfaceName(interface), instance)
300         print "    }"
301     
302     def visitPolymorphic(self, type, instance):
303         # XXX: There might be polymorphic values that need wrapping in the future
304         raise NotImplementedError
305
306
307 class ValueUnwrapper(ValueWrapper):
308     '''Reverse of ValueWrapper.'''
309
310     allocated = False
311
312     def visitArray(self, array, instance):
313         if self.allocated or isinstance(instance, stdapi.Interface):
314             return ValueWrapper.visitArray(self, array, instance)
315         elem_type = array.type.mutable()
316         print "    if (%s && %s) {" % (instance, array.length)
317         print "        %s * _t = static_cast<%s *>(alloca(%s * sizeof *_t));" % (elem_type, elem_type, array.length)
318         print "        for (size_t _i = 0, _s = %s; _i < _s; ++_i) {" % array.length
319         print "            _t[_i] = %s[_i];" % instance 
320         self.allocated = True
321         self.visit(array.type, "_t[_i]")
322         print "        }"
323         print "        %s = _t;" % instance
324         print "    }"
325
326     def visitInterfacePointer(self, interface, instance):
327         print r'    if (%s) {' % instance
328         print r'        const %s *pWrapper = static_cast<const %s*>(%s);' % (getWrapperInterfaceName(interface), getWrapperInterfaceName(interface), instance)
329         print r'        if (pWrapper && pWrapper->m_dwMagic == 0xd8365d6c) {'
330         print r'            %s = pWrapper->m_pInstance;' % (instance,)
331         print r'        } else {'
332         print r'            os::log("apitrace: warning: %%s: unexpected %%s pointer\n", __FUNCTION__, "%s");' % interface.name
333         print r'        }'
334         print r'    }'
335
336
337 class Tracer:
338     '''Base class to orchestrate the code generation of API tracing.'''
339
340     def __init__(self):
341         self.api = None
342
343     def serializerFactory(self):
344         '''Create a serializer.
345         
346         Can be overriden by derived classes to inject their own serialzer.
347         '''
348
349         return ValueSerializer()
350
351     def traceApi(self, api):
352         self.api = api
353
354         self.header(api)
355
356         # Includes
357         for header in api.headers:
358             print header
359         print
360
361         # Generate the serializer functions
362         types = api.getAllTypes()
363         visitor = ComplexValueSerializer(self.serializerFactory())
364         map(visitor.visit, types)
365         print
366
367         # Interfaces wrapers
368         self.traceInterfaces(api)
369
370         # Function wrappers
371         self.interface = None
372         self.base = None
373         map(self.traceFunctionDecl, api.functions)
374         map(self.traceFunctionImpl, api.functions)
375         print
376
377         self.footer(api)
378
379     def header(self, api):
380         pass
381
382     def footer(self, api):
383         pass
384
385     def traceFunctionDecl(self, function):
386         # Per-function declarations
387
388         if function.args:
389             print 'static const char * __%s_args[%u] = {%s};' % (function.name, len(function.args), ', '.join(['"%s"' % arg.name for arg in function.args]))
390         else:
391             print 'static const char ** __%s_args = NULL;' % (function.name,)
392         print 'static const trace::FunctionSig __%s_sig = {%u, "%s", %u, __%s_args};' % (function.name, function.id, function.name, len(function.args), function.name)
393         print
394
395     def isFunctionPublic(self, function):
396         return True
397
398     def traceFunctionImpl(self, function):
399         if self.isFunctionPublic(function):
400             print 'extern "C" PUBLIC'
401         else:
402             print 'extern "C" PRIVATE'
403         print function.prototype() + ' {'
404         if function.type is not stdapi.Void:
405             print '    %s __result;' % function.type
406         self.traceFunctionImplBody(function)
407         if function.type is not stdapi.Void:
408             self.wrapRet(function, "__result")
409             print '    return __result;'
410         print '}'
411         print
412
413     def traceFunctionImplBody(self, function):
414         print '    unsigned __call = trace::localWriter.beginEnter(&__%s_sig);' % (function.name,)
415         for arg in function.args:
416             if not arg.output:
417                 self.unwrapArg(function, arg)
418                 self.serializeArg(function, arg)
419         print '    trace::localWriter.endEnter();'
420         self.invokeFunction(function)
421         print '    trace::localWriter.beginLeave(__call);'
422         for arg in function.args:
423             if arg.output:
424                 self.serializeArg(function, arg)
425                 self.wrapArg(function, arg)
426         if function.type is not stdapi.Void:
427             self.serializeRet(function, "__result")
428         print '    trace::localWriter.endLeave();'
429
430     def invokeFunction(self, function, prefix='__', suffix=''):
431         if function.type is stdapi.Void:
432             result = ''
433         else:
434             result = '__result = '
435         dispatch = prefix + function.name + suffix
436         print '    %s%s(%s);' % (result, dispatch, ', '.join([str(arg.name) for arg in function.args]))
437
438     def serializeArg(self, function, arg):
439         print '    trace::localWriter.beginArg(%u);' % (arg.index,)
440         self.serializeArgValue(function, arg)
441         print '    trace::localWriter.endArg();'
442
443     def serializeArgValue(self, function, arg):
444         self.serializeValue(arg.type, arg.name)
445
446     def wrapArg(self, function, arg):
447         assert not isinstance(arg.type, stdapi.ObjPointer)
448
449         from specs.winapi import REFIID
450         riid = None
451         for other_arg in function.args:
452             if not other_arg.output and other_arg.type is REFIID:
453                 riid = other_arg
454         if riid is not None and isinstance(arg.type, stdapi.Pointer):
455             assert isinstance(arg.type.type, stdapi.ObjPointer)
456             obj_type = arg.type.type.type
457             assert obj_type is stdapi.Void
458             self.wrapIid(function, riid, arg)
459             return
460
461         self.wrapValue(arg.type, arg.name)
462
463     def unwrapArg(self, function, arg):
464         self.unwrapValue(arg.type, arg.name)
465
466     def serializeRet(self, function, instance):
467         print '    trace::localWriter.beginReturn();'
468         self.serializeValue(function.type, instance)
469         print '    trace::localWriter.endReturn();'
470
471     def serializeValue(self, type, instance):
472         serializer = self.serializerFactory()
473         serializer.visit(type, instance)
474
475     def wrapRet(self, function, instance):
476         self.wrapValue(function.type, instance)
477
478     def unwrapRet(self, function, instance):
479         self.unwrapValue(function.type, instance)
480
481     def needsWrapping(self, type):
482         visitor = WrapDecider()
483         visitor.visit(type)
484         return visitor.needsWrapping
485
486     def wrapValue(self, type, instance):
487         if self.needsWrapping(type):
488             visitor = ValueWrapper()
489             visitor.visit(type, instance)
490
491     def unwrapValue(self, type, instance):
492         if self.needsWrapping(type):
493             visitor = ValueUnwrapper()
494             visitor.visit(type, instance)
495
496     def traceInterfaces(self, api):
497         interfaces = api.getAllInterfaces()
498         if not interfaces:
499             return
500         map(self.declareWrapperInterface, interfaces)
501         self.implementIidWrapper(api)
502         map(self.implementWrapperInterface, interfaces)
503         print
504
505     def declareWrapperInterface(self, interface):
506         print "class %s : public %s " % (getWrapperInterfaceName(interface), interface.name)
507         print "{"
508         print "public:"
509         print "    %s(%s * pInstance);" % (getWrapperInterfaceName(interface), interface.name)
510         print "    virtual ~%s();" % getWrapperInterfaceName(interface)
511         print
512         for method in interface.iterMethods():
513             print "    " + method.prototype() + ";"
514         print
515         self.declareWrapperInterfaceVariables(interface)
516         print "};"
517         print
518
519     def declareWrapperInterfaceVariables(self, interface):
520         #print "private:"
521         print "    DWORD m_dwMagic;"
522         print "    %s * m_pInstance;" % (interface.name,)
523
524     def implementWrapperInterface(self, interface):
525         self.interface = interface
526
527         print '%s::%s(%s * pInstance) {' % (getWrapperInterfaceName(interface), getWrapperInterfaceName(interface), interface.name)
528         print '    m_dwMagic = 0xd8365d6c;'
529         print '    m_pInstance = pInstance;'
530         print '}'
531         print
532         print '%s::~%s() {' % (getWrapperInterfaceName(interface), getWrapperInterfaceName(interface))
533         print '}'
534         print
535         
536         for base, method in interface.iterBaseMethods():
537             self.base = base
538             self.implementWrapperInterfaceMethod(interface, base, method)
539
540         print
541
542     def implementWrapperInterfaceMethod(self, interface, base, method):
543         print method.prototype(getWrapperInterfaceName(interface) + '::' + method.name) + ' {'
544         if method.type is not stdapi.Void:
545             print '    %s __result;' % method.type
546     
547         self.implementWrapperInterfaceMethodBody(interface, base, method)
548     
549         if method.type is not stdapi.Void:
550             print '    return __result;'
551         print '}'
552         print
553
554     def implementWrapperInterfaceMethodBody(self, interface, base, method):
555         print '    static const char * __args[%u] = {%s};' % (len(method.args) + 1, ', '.join(['"this"'] + ['"%s"' % arg.name for arg in method.args]))
556         print '    static const trace::FunctionSig __sig = {%u, "%s", %u, __args};' % (method.id, interface.name + '::' + method.name, len(method.args) + 1)
557         print '    unsigned __call = trace::localWriter.beginEnter(&__sig);'
558         print '    trace::localWriter.beginArg(0);'
559         print '    trace::localWriter.writePointer((uintptr_t)m_pInstance);'
560         print '    trace::localWriter.endArg();'
561         for arg in method.args:
562             if not arg.output:
563                 self.unwrapArg(method, arg)
564                 self.serializeArg(method, arg)
565         print '    trace::localWriter.endEnter();'
566         
567         self.invokeMethod(interface, base, method)
568
569         print '    trace::localWriter.beginLeave(__call);'
570         for arg in method.args:
571             if arg.output:
572                 self.serializeArg(method, arg)
573                 self.wrapArg(method, arg)
574
575         if method.type is not stdapi.Void:
576             print '    trace::localWriter.beginReturn();'
577             self.serializeValue(method.type, "__result")
578             print '    trace::localWriter.endReturn();'
579             self.wrapValue(method.type, '__result')
580         print '    trace::localWriter.endLeave();'
581         if method.name == 'Release':
582             assert method.type is not stdapi.Void
583             print '    if (!__result)'
584             print '        delete this;'
585
586     def implementIidWrapper(self, api):
587         print r'static void'
588         print r'warnIID(const char *functionName, REFIID riid, const char *reason) {'
589         print r'    os::log("apitrace: warning: %s: %s IID {0x%08lX,0x%04X,0x%04X,{0x%02X,0x%02X,0x%02X,0x%02X,0x%02X,0x%02X,0x%02X,0x%02X}}\n",'
590         print r'            functionName, reason,'
591         print r'            riid.Data1, riid.Data2, riid.Data3,'
592         print r'            riid.Data4[0], riid.Data4[1], riid.Data4[2], riid.Data4[3], riid.Data4[4], riid.Data4[5], riid.Data4[6], riid.Data4[7]);'
593         print r'}'
594         print 
595         print r'static void'
596         print r'wrapIID(const char *functionName, REFIID riid, void * * ppvObj) {'
597         print r'    if (!ppvObj || !*ppvObj) {'
598         print r'        return;'
599         print r'    }'
600         else_ = ''
601         for iface in api.getAllInterfaces():
602             print r'    %sif (riid == IID_%s) {' % (else_, iface.name)
603             print r'        *ppvObj = new Wrap%s((%s *) *ppvObj);' % (iface.name, iface.name)
604             print r'    }'
605             else_ = 'else '
606         print r'    %s{' % else_
607         print r'        warnIID(functionName, riid, "unknown");'
608         print r'    }'
609         print r'}'
610         print
611
612     def wrapIid(self, function, riid, out):
613         print r'    if (%s && *%s) {' % (out.name, out.name)
614         functionName = function.name
615         else_ = ''
616         if self.interface is not None:
617             functionName = self.interface.name + '::' + functionName
618             print r'        %sif (*%s == m_pInstance) {' % (else_, out.name,)
619             print r'            *%s = this;' % (out.name,)
620             print r'            if (%s) {' % ' && '.join('%s != IID_%s' % (riid.name, iface.name) for iface in self.interface.iterBases()) 
621             print r'                warnIID("%s", %s, "unexpected");' % (functionName, riid.name)
622             print r'            }'
623             print r'        }'
624             else_ = 'else '
625         print r'        %s{' % else_
626         print r'             wrapIID("%s", %s, %s);' % (functionName, riid.name, out.name) 
627         print r'        }'
628         print r'    }'
629
630     def invokeMethod(self, interface, base, method):
631         if method.type is stdapi.Void:
632             result = ''
633         else:
634             result = '__result = '
635         print '    %sstatic_cast<%s *>(m_pInstance)->%s(%s);' % (result, base, method.name, ', '.join([str(arg.name) for arg in method.args]))
636     
637     def emit_memcpy(self, dest, src, length):
638         print '        unsigned __call = trace::localWriter.beginEnter(&trace::memcpy_sig);'
639         print '        trace::localWriter.beginArg(0);'
640         print '        trace::localWriter.writePointer((uintptr_t)%s);' % dest
641         print '        trace::localWriter.endArg();'
642         print '        trace::localWriter.beginArg(1);'
643         print '        trace::localWriter.writeBlob(%s, %s);' % (src, length)
644         print '        trace::localWriter.endArg();'
645         print '        trace::localWriter.beginArg(2);'
646         print '        trace::localWriter.writeUInt(%s);' % length
647         print '        trace::localWriter.endArg();'
648         print '        trace::localWriter.endEnter();'
649         print '        trace::localWriter.beginLeave(__call);'
650         print '        trace::localWriter.endLeave();'
651