]> git.cworth.org Git - apitrace/blob - wrappers/trace.py
Properly (un)wrap array arguments.
[apitrace] / wrappers / trace.py
1 ##########################################################################
2 #
3 # Copyright 2008-2010 VMware, Inc.
4 # All Rights Reserved.
5 #
6 # Permission is hereby granted, free of charge, to any person obtaining a copy
7 # of this software and associated documentation files (the "Software"), to deal
8 # in the Software without restriction, including without limitation the rights
9 # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 # copies of the Software, and to permit persons to whom the Software is
11 # furnished to do so, subject to the following conditions:
12 #
13 # The above copyright notice and this permission notice shall be included in
14 # all copies or substantial portions of the Software.
15 #
16 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
22 # THE SOFTWARE.
23 #
24 ##########################################################################/
25
26 """Common trace code generation."""
27
28
29 # Adjust path
30 import os.path
31 import sys
32 sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
33
34
35 import specs.stdapi as stdapi
36
37
38 def getWrapperInterfaceName(interface):
39     return "Wrap" + interface.expr
40
41
42 class ComplexValueSerializer(stdapi.OnceVisitor):
43     '''Type visitors which generates serialization functions for
44     complex types.
45     
46     Simple types are serialized inline.
47     '''
48
49     def __init__(self, serializer):
50         stdapi.OnceVisitor.__init__(self)
51         self.serializer = serializer
52
53     def visitVoid(self, literal):
54         pass
55
56     def visitLiteral(self, literal):
57         pass
58
59     def visitString(self, string):
60         pass
61
62     def visitConst(self, const):
63         self.visit(const.type)
64
65     def visitStruct(self, struct):
66         for type, name in struct.members:
67             self.visit(type)
68         print 'static void _write__%s(const %s &value) {' % (struct.tag, struct.expr)
69         print '    static const char * members[%u] = {' % (len(struct.members),)
70         for type, name,  in struct.members:
71             print '        "%s",' % (name,)
72         print '    };'
73         print '    static const trace::StructSig sig = {'
74         print '       %u, "%s", %u, members' % (struct.id, struct.name, len(struct.members))
75         print '    };'
76         print '    trace::localWriter.beginStruct(&sig);'
77         for type, name in struct.members:
78             self.serializer.visit(type, 'value.%s' % (name,))
79         print '    trace::localWriter.endStruct();'
80         print '}'
81         print
82
83     def visitArray(self, array):
84         self.visit(array.type)
85
86     def visitBlob(self, array):
87         pass
88
89     def visitEnum(self, enum):
90         print 'static const trace::EnumValue __enum%s_values[] = {' % (enum.tag)
91         for value in enum.values:
92             print '   {"%s", %s},' % (value, value)
93         print '};'
94         print
95         print 'static const trace::EnumSig __enum%s_sig = {' % (enum.tag)
96         print '   %u, %u, __enum%s_values' % (enum.id, len(enum.values), enum.tag)
97         print '};'
98         print
99
100     def visitBitmask(self, bitmask):
101         print 'static const trace::BitmaskFlag __bitmask%s_flags[] = {' % (bitmask.tag)
102         for value in bitmask.values:
103             print '   {"%s", %s},' % (value, value)
104         print '};'
105         print
106         print 'static const trace::BitmaskSig __bitmask%s_sig = {' % (bitmask.tag)
107         print '   %u, %u, __bitmask%s_flags' % (bitmask.id, len(bitmask.values), bitmask.tag)
108         print '};'
109         print
110
111     def visitPointer(self, pointer):
112         self.visit(pointer.type)
113
114     def visitIntPointer(self, pointer):
115         pass
116
117     def visitObjPointer(self, pointer):
118         self.visit(pointer.type)
119
120     def visitLinearPointer(self, pointer):
121         self.visit(pointer.type)
122
123     def visitHandle(self, handle):
124         self.visit(handle.type)
125
126     def visitReference(self, reference):
127         self.visit(reference.type)
128
129     def visitAlias(self, alias):
130         self.visit(alias.type)
131
132     def visitOpaque(self, opaque):
133         pass
134
135     def visitInterface(self, interface):
136         pass
137
138     def visitPolymorphic(self, polymorphic):
139         print 'static void _write__%s(int selector, const %s & value) {' % (polymorphic.tag, polymorphic.expr)
140         print '    switch (selector) {'
141         for cases, type in polymorphic.iterSwitch():
142             for case in cases:
143                 print '    %s:' % case
144             self.serializer.visit(type, 'static_cast<%s>(value)' % (type,))
145             print '        break;'
146         print '    }'
147         print '}'
148         print
149
150
151 class ValueSerializer(stdapi.Visitor):
152     '''Visitor which generates code to serialize any type.
153     
154     Simple types are serialized inline here, whereas the serialization of
155     complex types is dispatched to the serialization functions generated by
156     ComplexValueSerializer visitor above.
157     '''
158
159     def visitLiteral(self, literal, instance):
160         print '    trace::localWriter.write%s(%s);' % (literal.kind, instance)
161
162     def visitString(self, string, instance):
163         if string.kind == 'String':
164             cast = 'const char *'
165         elif string.kind == 'WString':
166             cast = 'const wchar_t *'
167         else:
168             assert False
169         if cast != string.expr:
170             # reinterpret_cast is necessary for GLubyte * <=> char *
171             instance = 'reinterpret_cast<%s>(%s)' % (cast, instance)
172         if string.length is not None:
173             length = ', %s' % string.length
174         else:
175             length = ''
176         print '    trace::localWriter.write%s(%s%s);' % (string.kind, instance, length)
177
178     def visitConst(self, const, instance):
179         self.visit(const.type, instance)
180
181     def visitStruct(self, struct, instance):
182         print '    _write__%s(%s);' % (struct.tag, instance)
183
184     def visitArray(self, array, instance):
185         length = '__c' + array.type.tag
186         index = '__i' + array.type.tag
187         print '    if (%s) {' % instance
188         print '        size_t %s = %s;' % (length, array.length)
189         print '        trace::localWriter.beginArray(%s);' % length
190         print '        for (size_t %s = 0; %s < %s; ++%s) {' % (index, index, length, index)
191         print '            trace::localWriter.beginElement();'
192         self.visit(array.type, '(%s)[%s]' % (instance, index))
193         print '            trace::localWriter.endElement();'
194         print '        }'
195         print '        trace::localWriter.endArray();'
196         print '    } else {'
197         print '        trace::localWriter.writeNull();'
198         print '    }'
199
200     def visitBlob(self, blob, instance):
201         print '    trace::localWriter.writeBlob(%s, %s);' % (instance, blob.size)
202
203     def visitEnum(self, enum, instance):
204         print '    trace::localWriter.writeEnum(&__enum%s_sig, %s);' % (enum.tag, instance)
205
206     def visitBitmask(self, bitmask, instance):
207         print '    trace::localWriter.writeBitmask(&__bitmask%s_sig, %s);' % (bitmask.tag, instance)
208
209     def visitPointer(self, pointer, instance):
210         print '    if (%s) {' % instance
211         print '        trace::localWriter.beginArray(1);'
212         print '        trace::localWriter.beginElement();'
213         self.visit(pointer.type, "*" + instance)
214         print '        trace::localWriter.endElement();'
215         print '        trace::localWriter.endArray();'
216         print '    } else {'
217         print '        trace::localWriter.writeNull();'
218         print '    }'
219
220     def visitIntPointer(self, pointer, instance):
221         print '    trace::localWriter.writeOpaque((const void *)%s);' % instance
222
223     def visitObjPointer(self, pointer, instance):
224         print '    trace::localWriter.writeOpaque((const void *)%s);' % instance
225
226     def visitLinearPointer(self, pointer, instance):
227         print '    trace::localWriter.writeOpaque((const void *)%s);' % instance
228
229     def visitReference(self, reference, instance):
230         self.visit(reference.type, instance)
231
232     def visitHandle(self, handle, instance):
233         self.visit(handle.type, instance)
234
235     def visitAlias(self, alias, instance):
236         self.visit(alias.type, instance)
237
238     def visitOpaque(self, opaque, instance):
239         print '    trace::localWriter.writeOpaque((const void *)%s);' % instance
240
241     def visitInterface(self, interface, instance):
242         print '    trace::localWriter.writeOpaque((const void *)&%s);' % instance
243
244     def visitPolymorphic(self, polymorphic, instance):
245         print '    _write__%s(%s, %s);' % (polymorphic.tag, polymorphic.switchExpr, instance)
246
247
248 class WrapDecider(stdapi.Traverser):
249     '''Type visitor which will decide wheter this type will need wrapping or not.
250     
251     For complex types (arrays, structures), we need to know this before hand.
252     '''
253
254     def __init__(self):
255         self.needsWrapping = False
256
257     def visitVoid(self, void):
258         raise NotImplementedError
259
260     def visitLinearPointer(self, void):
261         pass
262
263     def visitInterface(self, interface):
264         self.needsWrapping = True
265
266
267 class ValueWrapper(stdapi.Traverser):
268     '''Type visitor which will generate the code to wrap an instance.
269     
270     Wrapping is necessary mostly for interfaces, however interface pointers can
271     appear anywhere inside complex types.
272     '''
273
274     def visitStruct(self, struct, instance):
275         for type, name in struct.members:
276             self.visit(type, "(%s).%s" % (instance, name))
277
278     def visitArray(self, array, instance):
279         print "    if (%s) {" % instance
280         print "        for (size_t _i = 0, _s = %s; _i < _s; ++_i) {" % array.length
281         self.visit(array.type, instance + "[_i]")
282         print "        }"
283         print "    }"
284
285     def visitPointer(self, pointer, instance):
286         print "    if (%s) {" % instance
287         self.visit(pointer.type, "*" + instance)
288         print "    }"
289     
290     def visitObjPointer(self, pointer, instance):
291         print "    if (%s) {" % instance
292         self.visit(pointer.type, "*" + instance)
293         print "    }"
294     
295     def visitInterface(self, interface, instance):
296         assert instance.startswith('*')
297         instance = instance[1:]
298         print "    if (%s) {" % instance
299         print "        %s = new %s(%s);" % (instance, getWrapperInterfaceName(interface), instance)
300         print "    }"
301     
302     def visitPolymorphic(self, type, instance):
303         # XXX: There might be polymorphic values that need wrapping in the future
304         raise NotImplementedError
305
306
307 class ValueUnwrapper(ValueWrapper):
308     '''Reverse of ValueWrapper.'''
309
310     allocated = False
311
312     def visitArray(self, array, instance):
313         if self.allocated or isinstance(instance, stdapi.Interface):
314             return ValueWrapper.visitArray(self, array, instance)
315         elem_type = array.type.mutable()
316         print "    if (%s && %s) {" % (instance, array.length)
317         print "        %s * _t = static_cast<%s *>(alloca(%s * sizeof *_t));" % (elem_type, elem_type, array.length)
318         print "        for (size_t _i = 0, _s = %s; _i < _s; ++_i) {" % array.length
319         print "            _t[_i] = %s[_i];" % instance 
320         self.allocated = True
321         self.visit(array.type, "_t[_i]")
322         print "        }"
323         print "        %s = _t;" % instance
324         print "    }"
325
326     def visitInterface(self, interface, instance):
327         assert instance.startswith('*')
328         instance = instance[1:]
329         print r'    if (%s) {' % instance
330         print r'        const %s *pWrapper = static_cast<const %s*>(%s);' % (getWrapperInterfaceName(interface), getWrapperInterfaceName(interface), instance)
331         print r'        if (pWrapper && pWrapper->m_dwMagic == 0xd8365d6c) {'
332         print r'            %s = pWrapper->m_pInstance;' % (instance,)
333         print r'        } else {'
334         print r'            os::log("apitrace: warning: %%s: unexpected %%s pointer\n", __FUNCTION__, "%s");' % interface.name
335         print r'        }'
336         print r'    }'
337
338
339 class Tracer:
340     '''Base class to orchestrate the code generation of API tracing.'''
341
342     def __init__(self):
343         self.api = None
344
345     def serializerFactory(self):
346         '''Create a serializer.
347         
348         Can be overriden by derived classes to inject their own serialzer.
349         '''
350
351         return ValueSerializer()
352
353     def trace_api(self, api):
354         self.api = api
355
356         self.header(api)
357
358         # Includes
359         for header in api.headers:
360             print header
361         print
362
363         # Generate the serializer functions
364         types = api.getAllTypes()
365         visitor = ComplexValueSerializer(self.serializerFactory())
366         map(visitor.visit, types)
367         print
368
369         # Interfaces wrapers
370         interfaces = api.getAllInterfaces()
371         map(self.declareWrapperInterface, interfaces)
372         map(self.implementWrapperInterface, interfaces)
373         print
374
375         # Function wrappers
376         map(self.traceFunctionDecl, api.functions)
377         map(self.traceFunctionImpl, api.functions)
378         print
379
380         self.footer(api)
381
382     def header(self, api):
383         pass
384
385     def footer(self, api):
386         pass
387
388     def traceFunctionDecl(self, function):
389         # Per-function declarations
390
391         if function.args:
392             print 'static const char * __%s_args[%u] = {%s};' % (function.name, len(function.args), ', '.join(['"%s"' % arg.name for arg in function.args]))
393         else:
394             print 'static const char ** __%s_args = NULL;' % (function.name,)
395         print 'static const trace::FunctionSig __%s_sig = {%u, "%s", %u, __%s_args};' % (function.name, function.id, function.name, len(function.args), function.name)
396         print
397
398     def isFunctionPublic(self, function):
399         return True
400
401     def traceFunctionImpl(self, function):
402         if self.isFunctionPublic(function):
403             print 'extern "C" PUBLIC'
404         else:
405             print 'extern "C" PRIVATE'
406         print function.prototype() + ' {'
407         if function.type is not stdapi.Void:
408             print '    %s __result;' % function.type
409         self.traceFunctionImplBody(function)
410         if function.type is not stdapi.Void:
411             self.wrapRet(function, "__result")
412             print '    return __result;'
413         print '}'
414         print
415
416     def traceFunctionImplBody(self, function):
417         print '    unsigned __call = trace::localWriter.beginEnter(&__%s_sig);' % (function.name,)
418         for arg in function.args:
419             if not arg.output:
420                 self.unwrapArg(function, arg)
421                 self.serializeArg(function, arg)
422         print '    trace::localWriter.endEnter();'
423         self.invokeFunction(function)
424         print '    trace::localWriter.beginLeave(__call);'
425         for arg in function.args:
426             if arg.output:
427                 self.serializeArg(function, arg)
428                 self.wrapArg(function, arg)
429         if function.type is not stdapi.Void:
430             self.serializeRet(function, "__result")
431         print '    trace::localWriter.endLeave();'
432
433     def invokeFunction(self, function, prefix='__', suffix=''):
434         if function.type is stdapi.Void:
435             result = ''
436         else:
437             result = '__result = '
438         dispatch = prefix + function.name + suffix
439         print '    %s%s(%s);' % (result, dispatch, ', '.join([str(arg.name) for arg in function.args]))
440
441     def serializeArg(self, function, arg):
442         print '    trace::localWriter.beginArg(%u);' % (arg.index,)
443         self.serializeArgValue(function, arg)
444         print '    trace::localWriter.endArg();'
445
446     def serializeArgValue(self, function, arg):
447         self.serializeValue(arg.type, arg.name)
448
449     def wrapArg(self, function, arg):
450         self.wrapValue(arg.type, arg.name)
451
452     def unwrapArg(self, function, arg):
453         self.unwrapValue(arg.type, arg.name)
454
455     def serializeRet(self, function, instance):
456         print '    trace::localWriter.beginReturn();'
457         self.serializeValue(function.type, instance)
458         print '    trace::localWriter.endReturn();'
459
460     def serializeValue(self, type, instance):
461         serializer = self.serializerFactory()
462         serializer.visit(type, instance)
463
464     def wrapRet(self, function, instance):
465         self.wrapValue(function.type, instance)
466
467     def unwrapRet(self, function, instance):
468         self.unwrapValue(function.type, instance)
469
470     def needsWrapping(self, type):
471         visitor = WrapDecider()
472         visitor.visit(type)
473         return visitor.needsWrapping
474
475     def wrapValue(self, type, instance):
476         if self.needsWrapping(type):
477             visitor = ValueWrapper()
478             visitor.visit(type, instance)
479
480     def unwrapValue(self, type, instance):
481         if self.needsWrapping(type):
482             visitor = ValueUnwrapper()
483             visitor.visit(type, instance)
484
485     def declareWrapperInterface(self, interface):
486         print "class %s : public %s " % (getWrapperInterfaceName(interface), interface.name)
487         print "{"
488         print "public:"
489         print "    %s(%s * pInstance);" % (getWrapperInterfaceName(interface), interface.name)
490         print "    virtual ~%s();" % getWrapperInterfaceName(interface)
491         print
492         for method in interface.iterMethods():
493             print "    " + method.prototype() + ";"
494         print
495         self.declareWrapperInterfaceVariables(interface)
496         print "};"
497         print
498
499     def declareWrapperInterfaceVariables(self, interface):
500         #print "private:"
501         print "    DWORD m_dwMagic;"
502         print "    %s * m_pInstance;" % (interface.name,)
503
504     def implementWrapperInterface(self, interface):
505         print '%s::%s(%s * pInstance) {' % (getWrapperInterfaceName(interface), getWrapperInterfaceName(interface), interface.name)
506         print '    m_dwMagic = 0xd8365d6c;'
507         print '    m_pInstance = pInstance;'
508         print '}'
509         print
510         print '%s::~%s() {' % (getWrapperInterfaceName(interface), getWrapperInterfaceName(interface))
511         print '}'
512         print
513         for base, method in interface.iterBaseMethods():
514             self.implementWrapperInterfaceMethod(interface, base, method)
515         print
516
517     def implementWrapperInterfaceMethod(self, interface, base, method):
518         print method.prototype(getWrapperInterfaceName(interface) + '::' + method.name) + ' {'
519         if method.type is not stdapi.Void:
520             print '    %s __result;' % method.type
521     
522         self.implementWrapperInterfaceMethodBody(interface, base, method)
523     
524         if method.type is not stdapi.Void:
525             print '    return __result;'
526         print '}'
527         print
528
529     def implementWrapperInterfaceMethodBody(self, interface, base, method):
530         print '    static const char * __args[%u] = {%s};' % (len(method.args) + 1, ', '.join(['"this"'] + ['"%s"' % arg.name for arg in method.args]))
531         print '    static const trace::FunctionSig __sig = {%u, "%s", %u, __args};' % (method.id, interface.name + '::' + method.name, len(method.args) + 1)
532         print '    unsigned __call = trace::localWriter.beginEnter(&__sig);'
533         print '    trace::localWriter.beginArg(0);'
534         print '    trace::localWriter.writeOpaque((const void *)m_pInstance);'
535         print '    trace::localWriter.endArg();'
536
537         from specs.winapi import REFIID
538         from specs.stdapi import Pointer, Opaque, Interface
539
540         riid = None
541         for arg in method.args:
542             if not arg.output:
543                 self.unwrapArg(method, arg)
544                 self.serializeArg(method, arg)
545                 if arg.type is REFIID:
546                     riid = arg
547         print '    trace::localWriter.endEnter();'
548         
549         self.invokeMethod(interface, base, method)
550
551         print '    trace::localWriter.beginLeave(__call);'
552         for arg in method.args:
553             if arg.output:
554                 self.serializeArg(method, arg)
555                 self.wrapArg(method, arg)
556                 if riid is not None and isinstance(arg.type, Pointer):
557                     if isinstance(arg.type.type, Opaque):
558                         self.wrapIid(interface, method, riid, arg)
559                     else:
560                         assert isinstance(arg.type.type, Pointer)
561                         assert isinstance(arg.type.type.type, Interface)
562
563         if method.type is not stdapi.Void:
564             print '    trace::localWriter.beginReturn();'
565             self.serializeValue(method.type, "__result")
566             print '    trace::localWriter.endReturn();'
567             self.wrapValue(method.type, '__result')
568         print '    trace::localWriter.endLeave();'
569         if method.name == 'Release':
570             assert method.type is not stdapi.Void
571             print '    if (!__result)'
572             print '        delete this;'
573
574     def wrapIid(self, interface, method, riid, out):
575             print '    if (%s && *%s) {' % (out.name, out.name)
576             print '        if (*%s == m_pInstance) {' % (out.name,)
577             print '            *%s = this;' % (out.name,)
578             print '        }'
579             for iface in self.api.getAllInterfaces():
580                 print r'        else if (%s == IID_%s) {' % (riid.name, iface.name)
581                 print r'            *%s = new Wrap%s((%s *) *%s);' % (out.name, iface.name, iface.name, out.name)
582                 print r'        }'
583             print r'        else {'
584             print r'            os::log("apitrace: warning: %s::%s: unknown IID {0x%08lX,0x%04X,0x%04X,{0x%02X,0x%02X,0x%02X,0x%02X,0x%02X,0x%02X,0x%02X,0x%02X}}\n",'
585             print r'                    "%s", "%s",' % (interface.name, method.name)
586             print r'                    %s.Data1, %s.Data2, %s.Data3,' % (riid.name, riid.name, riid.name)
587             print r'                    %s.Data4[0],' % (riid.name,)
588             print r'                    %s.Data4[1],' % (riid.name,)
589             print r'                    %s.Data4[2],' % (riid.name,)
590             print r'                    %s.Data4[3],' % (riid.name,)
591             print r'                    %s.Data4[4],' % (riid.name,)
592             print r'                    %s.Data4[5],' % (riid.name,)
593             print r'                    %s.Data4[6],' % (riid.name,)
594             print r'                    %s.Data4[7]);' % (riid.name,)
595             print r'        }'
596             print '    }'
597
598     def invokeMethod(self, interface, base, method):
599         if method.type is stdapi.Void:
600             result = ''
601         else:
602             result = '__result = '
603         print '    %sstatic_cast<%s *>(m_pInstance)->%s(%s);' % (result, base, method.name, ', '.join([str(arg.name) for arg in method.args]))
604     
605     def emit_memcpy(self, dest, src, length):
606         print '        unsigned __call = trace::localWriter.beginEnter(&trace::memcpy_sig);'
607         print '        trace::localWriter.beginArg(0);'
608         print '        trace::localWriter.writeOpaque(%s);' % dest
609         print '        trace::localWriter.endArg();'
610         print '        trace::localWriter.beginArg(1);'
611         print '        trace::localWriter.writeBlob(%s, %s);' % (src, length)
612         print '        trace::localWriter.endArg();'
613         print '        trace::localWriter.beginArg(2);'
614         print '        trace::localWriter.writeUInt(%s);' % length
615         print '        trace::localWriter.endArg();'
616         print '        trace::localWriter.endEnter();'
617         print '        trace::localWriter.beginLeave(__call);'
618         print '        trace::localWriter.endLeave();'
619