]> git.cworth.org Git - apitrace/blob - trace.py
Describe C++ references accurately.
[apitrace] / trace.py
1 ##########################################################################
2 #
3 # Copyright 2008-2010 VMware, Inc.
4 # All Rights Reserved.
5 #
6 # Permission is hereby granted, free of charge, to any person obtaining a copy
7 # of this software and associated documentation files (the "Software"), to deal
8 # in the Software without restriction, including without limitation the rights
9 # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 # copies of the Software, and to permit persons to whom the Software is
11 # furnished to do so, subject to the following conditions:
12 #
13 # The above copyright notice and this permission notice shall be included in
14 # all copies or substantial portions of the Software.
15 #
16 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
22 # THE SOFTWARE.
23 #
24 ##########################################################################/
25
26 """Common trace code generation."""
27
28
29 import specs.stdapi as stdapi
30
31
32 def getWrapperInterfaceName(interface):
33     return "Wrap" + interface.expr
34
35
36 class ComplexValueSerializer(stdapi.OnceVisitor):
37     '''Type visitors which generates serialization functions for
38     complex types.
39     
40     Simple types are serialized inline.
41     '''
42
43     def __init__(self, serializer):
44         stdapi.OnceVisitor.__init__(self)
45         self.serializer = serializer
46
47     def visitVoid(self, literal):
48         pass
49
50     def visitLiteral(self, literal):
51         pass
52
53     def visitString(self, string):
54         pass
55
56     def visitConst(self, const):
57         self.visit(const.type)
58
59     def visitStruct(self, struct):
60         for type, name in struct.members:
61             self.visit(type)
62         print 'static void _write__%s(const %s &value) {' % (struct.tag, struct.expr)
63         print '    static const char * members[%u] = {' % (len(struct.members),)
64         for type, name,  in struct.members:
65             print '        "%s",' % (name,)
66         print '    };'
67         print '    static const trace::StructSig sig = {'
68         print '       %u, "%s", %u, members' % (struct.id, struct.name, len(struct.members))
69         print '    };'
70         print '    trace::localWriter.beginStruct(&sig);'
71         for type, name in struct.members:
72             self.serializer.visit(type, 'value.%s' % (name,))
73         print '    trace::localWriter.endStruct();'
74         print '}'
75         print
76
77     def visitArray(self, array):
78         self.visit(array.type)
79
80     def visitBlob(self, array):
81         pass
82
83     def visitEnum(self, enum):
84         print 'static const trace::EnumValue __enum%s_values[] = {' % (enum.tag)
85         for value in enum.values:
86             print '   {"%s", %s},' % (value, value)
87         print '};'
88         print
89         print 'static const trace::EnumSig __enum%s_sig = {' % (enum.tag)
90         print '   %u, %u, __enum%s_values' % (enum.id, len(enum.values), enum.tag)
91         print '};'
92         print
93
94     def visitBitmask(self, bitmask):
95         print 'static const trace::BitmaskFlag __bitmask%s_flags[] = {' % (bitmask.tag)
96         for value in bitmask.values:
97             print '   {"%s", %s},' % (value, value)
98         print '};'
99         print
100         print 'static const trace::BitmaskSig __bitmask%s_sig = {' % (bitmask.tag)
101         print '   %u, %u, __bitmask%s_flags' % (bitmask.id, len(bitmask.values), bitmask.tag)
102         print '};'
103         print
104
105     def visitPointer(self, pointer):
106         self.visit(pointer.type)
107
108     def visitIntPointer(self, pointer):
109         pass
110
111     def visitLinearPointer(self, pointer):
112         self.visit(pointer.type)
113
114     def visitHandle(self, handle):
115         self.visit(handle.type)
116
117     def visitReference(self, reference):
118         self.visit(reference.type)
119
120     def visitAlias(self, alias):
121         self.visit(alias.type)
122
123     def visitOpaque(self, opaque):
124         pass
125
126     def visitInterface(self, interface):
127         pass
128
129     def visitPolymorphic(self, polymorphic):
130         print 'static void _write__%s(int selector, const %s & value) {' % (polymorphic.tag, polymorphic.expr)
131         print '    switch (selector) {'
132         for cases, type in polymorphic.iterSwitch():
133             for case in cases:
134                 print '    %s:' % case
135             self.serializer.visit(type, 'static_cast<%s>(value)' % (type,))
136             print '        break;'
137         print '    }'
138         print '}'
139         print
140
141
142 class ValueSerializer(stdapi.Visitor):
143     '''Visitor which generates code to serialize any type.
144     
145     Simple types are serialized inline here, whereas the serialization of
146     complex types is dispatched to the serialization functions generated by
147     ComplexValueSerializer visitor above.
148     '''
149
150     def visitLiteral(self, literal, instance):
151         print '    trace::localWriter.write%s(%s);' % (literal.kind, instance)
152
153     def visitString(self, string, instance):
154         if string.kind == 'String':
155             cast = 'const char *'
156         elif string.kind == 'WString':
157             cast = 'const wchar_t *'
158         else:
159             assert False
160         if cast != string.expr:
161             # reinterpret_cast is necessary for GLubyte * <=> char *
162             instance = 'reinterpret_cast<%s>(%s)' % (cast, instance)
163         if string.length is not None:
164             length = ', %s' % string.length
165         else:
166             length = ''
167         print '    trace::localWriter.write%s(%s%s);' % (string.kind, instance, length)
168
169     def visitConst(self, const, instance):
170         self.visit(const.type, instance)
171
172     def visitStruct(self, struct, instance):
173         print '    _write__%s(%s);' % (struct.tag, instance)
174
175     def visitArray(self, array, instance):
176         length = '__c' + array.type.tag
177         index = '__i' + array.type.tag
178         print '    if (%s) {' % instance
179         print '        size_t %s = %s;' % (length, array.length)
180         print '        trace::localWriter.beginArray(%s);' % length
181         print '        for (size_t %s = 0; %s < %s; ++%s) {' % (index, index, length, index)
182         print '            trace::localWriter.beginElement();'
183         self.visit(array.type, '(%s)[%s]' % (instance, index))
184         print '            trace::localWriter.endElement();'
185         print '        }'
186         print '        trace::localWriter.endArray();'
187         print '    } else {'
188         print '        trace::localWriter.writeNull();'
189         print '    }'
190
191     def visitBlob(self, blob, instance):
192         print '    trace::localWriter.writeBlob(%s, %s);' % (instance, blob.size)
193
194     def visitEnum(self, enum, instance):
195         print '    trace::localWriter.writeEnum(&__enum%s_sig, %s);' % (enum.tag, instance)
196
197     def visitBitmask(self, bitmask, instance):
198         print '    trace::localWriter.writeBitmask(&__bitmask%s_sig, %s);' % (bitmask.tag, instance)
199
200     def visitPointer(self, pointer, instance):
201         print '    if (%s) {' % instance
202         print '        trace::localWriter.beginArray(1);'
203         print '        trace::localWriter.beginElement();'
204         self.visit(pointer.type, "*" + instance)
205         print '        trace::localWriter.endElement();'
206         print '        trace::localWriter.endArray();'
207         print '    } else {'
208         print '        trace::localWriter.writeNull();'
209         print '    }'
210
211     def visitIntPointer(self, pointer, instance):
212         print '    trace::localWriter.writeOpaque((const void *)%s);' % instance
213
214     def visitLinearPointer(self, pointer, instance):
215         print '    trace::localWriter.writeOpaque((const void *)%s);' % instance
216
217     def visitReference(self, reference, instance):
218         self.visit(reference.type, instance)
219
220     def visitHandle(self, handle, instance):
221         self.visit(handle.type, instance)
222
223     def visitAlias(self, alias, instance):
224         self.visit(alias.type, instance)
225
226     def visitOpaque(self, opaque, instance):
227         print '    trace::localWriter.writeOpaque((const void *)%s);' % instance
228
229     def visitInterface(self, interface, instance):
230         print '    trace::localWriter.writeOpaque((const void *)&%s);' % instance
231
232     def visitPolymorphic(self, polymorphic, instance):
233         print '    _write__%s(%s, %s);' % (polymorphic.tag, polymorphic.switchExpr, instance)
234
235
236 class ValueWrapper(stdapi.Visitor):
237     '''Type visitor which will generate the code to wrap an instance.
238     
239     Wrapping is necessary mostly for interfaces, however interface pointers can
240     appear anywhere inside complex types.
241     '''
242
243     def visitVoid(self, type, instance):
244         raise NotImplementedError
245
246     def visitLiteral(self, type, instance):
247         pass
248
249     def visitString(self, type, instance):
250         pass
251
252     def visitConst(self, type, instance):
253         pass
254
255     def visitStruct(self, struct, instance):
256         for type, name in struct.members:
257             self.visit(type, "(%s).%s" % (instance, name))
258
259     def visitArray(self, array, instance):
260         # XXX: actually it is possible to return an array of pointers
261         pass
262
263     def visitBlob(self, blob, instance):
264         pass
265
266     def visitEnum(self, enum, instance):
267         pass
268
269     def visitBitmask(self, bitmask, instance):
270         pass
271
272     def visitPointer(self, pointer, instance):
273         print "    if (%s) {" % instance
274         self.visit(pointer.type, "*" + instance)
275         print "    }"
276     
277     def visitIntPointer(self, pointer, instance):
278         pass
279
280     def visitLinearPointer(self, pointer, instance):
281         pass
282
283     def visitReference(self, reference, instance):
284         self.visit(reference.type, instance)
285     
286     def visitHandle(self, handle, instance):
287         self.visit(handle.type, instance)
288
289     def visitAlias(self, alias, instance):
290         self.visit(alias.type, instance)
291
292     def visitOpaque(self, opaque, instance):
293         pass
294     
295     def visitInterface(self, interface, instance):
296         assert instance.startswith('*')
297         instance = instance[1:]
298         print "    if (%s) {" % instance
299         print "        %s = new %s(%s);" % (instance, getWrapperInterfaceName(interface), instance)
300         print "    }"
301     
302     def visitPolymorphic(self, type, instance):
303         # XXX: There might be polymorphic values that need wrapping in the future
304         pass
305
306
307 class ValueUnwrapper(ValueWrapper):
308     '''Reverse of ValueWrapper.'''
309
310     def visitInterface(self, interface, instance):
311         assert instance.startswith('*')
312         instance = instance[1:]
313         print r'    if (%s) {' % instance
314         print r'        %s *pWrapper = static_cast<%s*>(%s);' % (getWrapperInterfaceName(interface), getWrapperInterfaceName(interface), instance)
315         print r'        if (pWrapper && pWrapper->m_dwMagic == 0xd8365d6c) {'
316         print r'            %s = pWrapper->m_pInstance;' % (instance,)
317         print r'        } else {'
318         print r'            os::log("apitrace: warning: %%s: unexpected %%s pointer\n", __FUNCTION__, "%s");' % interface.name
319         print r'        }'
320         print r'    }'
321
322
323 class Tracer:
324     '''Base class to orchestrate the code generation of API tracing.'''
325
326     def __init__(self):
327         self.api = None
328
329     def serializerFactory(self):
330         '''Create a serializer.
331         
332         Can be overriden by derived classes to inject their own serialzer.
333         '''
334
335         return ValueSerializer()
336
337     def trace_api(self, api):
338         self.api = api
339
340         self.header(api)
341
342         # Includes
343         for header in api.headers:
344             print header
345         print
346
347         # Generate the serializer functions
348         types = api.getAllTypes()
349         visitor = ComplexValueSerializer(self.serializerFactory())
350         map(visitor.visit, types)
351         print
352
353         # Interfaces wrapers
354         interfaces = api.getAllInterfaces()
355         map(self.declareWrapperInterface, interfaces)
356         map(self.implementWrapperInterface, interfaces)
357         print
358
359         # Function wrappers
360         map(self.traceFunctionDecl, api.functions)
361         map(self.traceFunctionImpl, api.functions)
362         print
363
364         self.footer(api)
365
366     def header(self, api):
367         pass
368
369     def footer(self, api):
370         pass
371
372     def traceFunctionDecl(self, function):
373         # Per-function declarations
374
375         if function.args:
376             print 'static const char * __%s_args[%u] = {%s};' % (function.name, len(function.args), ', '.join(['"%s"' % arg.name for arg in function.args]))
377         else:
378             print 'static const char ** __%s_args = NULL;' % (function.name,)
379         print 'static const trace::FunctionSig __%s_sig = {%u, "%s", %u, __%s_args};' % (function.name, function.id, function.name, len(function.args), function.name)
380         print
381
382     def isFunctionPublic(self, function):
383         return True
384
385     def traceFunctionImpl(self, function):
386         if self.isFunctionPublic(function):
387             print 'extern "C" PUBLIC'
388         else:
389             print 'extern "C" PRIVATE'
390         print function.prototype() + ' {'
391         if function.type is not stdapi.Void:
392             print '    %s __result;' % function.type
393         self.traceFunctionImplBody(function)
394         if function.type is not stdapi.Void:
395             self.wrapRet(function, "__result")
396             print '    return __result;'
397         print '}'
398         print
399
400     def traceFunctionImplBody(self, function):
401         print '    unsigned __call = trace::localWriter.beginEnter(&__%s_sig);' % (function.name,)
402         for arg in function.args:
403             if not arg.output:
404                 self.unwrapArg(function, arg)
405                 self.serializeArg(function, arg)
406         print '    trace::localWriter.endEnter();'
407         self.invokeFunction(function)
408         print '    trace::localWriter.beginLeave(__call);'
409         for arg in function.args:
410             if arg.output:
411                 self.serializeArg(function, arg)
412                 self.wrapArg(function, arg)
413         if function.type is not stdapi.Void:
414             self.serializeRet(function, "__result")
415         print '    trace::localWriter.endLeave();'
416
417     def invokeFunction(self, function, prefix='__', suffix=''):
418         if function.type is stdapi.Void:
419             result = ''
420         else:
421             result = '__result = '
422         dispatch = prefix + function.name + suffix
423         print '    %s%s(%s);' % (result, dispatch, ', '.join([str(arg.name) for arg in function.args]))
424
425     def serializeArg(self, function, arg):
426         print '    trace::localWriter.beginArg(%u);' % (arg.index,)
427         self.serializeArgValue(function, arg)
428         print '    trace::localWriter.endArg();'
429
430     def serializeArgValue(self, function, arg):
431         self.serializeValue(arg.type, arg.name)
432
433     def wrapArg(self, function, arg):
434         self.wrapValue(arg.type, arg.name)
435
436     def unwrapArg(self, function, arg):
437         self.unwrapValue(arg.type, arg.name)
438
439     def serializeRet(self, function, instance):
440         print '    trace::localWriter.beginReturn();'
441         self.serializeValue(function.type, instance)
442         print '    trace::localWriter.endReturn();'
443
444     def serializeValue(self, type, instance):
445         serializer = self.serializerFactory()
446         serializer.visit(type, instance)
447
448     def wrapRet(self, function, instance):
449         self.wrapValue(function.type, instance)
450
451     def unwrapRet(self, function, instance):
452         self.unwrapValue(function.type, instance)
453
454     def wrapValue(self, type, instance):
455         visitor = ValueWrapper()
456         visitor.visit(type, instance)
457
458     def unwrapValue(self, type, instance):
459         visitor = ValueUnwrapper()
460         visitor.visit(type, instance)
461
462     def declareWrapperInterface(self, interface):
463         print "class %s : public %s " % (getWrapperInterfaceName(interface), interface.name)
464         print "{"
465         print "public:"
466         print "    %s(%s * pInstance);" % (getWrapperInterfaceName(interface), interface.name)
467         print "    virtual ~%s();" % getWrapperInterfaceName(interface)
468         print
469         for method in interface.iterMethods():
470             print "    " + method.prototype() + ";"
471         print
472         self.declareWrapperInterfaceVariables(interface)
473         print "};"
474         print
475
476     def declareWrapperInterfaceVariables(self, interface):
477         #print "private:"
478         print "    DWORD m_dwMagic;"
479         print "    %s * m_pInstance;" % (interface.name,)
480
481     def implementWrapperInterface(self, interface):
482         print '%s::%s(%s * pInstance) {' % (getWrapperInterfaceName(interface), getWrapperInterfaceName(interface), interface.name)
483         print '    m_dwMagic = 0xd8365d6c;'
484         print '    m_pInstance = pInstance;'
485         print '}'
486         print
487         print '%s::~%s() {' % (getWrapperInterfaceName(interface), getWrapperInterfaceName(interface))
488         print '}'
489         print
490         for base, method in interface.iterBaseMethods():
491             self.implementWrapperInterfaceMethod(interface, base, method)
492         print
493
494     def implementWrapperInterfaceMethod(self, interface, base, method):
495         print method.prototype(getWrapperInterfaceName(interface) + '::' + method.name) + ' {'
496         if method.type is not stdapi.Void:
497             print '    %s __result;' % method.type
498     
499         self.implementWrapperInterfaceMethodBody(interface, base, method)
500     
501         if method.type is not stdapi.Void:
502             print '    return __result;'
503         print '}'
504         print
505
506     def implementWrapperInterfaceMethodBody(self, interface, base, method):
507         print '    static const char * __args[%u] = {%s};' % (len(method.args) + 1, ', '.join(['"this"'] + ['"%s"' % arg.name for arg in method.args]))
508         print '    static const trace::FunctionSig __sig = {%u, "%s", %u, __args};' % (method.id, interface.name + '::' + method.name, len(method.args) + 1)
509         print '    unsigned __call = trace::localWriter.beginEnter(&__sig);'
510         print '    trace::localWriter.beginArg(0);'
511         print '    trace::localWriter.writeOpaque((const void *)m_pInstance);'
512         print '    trace::localWriter.endArg();'
513
514         from specs.winapi import REFIID
515         from specs.stdapi import Pointer, Opaque, Interface
516
517         riid = None
518         for arg in method.args:
519             if not arg.output:
520                 self.unwrapArg(method, arg)
521                 self.serializeArg(method, arg)
522                 if arg.type is REFIID:
523                     riid = arg
524         print '    trace::localWriter.endEnter();'
525         
526         self.invokeMethod(interface, base, method)
527
528         print '    trace::localWriter.beginLeave(__call);'
529         for arg in method.args:
530             if arg.output:
531                 self.serializeArg(method, arg)
532                 self.wrapArg(method, arg)
533                 if riid is not None and isinstance(arg.type, Pointer):
534                     if isinstance(arg.type.type, Opaque):
535                         self.wrapIid(riid, arg)
536                     else:
537                         assert isinstance(arg.type.type, Pointer)
538                         assert isinstance(arg.type.type.type, Interface)
539
540         if method.type is not stdapi.Void:
541             print '    trace::localWriter.beginReturn();'
542             self.serializeValue(method.type, "__result")
543             print '    trace::localWriter.endReturn();'
544             self.wrapValue(method.type, '__result')
545         print '    trace::localWriter.endLeave();'
546         if method.name == 'Release':
547             assert method.type is not stdapi.Void
548             print '    if (!__result)'
549             print '        delete this;'
550
551     def wrapIid(self, riid, out):
552             print '    if (%s && *%s) {' % (out.name, out.name)
553             print '        if (*%s == m_pInstance) {' % (out.name,)
554             print '            *%s = this;' % (out.name,)
555             print '        }'
556             for iface in self.api.getAllInterfaces():
557                 print r'        else if (%s == IID_%s) {' % (riid.name, iface.name)
558                 print r'            *%s = new Wrap%s((%s *) *%s);' % (out.name, iface.name, iface.name, out.name)
559                 print r'        }'
560             print r'        else {'
561             print r'            os::log("apitrace: warning: %s: unknown REFIID {0x%08lX,0x%04X,0x%04X,{0x%02X,0x%02X,0x%02X,0x%02X,0x%02X,0x%02X,0x%02X,0x%02X}}\n",'
562             print r'                    __FUNCTION__,'
563             print r'                    %s.Data1, %s.Data2, %s.Data3,' % (riid.name, riid.name, riid.name)
564             print r'                    %s.Data4[0],' % (riid.name,)
565             print r'                    %s.Data4[1],' % (riid.name,)
566             print r'                    %s.Data4[2],' % (riid.name,)
567             print r'                    %s.Data4[3],' % (riid.name,)
568             print r'                    %s.Data4[4],' % (riid.name,)
569             print r'                    %s.Data4[5],' % (riid.name,)
570             print r'                    %s.Data4[6],' % (riid.name,)
571             print r'                    %s.Data4[7]);' % (riid.name,)
572             print r'        }'
573             print '    }'
574
575     def invokeMethod(self, interface, base, method):
576         if method.type is stdapi.Void:
577             result = ''
578         else:
579             result = '__result = '
580         print '    %sstatic_cast<%s *>(m_pInstance)->%s(%s);' % (result, base, method.name, ', '.join([str(arg.name) for arg in method.args]))
581     
582     def emit_memcpy(self, dest, src, length):
583         print '        unsigned __call = trace::localWriter.beginEnter(&trace::memcpy_sig);'
584         print '        trace::localWriter.beginArg(0);'
585         print '        trace::localWriter.writeOpaque(%s);' % dest
586         print '        trace::localWriter.endArg();'
587         print '        trace::localWriter.beginArg(1);'
588         print '        trace::localWriter.writeBlob(%s, %s);' % (src, length)
589         print '        trace::localWriter.endArg();'
590         print '        trace::localWriter.beginArg(2);'
591         print '        trace::localWriter.writeUInt(%s);' % length
592         print '        trace::localWriter.endArg();'
593         print '        trace::localWriter.endEnter();'
594         print '        trace::localWriter.beginLeave(__call);'
595         print '        trace::localWriter.endLeave();'
596