]> git.cworth.org Git - apitrace/blob - trace.py
Cleanup unicode support.
[apitrace] / trace.py
1 ##########################################################################
2 #
3 # Copyright 2008-2010 VMware, Inc.
4 # All Rights Reserved.
5 #
6 # Permission is hereby granted, free of charge, to any person obtaining a copy
7 # of this software and associated documentation files (the "Software"), to deal
8 # in the Software without restriction, including without limitation the rights
9 # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 # copies of the Software, and to permit persons to whom the Software is
11 # furnished to do so, subject to the following conditions:
12 #
13 # The above copyright notice and this permission notice shall be included in
14 # all copies or substantial portions of the Software.
15 #
16 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
22 # THE SOFTWARE.
23 #
24 ##########################################################################/
25
26 """Common trace code generation."""
27
28
29 import specs.stdapi as stdapi
30
31
32 def getWrapperInterfaceName(interface):
33     return "Wrap" + interface.expr
34
35
36 class ComplexValueSerializer(stdapi.OnceVisitor):
37     '''Type visitors which generates serialization functions for
38     complex types.
39     
40     Simple types are serialized inline.
41     '''
42
43     def __init__(self, serializer):
44         stdapi.OnceVisitor.__init__(self)
45         self.serializer = serializer
46
47     def visitVoid(self, literal):
48         pass
49
50     def visitLiteral(self, literal):
51         pass
52
53     def visitString(self, string):
54         pass
55
56     def visitConst(self, const):
57         self.visit(const.type)
58
59     def visitStruct(self, struct):
60         for type, name in struct.members:
61             self.visit(type)
62         print 'static void _write__%s(const %s &value) {' % (struct.tag, struct.expr)
63         print '    static const char * members[%u] = {' % (len(struct.members),)
64         for type, name,  in struct.members:
65             print '        "%s",' % (name,)
66         print '    };'
67         print '    static const trace::StructSig sig = {'
68         print '       %u, "%s", %u, members' % (struct.id, struct.name, len(struct.members))
69         print '    };'
70         print '    trace::localWriter.beginStruct(&sig);'
71         for type, name in struct.members:
72             self.serializer.visit(type, 'value.%s' % (name,))
73         print '    trace::localWriter.endStruct();'
74         print '}'
75         print
76
77     def visitArray(self, array):
78         self.visit(array.type)
79
80     def visitBlob(self, array):
81         pass
82
83     def visitEnum(self, enum):
84         print 'static const trace::EnumValue __enum%s_values[] = {' % (enum.tag)
85         for value in enum.values:
86             print '   {"%s", %s},' % (value, value)
87         print '};'
88         print
89         print 'static const trace::EnumSig __enum%s_sig = {' % (enum.tag)
90         print '   %u, %u, __enum%s_values' % (enum.id, len(enum.values), enum.tag)
91         print '};'
92         print
93
94     def visitBitmask(self, bitmask):
95         print 'static const trace::BitmaskFlag __bitmask%s_flags[] = {' % (bitmask.tag)
96         for value in bitmask.values:
97             print '   {"%s", %s},' % (value, value)
98         print '};'
99         print
100         print 'static const trace::BitmaskSig __bitmask%s_sig = {' % (bitmask.tag)
101         print '   %u, %u, __bitmask%s_flags' % (bitmask.id, len(bitmask.values), bitmask.tag)
102         print '};'
103         print
104
105     def visitPointer(self, pointer):
106         self.visit(pointer.type)
107
108     def visitIntPointer(self, pointer):
109         pass
110
111     def visitLinearPointer(self, pointer):
112         self.visit(pointer.type)
113
114     def visitHandle(self, handle):
115         self.visit(handle.type)
116
117     def visitAlias(self, alias):
118         self.visit(alias.type)
119
120     def visitOpaque(self, opaque):
121         pass
122
123     def visitInterface(self, interface):
124         pass
125
126     def visitPolymorphic(self, polymorphic):
127         print 'static void _write__%s(int selector, const %s & value) {' % (polymorphic.tag, polymorphic.expr)
128         print '    switch (selector) {'
129         for cases, type in polymorphic.iterSwitch():
130             for case in cases:
131                 print '    %s:' % case
132             self.serializer.visit(type, 'static_cast<%s>(value)' % (type,))
133             print '        break;'
134         print '    }'
135         print '}'
136         print
137
138
139 class ValueSerializer(stdapi.Visitor):
140     '''Visitor which generates code to serialize any type.
141     
142     Simple types are serialized inline here, whereas the serialization of
143     complex types is dispatched to the serialization functions generated by
144     ComplexValueSerializer visitor above.
145     '''
146
147     def visitLiteral(self, literal, instance):
148         print '    trace::localWriter.write%s(%s);' % (literal.kind, instance)
149
150     def visitString(self, string, instance):
151         if string.kind == 'String':
152             cast = 'const char *'
153         elif string.kind == 'WString':
154             cast = 'const wchar_t *'
155         else:
156             assert False
157         if cast != string.expr:
158             # reinterpret_cast is necessary for GLubyte * <=> char *
159             instance = 'reinterpret_cast<%s>(%s)' % (cast, instance)
160         if string.length is not None:
161             length = ', %s' % string.length
162         else:
163             length = ''
164         print '    trace::localWriter.write%s(%s%s);' % (string.kind, instance, length)
165
166     def visitConst(self, const, instance):
167         self.visit(const.type, instance)
168
169     def visitStruct(self, struct, instance):
170         print '    _write__%s(%s);' % (struct.tag, instance)
171
172     def visitArray(self, array, instance):
173         length = '__c' + array.type.tag
174         index = '__i' + array.type.tag
175         print '    if (%s) {' % instance
176         print '        size_t %s = %s;' % (length, array.length)
177         print '        trace::localWriter.beginArray(%s);' % length
178         print '        for (size_t %s = 0; %s < %s; ++%s) {' % (index, index, length, index)
179         print '            trace::localWriter.beginElement();'
180         self.visit(array.type, '(%s)[%s]' % (instance, index))
181         print '            trace::localWriter.endElement();'
182         print '        }'
183         print '        trace::localWriter.endArray();'
184         print '    } else {'
185         print '        trace::localWriter.writeNull();'
186         print '    }'
187
188     def visitBlob(self, blob, instance):
189         print '    trace::localWriter.writeBlob(%s, %s);' % (instance, blob.size)
190
191     def visitEnum(self, enum, instance):
192         print '    trace::localWriter.writeEnum(&__enum%s_sig, %s);' % (enum.tag, instance)
193
194     def visitBitmask(self, bitmask, instance):
195         print '    trace::localWriter.writeBitmask(&__bitmask%s_sig, %s);' % (bitmask.tag, instance)
196
197     def visitPointer(self, pointer, instance):
198         print '    if (%s) {' % instance
199         print '        trace::localWriter.beginArray(1);'
200         print '        trace::localWriter.beginElement();'
201         self.visit(pointer.type, "*" + instance)
202         print '        trace::localWriter.endElement();'
203         print '        trace::localWriter.endArray();'
204         print '    } else {'
205         print '        trace::localWriter.writeNull();'
206         print '    }'
207
208     def visitIntPointer(self, pointer, instance):
209         print '    trace::localWriter.writeOpaque((const void *)%s);' % instance
210
211     def visitLinearPointer(self, pointer, instance):
212         print '    trace::localWriter.writeOpaque((const void *)%s);' % instance
213
214     def visitHandle(self, handle, instance):
215         self.visit(handle.type, instance)
216
217     def visitAlias(self, alias, instance):
218         self.visit(alias.type, instance)
219
220     def visitOpaque(self, opaque, instance):
221         print '    trace::localWriter.writeOpaque((const void *)%s);' % instance
222
223     def visitInterface(self, interface, instance):
224         print '    trace::localWriter.writeOpaque((const void *)&%s);' % instance
225
226     def visitPolymorphic(self, polymorphic, instance):
227         print '    _write__%s(%s, %s);' % (polymorphic.tag, polymorphic.switchExpr, instance)
228
229
230 class ValueWrapper(stdapi.Visitor):
231     '''Type visitor which will generate the code to wrap an instance.
232     
233     Wrapping is necessary mostly for interfaces, however interface pointers can
234     appear anywhere inside complex types.
235     '''
236
237     def visitVoid(self, type, instance):
238         raise NotImplementedError
239
240     def visitLiteral(self, type, instance):
241         pass
242
243     def visitString(self, type, instance):
244         pass
245
246     def visitConst(self, type, instance):
247         pass
248
249     def visitStruct(self, struct, instance):
250         for type, name in struct.members:
251             self.visit(type, "(%s).%s" % (instance, name))
252
253     def visitArray(self, array, instance):
254         # XXX: actually it is possible to return an array of pointers
255         pass
256
257     def visitBlob(self, blob, instance):
258         pass
259
260     def visitEnum(self, enum, instance):
261         pass
262
263     def visitBitmask(self, bitmask, instance):
264         pass
265
266     def visitPointer(self, pointer, instance):
267         print "    if (%s) {" % instance
268         self.visit(pointer.type, "*" + instance)
269         print "    }"
270     
271     def visitIntPointer(self, pointer, instance):
272         pass
273
274     def visitLinearPointer(self, pointer, instance):
275         pass
276
277     def visitHandle(self, handle, instance):
278         self.visit(handle.type, instance)
279
280     def visitAlias(self, alias, instance):
281         self.visit(alias.type, instance)
282
283     def visitOpaque(self, opaque, instance):
284         pass
285     
286     def visitInterface(self, interface, instance):
287         assert instance.startswith('*')
288         instance = instance[1:]
289         print "    if (%s) {" % instance
290         print "        %s = new %s(%s);" % (instance, getWrapperInterfaceName(interface), instance)
291         print "    }"
292     
293     def visitPolymorphic(self, type, instance):
294         # XXX: There might be polymorphic values that need wrapping in the future
295         pass
296
297
298 class ValueUnwrapper(ValueWrapper):
299     '''Reverse of ValueWrapper.'''
300
301     def visitInterface(self, interface, instance):
302         assert instance.startswith('*')
303         instance = instance[1:]
304         print "    if (%s) {" % instance
305         print "        %s = static_cast<%s *>(%s)->m_pInstance;" % (instance, getWrapperInterfaceName(interface), instance)
306         print "    }"
307
308
309 class Tracer:
310     '''Base class to orchestrate the code generation of API tracing.'''
311
312     def __init__(self):
313         self.api = None
314
315     def serializerFactory(self):
316         '''Create a serializer.
317         
318         Can be overriden by derived classes to inject their own serialzer.
319         '''
320
321         return ValueSerializer()
322
323     def trace_api(self, api):
324         self.api = api
325
326         self.header(api)
327
328         # Includes
329         for header in api.headers:
330             print header
331         print
332
333         # Generate the serializer functions
334         types = api.getAllTypes()
335         visitor = ComplexValueSerializer(self.serializerFactory())
336         map(visitor.visit, types)
337         print
338
339         # Interfaces wrapers
340         interfaces = api.getAllInterfaces()
341         map(self.declareWrapperInterface, interfaces)
342         map(self.implementWrapperInterface, interfaces)
343         print
344
345         # Function wrappers
346         map(self.traceFunctionDecl, api.functions)
347         map(self.traceFunctionImpl, api.functions)
348         print
349
350         self.footer(api)
351
352     def header(self, api):
353         pass
354
355     def footer(self, api):
356         pass
357
358     def traceFunctionDecl(self, function):
359         # Per-function declarations
360
361         if function.args:
362             print 'static const char * __%s_args[%u] = {%s};' % (function.name, len(function.args), ', '.join(['"%s"' % arg.name for arg in function.args]))
363         else:
364             print 'static const char ** __%s_args = NULL;' % (function.name,)
365         print 'static const trace::FunctionSig __%s_sig = {%u, "%s", %u, __%s_args};' % (function.name, function.id, function.name, len(function.args), function.name)
366         print
367
368     def isFunctionPublic(self, function):
369         return True
370
371     def traceFunctionImpl(self, function):
372         if self.isFunctionPublic(function):
373             print 'extern "C" PUBLIC'
374         else:
375             print 'extern "C" PRIVATE'
376         print function.prototype() + ' {'
377         if function.type is not stdapi.Void:
378             print '    %s __result;' % function.type
379         self.traceFunctionImplBody(function)
380         if function.type is not stdapi.Void:
381             self.wrapRet(function, "__result")
382             print '    return __result;'
383         print '}'
384         print
385
386     def traceFunctionImplBody(self, function):
387         print '    unsigned __call = trace::localWriter.beginEnter(&__%s_sig);' % (function.name,)
388         for arg in function.args:
389             if not arg.output:
390                 self.unwrapArg(function, arg)
391                 self.serializeArg(function, arg)
392         print '    trace::localWriter.endEnter();'
393         self.invokeFunction(function)
394         print '    trace::localWriter.beginLeave(__call);'
395         for arg in function.args:
396             if arg.output:
397                 self.serializeArg(function, arg)
398                 self.wrapArg(function, arg)
399         if function.type is not stdapi.Void:
400             self.serializeRet(function, "__result")
401         print '    trace::localWriter.endLeave();'
402
403     def invokeFunction(self, function, prefix='__', suffix=''):
404         if function.type is stdapi.Void:
405             result = ''
406         else:
407             result = '__result = '
408         dispatch = prefix + function.name + suffix
409         print '    %s%s(%s);' % (result, dispatch, ', '.join([str(arg.name) for arg in function.args]))
410
411     def serializeArg(self, function, arg):
412         print '    trace::localWriter.beginArg(%u);' % (arg.index,)
413         self.serializeArgValue(function, arg)
414         print '    trace::localWriter.endArg();'
415
416     def serializeArgValue(self, function, arg):
417         self.serializeValue(arg.type, arg.name)
418
419     def wrapArg(self, function, arg):
420         self.wrapValue(arg.type, arg.name)
421
422     def unwrapArg(self, function, arg):
423         self.unwrapValue(arg.type, arg.name)
424
425     def serializeRet(self, function, instance):
426         print '    trace::localWriter.beginReturn();'
427         self.serializeValue(function.type, instance)
428         print '    trace::localWriter.endReturn();'
429
430     def serializeValue(self, type, instance):
431         serializer = self.serializerFactory()
432         serializer.visit(type, instance)
433
434     def wrapRet(self, function, instance):
435         self.wrapValue(function.type, instance)
436
437     def unwrapRet(self, function, instance):
438         self.unwrapValue(function.type, instance)
439
440     def wrapValue(self, type, instance):
441         visitor = ValueWrapper()
442         visitor.visit(type, instance)
443
444     def unwrapValue(self, type, instance):
445         visitor = ValueUnwrapper()
446         visitor.visit(type, instance)
447
448     def declareWrapperInterface(self, interface):
449         print "class %s : public %s " % (getWrapperInterfaceName(interface), interface.name)
450         print "{"
451         print "public:"
452         print "    %s(%s * pInstance);" % (getWrapperInterfaceName(interface), interface.name)
453         print "    virtual ~%s();" % getWrapperInterfaceName(interface)
454         print
455         for method in interface.iterMethods():
456             print "    " + method.prototype() + ";"
457         print
458         self.declareWrapperInterfaceVariables(interface)
459         print "};"
460         print
461
462     def declareWrapperInterfaceVariables(self, interface):
463         #print "private:"
464         print "    %s * m_pInstance;" % (interface.name,)
465
466     def implementWrapperInterface(self, interface):
467         print '%s::%s(%s * pInstance) {' % (getWrapperInterfaceName(interface), getWrapperInterfaceName(interface), interface.name)
468         print '    m_pInstance = pInstance;'
469         print '}'
470         print
471         print '%s::~%s() {' % (getWrapperInterfaceName(interface), getWrapperInterfaceName(interface))
472         print '}'
473         print
474         for method in interface.iterMethods():
475             self.implementWrapperInterfaceMethod(interface, method)
476         print
477
478     def implementWrapperInterfaceMethod(self, interface, method):
479         print method.prototype(getWrapperInterfaceName(interface) + '::' + method.name) + ' {'
480         if method.type is not stdapi.Void:
481             print '    %s __result;' % method.type
482     
483         self.implementWrapperInterfaceMethodBody(interface, method)
484     
485         if method.type is not stdapi.Void:
486             print '    return __result;'
487         print '}'
488         print
489
490     def implementWrapperInterfaceMethodBody(self, interface, method):
491         print '    static const char * __args[%u] = {%s};' % (len(method.args) + 1, ', '.join(['"this"'] + ['"%s"' % arg.name for arg in method.args]))
492         print '    static const trace::FunctionSig __sig = {%u, "%s", %u, __args};' % (method.id, interface.name + '::' + method.name, len(method.args) + 1)
493         print '    unsigned __call = trace::localWriter.beginEnter(&__sig);'
494         print '    trace::localWriter.beginArg(0);'
495         print '    trace::localWriter.writeOpaque((const void *)m_pInstance);'
496         print '    trace::localWriter.endArg();'
497
498         from specs.winapi import REFIID
499         from specs.stdapi import Pointer, Opaque
500
501         riid = None
502         for arg in method.args:
503             if not arg.output:
504                 self.unwrapArg(method, arg)
505                 self.serializeArg(method, arg)
506                 if arg.type is REFIID:
507                     riid = arg
508         print '    trace::localWriter.endEnter();'
509         
510         self.invokeMethod(interface, method)
511
512         print '    trace::localWriter.beginLeave(__call);'
513         for arg in method.args:
514             if arg.output:
515                 self.serializeArg(method, arg)
516                 self.wrapArg(method, arg)
517                 if riid is not None and isinstance(arg.type, Pointer):
518                     assert isinstance(arg.type.type, Opaque)
519                     self.wrapIid(interface, method, riid, arg)
520
521         if method.type is not stdapi.Void:
522             print '    trace::localWriter.beginReturn();'
523             self.serializeValue(method.type, "__result")
524             print '    trace::localWriter.endReturn();'
525             self.wrapValue(method.type, '__result')
526         print '    trace::localWriter.endLeave();'
527         if method.name == 'Release':
528             assert method.type is not stdapi.Void
529             print '    if (!__result)'
530             print '        delete this;'
531
532     def wrapIid(self, interface, method, riid, out):
533             print '    if (%s && *%s) {' % (out.name, out.name)
534             print '        if (*%s == m_pInstance) {' % (out.name,)
535             print '            *%s = this;' % (out.name,)
536             print '        }'
537             for iface in self.api.interfaces:
538                 print r'        else if (%s == IID_%s) {' % (riid.name, iface.name)
539                 print r'            *%s = new Wrap%s((%s *) *%s);' % (out.name, iface.name, iface.name, out.name)
540                 print r'        }'
541             print r'        else {'
542             print r'            os::log("apitrace: warning: unknown REFIID {0x%08lX,0x%04X,0x%04X,{0x%02X,0x%02X,0x%02X,0x%02X,0x%02X,0x%02X,0x%02X,0x%02X}}\n",'
543             print r'                    %s.Data1, %s.Data2, %s.Data3,' % (riid.name, riid.name, riid.name)
544             print r'                    %s.Data4[0],' % (riid.name,)
545             print r'                    %s.Data4[1],' % (riid.name,)
546             print r'                    %s.Data4[2],' % (riid.name,)
547             print r'                    %s.Data4[3],' % (riid.name,)
548             print r'                    %s.Data4[4],' % (riid.name,)
549             print r'                    %s.Data4[5],' % (riid.name,)
550             print r'                    %s.Data4[6],' % (riid.name,)
551             print r'                    %s.Data4[7]);' % (riid.name,)
552             print r'        }'
553             print '    }'
554
555     def invokeMethod(self, interface, method):
556         if method.type is stdapi.Void:
557             result = ''
558         else:
559             result = '__result = '
560         print '    %sm_pInstance->%s(%s);' % (result, method.name, ', '.join([str(arg.name) for arg in method.args]))
561     
562     def emit_memcpy(self, dest, src, length):
563         print '        unsigned __call = trace::localWriter.beginEnter(&trace::memcpy_sig);'
564         print '        trace::localWriter.beginArg(0);'
565         print '        trace::localWriter.writeOpaque(%s);' % dest
566         print '        trace::localWriter.endArg();'
567         print '        trace::localWriter.beginArg(1);'
568         print '        trace::localWriter.writeBlob(%s, %s);' % (src, length)
569         print '        trace::localWriter.endArg();'
570         print '        trace::localWriter.beginArg(2);'
571         print '        trace::localWriter.writeUInt(%s);' % length
572         print '        trace::localWriter.endArg();'
573         print '        trace::localWriter.endEnter();'
574         print '        trace::localWriter.beginLeave(__call);'
575         print '        trace::localWriter.endLeave();'
576