]> git.cworth.org Git - apitrace/blob - wrappers/trace.py
Move tracers to wrappers subdirectory.
[apitrace] / wrappers / trace.py
1 ##########################################################################
2 #
3 # Copyright 2008-2010 VMware, Inc.
4 # All Rights Reserved.
5 #
6 # Permission is hereby granted, free of charge, to any person obtaining a copy
7 # of this software and associated documentation files (the "Software"), to deal
8 # in the Software without restriction, including without limitation the rights
9 # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 # copies of the Software, and to permit persons to whom the Software is
11 # furnished to do so, subject to the following conditions:
12 #
13 # The above copyright notice and this permission notice shall be included in
14 # all copies or substantial portions of the Software.
15 #
16 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
22 # THE SOFTWARE.
23 #
24 ##########################################################################/
25
26 """Common trace code generation."""
27
28
29 # Adjust path
30 import os.path
31 import sys
32 sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
33
34
35 import specs.stdapi as stdapi
36
37
38 def getWrapperInterfaceName(interface):
39     return "Wrap" + interface.expr
40
41
42 class ComplexValueSerializer(stdapi.OnceVisitor):
43     '''Type visitors which generates serialization functions for
44     complex types.
45     
46     Simple types are serialized inline.
47     '''
48
49     def __init__(self, serializer):
50         stdapi.OnceVisitor.__init__(self)
51         self.serializer = serializer
52
53     def visitVoid(self, literal):
54         pass
55
56     def visitLiteral(self, literal):
57         pass
58
59     def visitString(self, string):
60         pass
61
62     def visitConst(self, const):
63         self.visit(const.type)
64
65     def visitStruct(self, struct):
66         for type, name in struct.members:
67             self.visit(type)
68         print 'static void _write__%s(const %s &value) {' % (struct.tag, struct.expr)
69         print '    static const char * members[%u] = {' % (len(struct.members),)
70         for type, name,  in struct.members:
71             print '        "%s",' % (name,)
72         print '    };'
73         print '    static const trace::StructSig sig = {'
74         print '       %u, "%s", %u, members' % (struct.id, struct.name, len(struct.members))
75         print '    };'
76         print '    trace::localWriter.beginStruct(&sig);'
77         for type, name in struct.members:
78             self.serializer.visit(type, 'value.%s' % (name,))
79         print '    trace::localWriter.endStruct();'
80         print '}'
81         print
82
83     def visitArray(self, array):
84         self.visit(array.type)
85
86     def visitBlob(self, array):
87         pass
88
89     def visitEnum(self, enum):
90         print 'static const trace::EnumValue __enum%s_values[] = {' % (enum.tag)
91         for value in enum.values:
92             print '   {"%s", %s},' % (value, value)
93         print '};'
94         print
95         print 'static const trace::EnumSig __enum%s_sig = {' % (enum.tag)
96         print '   %u, %u, __enum%s_values' % (enum.id, len(enum.values), enum.tag)
97         print '};'
98         print
99
100     def visitBitmask(self, bitmask):
101         print 'static const trace::BitmaskFlag __bitmask%s_flags[] = {' % (bitmask.tag)
102         for value in bitmask.values:
103             print '   {"%s", %s},' % (value, value)
104         print '};'
105         print
106         print 'static const trace::BitmaskSig __bitmask%s_sig = {' % (bitmask.tag)
107         print '   %u, %u, __bitmask%s_flags' % (bitmask.id, len(bitmask.values), bitmask.tag)
108         print '};'
109         print
110
111     def visitPointer(self, pointer):
112         self.visit(pointer.type)
113
114     def visitIntPointer(self, pointer):
115         pass
116
117     def visitObjPointer(self, pointer):
118         self.visit(pointer.type)
119
120     def visitLinearPointer(self, pointer):
121         self.visit(pointer.type)
122
123     def visitHandle(self, handle):
124         self.visit(handle.type)
125
126     def visitReference(self, reference):
127         self.visit(reference.type)
128
129     def visitAlias(self, alias):
130         self.visit(alias.type)
131
132     def visitOpaque(self, opaque):
133         pass
134
135     def visitInterface(self, interface):
136         pass
137
138     def visitPolymorphic(self, polymorphic):
139         print 'static void _write__%s(int selector, const %s & value) {' % (polymorphic.tag, polymorphic.expr)
140         print '    switch (selector) {'
141         for cases, type in polymorphic.iterSwitch():
142             for case in cases:
143                 print '    %s:' % case
144             self.serializer.visit(type, 'static_cast<%s>(value)' % (type,))
145             print '        break;'
146         print '    }'
147         print '}'
148         print
149
150
151 class ValueSerializer(stdapi.Visitor):
152     '''Visitor which generates code to serialize any type.
153     
154     Simple types are serialized inline here, whereas the serialization of
155     complex types is dispatched to the serialization functions generated by
156     ComplexValueSerializer visitor above.
157     '''
158
159     def visitLiteral(self, literal, instance):
160         print '    trace::localWriter.write%s(%s);' % (literal.kind, instance)
161
162     def visitString(self, string, instance):
163         if string.kind == 'String':
164             cast = 'const char *'
165         elif string.kind == 'WString':
166             cast = 'const wchar_t *'
167         else:
168             assert False
169         if cast != string.expr:
170             # reinterpret_cast is necessary for GLubyte * <=> char *
171             instance = 'reinterpret_cast<%s>(%s)' % (cast, instance)
172         if string.length is not None:
173             length = ', %s' % string.length
174         else:
175             length = ''
176         print '    trace::localWriter.write%s(%s%s);' % (string.kind, instance, length)
177
178     def visitConst(self, const, instance):
179         self.visit(const.type, instance)
180
181     def visitStruct(self, struct, instance):
182         print '    _write__%s(%s);' % (struct.tag, instance)
183
184     def visitArray(self, array, instance):
185         length = '__c' + array.type.tag
186         index = '__i' + array.type.tag
187         print '    if (%s) {' % instance
188         print '        size_t %s = %s;' % (length, array.length)
189         print '        trace::localWriter.beginArray(%s);' % length
190         print '        for (size_t %s = 0; %s < %s; ++%s) {' % (index, index, length, index)
191         print '            trace::localWriter.beginElement();'
192         self.visit(array.type, '(%s)[%s]' % (instance, index))
193         print '            trace::localWriter.endElement();'
194         print '        }'
195         print '        trace::localWriter.endArray();'
196         print '    } else {'
197         print '        trace::localWriter.writeNull();'
198         print '    }'
199
200     def visitBlob(self, blob, instance):
201         print '    trace::localWriter.writeBlob(%s, %s);' % (instance, blob.size)
202
203     def visitEnum(self, enum, instance):
204         print '    trace::localWriter.writeEnum(&__enum%s_sig, %s);' % (enum.tag, instance)
205
206     def visitBitmask(self, bitmask, instance):
207         print '    trace::localWriter.writeBitmask(&__bitmask%s_sig, %s);' % (bitmask.tag, instance)
208
209     def visitPointer(self, pointer, instance):
210         print '    if (%s) {' % instance
211         print '        trace::localWriter.beginArray(1);'
212         print '        trace::localWriter.beginElement();'
213         self.visit(pointer.type, "*" + instance)
214         print '        trace::localWriter.endElement();'
215         print '        trace::localWriter.endArray();'
216         print '    } else {'
217         print '        trace::localWriter.writeNull();'
218         print '    }'
219
220     def visitIntPointer(self, pointer, instance):
221         print '    trace::localWriter.writeOpaque((const void *)%s);' % instance
222
223     def visitObjPointer(self, pointer, instance):
224         print '    trace::localWriter.writeOpaque((const void *)%s);' % instance
225
226     def visitLinearPointer(self, pointer, instance):
227         print '    trace::localWriter.writeOpaque((const void *)%s);' % instance
228
229     def visitReference(self, reference, instance):
230         self.visit(reference.type, instance)
231
232     def visitHandle(self, handle, instance):
233         self.visit(handle.type, instance)
234
235     def visitAlias(self, alias, instance):
236         self.visit(alias.type, instance)
237
238     def visitOpaque(self, opaque, instance):
239         print '    trace::localWriter.writeOpaque((const void *)%s);' % instance
240
241     def visitInterface(self, interface, instance):
242         print '    trace::localWriter.writeOpaque((const void *)&%s);' % instance
243
244     def visitPolymorphic(self, polymorphic, instance):
245         print '    _write__%s(%s, %s);' % (polymorphic.tag, polymorphic.switchExpr, instance)
246
247
248 class ValueWrapper(stdapi.Visitor):
249     '''Type visitor which will generate the code to wrap an instance.
250     
251     Wrapping is necessary mostly for interfaces, however interface pointers can
252     appear anywhere inside complex types.
253     '''
254
255     def visitVoid(self, type, instance):
256         raise NotImplementedError
257
258     def visitLiteral(self, type, instance):
259         pass
260
261     def visitString(self, type, instance):
262         pass
263
264     def visitConst(self, type, instance):
265         pass
266
267     def visitStruct(self, struct, instance):
268         for type, name in struct.members:
269             self.visit(type, "(%s).%s" % (instance, name))
270
271     def visitArray(self, array, instance):
272         # XXX: actually it is possible to return an array of pointers
273         pass
274
275     def visitBlob(self, blob, instance):
276         pass
277
278     def visitEnum(self, enum, instance):
279         pass
280
281     def visitBitmask(self, bitmask, instance):
282         pass
283
284     def visitPointer(self, pointer, instance):
285         print "    if (%s) {" % instance
286         self.visit(pointer.type, "*" + instance)
287         print "    }"
288     
289     def visitIntPointer(self, pointer, instance):
290         pass
291
292     def visitObjPointer(self, pointer, instance):
293         print "    if (%s) {" % instance
294         self.visit(pointer.type, "*" + instance)
295         print "    }"
296     
297     def visitLinearPointer(self, pointer, instance):
298         pass
299
300     def visitReference(self, reference, instance):
301         self.visit(reference.type, instance)
302     
303     def visitHandle(self, handle, instance):
304         self.visit(handle.type, instance)
305
306     def visitAlias(self, alias, instance):
307         self.visit(alias.type, instance)
308
309     def visitOpaque(self, opaque, instance):
310         pass
311     
312     def visitInterface(self, interface, instance):
313         assert instance.startswith('*')
314         instance = instance[1:]
315         print "    if (%s) {" % instance
316         print "        %s = new %s(%s);" % (instance, getWrapperInterfaceName(interface), instance)
317         print "    }"
318     
319     def visitPolymorphic(self, type, instance):
320         # XXX: There might be polymorphic values that need wrapping in the future
321         pass
322
323
324 class ValueUnwrapper(ValueWrapper):
325     '''Reverse of ValueWrapper.'''
326
327     def visitInterface(self, interface, instance):
328         assert instance.startswith('*')
329         instance = instance[1:]
330         print r'    if (%s) {' % instance
331         print r'        %s *pWrapper = static_cast<%s*>(%s);' % (getWrapperInterfaceName(interface), getWrapperInterfaceName(interface), instance)
332         print r'        if (pWrapper && pWrapper->m_dwMagic == 0xd8365d6c) {'
333         print r'            %s = pWrapper->m_pInstance;' % (instance,)
334         print r'        } else {'
335         print r'            os::log("apitrace: warning: %%s: unexpected %%s pointer\n", __FUNCTION__, "%s");' % interface.name
336         print r'        }'
337         print r'    }'
338
339
340 class Tracer:
341     '''Base class to orchestrate the code generation of API tracing.'''
342
343     def __init__(self):
344         self.api = None
345
346     def serializerFactory(self):
347         '''Create a serializer.
348         
349         Can be overriden by derived classes to inject their own serialzer.
350         '''
351
352         return ValueSerializer()
353
354     def trace_api(self, api):
355         self.api = api
356
357         self.header(api)
358
359         # Includes
360         for header in api.headers:
361             print header
362         print
363
364         # Generate the serializer functions
365         types = api.getAllTypes()
366         visitor = ComplexValueSerializer(self.serializerFactory())
367         map(visitor.visit, types)
368         print
369
370         # Interfaces wrapers
371         interfaces = api.getAllInterfaces()
372         map(self.declareWrapperInterface, interfaces)
373         map(self.implementWrapperInterface, interfaces)
374         print
375
376         # Function wrappers
377         map(self.traceFunctionDecl, api.functions)
378         map(self.traceFunctionImpl, api.functions)
379         print
380
381         self.footer(api)
382
383     def header(self, api):
384         pass
385
386     def footer(self, api):
387         pass
388
389     def traceFunctionDecl(self, function):
390         # Per-function declarations
391
392         if function.args:
393             print 'static const char * __%s_args[%u] = {%s};' % (function.name, len(function.args), ', '.join(['"%s"' % arg.name for arg in function.args]))
394         else:
395             print 'static const char ** __%s_args = NULL;' % (function.name,)
396         print 'static const trace::FunctionSig __%s_sig = {%u, "%s", %u, __%s_args};' % (function.name, function.id, function.name, len(function.args), function.name)
397         print
398
399     def isFunctionPublic(self, function):
400         return True
401
402     def traceFunctionImpl(self, function):
403         if self.isFunctionPublic(function):
404             print 'extern "C" PUBLIC'
405         else:
406             print 'extern "C" PRIVATE'
407         print function.prototype() + ' {'
408         if function.type is not stdapi.Void:
409             print '    %s __result;' % function.type
410         self.traceFunctionImplBody(function)
411         if function.type is not stdapi.Void:
412             self.wrapRet(function, "__result")
413             print '    return __result;'
414         print '}'
415         print
416
417     def traceFunctionImplBody(self, function):
418         print '    unsigned __call = trace::localWriter.beginEnter(&__%s_sig);' % (function.name,)
419         for arg in function.args:
420             if not arg.output:
421                 self.unwrapArg(function, arg)
422                 self.serializeArg(function, arg)
423         print '    trace::localWriter.endEnter();'
424         self.invokeFunction(function)
425         print '    trace::localWriter.beginLeave(__call);'
426         for arg in function.args:
427             if arg.output:
428                 self.serializeArg(function, arg)
429                 self.wrapArg(function, arg)
430         if function.type is not stdapi.Void:
431             self.serializeRet(function, "__result")
432         print '    trace::localWriter.endLeave();'
433
434     def invokeFunction(self, function, prefix='__', suffix=''):
435         if function.type is stdapi.Void:
436             result = ''
437         else:
438             result = '__result = '
439         dispatch = prefix + function.name + suffix
440         print '    %s%s(%s);' % (result, dispatch, ', '.join([str(arg.name) for arg in function.args]))
441
442     def serializeArg(self, function, arg):
443         print '    trace::localWriter.beginArg(%u);' % (arg.index,)
444         self.serializeArgValue(function, arg)
445         print '    trace::localWriter.endArg();'
446
447     def serializeArgValue(self, function, arg):
448         self.serializeValue(arg.type, arg.name)
449
450     def wrapArg(self, function, arg):
451         self.wrapValue(arg.type, arg.name)
452
453     def unwrapArg(self, function, arg):
454         self.unwrapValue(arg.type, arg.name)
455
456     def serializeRet(self, function, instance):
457         print '    trace::localWriter.beginReturn();'
458         self.serializeValue(function.type, instance)
459         print '    trace::localWriter.endReturn();'
460
461     def serializeValue(self, type, instance):
462         serializer = self.serializerFactory()
463         serializer.visit(type, instance)
464
465     def wrapRet(self, function, instance):
466         self.wrapValue(function.type, instance)
467
468     def unwrapRet(self, function, instance):
469         self.unwrapValue(function.type, instance)
470
471     def wrapValue(self, type, instance):
472         visitor = ValueWrapper()
473         visitor.visit(type, instance)
474
475     def unwrapValue(self, type, instance):
476         visitor = ValueUnwrapper()
477         visitor.visit(type, instance)
478
479     def declareWrapperInterface(self, interface):
480         print "class %s : public %s " % (getWrapperInterfaceName(interface), interface.name)
481         print "{"
482         print "public:"
483         print "    %s(%s * pInstance);" % (getWrapperInterfaceName(interface), interface.name)
484         print "    virtual ~%s();" % getWrapperInterfaceName(interface)
485         print
486         for method in interface.iterMethods():
487             print "    " + method.prototype() + ";"
488         print
489         self.declareWrapperInterfaceVariables(interface)
490         print "};"
491         print
492
493     def declareWrapperInterfaceVariables(self, interface):
494         #print "private:"
495         print "    DWORD m_dwMagic;"
496         print "    %s * m_pInstance;" % (interface.name,)
497
498     def implementWrapperInterface(self, interface):
499         print '%s::%s(%s * pInstance) {' % (getWrapperInterfaceName(interface), getWrapperInterfaceName(interface), interface.name)
500         print '    m_dwMagic = 0xd8365d6c;'
501         print '    m_pInstance = pInstance;'
502         print '}'
503         print
504         print '%s::~%s() {' % (getWrapperInterfaceName(interface), getWrapperInterfaceName(interface))
505         print '}'
506         print
507         for base, method in interface.iterBaseMethods():
508             self.implementWrapperInterfaceMethod(interface, base, method)
509         print
510
511     def implementWrapperInterfaceMethod(self, interface, base, method):
512         print method.prototype(getWrapperInterfaceName(interface) + '::' + method.name) + ' {'
513         if method.type is not stdapi.Void:
514             print '    %s __result;' % method.type
515     
516         self.implementWrapperInterfaceMethodBody(interface, base, method)
517     
518         if method.type is not stdapi.Void:
519             print '    return __result;'
520         print '}'
521         print
522
523     def implementWrapperInterfaceMethodBody(self, interface, base, method):
524         print '    static const char * __args[%u] = {%s};' % (len(method.args) + 1, ', '.join(['"this"'] + ['"%s"' % arg.name for arg in method.args]))
525         print '    static const trace::FunctionSig __sig = {%u, "%s", %u, __args};' % (method.id, interface.name + '::' + method.name, len(method.args) + 1)
526         print '    unsigned __call = trace::localWriter.beginEnter(&__sig);'
527         print '    trace::localWriter.beginArg(0);'
528         print '    trace::localWriter.writeOpaque((const void *)m_pInstance);'
529         print '    trace::localWriter.endArg();'
530
531         from specs.winapi import REFIID
532         from specs.stdapi import Pointer, Opaque, Interface
533
534         riid = None
535         for arg in method.args:
536             if not arg.output:
537                 self.unwrapArg(method, arg)
538                 self.serializeArg(method, arg)
539                 if arg.type is REFIID:
540                     riid = arg
541         print '    trace::localWriter.endEnter();'
542         
543         self.invokeMethod(interface, base, method)
544
545         print '    trace::localWriter.beginLeave(__call);'
546         for arg in method.args:
547             if arg.output:
548                 self.serializeArg(method, arg)
549                 self.wrapArg(method, arg)
550                 if riid is not None and isinstance(arg.type, Pointer):
551                     if isinstance(arg.type.type, Opaque):
552                         self.wrapIid(interface, method, riid, arg)
553                     else:
554                         assert isinstance(arg.type.type, Pointer)
555                         assert isinstance(arg.type.type.type, Interface)
556
557         if method.type is not stdapi.Void:
558             print '    trace::localWriter.beginReturn();'
559             self.serializeValue(method.type, "__result")
560             print '    trace::localWriter.endReturn();'
561             self.wrapValue(method.type, '__result')
562         print '    trace::localWriter.endLeave();'
563         if method.name == 'Release':
564             assert method.type is not stdapi.Void
565             print '    if (!__result)'
566             print '        delete this;'
567
568     def wrapIid(self, interface, method, riid, out):
569             print '    if (%s && *%s) {' % (out.name, out.name)
570             print '        if (*%s == m_pInstance) {' % (out.name,)
571             print '            *%s = this;' % (out.name,)
572             print '        }'
573             for iface in self.api.getAllInterfaces():
574                 print r'        else if (%s == IID_%s) {' % (riid.name, iface.name)
575                 print r'            *%s = new Wrap%s((%s *) *%s);' % (out.name, iface.name, iface.name, out.name)
576                 print r'        }'
577             print r'        else {'
578             print r'            os::log("apitrace: warning: %s::%s: unknown IID {0x%08lX,0x%04X,0x%04X,{0x%02X,0x%02X,0x%02X,0x%02X,0x%02X,0x%02X,0x%02X,0x%02X}}\n",'
579             print r'                    "%s", "%s",' % (interface.name, method.name)
580             print r'                    %s.Data1, %s.Data2, %s.Data3,' % (riid.name, riid.name, riid.name)
581             print r'                    %s.Data4[0],' % (riid.name,)
582             print r'                    %s.Data4[1],' % (riid.name,)
583             print r'                    %s.Data4[2],' % (riid.name,)
584             print r'                    %s.Data4[3],' % (riid.name,)
585             print r'                    %s.Data4[4],' % (riid.name,)
586             print r'                    %s.Data4[5],' % (riid.name,)
587             print r'                    %s.Data4[6],' % (riid.name,)
588             print r'                    %s.Data4[7]);' % (riid.name,)
589             print r'        }'
590             print '    }'
591
592     def invokeMethod(self, interface, base, method):
593         if method.type is stdapi.Void:
594             result = ''
595         else:
596             result = '__result = '
597         print '    %sstatic_cast<%s *>(m_pInstance)->%s(%s);' % (result, base, method.name, ', '.join([str(arg.name) for arg in method.args]))
598     
599     def emit_memcpy(self, dest, src, length):
600         print '        unsigned __call = trace::localWriter.beginEnter(&trace::memcpy_sig);'
601         print '        trace::localWriter.beginArg(0);'
602         print '        trace::localWriter.writeOpaque(%s);' % dest
603         print '        trace::localWriter.endArg();'
604         print '        trace::localWriter.beginArg(1);'
605         print '        trace::localWriter.writeBlob(%s, %s);' % (src, length)
606         print '        trace::localWriter.endArg();'
607         print '        trace::localWriter.beginArg(2);'
608         print '        trace::localWriter.writeUInt(%s);' % length
609         print '        trace::localWriter.endArg();'
610         print '        trace::localWriter.endEnter();'
611         print '        trace::localWriter.beginLeave(__call);'
612         print '        trace::localWriter.endLeave();'
613