]> git.cworth.org Git - apitrace/blob - trace.py
Cleanup/comment/format code.
[apitrace] / trace.py
1 ##########################################################################
2 #
3 # Copyright 2008-2010 VMware, Inc.
4 # All Rights Reserved.
5 #
6 # Permission is hereby granted, free of charge, to any person obtaining a copy
7 # of this software and associated documentation files (the "Software"), to deal
8 # in the Software without restriction, including without limitation the rights
9 # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 # copies of the Software, and to permit persons to whom the Software is
11 # furnished to do so, subject to the following conditions:
12 #
13 # The above copyright notice and this permission notice shall be included in
14 # all copies or substantial portions of the Software.
15 #
16 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
22 # THE SOFTWARE.
23 #
24 ##########################################################################/
25
26 """Common trace code generation."""
27
28
29 import specs.stdapi as stdapi
30
31
32 def interface_wrap_name(interface):
33     return "Wrap" + interface.expr
34
35
36 class ComplexValueSerializer(stdapi.OnceVisitor):
37     '''Type visitors which generates serialization functions for
38     complex types.
39     
40     Simple types are serialized inline.
41     '''
42
43     def __init__(self, serializer):
44         stdapi.OnceVisitor.__init__(self)
45         self.serializer = serializer
46
47     def visitVoid(self, literal):
48         pass
49
50     def visitLiteral(self, literal):
51         pass
52
53     def visitString(self, string):
54         pass
55
56     def visitConst(self, const):
57         self.visit(const.type)
58
59     def visitStruct(self, struct):
60         for type, name in struct.members:
61             self.visit(type)
62         print 'static void _write__%s(const %s &value) {' % (struct.tag, struct.expr)
63         print '    static const char * members[%u] = {' % (len(struct.members),)
64         for type, name,  in struct.members:
65             print '        "%s",' % (name,)
66         print '    };'
67         print '    static const trace::StructSig sig = {'
68         print '       %u, "%s", %u, members' % (struct.id, struct.name, len(struct.members))
69         print '    };'
70         print '    trace::localWriter.beginStruct(&sig);'
71         for type, name in struct.members:
72             self.serializer.visit(type, 'value.%s' % (name,))
73         print '    trace::localWriter.endStruct();'
74         print '}'
75         print
76
77     def visitArray(self, array):
78         self.visit(array.type)
79
80     def visitBlob(self, array):
81         pass
82
83     def visitEnum(self, enum):
84         print 'static const trace::EnumValue __enum%s_values[] = {' % (enum.tag)
85         for value in enum.values:
86             print '   {"%s", %s},' % (value, value)
87         print '};'
88         print
89         print 'static const trace::EnumSig __enum%s_sig = {' % (enum.tag)
90         print '   %u, %u, __enum%s_values' % (enum.id, len(enum.values), enum.tag)
91         print '};'
92         print
93
94     def visitBitmask(self, bitmask):
95         print 'static const trace::BitmaskFlag __bitmask%s_flags[] = {' % (bitmask.tag)
96         for value in bitmask.values:
97             print '   {"%s", %s},' % (value, value)
98         print '};'
99         print
100         print 'static const trace::BitmaskSig __bitmask%s_sig = {' % (bitmask.tag)
101         print '   %u, %u, __bitmask%s_flags' % (bitmask.id, len(bitmask.values), bitmask.tag)
102         print '};'
103         print
104
105     def visitPointer(self, pointer):
106         self.visit(pointer.type)
107
108     def visitHandle(self, handle):
109         self.visit(handle.type)
110
111     def visitAlias(self, alias):
112         self.visit(alias.type)
113
114     def visitOpaque(self, opaque):
115         pass
116
117     def visitInterface(self, interface):
118         print "class %s : public %s " % (interface_wrap_name(interface), interface.name)
119         print "{"
120         print "public:"
121         print "    %s(%s * pInstance);" % (interface_wrap_name(interface), interface.name)
122         print "    virtual ~%s();" % interface_wrap_name(interface)
123         print
124         for method in interface.iterMethods():
125             print "    " + method.prototype() + ";"
126         print
127         #print "private:"
128         print "    %s * m_pInstance;" % (interface.name,)
129         print "};"
130         print
131
132     def visitPolymorphic(self, polymorphic):
133         print 'static void _write__%s(int selector, const %s & value) {' % (polymorphic.tag, polymorphic.expr)
134         print '    switch (selector) {'
135         for cases, type in polymorphic.iterSwitch():
136             for case in cases:
137                 print '    %s:' % case
138             self.serializer.visit(type, 'static_cast<%s>(value)' % (type,))
139             print '        break;'
140         print '    }'
141         print '}'
142         print
143
144
145 class ValueSerializer(stdapi.Visitor):
146     '''Visitor which generates code to serialize any type.
147     
148     Simple types are serialized inline here, whereas the serialization of
149     complex types is dispatched to the serialization functions generated by
150     ComplexValueSerializer visitor above.
151     '''
152
153     def visitLiteral(self, literal, instance):
154         print '    trace::localWriter.write%s(%s);' % (literal.kind, instance)
155
156     def visitString(self, string, instance):
157         if string.length is not None:
158             print '    trace::localWriter.writeString((const char *)%s, %s);' % (instance, string.length)
159         else:
160             print '    trace::localWriter.writeString((const char *)%s);' % instance
161
162     def visitConst(self, const, instance):
163         self.visit(const.type, instance)
164
165     def visitStruct(self, struct, instance):
166         print '    _write__%s(%s);' % (struct.tag, instance)
167
168     def visitArray(self, array, instance):
169         length = '__c' + array.type.tag
170         index = '__i' + array.type.tag
171         print '    if (%s) {' % instance
172         print '        size_t %s = %s;' % (length, array.length)
173         print '        trace::localWriter.beginArray(%s);' % length
174         print '        for (size_t %s = 0; %s < %s; ++%s) {' % (index, index, length, index)
175         print '            trace::localWriter.beginElement();'
176         self.visit(array.type, '(%s)[%s]' % (instance, index))
177         print '            trace::localWriter.endElement();'
178         print '        }'
179         print '        trace::localWriter.endArray();'
180         print '    } else {'
181         print '        trace::localWriter.writeNull();'
182         print '    }'
183
184     def visitBlob(self, blob, instance):
185         print '    trace::localWriter.writeBlob(%s, %s);' % (instance, blob.size)
186
187     def visitEnum(self, enum, instance):
188         print '    trace::localWriter.writeEnum(&__enum%s_sig, %s);' % (enum.tag, instance)
189
190     def visitBitmask(self, bitmask, instance):
191         print '    trace::localWriter.writeBitmask(&__bitmask%s_sig, %s);' % (bitmask.tag, instance)
192
193     def visitPointer(self, pointer, instance):
194         print '    if (%s) {' % instance
195         print '        trace::localWriter.beginArray(1);'
196         print '        trace::localWriter.beginElement();'
197         self.visit(pointer.type, "*" + instance)
198         print '        trace::localWriter.endElement();'
199         print '        trace::localWriter.endArray();'
200         print '    } else {'
201         print '        trace::localWriter.writeNull();'
202         print '    }'
203
204     def visitHandle(self, handle, instance):
205         self.visit(handle.type, instance)
206
207     def visitAlias(self, alias, instance):
208         self.visit(alias.type, instance)
209
210     def visitOpaque(self, opaque, instance):
211         print '    trace::localWriter.writeOpaque((const void *)%s);' % instance
212
213     def visitInterface(self, interface, instance):
214         print '    trace::localWriter.writeOpaque((const void *)&%s);' % instance
215
216     def visitPolymorphic(self, polymorphic, instance):
217         print '    _write__%s(%s, %s);' % (polymorphic.tag, polymorphic.switchExpr, instance)
218
219
220 class ValueWrapper(stdapi.Visitor):
221     '''Type visitor which will generate the code to wrap an instance.
222     
223     Wrapping is necessary mostly for interfaces, however interface pointers can
224     appear anywhere inside complex types.
225     '''
226
227     def visitVoid(self, type, instance):
228         raise NotImplementedError
229
230     def visitLiteral(self, type, instance):
231         pass
232
233     def visitString(self, type, instance):
234         pass
235
236     def visitConst(self, type, instance):
237         pass
238
239     def visitStruct(self, struct, instance):
240         for type, name in struct.members:
241             self.visit(type, "(%s).%s" % (instance, name))
242
243     def visitArray(self, array, instance):
244         # XXX: actually it is possible to return an array of pointers
245         pass
246
247     def visitBlob(self, blob, instance):
248         pass
249
250     def visitEnum(self, enum, instance):
251         pass
252
253     def visitBitmask(self, bitmask, instance):
254         pass
255
256     def visitPointer(self, pointer, instance):
257         print "    if (%s) {" % instance
258         self.visit(pointer.type, "*" + instance)
259         print "    }"
260
261     def visitHandle(self, handle, instance):
262         self.visit(handle.type, instance)
263
264     def visitAlias(self, alias, instance):
265         self.visit(alias.type, instance)
266
267     def visitOpaque(self, opaque, instance):
268         pass
269     
270     def visitInterface(self, interface, instance):
271         assert instance.startswith('*')
272         instance = instance[1:]
273         print "    if (%s) {" % instance
274         print "        %s = new %s(%s);" % (instance, interface_wrap_name(interface), instance)
275         print "    }"
276     
277     def visitPolymorphic(self, type, instance):
278         # XXX: There might be polymorphic values that need wrapping in the future
279         pass
280
281
282 class ValueUnwrapper(ValueWrapper):
283     '''Reverse of ValueWrapper.'''
284
285     def visitInterface(self, interface, instance):
286         assert instance.startswith('*')
287         instance = instance[1:]
288         print "    if (%s) {" % instance
289         print "        %s = static_cast<%s *>(%s)->m_pInstance;" % (instance, interface_wrap_name(interface), instance)
290         print "    }"
291
292
293 class Tracer:
294     '''Base class to orchestrate the code generation of API tracing.'''
295
296     def __init__(self):
297         self.api = None
298
299     def serializerFactory(self):
300         '''Create a serializer.
301         
302         Can be overriden by derived classes to inject their own serialzer.
303         '''
304
305         return ValueSerializer()
306
307     def trace_api(self, api):
308         self.api = api
309
310         self.header(api)
311
312         # Includes
313         for header in api.headers:
314             print header
315         print
316
317         # Generate the serializer functions
318         types = api.all_types()
319         visitor = ComplexValueSerializer(self.serializerFactory())
320         map(visitor.visit, types)
321         print
322
323         # Interfaces wrapers
324         interfaces = [type for type in types if isinstance(type, stdapi.Interface)]
325         map(self.traceInterfaceImpl, interfaces)
326         print
327
328         # Function wrappers
329         map(self.traceFunctionDecl, api.functions)
330         map(self.traceFunctionImpl, api.functions)
331         print
332
333         self.footer(api)
334
335     def header(self, api):
336         pass
337
338     def footer(self, api):
339         pass
340
341     def traceFunctionDecl(self, function):
342         # Per-function declarations
343
344         if function.args:
345             print 'static const char * __%s_args[%u] = {%s};' % (function.name, len(function.args), ', '.join(['"%s"' % arg.name for arg in function.args]))
346         else:
347             print 'static const char ** __%s_args = NULL;' % (function.name,)
348         print 'static const trace::FunctionSig __%s_sig = {%u, "%s", %u, __%s_args};' % (function.name, function.id, function.name, len(function.args), function.name)
349         print
350
351     def isFunctionPublic(self, function):
352         return True
353
354     def traceFunctionImpl(self, function):
355         if self.isFunctionPublic(function):
356             print 'extern "C" PUBLIC'
357         else:
358             print 'extern "C" PRIVATE'
359         print function.prototype() + ' {'
360         if function.type is not stdapi.Void:
361             print '    %s __result;' % function.type
362         self.traceFunctionImplBody(function)
363         if function.type is not stdapi.Void:
364             self.wrapRet(function, "__result")
365             print '    return __result;'
366         print '}'
367         print
368
369     def traceFunctionImplBody(self, function):
370         print '    unsigned __call = trace::localWriter.beginEnter(&__%s_sig);' % (function.name,)
371         for arg in function.args:
372             if not arg.output:
373                 self.unwrapArg(function, arg)
374                 self.serializeArg(function, arg)
375         print '    trace::localWriter.endEnter();'
376         self.invokeFunction(function)
377         print '    trace::localWriter.beginLeave(__call);'
378         for arg in function.args:
379             if arg.output:
380                 self.serializeArg(function, arg)
381                 self.wrapArg(function, arg)
382         if function.type is not stdapi.Void:
383             self.serializeRet(function, "__result")
384         print '    trace::localWriter.endLeave();'
385
386     def invokeFunction(self, function, prefix='__', suffix=''):
387         if function.type is stdapi.Void:
388             result = ''
389         else:
390             result = '__result = '
391         dispatch = prefix + function.name + suffix
392         print '    %s%s(%s);' % (result, dispatch, ', '.join([str(arg.name) for arg in function.args]))
393
394     def serializeArg(self, function, arg):
395         print '    trace::localWriter.beginArg(%u);' % (arg.index,)
396         self.serializeArgValue(function, arg)
397         print '    trace::localWriter.endArg();'
398
399     def serializeArgValue(self, function, arg):
400         self.serializeValue(arg.type, arg.name)
401
402     def wrapArg(self, function, arg):
403         self.wrapValue(arg.type, arg.name)
404
405     def unwrapArg(self, function, arg):
406         self.unwrapValue(arg.type, arg.name)
407
408     def serializeRet(self, function, instance):
409         print '    trace::localWriter.beginReturn();'
410         self.serializeValue(function.type, instance)
411         print '    trace::localWriter.endReturn();'
412
413     def serializeValue(self, type, instance):
414         serializer = self.serializerFactory()
415         serializer.visit(type, instance)
416
417     def wrapRet(self, function, instance):
418         self.wrapValue(function.type, instance)
419
420     def unwrapRet(self, function, instance):
421         self.unwrapValue(function.type, instance)
422
423     def wrapValue(self, type, instance):
424         visitor = ValueWrapper()
425         visitor.visit(type, instance)
426
427     def unwrapValue(self, type, instance):
428         visitor = ValueUnwrapper()
429         visitor.visit(type, instance)
430
431     def traceInterfaceImpl(self, interface):
432         print '%s::%s(%s * pInstance) {' % (interface_wrap_name(interface), interface_wrap_name(interface), interface.name)
433         print '    m_pInstance = pInstance;'
434         print '}'
435         print
436         print '%s::~%s() {' % (interface_wrap_name(interface), interface_wrap_name(interface))
437         print '}'
438         print
439         for method in interface.iterMethods():
440             self.traceMethod(interface, method)
441         print
442
443     def traceMethod(self, interface, method):
444         print method.prototype(interface_wrap_name(interface) + '::' + method.name) + ' {'
445         print '    static const char * __args[%u] = {%s};' % (len(method.args) + 1, ', '.join(['"this"'] + ['"%s"' % arg.name for arg in method.args]))
446         print '    static const trace::FunctionSig __sig = {%u, "%s", %u, __args};' % (method.id, interface.name + '::' + method.name, len(method.args) + 1)
447         print '    unsigned __call = trace::localWriter.beginEnter(&__sig);'
448         print '    trace::localWriter.beginArg(0);'
449         print '    trace::localWriter.writeOpaque((const void *)m_pInstance);'
450         print '    trace::localWriter.endArg();'
451         for arg in method.args:
452             if not arg.output:
453                 self.unwrapArg(method, arg)
454                 self.serializeArg(method, arg)
455         if method.type is stdapi.Void:
456             result = ''
457         else:
458             print '    %s __result;' % method.type
459             result = '__result = '
460         print '    trace::localWriter.endEnter();'
461         print '    %sm_pInstance->%s(%s);' % (result, method.name, ', '.join([str(arg.name) for arg in method.args]))
462         print '    trace::localWriter.beginLeave(__call);'
463         for arg in method.args:
464             if arg.output:
465                 self.serializeArg(method, arg)
466                 self.wrapArg(method, arg)
467         if method.type is not stdapi.Void:
468             print '    trace::localWriter.beginReturn();'
469             self.serializeValue(method.type, "__result")
470             print '    trace::localWriter.endReturn();'
471             self.wrapValue(method.type, '__result')
472         print '    trace::localWriter.endLeave();'
473         if method.name == 'QueryInterface':
474             print '    if (ppvObj && *ppvObj) {'
475             print '        if (*ppvObj == m_pInstance) {'
476             print '            *ppvObj = this;'
477             print '        }'
478             for iface in self.api.interfaces:
479                 print r'        else if (riid == IID_%s) {' % iface.name
480                 print r'            *ppvObj = new Wrap%s((%s *) *ppvObj);' % (iface.name, iface.name)
481                 print r'        }'
482             print r'        else {'
483             print r'            os::log("apitrace: warning: unknown REFIID {0x%08lX,0x%04X,0x%04X,{0x%02X,0x%02X,0x%02X,0x%02X,0x%02X,0x%02X,0x%02X,0x%02X}}\n",'
484             print r'                    riid.Data1, riid.Data2, riid.Data3,'
485             print r'                    riid.Data4[0],'
486             print r'                    riid.Data4[1],'
487             print r'                    riid.Data4[2],'
488             print r'                    riid.Data4[3],'
489             print r'                    riid.Data4[4],'
490             print r'                    riid.Data4[5],'
491             print r'                    riid.Data4[6],'
492             print r'                    riid.Data4[7]);'
493             print r'        }'
494             print '    }'
495         if method.name == 'Release':
496             assert method.type is not stdapi.Void
497             print '    if (!__result)'
498             print '        delete this;'
499         if method.type is not stdapi.Void:
500             print '    return __result;'
501         print '}'
502         print
503