]> git.cworth.org Git - apitrace/blob - wrappers/trace.py
Handle REFIIDs on functions too.
[apitrace] / wrappers / trace.py
1 ##########################################################################
2 #
3 # Copyright 2008-2010 VMware, Inc.
4 # All Rights Reserved.
5 #
6 # Permission is hereby granted, free of charge, to any person obtaining a copy
7 # of this software and associated documentation files (the "Software"), to deal
8 # in the Software without restriction, including without limitation the rights
9 # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 # copies of the Software, and to permit persons to whom the Software is
11 # furnished to do so, subject to the following conditions:
12 #
13 # The above copyright notice and this permission notice shall be included in
14 # all copies or substantial portions of the Software.
15 #
16 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
22 # THE SOFTWARE.
23 #
24 ##########################################################################/
25
26 """Common trace code generation."""
27
28
29 # Adjust path
30 import os.path
31 import sys
32 sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
33
34
35 import specs.stdapi as stdapi
36
37
38 def getWrapperInterfaceName(interface):
39     return "Wrap" + interface.expr
40
41
42 class ComplexValueSerializer(stdapi.OnceVisitor):
43     '''Type visitors which generates serialization functions for
44     complex types.
45     
46     Simple types are serialized inline.
47     '''
48
49     def __init__(self, serializer):
50         stdapi.OnceVisitor.__init__(self)
51         self.serializer = serializer
52
53     def visitVoid(self, literal):
54         pass
55
56     def visitLiteral(self, literal):
57         pass
58
59     def visitString(self, string):
60         pass
61
62     def visitConst(self, const):
63         self.visit(const.type)
64
65     def visitStruct(self, struct):
66         for type, name in struct.members:
67             self.visit(type)
68         print 'static void _write__%s(const %s &value) {' % (struct.tag, struct.expr)
69         print '    static const char * members[%u] = {' % (len(struct.members),)
70         for type, name,  in struct.members:
71             print '        "%s",' % (name,)
72         print '    };'
73         print '    static const trace::StructSig sig = {'
74         print '       %u, "%s", %u, members' % (struct.id, struct.name, len(struct.members))
75         print '    };'
76         print '    trace::localWriter.beginStruct(&sig);'
77         for type, name in struct.members:
78             self.serializer.visit(type, 'value.%s' % (name,))
79         print '    trace::localWriter.endStruct();'
80         print '}'
81         print
82
83     def visitArray(self, array):
84         self.visit(array.type)
85
86     def visitBlob(self, array):
87         pass
88
89     def visitEnum(self, enum):
90         print 'static const trace::EnumValue __enum%s_values[] = {' % (enum.tag)
91         for value in enum.values:
92             print '   {"%s", %s},' % (value, value)
93         print '};'
94         print
95         print 'static const trace::EnumSig __enum%s_sig = {' % (enum.tag)
96         print '   %u, %u, __enum%s_values' % (enum.id, len(enum.values), enum.tag)
97         print '};'
98         print
99
100     def visitBitmask(self, bitmask):
101         print 'static const trace::BitmaskFlag __bitmask%s_flags[] = {' % (bitmask.tag)
102         for value in bitmask.values:
103             print '   {"%s", %s},' % (value, value)
104         print '};'
105         print
106         print 'static const trace::BitmaskSig __bitmask%s_sig = {' % (bitmask.tag)
107         print '   %u, %u, __bitmask%s_flags' % (bitmask.id, len(bitmask.values), bitmask.tag)
108         print '};'
109         print
110
111     def visitPointer(self, pointer):
112         self.visit(pointer.type)
113
114     def visitIntPointer(self, pointer):
115         pass
116
117     def visitObjPointer(self, pointer):
118         self.visit(pointer.type)
119
120     def visitLinearPointer(self, pointer):
121         self.visit(pointer.type)
122
123     def visitHandle(self, handle):
124         self.visit(handle.type)
125
126     def visitReference(self, reference):
127         self.visit(reference.type)
128
129     def visitAlias(self, alias):
130         self.visit(alias.type)
131
132     def visitOpaque(self, opaque):
133         pass
134
135     def visitInterface(self, interface):
136         pass
137
138     def visitPolymorphic(self, polymorphic):
139         print 'static void _write__%s(int selector, const %s & value) {' % (polymorphic.tag, polymorphic.expr)
140         print '    switch (selector) {'
141         for cases, type in polymorphic.iterSwitch():
142             for case in cases:
143                 print '    %s:' % case
144             self.serializer.visit(type, 'static_cast<%s>(value)' % (type,))
145             print '        break;'
146         print '    }'
147         print '}'
148         print
149
150
151 class ValueSerializer(stdapi.Visitor):
152     '''Visitor which generates code to serialize any type.
153     
154     Simple types are serialized inline here, whereas the serialization of
155     complex types is dispatched to the serialization functions generated by
156     ComplexValueSerializer visitor above.
157     '''
158
159     def visitLiteral(self, literal, instance):
160         print '    trace::localWriter.write%s(%s);' % (literal.kind, instance)
161
162     def visitString(self, string, instance):
163         if string.kind == 'String':
164             cast = 'const char *'
165         elif string.kind == 'WString':
166             cast = 'const wchar_t *'
167         else:
168             assert False
169         if cast != string.expr:
170             # reinterpret_cast is necessary for GLubyte * <=> char *
171             instance = 'reinterpret_cast<%s>(%s)' % (cast, instance)
172         if string.length is not None:
173             length = ', %s' % string.length
174         else:
175             length = ''
176         print '    trace::localWriter.write%s(%s%s);' % (string.kind, instance, length)
177
178     def visitConst(self, const, instance):
179         self.visit(const.type, instance)
180
181     def visitStruct(self, struct, instance):
182         print '    _write__%s(%s);' % (struct.tag, instance)
183
184     def visitArray(self, array, instance):
185         length = '__c' + array.type.tag
186         index = '__i' + array.type.tag
187         print '    if (%s) {' % instance
188         print '        size_t %s = %s;' % (length, array.length)
189         print '        trace::localWriter.beginArray(%s);' % length
190         print '        for (size_t %s = 0; %s < %s; ++%s) {' % (index, index, length, index)
191         print '            trace::localWriter.beginElement();'
192         self.visit(array.type, '(%s)[%s]' % (instance, index))
193         print '            trace::localWriter.endElement();'
194         print '        }'
195         print '        trace::localWriter.endArray();'
196         print '    } else {'
197         print '        trace::localWriter.writeNull();'
198         print '    }'
199
200     def visitBlob(self, blob, instance):
201         print '    trace::localWriter.writeBlob(%s, %s);' % (instance, blob.size)
202
203     def visitEnum(self, enum, instance):
204         print '    trace::localWriter.writeEnum(&__enum%s_sig, %s);' % (enum.tag, instance)
205
206     def visitBitmask(self, bitmask, instance):
207         print '    trace::localWriter.writeBitmask(&__bitmask%s_sig, %s);' % (bitmask.tag, instance)
208
209     def visitPointer(self, pointer, instance):
210         print '    if (%s) {' % instance
211         print '        trace::localWriter.beginArray(1);'
212         print '        trace::localWriter.beginElement();'
213         self.visit(pointer.type, "*" + instance)
214         print '        trace::localWriter.endElement();'
215         print '        trace::localWriter.endArray();'
216         print '    } else {'
217         print '        trace::localWriter.writeNull();'
218         print '    }'
219
220     def visitIntPointer(self, pointer, instance):
221         print '    trace::localWriter.writeOpaque((const void *)%s);' % instance
222
223     def visitObjPointer(self, pointer, instance):
224         print '    trace::localWriter.writeOpaque((const void *)%s);' % instance
225
226     def visitLinearPointer(self, pointer, instance):
227         print '    trace::localWriter.writeOpaque((const void *)%s);' % instance
228
229     def visitReference(self, reference, instance):
230         self.visit(reference.type, instance)
231
232     def visitHandle(self, handle, instance):
233         self.visit(handle.type, instance)
234
235     def visitAlias(self, alias, instance):
236         self.visit(alias.type, instance)
237
238     def visitOpaque(self, opaque, instance):
239         print '    trace::localWriter.writeOpaque((const void *)%s);' % instance
240
241     def visitInterface(self, interface, instance):
242         print '    trace::localWriter.writeOpaque((const void *)&%s);' % instance
243
244     def visitPolymorphic(self, polymorphic, instance):
245         print '    _write__%s(%s, %s);' % (polymorphic.tag, polymorphic.switchExpr, instance)
246
247
248 class WrapDecider(stdapi.Traverser):
249     '''Type visitor which will decide wheter this type will need wrapping or not.
250     
251     For complex types (arrays, structures), we need to know this before hand.
252     '''
253
254     def __init__(self):
255         self.needsWrapping = False
256
257     def visitLinearPointer(self, void):
258         pass
259
260     def visitInterface(self, interface):
261         self.needsWrapping = True
262
263
264 class ValueWrapper(stdapi.Traverser):
265     '''Type visitor which will generate the code to wrap an instance.
266     
267     Wrapping is necessary mostly for interfaces, however interface pointers can
268     appear anywhere inside complex types.
269     '''
270
271     def visitStruct(self, struct, instance):
272         for type, name in struct.members:
273             self.visit(type, "(%s).%s" % (instance, name))
274
275     def visitArray(self, array, instance):
276         print "    if (%s) {" % instance
277         print "        for (size_t _i = 0, _s = %s; _i < _s; ++_i) {" % array.length
278         self.visit(array.type, instance + "[_i]")
279         print "        }"
280         print "    }"
281
282     def visitPointer(self, pointer, instance):
283         print "    if (%s) {" % instance
284         self.visit(pointer.type, "*" + instance)
285         print "    }"
286     
287     def visitObjPointer(self, pointer, instance):
288         elem_type = pointer.type.mutable()
289         if isinstance(elem_type, stdapi.Interface):
290             self.visitInterfacePointer(elem_type, instance)
291         else:
292             self.visitPointer(self, pointer, instance)
293     
294     def visitInterface(self, interface, instance):
295         raise NotImplementedError
296
297     def visitInterfacePointer(self, interface, instance):
298         print "    if (%s) {" % instance
299         print "        %s = new %s(%s);" % (instance, getWrapperInterfaceName(interface), instance)
300         print "    }"
301     
302     def visitPolymorphic(self, type, instance):
303         # XXX: There might be polymorphic values that need wrapping in the future
304         raise NotImplementedError
305
306
307 class ValueUnwrapper(ValueWrapper):
308     '''Reverse of ValueWrapper.'''
309
310     allocated = False
311
312     def visitArray(self, array, instance):
313         if self.allocated or isinstance(instance, stdapi.Interface):
314             return ValueWrapper.visitArray(self, array, instance)
315         elem_type = array.type.mutable()
316         print "    if (%s && %s) {" % (instance, array.length)
317         print "        %s * _t = static_cast<%s *>(alloca(%s * sizeof *_t));" % (elem_type, elem_type, array.length)
318         print "        for (size_t _i = 0, _s = %s; _i < _s; ++_i) {" % array.length
319         print "            _t[_i] = %s[_i];" % instance 
320         self.allocated = True
321         self.visit(array.type, "_t[_i]")
322         print "        }"
323         print "        %s = _t;" % instance
324         print "    }"
325
326     def visitInterfacePointer(self, interface, instance):
327         print r'    if (%s) {' % instance
328         print r'        const %s *pWrapper = static_cast<const %s*>(%s);' % (getWrapperInterfaceName(interface), getWrapperInterfaceName(interface), instance)
329         print r'        if (pWrapper && pWrapper->m_dwMagic == 0xd8365d6c) {'
330         print r'            %s = pWrapper->m_pInstance;' % (instance,)
331         print r'        } else {'
332         print r'            os::log("apitrace: warning: %%s: unexpected %%s pointer\n", __FUNCTION__, "%s");' % interface.name
333         print r'        }'
334         print r'    }'
335
336
337 class Tracer:
338     '''Base class to orchestrate the code generation of API tracing.'''
339
340     def __init__(self):
341         self.api = None
342
343     def serializerFactory(self):
344         '''Create a serializer.
345         
346         Can be overriden by derived classes to inject their own serialzer.
347         '''
348
349         return ValueSerializer()
350
351     def trace_api(self, api):
352         self.api = api
353
354         self.header(api)
355
356         # Includes
357         for header in api.headers:
358             print header
359         print
360
361         # Generate the serializer functions
362         types = api.getAllTypes()
363         visitor = ComplexValueSerializer(self.serializerFactory())
364         map(visitor.visit, types)
365         print
366
367         # Interfaces wrapers
368         interfaces = api.getAllInterfaces()
369         map(self.declareWrapperInterface, interfaces)
370         map(self.implementWrapperInterface, interfaces)
371         print
372
373         # Function wrappers
374         self.interface = None
375         self.base = None
376         map(self.traceFunctionDecl, api.functions)
377         map(self.traceFunctionImpl, api.functions)
378         print
379
380         self.footer(api)
381
382     def header(self, api):
383         pass
384
385     def footer(self, api):
386         pass
387
388     def traceFunctionDecl(self, function):
389         # Per-function declarations
390
391         if function.args:
392             print 'static const char * __%s_args[%u] = {%s};' % (function.name, len(function.args), ', '.join(['"%s"' % arg.name for arg in function.args]))
393         else:
394             print 'static const char ** __%s_args = NULL;' % (function.name,)
395         print 'static const trace::FunctionSig __%s_sig = {%u, "%s", %u, __%s_args};' % (function.name, function.id, function.name, len(function.args), function.name)
396         print
397
398     def isFunctionPublic(self, function):
399         return True
400
401     def traceFunctionImpl(self, function):
402         if self.isFunctionPublic(function):
403             print 'extern "C" PUBLIC'
404         else:
405             print 'extern "C" PRIVATE'
406         print function.prototype() + ' {'
407         if function.type is not stdapi.Void:
408             print '    %s __result;' % function.type
409         self.traceFunctionImplBody(function)
410         if function.type is not stdapi.Void:
411             self.wrapRet(function, "__result")
412             print '    return __result;'
413         print '}'
414         print
415
416     def traceFunctionImplBody(self, function):
417         print '    unsigned __call = trace::localWriter.beginEnter(&__%s_sig);' % (function.name,)
418         for arg in function.args:
419             if not arg.output:
420                 self.unwrapArg(function, arg)
421                 self.serializeArg(function, arg)
422         print '    trace::localWriter.endEnter();'
423         self.invokeFunction(function)
424         print '    trace::localWriter.beginLeave(__call);'
425         for arg in function.args:
426             if arg.output:
427                 self.serializeArg(function, arg)
428                 self.wrapArg(function, arg)
429         if function.type is not stdapi.Void:
430             self.serializeRet(function, "__result")
431         print '    trace::localWriter.endLeave();'
432
433     def invokeFunction(self, function, prefix='__', suffix=''):
434         if function.type is stdapi.Void:
435             result = ''
436         else:
437             result = '__result = '
438         dispatch = prefix + function.name + suffix
439         print '    %s%s(%s);' % (result, dispatch, ', '.join([str(arg.name) for arg in function.args]))
440
441     def serializeArg(self, function, arg):
442         print '    trace::localWriter.beginArg(%u);' % (arg.index,)
443         self.serializeArgValue(function, arg)
444         print '    trace::localWriter.endArg();'
445
446     def serializeArgValue(self, function, arg):
447         self.serializeValue(arg.type, arg.name)
448
449     def wrapArg(self, function, arg):
450         assert not isinstance(arg.type, stdapi.ObjPointer)
451
452         from specs.winapi import REFIID
453         riid = None
454         for other_arg in function.args:
455             if not other_arg.output and other_arg.type is REFIID:
456                 riid = other_arg
457         if riid is not None and isinstance(arg.type, stdapi.Pointer):
458             assert isinstance(arg.type.type, stdapi.ObjPointer)
459             obj_type = arg.type.type.type
460             assert obj_type is stdapi.Void
461             self.wrapIid(function, riid, arg)
462             return
463
464         self.wrapValue(arg.type, arg.name)
465
466     def unwrapArg(self, function, arg):
467         self.unwrapValue(arg.type, arg.name)
468
469     def serializeRet(self, function, instance):
470         print '    trace::localWriter.beginReturn();'
471         self.serializeValue(function.type, instance)
472         print '    trace::localWriter.endReturn();'
473
474     def serializeValue(self, type, instance):
475         serializer = self.serializerFactory()
476         serializer.visit(type, instance)
477
478     def wrapRet(self, function, instance):
479         self.wrapValue(function.type, instance)
480
481     def unwrapRet(self, function, instance):
482         self.unwrapValue(function.type, instance)
483
484     def needsWrapping(self, type):
485         visitor = WrapDecider()
486         visitor.visit(type)
487         return visitor.needsWrapping
488
489     def wrapValue(self, type, instance):
490         if self.needsWrapping(type):
491             visitor = ValueWrapper()
492             visitor.visit(type, instance)
493
494     def unwrapValue(self, type, instance):
495         if self.needsWrapping(type):
496             visitor = ValueUnwrapper()
497             visitor.visit(type, instance)
498
499     def declareWrapperInterface(self, interface):
500         print "class %s : public %s " % (getWrapperInterfaceName(interface), interface.name)
501         print "{"
502         print "public:"
503         print "    %s(%s * pInstance);" % (getWrapperInterfaceName(interface), interface.name)
504         print "    virtual ~%s();" % getWrapperInterfaceName(interface)
505         print
506         for method in interface.iterMethods():
507             print "    " + method.prototype() + ";"
508         print
509         self.declareWrapperInterfaceVariables(interface)
510         print "};"
511         print
512
513     def declareWrapperInterfaceVariables(self, interface):
514         #print "private:"
515         print "    DWORD m_dwMagic;"
516         print "    %s * m_pInstance;" % (interface.name,)
517
518     def implementWrapperInterface(self, interface):
519         self.interface = interface
520
521         print '%s::%s(%s * pInstance) {' % (getWrapperInterfaceName(interface), getWrapperInterfaceName(interface), interface.name)
522         print '    m_dwMagic = 0xd8365d6c;'
523         print '    m_pInstance = pInstance;'
524         print '}'
525         print
526         print '%s::~%s() {' % (getWrapperInterfaceName(interface), getWrapperInterfaceName(interface))
527         print '}'
528         print
529         
530         for base, method in interface.iterBaseMethods():
531             self.base = base
532             self.implementWrapperInterfaceMethod(interface, base, method)
533
534         print
535
536     def implementWrapperInterfaceMethod(self, interface, base, method):
537         print method.prototype(getWrapperInterfaceName(interface) + '::' + method.name) + ' {'
538         if method.type is not stdapi.Void:
539             print '    %s __result;' % method.type
540     
541         self.implementWrapperInterfaceMethodBody(interface, base, method)
542     
543         if method.type is not stdapi.Void:
544             print '    return __result;'
545         print '}'
546         print
547
548     def implementWrapperInterfaceMethodBody(self, interface, base, method):
549         print '    static const char * __args[%u] = {%s};' % (len(method.args) + 1, ', '.join(['"this"'] + ['"%s"' % arg.name for arg in method.args]))
550         print '    static const trace::FunctionSig __sig = {%u, "%s", %u, __args};' % (method.id, interface.name + '::' + method.name, len(method.args) + 1)
551         print '    unsigned __call = trace::localWriter.beginEnter(&__sig);'
552         print '    trace::localWriter.beginArg(0);'
553         print '    trace::localWriter.writeOpaque((const void *)m_pInstance);'
554         print '    trace::localWriter.endArg();'
555         for arg in method.args:
556             if not arg.output:
557                 self.unwrapArg(method, arg)
558                 self.serializeArg(method, arg)
559         print '    trace::localWriter.endEnter();'
560         
561         self.invokeMethod(interface, base, method)
562
563         print '    trace::localWriter.beginLeave(__call);'
564         for arg in method.args:
565             if arg.output:
566                 self.serializeArg(method, arg)
567                 self.wrapArg(method, arg)
568
569         if method.type is not stdapi.Void:
570             print '    trace::localWriter.beginReturn();'
571             self.serializeValue(method.type, "__result")
572             print '    trace::localWriter.endReturn();'
573             self.wrapValue(method.type, '__result')
574         print '    trace::localWriter.endLeave();'
575         if method.name == 'Release':
576             assert method.type is not stdapi.Void
577             print '    if (!__result)'
578             print '        delete this;'
579
580     def wrapIid(self, function, riid, out):
581         print r'    if (%s && *%s) {' % (out.name, out.name)
582         function_name = function.name
583         else_ = ''
584         if self.interface is not None:
585             function_name = self.interface.name + '::' + function_name
586             print r'        %sif (*%s == m_pInstance) {' % (else_, out.name,)
587             print r'            *%s = this;' % (out.name,)
588             print r'        }'
589             else_ = 'else '
590         for iface in self.api.getAllInterfaces():
591             print r'        %sif (%s == IID_%s) {' % (else_, riid.name, iface.name)
592             print r'            *%s = new Wrap%s((%s *) *%s);' % (out.name, iface.name, iface.name, out.name)
593             print r'        }'
594             else_ = 'else '
595         print r'        %s{' % else_
596         print r'            os::log("apitrace: warning: %s: unknown IID {0x%08lX,0x%04X,0x%04X,{0x%02X,0x%02X,0x%02X,0x%02X,0x%02X,0x%02X,0x%02X,0x%02X}}\n",'
597         print r'                    "%s",' % (function_name)
598         print r'                    %s.Data1, %s.Data2, %s.Data3,' % (riid.name, riid.name, riid.name)
599         print r'                    %s.Data4[0],' % (riid.name,)
600         print r'                    %s.Data4[1],' % (riid.name,)
601         print r'                    %s.Data4[2],' % (riid.name,)
602         print r'                    %s.Data4[3],' % (riid.name,)
603         print r'                    %s.Data4[4],' % (riid.name,)
604         print r'                    %s.Data4[5],' % (riid.name,)
605         print r'                    %s.Data4[6],' % (riid.name,)
606         print r'                    %s.Data4[7]);' % (riid.name,)
607         print r'        }'
608         print r'    }'
609
610     def invokeMethod(self, interface, base, method):
611         if method.type is stdapi.Void:
612             result = ''
613         else:
614             result = '__result = '
615         print '    %sstatic_cast<%s *>(m_pInstance)->%s(%s);' % (result, base, method.name, ', '.join([str(arg.name) for arg in method.args]))
616     
617     def emit_memcpy(self, dest, src, length):
618         print '        unsigned __call = trace::localWriter.beginEnter(&trace::memcpy_sig);'
619         print '        trace::localWriter.beginArg(0);'
620         print '        trace::localWriter.writeOpaque(%s);' % dest
621         print '        trace::localWriter.endArg();'
622         print '        trace::localWriter.beginArg(1);'
623         print '        trace::localWriter.writeBlob(%s, %s);' % (src, length)
624         print '        trace::localWriter.endArg();'
625         print '        trace::localWriter.beginArg(2);'
626         print '        trace::localWriter.writeUInt(%s);' % length
627         print '        trace::localWriter.endArg();'
628         print '        trace::localWriter.endEnter();'
629         print '        trace::localWriter.beginLeave(__call);'
630         print '        trace::localWriter.endLeave();'
631