]> git.cworth.org Git - apitrace/blob - wrappers/trace.py
Use ObjPointers consistenty.
[apitrace] / wrappers / trace.py
1 ##########################################################################
2 #
3 # Copyright 2008-2010 VMware, Inc.
4 # All Rights Reserved.
5 #
6 # Permission is hereby granted, free of charge, to any person obtaining a copy
7 # of this software and associated documentation files (the "Software"), to deal
8 # in the Software without restriction, including without limitation the rights
9 # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 # copies of the Software, and to permit persons to whom the Software is
11 # furnished to do so, subject to the following conditions:
12 #
13 # The above copyright notice and this permission notice shall be included in
14 # all copies or substantial portions of the Software.
15 #
16 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
22 # THE SOFTWARE.
23 #
24 ##########################################################################/
25
26 """Common trace code generation."""
27
28
29 # Adjust path
30 import os.path
31 import sys
32 sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
33
34
35 import specs.stdapi as stdapi
36
37
38 def getWrapperInterfaceName(interface):
39     return "Wrap" + interface.expr
40
41
42 class ComplexValueSerializer(stdapi.OnceVisitor):
43     '''Type visitors which generates serialization functions for
44     complex types.
45     
46     Simple types are serialized inline.
47     '''
48
49     def __init__(self, serializer):
50         stdapi.OnceVisitor.__init__(self)
51         self.serializer = serializer
52
53     def visitVoid(self, literal):
54         pass
55
56     def visitLiteral(self, literal):
57         pass
58
59     def visitString(self, string):
60         pass
61
62     def visitConst(self, const):
63         self.visit(const.type)
64
65     def visitStruct(self, struct):
66         for type, name in struct.members:
67             self.visit(type)
68         print 'static void _write__%s(const %s &value) {' % (struct.tag, struct.expr)
69         print '    static const char * members[%u] = {' % (len(struct.members),)
70         for type, name,  in struct.members:
71             print '        "%s",' % (name,)
72         print '    };'
73         print '    static const trace::StructSig sig = {'
74         print '       %u, "%s", %u, members' % (struct.id, struct.name, len(struct.members))
75         print '    };'
76         print '    trace::localWriter.beginStruct(&sig);'
77         for type, name in struct.members:
78             self.serializer.visit(type, 'value.%s' % (name,))
79         print '    trace::localWriter.endStruct();'
80         print '}'
81         print
82
83     def visitArray(self, array):
84         self.visit(array.type)
85
86     def visitBlob(self, array):
87         pass
88
89     def visitEnum(self, enum):
90         print 'static const trace::EnumValue __enum%s_values[] = {' % (enum.tag)
91         for value in enum.values:
92             print '   {"%s", %s},' % (value, value)
93         print '};'
94         print
95         print 'static const trace::EnumSig __enum%s_sig = {' % (enum.tag)
96         print '   %u, %u, __enum%s_values' % (enum.id, len(enum.values), enum.tag)
97         print '};'
98         print
99
100     def visitBitmask(self, bitmask):
101         print 'static const trace::BitmaskFlag __bitmask%s_flags[] = {' % (bitmask.tag)
102         for value in bitmask.values:
103             print '   {"%s", %s},' % (value, value)
104         print '};'
105         print
106         print 'static const trace::BitmaskSig __bitmask%s_sig = {' % (bitmask.tag)
107         print '   %u, %u, __bitmask%s_flags' % (bitmask.id, len(bitmask.values), bitmask.tag)
108         print '};'
109         print
110
111     def visitPointer(self, pointer):
112         self.visit(pointer.type)
113
114     def visitIntPointer(self, pointer):
115         pass
116
117     def visitObjPointer(self, pointer):
118         self.visit(pointer.type)
119
120     def visitLinearPointer(self, pointer):
121         self.visit(pointer.type)
122
123     def visitHandle(self, handle):
124         self.visit(handle.type)
125
126     def visitReference(self, reference):
127         self.visit(reference.type)
128
129     def visitAlias(self, alias):
130         self.visit(alias.type)
131
132     def visitOpaque(self, opaque):
133         pass
134
135     def visitInterface(self, interface):
136         pass
137
138     def visitPolymorphic(self, polymorphic):
139         print 'static void _write__%s(int selector, const %s & value) {' % (polymorphic.tag, polymorphic.expr)
140         print '    switch (selector) {'
141         for cases, type in polymorphic.iterSwitch():
142             for case in cases:
143                 print '    %s:' % case
144             self.serializer.visit(type, 'static_cast<%s>(value)' % (type,))
145             print '        break;'
146         print '    }'
147         print '}'
148         print
149
150
151 class ValueSerializer(stdapi.Visitor):
152     '''Visitor which generates code to serialize any type.
153     
154     Simple types are serialized inline here, whereas the serialization of
155     complex types is dispatched to the serialization functions generated by
156     ComplexValueSerializer visitor above.
157     '''
158
159     def visitLiteral(self, literal, instance):
160         print '    trace::localWriter.write%s(%s);' % (literal.kind, instance)
161
162     def visitString(self, string, instance):
163         if string.kind == 'String':
164             cast = 'const char *'
165         elif string.kind == 'WString':
166             cast = 'const wchar_t *'
167         else:
168             assert False
169         if cast != string.expr:
170             # reinterpret_cast is necessary for GLubyte * <=> char *
171             instance = 'reinterpret_cast<%s>(%s)' % (cast, instance)
172         if string.length is not None:
173             length = ', %s' % string.length
174         else:
175             length = ''
176         print '    trace::localWriter.write%s(%s%s);' % (string.kind, instance, length)
177
178     def visitConst(self, const, instance):
179         self.visit(const.type, instance)
180
181     def visitStruct(self, struct, instance):
182         print '    _write__%s(%s);' % (struct.tag, instance)
183
184     def visitArray(self, array, instance):
185         length = '__c' + array.type.tag
186         index = '__i' + array.type.tag
187         print '    if (%s) {' % instance
188         print '        size_t %s = %s;' % (length, array.length)
189         print '        trace::localWriter.beginArray(%s);' % length
190         print '        for (size_t %s = 0; %s < %s; ++%s) {' % (index, index, length, index)
191         print '            trace::localWriter.beginElement();'
192         self.visit(array.type, '(%s)[%s]' % (instance, index))
193         print '            trace::localWriter.endElement();'
194         print '        }'
195         print '        trace::localWriter.endArray();'
196         print '    } else {'
197         print '        trace::localWriter.writeNull();'
198         print '    }'
199
200     def visitBlob(self, blob, instance):
201         print '    trace::localWriter.writeBlob(%s, %s);' % (instance, blob.size)
202
203     def visitEnum(self, enum, instance):
204         print '    trace::localWriter.writeEnum(&__enum%s_sig, %s);' % (enum.tag, instance)
205
206     def visitBitmask(self, bitmask, instance):
207         print '    trace::localWriter.writeBitmask(&__bitmask%s_sig, %s);' % (bitmask.tag, instance)
208
209     def visitPointer(self, pointer, instance):
210         print '    if (%s) {' % instance
211         print '        trace::localWriter.beginArray(1);'
212         print '        trace::localWriter.beginElement();'
213         self.visit(pointer.type, "*" + instance)
214         print '        trace::localWriter.endElement();'
215         print '        trace::localWriter.endArray();'
216         print '    } else {'
217         print '        trace::localWriter.writeNull();'
218         print '    }'
219
220     def visitIntPointer(self, pointer, instance):
221         print '    trace::localWriter.writeOpaque((const void *)%s);' % instance
222
223     def visitObjPointer(self, pointer, instance):
224         print '    trace::localWriter.writeOpaque((const void *)%s);' % instance
225
226     def visitLinearPointer(self, pointer, instance):
227         print '    trace::localWriter.writeOpaque((const void *)%s);' % instance
228
229     def visitReference(self, reference, instance):
230         self.visit(reference.type, instance)
231
232     def visitHandle(self, handle, instance):
233         self.visit(handle.type, instance)
234
235     def visitAlias(self, alias, instance):
236         self.visit(alias.type, instance)
237
238     def visitOpaque(self, opaque, instance):
239         print '    trace::localWriter.writeOpaque((const void *)%s);' % instance
240
241     def visitInterface(self, interface, instance):
242         print '    trace::localWriter.writeOpaque((const void *)&%s);' % instance
243
244     def visitPolymorphic(self, polymorphic, instance):
245         print '    _write__%s(%s, %s);' % (polymorphic.tag, polymorphic.switchExpr, instance)
246
247
248 class WrapDecider(stdapi.Traverser):
249     '''Type visitor which will decide wheter this type will need wrapping or not.
250     
251     For complex types (arrays, structures), we need to know this before hand.
252     '''
253
254     def __init__(self):
255         self.needsWrapping = False
256
257     def visitLinearPointer(self, void):
258         pass
259
260     def visitInterface(self, interface):
261         self.needsWrapping = True
262
263
264 class ValueWrapper(stdapi.Traverser):
265     '''Type visitor which will generate the code to wrap an instance.
266     
267     Wrapping is necessary mostly for interfaces, however interface pointers can
268     appear anywhere inside complex types.
269     '''
270
271     def visitStruct(self, struct, instance):
272         for type, name in struct.members:
273             self.visit(type, "(%s).%s" % (instance, name))
274
275     def visitArray(self, array, instance):
276         print "    if (%s) {" % instance
277         print "        for (size_t _i = 0, _s = %s; _i < _s; ++_i) {" % array.length
278         self.visit(array.type, instance + "[_i]")
279         print "        }"
280         print "    }"
281
282     def visitPointer(self, pointer, instance):
283         print "    if (%s) {" % instance
284         self.visit(pointer.type, "*" + instance)
285         print "    }"
286     
287     def visitObjPointer(self, pointer, instance):
288         elem_type = pointer.type.mutable()
289         if isinstance(elem_type, stdapi.Interface):
290             self.visitInterfacePointer(elem_type, instance)
291         else:
292             self.visitPointer(self, pointer, instance)
293     
294     def visitInterface(self, interface, instance):
295         raise NotImplementedError
296
297     def visitInterfacePointer(self, interface, instance):
298         print "    if (%s) {" % instance
299         print "        %s = new %s(%s);" % (instance, getWrapperInterfaceName(interface), instance)
300         print "    }"
301     
302     def visitPolymorphic(self, type, instance):
303         # XXX: There might be polymorphic values that need wrapping in the future
304         raise NotImplementedError
305
306
307 class ValueUnwrapper(ValueWrapper):
308     '''Reverse of ValueWrapper.'''
309
310     allocated = False
311
312     def visitArray(self, array, instance):
313         if self.allocated or isinstance(instance, stdapi.Interface):
314             return ValueWrapper.visitArray(self, array, instance)
315         elem_type = array.type.mutable()
316         print "    if (%s && %s) {" % (instance, array.length)
317         print "        %s * _t = static_cast<%s *>(alloca(%s * sizeof *_t));" % (elem_type, elem_type, array.length)
318         print "        for (size_t _i = 0, _s = %s; _i < _s; ++_i) {" % array.length
319         print "            _t[_i] = %s[_i];" % instance 
320         self.allocated = True
321         self.visit(array.type, "_t[_i]")
322         print "        }"
323         print "        %s = _t;" % instance
324         print "    }"
325
326     def visitInterfacePointer(self, interface, instance):
327         print r'    if (%s) {' % instance
328         print r'        const %s *pWrapper = static_cast<const %s*>(%s);' % (getWrapperInterfaceName(interface), getWrapperInterfaceName(interface), instance)
329         print r'        if (pWrapper && pWrapper->m_dwMagic == 0xd8365d6c) {'
330         print r'            %s = pWrapper->m_pInstance;' % (instance,)
331         print r'        } else {'
332         print r'            os::log("apitrace: warning: %%s: unexpected %%s pointer\n", __FUNCTION__, "%s");' % interface.name
333         print r'        }'
334         print r'    }'
335
336
337 class Tracer:
338     '''Base class to orchestrate the code generation of API tracing.'''
339
340     def __init__(self):
341         self.api = None
342
343     def serializerFactory(self):
344         '''Create a serializer.
345         
346         Can be overriden by derived classes to inject their own serialzer.
347         '''
348
349         return ValueSerializer()
350
351     def trace_api(self, api):
352         self.api = api
353
354         self.header(api)
355
356         # Includes
357         for header in api.headers:
358             print header
359         print
360
361         # Generate the serializer functions
362         types = api.getAllTypes()
363         visitor = ComplexValueSerializer(self.serializerFactory())
364         map(visitor.visit, types)
365         print
366
367         # Interfaces wrapers
368         interfaces = api.getAllInterfaces()
369         map(self.declareWrapperInterface, interfaces)
370         map(self.implementWrapperInterface, interfaces)
371         print
372
373         # Function wrappers
374         map(self.traceFunctionDecl, api.functions)
375         map(self.traceFunctionImpl, api.functions)
376         print
377
378         self.footer(api)
379
380     def header(self, api):
381         pass
382
383     def footer(self, api):
384         pass
385
386     def traceFunctionDecl(self, function):
387         # Per-function declarations
388
389         if function.args:
390             print 'static const char * __%s_args[%u] = {%s};' % (function.name, len(function.args), ', '.join(['"%s"' % arg.name for arg in function.args]))
391         else:
392             print 'static const char ** __%s_args = NULL;' % (function.name,)
393         print 'static const trace::FunctionSig __%s_sig = {%u, "%s", %u, __%s_args};' % (function.name, function.id, function.name, len(function.args), function.name)
394         print
395
396     def isFunctionPublic(self, function):
397         return True
398
399     def traceFunctionImpl(self, function):
400         if self.isFunctionPublic(function):
401             print 'extern "C" PUBLIC'
402         else:
403             print 'extern "C" PRIVATE'
404         print function.prototype() + ' {'
405         if function.type is not stdapi.Void:
406             print '    %s __result;' % function.type
407         self.traceFunctionImplBody(function)
408         if function.type is not stdapi.Void:
409             self.wrapRet(function, "__result")
410             print '    return __result;'
411         print '}'
412         print
413
414     def traceFunctionImplBody(self, function):
415         print '    unsigned __call = trace::localWriter.beginEnter(&__%s_sig);' % (function.name,)
416         for arg in function.args:
417             if not arg.output:
418                 self.unwrapArg(function, arg)
419                 self.serializeArg(function, arg)
420         print '    trace::localWriter.endEnter();'
421         self.invokeFunction(function)
422         print '    trace::localWriter.beginLeave(__call);'
423         for arg in function.args:
424             if arg.output:
425                 self.serializeArg(function, arg)
426                 self.wrapArg(function, arg)
427         if function.type is not stdapi.Void:
428             self.serializeRet(function, "__result")
429         print '    trace::localWriter.endLeave();'
430
431     def invokeFunction(self, function, prefix='__', suffix=''):
432         if function.type is stdapi.Void:
433             result = ''
434         else:
435             result = '__result = '
436         dispatch = prefix + function.name + suffix
437         print '    %s%s(%s);' % (result, dispatch, ', '.join([str(arg.name) for arg in function.args]))
438
439     def serializeArg(self, function, arg):
440         print '    trace::localWriter.beginArg(%u);' % (arg.index,)
441         self.serializeArgValue(function, arg)
442         print '    trace::localWriter.endArg();'
443
444     def serializeArgValue(self, function, arg):
445         self.serializeValue(arg.type, arg.name)
446
447     def wrapArg(self, function, arg):
448         assert not isinstance(arg.type, stdapi.ObjPointer)
449         self.wrapValue(arg.type, arg.name)
450
451     def unwrapArg(self, function, arg):
452         self.unwrapValue(arg.type, arg.name)
453
454     def serializeRet(self, function, instance):
455         print '    trace::localWriter.beginReturn();'
456         self.serializeValue(function.type, instance)
457         print '    trace::localWriter.endReturn();'
458
459     def serializeValue(self, type, instance):
460         serializer = self.serializerFactory()
461         serializer.visit(type, instance)
462
463     def wrapRet(self, function, instance):
464         self.wrapValue(function.type, instance)
465
466     def unwrapRet(self, function, instance):
467         self.unwrapValue(function.type, instance)
468
469     def needsWrapping(self, type):
470         visitor = WrapDecider()
471         visitor.visit(type)
472         return visitor.needsWrapping
473
474     def wrapValue(self, type, instance):
475         if self.needsWrapping(type):
476             visitor = ValueWrapper()
477             visitor.visit(type, instance)
478
479     def unwrapValue(self, type, instance):
480         if self.needsWrapping(type):
481             visitor = ValueUnwrapper()
482             visitor.visit(type, instance)
483
484     def declareWrapperInterface(self, interface):
485         print "class %s : public %s " % (getWrapperInterfaceName(interface), interface.name)
486         print "{"
487         print "public:"
488         print "    %s(%s * pInstance);" % (getWrapperInterfaceName(interface), interface.name)
489         print "    virtual ~%s();" % getWrapperInterfaceName(interface)
490         print
491         for method in interface.iterMethods():
492             print "    " + method.prototype() + ";"
493         print
494         self.declareWrapperInterfaceVariables(interface)
495         print "};"
496         print
497
498     def declareWrapperInterfaceVariables(self, interface):
499         #print "private:"
500         print "    DWORD m_dwMagic;"
501         print "    %s * m_pInstance;" % (interface.name,)
502
503     def implementWrapperInterface(self, interface):
504         print '%s::%s(%s * pInstance) {' % (getWrapperInterfaceName(interface), getWrapperInterfaceName(interface), interface.name)
505         print '    m_dwMagic = 0xd8365d6c;'
506         print '    m_pInstance = pInstance;'
507         print '}'
508         print
509         print '%s::~%s() {' % (getWrapperInterfaceName(interface), getWrapperInterfaceName(interface))
510         print '}'
511         print
512         for base, method in interface.iterBaseMethods():
513             self.implementWrapperInterfaceMethod(interface, base, method)
514         print
515
516     def implementWrapperInterfaceMethod(self, interface, base, method):
517         print method.prototype(getWrapperInterfaceName(interface) + '::' + method.name) + ' {'
518         if method.type is not stdapi.Void:
519             print '    %s __result;' % method.type
520     
521         self.implementWrapperInterfaceMethodBody(interface, base, method)
522     
523         if method.type is not stdapi.Void:
524             print '    return __result;'
525         print '}'
526         print
527
528     def implementWrapperInterfaceMethodBody(self, interface, base, method):
529         print '    static const char * __args[%u] = {%s};' % (len(method.args) + 1, ', '.join(['"this"'] + ['"%s"' % arg.name for arg in method.args]))
530         print '    static const trace::FunctionSig __sig = {%u, "%s", %u, __args};' % (method.id, interface.name + '::' + method.name, len(method.args) + 1)
531         print '    unsigned __call = trace::localWriter.beginEnter(&__sig);'
532         print '    trace::localWriter.beginArg(0);'
533         print '    trace::localWriter.writeOpaque((const void *)m_pInstance);'
534         print '    trace::localWriter.endArg();'
535
536         from specs.winapi import REFIID
537
538         riid = None
539         for arg in method.args:
540             if not arg.output:
541                 self.unwrapArg(method, arg)
542                 self.serializeArg(method, arg)
543                 if arg.type is REFIID:
544                     riid = arg
545         print '    trace::localWriter.endEnter();'
546         
547         self.invokeMethod(interface, base, method)
548
549         print '    trace::localWriter.beginLeave(__call);'
550         for arg in method.args:
551             if arg.output:
552                 self.serializeArg(method, arg)
553                 self.wrapArg(method, arg)
554                 if riid is not None and isinstance(arg.type, stdapi.Pointer):
555                     assert isinstance(arg.type.type, stdapi.ObjPointer)
556                     obj_type = arg.type.type.type
557                     assert obj_type is stdapi.Void
558                     self.wrapIid(interface, method, riid, arg)
559                     riid = None
560         assert riid is None
561
562         if method.type is not stdapi.Void:
563             print '    trace::localWriter.beginReturn();'
564             self.serializeValue(method.type, "__result")
565             print '    trace::localWriter.endReturn();'
566             self.wrapValue(method.type, '__result')
567         print '    trace::localWriter.endLeave();'
568         if method.name == 'Release':
569             assert method.type is not stdapi.Void
570             print '    if (!__result)'
571             print '        delete this;'
572
573     def wrapIid(self, interface, method, riid, out):
574             print '    if (%s && *%s) {' % (out.name, out.name)
575             print '        if (*%s == m_pInstance) {' % (out.name,)
576             print '            *%s = this;' % (out.name,)
577             print '        }'
578             for iface in self.api.getAllInterfaces():
579                 print r'        else if (%s == IID_%s) {' % (riid.name, iface.name)
580                 print r'            *%s = new Wrap%s((%s *) *%s);' % (out.name, iface.name, iface.name, out.name)
581                 print r'        }'
582             print r'        else {'
583             print r'            os::log("apitrace: warning: %s::%s: unknown IID {0x%08lX,0x%04X,0x%04X,{0x%02X,0x%02X,0x%02X,0x%02X,0x%02X,0x%02X,0x%02X,0x%02X}}\n",'
584             print r'                    "%s", "%s",' % (interface.name, method.name)
585             print r'                    %s.Data1, %s.Data2, %s.Data3,' % (riid.name, riid.name, riid.name)
586             print r'                    %s.Data4[0],' % (riid.name,)
587             print r'                    %s.Data4[1],' % (riid.name,)
588             print r'                    %s.Data4[2],' % (riid.name,)
589             print r'                    %s.Data4[3],' % (riid.name,)
590             print r'                    %s.Data4[4],' % (riid.name,)
591             print r'                    %s.Data4[5],' % (riid.name,)
592             print r'                    %s.Data4[6],' % (riid.name,)
593             print r'                    %s.Data4[7]);' % (riid.name,)
594             print r'        }'
595             print '    }'
596
597     def invokeMethod(self, interface, base, method):
598         if method.type is stdapi.Void:
599             result = ''
600         else:
601             result = '__result = '
602         print '    %sstatic_cast<%s *>(m_pInstance)->%s(%s);' % (result, base, method.name, ', '.join([str(arg.name) for arg in method.args]))
603     
604     def emit_memcpy(self, dest, src, length):
605         print '        unsigned __call = trace::localWriter.beginEnter(&trace::memcpy_sig);'
606         print '        trace::localWriter.beginArg(0);'
607         print '        trace::localWriter.writeOpaque(%s);' % dest
608         print '        trace::localWriter.endArg();'
609         print '        trace::localWriter.beginArg(1);'
610         print '        trace::localWriter.writeBlob(%s, %s);' % (src, length)
611         print '        trace::localWriter.endArg();'
612         print '        trace::localWriter.beginArg(2);'
613         print '        trace::localWriter.writeUInt(%s);' % length
614         print '        trace::localWriter.endArg();'
615         print '        trace::localWriter.endEnter();'
616         print '        trace::localWriter.beginLeave(__call);'
617         print '        trace::localWriter.endLeave();'
618