]> git.cworth.org Git - apitrace/blob - trace.py
Distinguish linear pointers.
[apitrace] / trace.py
1 ##########################################################################
2 #
3 # Copyright 2008-2010 VMware, Inc.
4 # All Rights Reserved.
5 #
6 # Permission is hereby granted, free of charge, to any person obtaining a copy
7 # of this software and associated documentation files (the "Software"), to deal
8 # in the Software without restriction, including without limitation the rights
9 # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 # copies of the Software, and to permit persons to whom the Software is
11 # furnished to do so, subject to the following conditions:
12 #
13 # The above copyright notice and this permission notice shall be included in
14 # all copies or substantial portions of the Software.
15 #
16 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
22 # THE SOFTWARE.
23 #
24 ##########################################################################/
25
26 """Common trace code generation."""
27
28
29 import specs.stdapi as stdapi
30
31
32 def interface_wrap_name(interface):
33     return "Wrap" + interface.expr
34
35
36 class ComplexValueSerializer(stdapi.OnceVisitor):
37     '''Type visitors which generates serialization functions for
38     complex types.
39     
40     Simple types are serialized inline.
41     '''
42
43     def __init__(self, serializer):
44         stdapi.OnceVisitor.__init__(self)
45         self.serializer = serializer
46
47     def visitVoid(self, literal):
48         pass
49
50     def visitLiteral(self, literal):
51         pass
52
53     def visitString(self, string):
54         pass
55
56     def visitConst(self, const):
57         self.visit(const.type)
58
59     def visitStruct(self, struct):
60         for type, name in struct.members:
61             self.visit(type)
62         print 'static void _write__%s(const %s &value) {' % (struct.tag, struct.expr)
63         print '    static const char * members[%u] = {' % (len(struct.members),)
64         for type, name,  in struct.members:
65             print '        "%s",' % (name,)
66         print '    };'
67         print '    static const trace::StructSig sig = {'
68         print '       %u, "%s", %u, members' % (struct.id, struct.name, len(struct.members))
69         print '    };'
70         print '    trace::localWriter.beginStruct(&sig);'
71         for type, name in struct.members:
72             self.serializer.visit(type, 'value.%s' % (name,))
73         print '    trace::localWriter.endStruct();'
74         print '}'
75         print
76
77     def visitArray(self, array):
78         self.visit(array.type)
79
80     def visitBlob(self, array):
81         pass
82
83     def visitEnum(self, enum):
84         print 'static const trace::EnumValue __enum%s_values[] = {' % (enum.tag)
85         for value in enum.values:
86             print '   {"%s", %s},' % (value, value)
87         print '};'
88         print
89         print 'static const trace::EnumSig __enum%s_sig = {' % (enum.tag)
90         print '   %u, %u, __enum%s_values' % (enum.id, len(enum.values), enum.tag)
91         print '};'
92         print
93
94     def visitBitmask(self, bitmask):
95         print 'static const trace::BitmaskFlag __bitmask%s_flags[] = {' % (bitmask.tag)
96         for value in bitmask.values:
97             print '   {"%s", %s},' % (value, value)
98         print '};'
99         print
100         print 'static const trace::BitmaskSig __bitmask%s_sig = {' % (bitmask.tag)
101         print '   %u, %u, __bitmask%s_flags' % (bitmask.id, len(bitmask.values), bitmask.tag)
102         print '};'
103         print
104
105     def visitPointer(self, pointer):
106         self.visit(pointer.type)
107
108     def visitIntPointer(self, pointer):
109         pass
110
111     def visitLinearPointer(self, pointer):
112         self.visit(pointer.type)
113
114     def visitHandle(self, handle):
115         self.visit(handle.type)
116
117     def visitAlias(self, alias):
118         self.visit(alias.type)
119
120     def visitOpaque(self, opaque):
121         pass
122
123     def visitInterface(self, interface):
124         print "class %s : public %s " % (interface_wrap_name(interface), interface.name)
125         print "{"
126         print "public:"
127         print "    %s(%s * pInstance);" % (interface_wrap_name(interface), interface.name)
128         print "    virtual ~%s();" % interface_wrap_name(interface)
129         print
130         for method in interface.iterMethods():
131             print "    " + method.prototype() + ";"
132         print
133         #print "private:"
134         print "    %s * m_pInstance;" % (interface.name,)
135         print "};"
136         print
137
138     def visitPolymorphic(self, polymorphic):
139         print 'static void _write__%s(int selector, const %s & value) {' % (polymorphic.tag, polymorphic.expr)
140         print '    switch (selector) {'
141         for cases, type in polymorphic.iterSwitch():
142             for case in cases:
143                 print '    %s:' % case
144             self.serializer.visit(type, 'static_cast<%s>(value)' % (type,))
145             print '        break;'
146         print '    }'
147         print '}'
148         print
149
150
151 class ValueSerializer(stdapi.Visitor):
152     '''Visitor which generates code to serialize any type.
153     
154     Simple types are serialized inline here, whereas the serialization of
155     complex types is dispatched to the serialization functions generated by
156     ComplexValueSerializer visitor above.
157     '''
158
159     def visitLiteral(self, literal, instance):
160         print '    trace::localWriter.write%s(%s);' % (literal.kind, instance)
161
162     def visitString(self, string, instance):
163         if string.length is not None:
164             print '    trace::localWriter.writeString((const char *)%s, %s);' % (instance, string.length)
165         else:
166             print '    trace::localWriter.writeString((const char *)%s);' % instance
167
168     def visitConst(self, const, instance):
169         self.visit(const.type, instance)
170
171     def visitStruct(self, struct, instance):
172         print '    _write__%s(%s);' % (struct.tag, instance)
173
174     def visitArray(self, array, instance):
175         length = '__c' + array.type.tag
176         index = '__i' + array.type.tag
177         print '    if (%s) {' % instance
178         print '        size_t %s = %s;' % (length, array.length)
179         print '        trace::localWriter.beginArray(%s);' % length
180         print '        for (size_t %s = 0; %s < %s; ++%s) {' % (index, index, length, index)
181         print '            trace::localWriter.beginElement();'
182         self.visit(array.type, '(%s)[%s]' % (instance, index))
183         print '            trace::localWriter.endElement();'
184         print '        }'
185         print '        trace::localWriter.endArray();'
186         print '    } else {'
187         print '        trace::localWriter.writeNull();'
188         print '    }'
189
190     def visitBlob(self, blob, instance):
191         print '    trace::localWriter.writeBlob(%s, %s);' % (instance, blob.size)
192
193     def visitEnum(self, enum, instance):
194         print '    trace::localWriter.writeEnum(&__enum%s_sig, %s);' % (enum.tag, instance)
195
196     def visitBitmask(self, bitmask, instance):
197         print '    trace::localWriter.writeBitmask(&__bitmask%s_sig, %s);' % (bitmask.tag, instance)
198
199     def visitPointer(self, pointer, instance):
200         print '    if (%s) {' % instance
201         print '        trace::localWriter.beginArray(1);'
202         print '        trace::localWriter.beginElement();'
203         self.visit(pointer.type, "*" + instance)
204         print '        trace::localWriter.endElement();'
205         print '        trace::localWriter.endArray();'
206         print '    } else {'
207         print '        trace::localWriter.writeNull();'
208         print '    }'
209
210     def visitIntPointer(self, pointer, instance):
211         print '    trace::localWriter.writeOpaque((const void *)%s);' % instance
212
213     def visitLinearPointer(self, pointer, instance):
214         print '    trace::localWriter.writeOpaque((const void *)%s);' % instance
215
216     def visitHandle(self, handle, instance):
217         self.visit(handle.type, instance)
218
219     def visitAlias(self, alias, instance):
220         self.visit(alias.type, instance)
221
222     def visitOpaque(self, opaque, instance):
223         print '    trace::localWriter.writeOpaque((const void *)%s);' % instance
224
225     def visitInterface(self, interface, instance):
226         print '    trace::localWriter.writeOpaque((const void *)&%s);' % instance
227
228     def visitPolymorphic(self, polymorphic, instance):
229         print '    _write__%s(%s, %s);' % (polymorphic.tag, polymorphic.switchExpr, instance)
230
231
232 class ValueWrapper(stdapi.Visitor):
233     '''Type visitor which will generate the code to wrap an instance.
234     
235     Wrapping is necessary mostly for interfaces, however interface pointers can
236     appear anywhere inside complex types.
237     '''
238
239     def visitVoid(self, type, instance):
240         raise NotImplementedError
241
242     def visitLiteral(self, type, instance):
243         pass
244
245     def visitString(self, type, instance):
246         pass
247
248     def visitConst(self, type, instance):
249         pass
250
251     def visitStruct(self, struct, instance):
252         for type, name in struct.members:
253             self.visit(type, "(%s).%s" % (instance, name))
254
255     def visitArray(self, array, instance):
256         # XXX: actually it is possible to return an array of pointers
257         pass
258
259     def visitBlob(self, blob, instance):
260         pass
261
262     def visitEnum(self, enum, instance):
263         pass
264
265     def visitBitmask(self, bitmask, instance):
266         pass
267
268     def visitPointer(self, pointer, instance):
269         print "    if (%s) {" % instance
270         self.visit(pointer.type, "*" + instance)
271         print "    }"
272     
273     def visitIntPointer(self, pointer, instance):
274         pass
275
276     def visitLinearPointer(self, pointer, instance):
277         pass
278
279     def visitHandle(self, handle, instance):
280         self.visit(handle.type, instance)
281
282     def visitAlias(self, alias, instance):
283         self.visit(alias.type, instance)
284
285     def visitOpaque(self, opaque, instance):
286         pass
287     
288     def visitInterface(self, interface, instance):
289         assert instance.startswith('*')
290         instance = instance[1:]
291         print "    if (%s) {" % instance
292         print "        %s = new %s(%s);" % (instance, interface_wrap_name(interface), instance)
293         print "    }"
294     
295     def visitPolymorphic(self, type, instance):
296         # XXX: There might be polymorphic values that need wrapping in the future
297         pass
298
299
300 class ValueUnwrapper(ValueWrapper):
301     '''Reverse of ValueWrapper.'''
302
303     def visitInterface(self, interface, instance):
304         assert instance.startswith('*')
305         instance = instance[1:]
306         print "    if (%s) {" % instance
307         print "        %s = static_cast<%s *>(%s)->m_pInstance;" % (instance, interface_wrap_name(interface), instance)
308         print "    }"
309
310
311 class Tracer:
312     '''Base class to orchestrate the code generation of API tracing.'''
313
314     def __init__(self):
315         self.api = None
316
317     def serializerFactory(self):
318         '''Create a serializer.
319         
320         Can be overriden by derived classes to inject their own serialzer.
321         '''
322
323         return ValueSerializer()
324
325     def trace_api(self, api):
326         self.api = api
327
328         self.header(api)
329
330         # Includes
331         for header in api.headers:
332             print header
333         print
334
335         # Generate the serializer functions
336         types = api.all_types()
337         visitor = ComplexValueSerializer(self.serializerFactory())
338         map(visitor.visit, types)
339         print
340
341         # Interfaces wrapers
342         interfaces = [type for type in types if isinstance(type, stdapi.Interface)]
343         map(self.traceInterfaceImpl, interfaces)
344         print
345
346         # Function wrappers
347         map(self.traceFunctionDecl, api.functions)
348         map(self.traceFunctionImpl, api.functions)
349         print
350
351         self.footer(api)
352
353     def header(self, api):
354         pass
355
356     def footer(self, api):
357         pass
358
359     def traceFunctionDecl(self, function):
360         # Per-function declarations
361
362         if function.args:
363             print 'static const char * __%s_args[%u] = {%s};' % (function.name, len(function.args), ', '.join(['"%s"' % arg.name for arg in function.args]))
364         else:
365             print 'static const char ** __%s_args = NULL;' % (function.name,)
366         print 'static const trace::FunctionSig __%s_sig = {%u, "%s", %u, __%s_args};' % (function.name, function.id, function.name, len(function.args), function.name)
367         print
368
369     def isFunctionPublic(self, function):
370         return True
371
372     def traceFunctionImpl(self, function):
373         if self.isFunctionPublic(function):
374             print 'extern "C" PUBLIC'
375         else:
376             print 'extern "C" PRIVATE'
377         print function.prototype() + ' {'
378         if function.type is not stdapi.Void:
379             print '    %s __result;' % function.type
380         self.traceFunctionImplBody(function)
381         if function.type is not stdapi.Void:
382             self.wrapRet(function, "__result")
383             print '    return __result;'
384         print '}'
385         print
386
387     def traceFunctionImplBody(self, function):
388         print '    unsigned __call = trace::localWriter.beginEnter(&__%s_sig);' % (function.name,)
389         for arg in function.args:
390             if not arg.output:
391                 self.unwrapArg(function, arg)
392                 self.serializeArg(function, arg)
393         print '    trace::localWriter.endEnter();'
394         self.invokeFunction(function)
395         print '    trace::localWriter.beginLeave(__call);'
396         for arg in function.args:
397             if arg.output:
398                 self.serializeArg(function, arg)
399                 self.wrapArg(function, arg)
400         if function.type is not stdapi.Void:
401             self.serializeRet(function, "__result")
402         print '    trace::localWriter.endLeave();'
403
404     def invokeFunction(self, function, prefix='__', suffix=''):
405         if function.type is stdapi.Void:
406             result = ''
407         else:
408             result = '__result = '
409         dispatch = prefix + function.name + suffix
410         print '    %s%s(%s);' % (result, dispatch, ', '.join([str(arg.name) for arg in function.args]))
411
412     def serializeArg(self, function, arg):
413         print '    trace::localWriter.beginArg(%u);' % (arg.index,)
414         self.serializeArgValue(function, arg)
415         print '    trace::localWriter.endArg();'
416
417     def serializeArgValue(self, function, arg):
418         self.serializeValue(arg.type, arg.name)
419
420     def wrapArg(self, function, arg):
421         self.wrapValue(arg.type, arg.name)
422
423     def unwrapArg(self, function, arg):
424         self.unwrapValue(arg.type, arg.name)
425
426     def serializeRet(self, function, instance):
427         print '    trace::localWriter.beginReturn();'
428         self.serializeValue(function.type, instance)
429         print '    trace::localWriter.endReturn();'
430
431     def serializeValue(self, type, instance):
432         serializer = self.serializerFactory()
433         serializer.visit(type, instance)
434
435     def wrapRet(self, function, instance):
436         self.wrapValue(function.type, instance)
437
438     def unwrapRet(self, function, instance):
439         self.unwrapValue(function.type, instance)
440
441     def wrapValue(self, type, instance):
442         visitor = ValueWrapper()
443         visitor.visit(type, instance)
444
445     def unwrapValue(self, type, instance):
446         visitor = ValueUnwrapper()
447         visitor.visit(type, instance)
448
449     def traceInterfaceImpl(self, interface):
450         print '%s::%s(%s * pInstance) {' % (interface_wrap_name(interface), interface_wrap_name(interface), interface.name)
451         print '    m_pInstance = pInstance;'
452         print '}'
453         print
454         print '%s::~%s() {' % (interface_wrap_name(interface), interface_wrap_name(interface))
455         print '}'
456         print
457         for method in interface.iterMethods():
458             self.traceMethod(interface, method)
459         print
460
461     def traceMethod(self, interface, method):
462         print method.prototype(interface_wrap_name(interface) + '::' + method.name) + ' {'
463         print '    static const char * __args[%u] = {%s};' % (len(method.args) + 1, ', '.join(['"this"'] + ['"%s"' % arg.name for arg in method.args]))
464         print '    static const trace::FunctionSig __sig = {%u, "%s", %u, __args};' % (method.id, interface.name + '::' + method.name, len(method.args) + 1)
465         print '    unsigned __call = trace::localWriter.beginEnter(&__sig);'
466         print '    trace::localWriter.beginArg(0);'
467         print '    trace::localWriter.writeOpaque((const void *)m_pInstance);'
468         print '    trace::localWriter.endArg();'
469         for arg in method.args:
470             if not arg.output:
471                 self.unwrapArg(method, arg)
472                 self.serializeArg(method, arg)
473         if method.type is stdapi.Void:
474             result = ''
475         else:
476             print '    %s __result;' % method.type
477             result = '__result = '
478         print '    trace::localWriter.endEnter();'
479         print '    %sm_pInstance->%s(%s);' % (result, method.name, ', '.join([str(arg.name) for arg in method.args]))
480         print '    trace::localWriter.beginLeave(__call);'
481         for arg in method.args:
482             if arg.output:
483                 self.serializeArg(method, arg)
484                 self.wrapArg(method, arg)
485         if method.type is not stdapi.Void:
486             print '    trace::localWriter.beginReturn();'
487             self.serializeValue(method.type, "__result")
488             print '    trace::localWriter.endReturn();'
489             self.wrapValue(method.type, '__result')
490         print '    trace::localWriter.endLeave();'
491         if method.name == 'QueryInterface':
492             print '    if (ppvObj && *ppvObj) {'
493             print '        if (*ppvObj == m_pInstance) {'
494             print '            *ppvObj = this;'
495             print '        }'
496             for iface in self.api.interfaces:
497                 print r'        else if (riid == IID_%s) {' % iface.name
498                 print r'            *ppvObj = new Wrap%s((%s *) *ppvObj);' % (iface.name, iface.name)
499                 print r'        }'
500             print r'        else {'
501             print r'            os::log("apitrace: warning: unknown REFIID {0x%08lX,0x%04X,0x%04X,{0x%02X,0x%02X,0x%02X,0x%02X,0x%02X,0x%02X,0x%02X,0x%02X}}\n",'
502             print r'                    riid.Data1, riid.Data2, riid.Data3,'
503             print r'                    riid.Data4[0],'
504             print r'                    riid.Data4[1],'
505             print r'                    riid.Data4[2],'
506             print r'                    riid.Data4[3],'
507             print r'                    riid.Data4[4],'
508             print r'                    riid.Data4[5],'
509             print r'                    riid.Data4[6],'
510             print r'                    riid.Data4[7]);'
511             print r'        }'
512             print '    }'
513         if method.name == 'Release':
514             assert method.type is not stdapi.Void
515             print '    if (!__result)'
516             print '        delete this;'
517         if method.type is not stdapi.Void:
518             print '    return __result;'
519         print '}'
520         print
521