]> git.cworth.org Git - apitrace/blob - trace.py
Fix major regression in D3D tracing.
[apitrace] / trace.py
1 ##########################################################################
2 #
3 # Copyright 2008-2010 VMware, Inc.
4 # All Rights Reserved.
5 #
6 # Permission is hereby granted, free of charge, to any person obtaining a copy
7 # of this software and associated documentation files (the "Software"), to deal
8 # in the Software without restriction, including without limitation the rights
9 # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 # copies of the Software, and to permit persons to whom the Software is
11 # furnished to do so, subject to the following conditions:
12 #
13 # The above copyright notice and this permission notice shall be included in
14 # all copies or substantial portions of the Software.
15 #
16 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
22 # THE SOFTWARE.
23 #
24 ##########################################################################/
25
26 """Common trace code generation."""
27
28
29 import specs.stdapi as stdapi
30
31
32 def getWrapperInterfaceName(interface):
33     return "Wrap" + interface.expr
34
35
36 class ComplexValueSerializer(stdapi.OnceVisitor):
37     '''Type visitors which generates serialization functions for
38     complex types.
39     
40     Simple types are serialized inline.
41     '''
42
43     def __init__(self, serializer):
44         stdapi.OnceVisitor.__init__(self)
45         self.serializer = serializer
46
47     def visitVoid(self, literal):
48         pass
49
50     def visitLiteral(self, literal):
51         pass
52
53     def visitString(self, string):
54         pass
55
56     def visitConst(self, const):
57         self.visit(const.type)
58
59     def visitStruct(self, struct):
60         for type, name in struct.members:
61             self.visit(type)
62         print 'static void _write__%s(const %s &value) {' % (struct.tag, struct.expr)
63         print '    static const char * members[%u] = {' % (len(struct.members),)
64         for type, name,  in struct.members:
65             print '        "%s",' % (name,)
66         print '    };'
67         print '    static const trace::StructSig sig = {'
68         print '       %u, "%s", %u, members' % (struct.id, struct.name, len(struct.members))
69         print '    };'
70         print '    trace::localWriter.beginStruct(&sig);'
71         for type, name in struct.members:
72             self.serializer.visit(type, 'value.%s' % (name,))
73         print '    trace::localWriter.endStruct();'
74         print '}'
75         print
76
77     def visitArray(self, array):
78         self.visit(array.type)
79
80     def visitBlob(self, array):
81         pass
82
83     def visitEnum(self, enum):
84         print 'static const trace::EnumValue __enum%s_values[] = {' % (enum.tag)
85         for value in enum.values:
86             print '   {"%s", %s},' % (value, value)
87         print '};'
88         print
89         print 'static const trace::EnumSig __enum%s_sig = {' % (enum.tag)
90         print '   %u, %u, __enum%s_values' % (enum.id, len(enum.values), enum.tag)
91         print '};'
92         print
93
94     def visitBitmask(self, bitmask):
95         print 'static const trace::BitmaskFlag __bitmask%s_flags[] = {' % (bitmask.tag)
96         for value in bitmask.values:
97             print '   {"%s", %s},' % (value, value)
98         print '};'
99         print
100         print 'static const trace::BitmaskSig __bitmask%s_sig = {' % (bitmask.tag)
101         print '   %u, %u, __bitmask%s_flags' % (bitmask.id, len(bitmask.values), bitmask.tag)
102         print '};'
103         print
104
105     def visitPointer(self, pointer):
106         self.visit(pointer.type)
107
108     def visitIntPointer(self, pointer):
109         pass
110
111     def visitLinearPointer(self, pointer):
112         self.visit(pointer.type)
113
114     def visitHandle(self, handle):
115         self.visit(handle.type)
116
117     def visitAlias(self, alias):
118         self.visit(alias.type)
119
120     def visitOpaque(self, opaque):
121         pass
122
123     def visitInterface(self, interface):
124         pass
125
126     def visitPolymorphic(self, polymorphic):
127         print 'static void _write__%s(int selector, const %s & value) {' % (polymorphic.tag, polymorphic.expr)
128         print '    switch (selector) {'
129         for cases, type in polymorphic.iterSwitch():
130             for case in cases:
131                 print '    %s:' % case
132             self.serializer.visit(type, 'static_cast<%s>(value)' % (type,))
133             print '        break;'
134         print '    }'
135         print '}'
136         print
137
138
139 class ValueSerializer(stdapi.Visitor):
140     '''Visitor which generates code to serialize any type.
141     
142     Simple types are serialized inline here, whereas the serialization of
143     complex types is dispatched to the serialization functions generated by
144     ComplexValueSerializer visitor above.
145     '''
146
147     def visitLiteral(self, literal, instance):
148         print '    trace::localWriter.write%s(%s);' % (literal.kind, instance)
149
150     def visitString(self, string, instance):
151         if string.kind == 'String':
152             cast = 'const char *'
153         elif string.kind == 'WString':
154             cast = 'const wchar_t *'
155         else:
156             assert False
157         if cast != string.expr:
158             # reinterpret_cast is necessary for GLubyte * <=> char *
159             instance = 'reinterpret_cast<%s>(%s)' % (cast, instance)
160         if string.length is not None:
161             length = ', %s' % string.length
162         else:
163             length = ''
164         print '    trace::localWriter.write%s(%s%s);' % (string.kind, instance, length)
165
166     def visitConst(self, const, instance):
167         self.visit(const.type, instance)
168
169     def visitStruct(self, struct, instance):
170         print '    _write__%s(%s);' % (struct.tag, instance)
171
172     def visitArray(self, array, instance):
173         length = '__c' + array.type.tag
174         index = '__i' + array.type.tag
175         print '    if (%s) {' % instance
176         print '        size_t %s = %s;' % (length, array.length)
177         print '        trace::localWriter.beginArray(%s);' % length
178         print '        for (size_t %s = 0; %s < %s; ++%s) {' % (index, index, length, index)
179         print '            trace::localWriter.beginElement();'
180         self.visit(array.type, '(%s)[%s]' % (instance, index))
181         print '            trace::localWriter.endElement();'
182         print '        }'
183         print '        trace::localWriter.endArray();'
184         print '    } else {'
185         print '        trace::localWriter.writeNull();'
186         print '    }'
187
188     def visitBlob(self, blob, instance):
189         print '    trace::localWriter.writeBlob(%s, %s);' % (instance, blob.size)
190
191     def visitEnum(self, enum, instance):
192         print '    trace::localWriter.writeEnum(&__enum%s_sig, %s);' % (enum.tag, instance)
193
194     def visitBitmask(self, bitmask, instance):
195         print '    trace::localWriter.writeBitmask(&__bitmask%s_sig, %s);' % (bitmask.tag, instance)
196
197     def visitPointer(self, pointer, instance):
198         print '    if (%s) {' % instance
199         print '        trace::localWriter.beginArray(1);'
200         print '        trace::localWriter.beginElement();'
201         self.visit(pointer.type, "*" + instance)
202         print '        trace::localWriter.endElement();'
203         print '        trace::localWriter.endArray();'
204         print '    } else {'
205         print '        trace::localWriter.writeNull();'
206         print '    }'
207
208     def visitIntPointer(self, pointer, instance):
209         print '    trace::localWriter.writeOpaque((const void *)%s);' % instance
210
211     def visitLinearPointer(self, pointer, instance):
212         print '    trace::localWriter.writeOpaque((const void *)%s);' % instance
213
214     def visitHandle(self, handle, instance):
215         self.visit(handle.type, instance)
216
217     def visitAlias(self, alias, instance):
218         self.visit(alias.type, instance)
219
220     def visitOpaque(self, opaque, instance):
221         print '    trace::localWriter.writeOpaque((const void *)%s);' % instance
222
223     def visitInterface(self, interface, instance):
224         print '    trace::localWriter.writeOpaque((const void *)&%s);' % instance
225
226     def visitPolymorphic(self, polymorphic, instance):
227         print '    _write__%s(%s, %s);' % (polymorphic.tag, polymorphic.switchExpr, instance)
228
229
230 class ValueWrapper(stdapi.Visitor):
231     '''Type visitor which will generate the code to wrap an instance.
232     
233     Wrapping is necessary mostly for interfaces, however interface pointers can
234     appear anywhere inside complex types.
235     '''
236
237     def visitVoid(self, type, instance):
238         raise NotImplementedError
239
240     def visitLiteral(self, type, instance):
241         pass
242
243     def visitString(self, type, instance):
244         pass
245
246     def visitConst(self, type, instance):
247         pass
248
249     def visitStruct(self, struct, instance):
250         for type, name in struct.members:
251             self.visit(type, "(%s).%s" % (instance, name))
252
253     def visitArray(self, array, instance):
254         # XXX: actually it is possible to return an array of pointers
255         pass
256
257     def visitBlob(self, blob, instance):
258         pass
259
260     def visitEnum(self, enum, instance):
261         pass
262
263     def visitBitmask(self, bitmask, instance):
264         pass
265
266     def visitPointer(self, pointer, instance):
267         print "    if (%s) {" % instance
268         self.visit(pointer.type, "*" + instance)
269         print "    }"
270     
271     def visitIntPointer(self, pointer, instance):
272         pass
273
274     def visitLinearPointer(self, pointer, instance):
275         pass
276
277     def visitHandle(self, handle, instance):
278         self.visit(handle.type, instance)
279
280     def visitAlias(self, alias, instance):
281         self.visit(alias.type, instance)
282
283     def visitOpaque(self, opaque, instance):
284         pass
285     
286     def visitInterface(self, interface, instance):
287         assert instance.startswith('*')
288         instance = instance[1:]
289         print "    if (%s) {" % instance
290         print "        %s = new %s(%s);" % (instance, getWrapperInterfaceName(interface), instance)
291         print "    }"
292     
293     def visitPolymorphic(self, type, instance):
294         # XXX: There might be polymorphic values that need wrapping in the future
295         pass
296
297
298 class ValueUnwrapper(ValueWrapper):
299     '''Reverse of ValueWrapper.'''
300
301     def visitInterface(self, interface, instance):
302         assert instance.startswith('*')
303         instance = instance[1:]
304         print r'    if (%s) {' % instance
305         print r'        %s *pWrapper = static_cast<%s*>(%s);' % (getWrapperInterfaceName(interface), getWrapperInterfaceName(interface), instance)
306         print r'        if (pWrapper && pWrapper->m_dwMagic == 0xd8365d6c) {'
307         print r'            %s = pWrapper->m_pInstance;' % (instance,)
308         print r'        } else {'
309         print r'            os::log("apitrace: warning: %%s: unexpected %%s pointer\n", __FUNCTION__, "%s");' % interface.name
310         print r'        }'
311         print r'    }'
312
313
314 class Tracer:
315     '''Base class to orchestrate the code generation of API tracing.'''
316
317     def __init__(self):
318         self.api = None
319
320     def serializerFactory(self):
321         '''Create a serializer.
322         
323         Can be overriden by derived classes to inject their own serialzer.
324         '''
325
326         return ValueSerializer()
327
328     def trace_api(self, api):
329         self.api = api
330
331         self.header(api)
332
333         # Includes
334         for header in api.headers:
335             print header
336         print
337
338         # Generate the serializer functions
339         types = api.getAllTypes()
340         visitor = ComplexValueSerializer(self.serializerFactory())
341         map(visitor.visit, types)
342         print
343
344         # Interfaces wrapers
345         interfaces = api.getAllInterfaces()
346         map(self.declareWrapperInterface, interfaces)
347         map(self.implementWrapperInterface, interfaces)
348         print
349
350         # Function wrappers
351         map(self.traceFunctionDecl, api.functions)
352         map(self.traceFunctionImpl, api.functions)
353         print
354
355         self.footer(api)
356
357     def header(self, api):
358         pass
359
360     def footer(self, api):
361         pass
362
363     def traceFunctionDecl(self, function):
364         # Per-function declarations
365
366         if function.args:
367             print 'static const char * __%s_args[%u] = {%s};' % (function.name, len(function.args), ', '.join(['"%s"' % arg.name for arg in function.args]))
368         else:
369             print 'static const char ** __%s_args = NULL;' % (function.name,)
370         print 'static const trace::FunctionSig __%s_sig = {%u, "%s", %u, __%s_args};' % (function.name, function.id, function.name, len(function.args), function.name)
371         print
372
373     def isFunctionPublic(self, function):
374         return True
375
376     def traceFunctionImpl(self, function):
377         if self.isFunctionPublic(function):
378             print 'extern "C" PUBLIC'
379         else:
380             print 'extern "C" PRIVATE'
381         print function.prototype() + ' {'
382         if function.type is not stdapi.Void:
383             print '    %s __result;' % function.type
384         self.traceFunctionImplBody(function)
385         if function.type is not stdapi.Void:
386             self.wrapRet(function, "__result")
387             print '    return __result;'
388         print '}'
389         print
390
391     def traceFunctionImplBody(self, function):
392         print '    unsigned __call = trace::localWriter.beginEnter(&__%s_sig);' % (function.name,)
393         for arg in function.args:
394             if not arg.output:
395                 self.unwrapArg(function, arg)
396                 self.serializeArg(function, arg)
397         print '    trace::localWriter.endEnter();'
398         self.invokeFunction(function)
399         print '    trace::localWriter.beginLeave(__call);'
400         for arg in function.args:
401             if arg.output:
402                 self.serializeArg(function, arg)
403                 self.wrapArg(function, arg)
404         if function.type is not stdapi.Void:
405             self.serializeRet(function, "__result")
406         print '    trace::localWriter.endLeave();'
407
408     def invokeFunction(self, function, prefix='__', suffix=''):
409         if function.type is stdapi.Void:
410             result = ''
411         else:
412             result = '__result = '
413         dispatch = prefix + function.name + suffix
414         print '    %s%s(%s);' % (result, dispatch, ', '.join([str(arg.name) for arg in function.args]))
415
416     def serializeArg(self, function, arg):
417         print '    trace::localWriter.beginArg(%u);' % (arg.index,)
418         self.serializeArgValue(function, arg)
419         print '    trace::localWriter.endArg();'
420
421     def serializeArgValue(self, function, arg):
422         self.serializeValue(arg.type, arg.name)
423
424     def wrapArg(self, function, arg):
425         self.wrapValue(arg.type, arg.name)
426
427     def unwrapArg(self, function, arg):
428         self.unwrapValue(arg.type, arg.name)
429
430     def serializeRet(self, function, instance):
431         print '    trace::localWriter.beginReturn();'
432         self.serializeValue(function.type, instance)
433         print '    trace::localWriter.endReturn();'
434
435     def serializeValue(self, type, instance):
436         serializer = self.serializerFactory()
437         serializer.visit(type, instance)
438
439     def wrapRet(self, function, instance):
440         self.wrapValue(function.type, instance)
441
442     def unwrapRet(self, function, instance):
443         self.unwrapValue(function.type, instance)
444
445     def wrapValue(self, type, instance):
446         visitor = ValueWrapper()
447         visitor.visit(type, instance)
448
449     def unwrapValue(self, type, instance):
450         visitor = ValueUnwrapper()
451         visitor.visit(type, instance)
452
453     def declareWrapperInterface(self, interface):
454         print "class %s : public %s " % (getWrapperInterfaceName(interface), interface.name)
455         print "{"
456         print "public:"
457         print "    %s(%s * pInstance);" % (getWrapperInterfaceName(interface), interface.name)
458         print "    virtual ~%s();" % getWrapperInterfaceName(interface)
459         print
460         for method in interface.iterMethods():
461             print "    " + method.prototype() + ";"
462         print
463         self.declareWrapperInterfaceVariables(interface)
464         print "};"
465         print
466
467     def declareWrapperInterfaceVariables(self, interface):
468         #print "private:"
469         print "    DWORD m_dwMagic;"
470         print "    %s * m_pInstance;" % (interface.name,)
471
472     def implementWrapperInterface(self, interface):
473         print '%s::%s(%s * pInstance) {' % (getWrapperInterfaceName(interface), getWrapperInterfaceName(interface), interface.name)
474         print '    m_dwMagic = 0xd8365d6c;'
475         print '    m_pInstance = pInstance;'
476         print '}'
477         print
478         print '%s::~%s() {' % (getWrapperInterfaceName(interface), getWrapperInterfaceName(interface))
479         print '}'
480         print
481         for base, method in interface.iterBaseMethods():
482             self.implementWrapperInterfaceMethod(interface, base, method)
483         print
484
485     def implementWrapperInterfaceMethod(self, interface, base, method):
486         print method.prototype(getWrapperInterfaceName(interface) + '::' + method.name) + ' {'
487         if method.type is not stdapi.Void:
488             print '    %s __result;' % method.type
489     
490         self.implementWrapperInterfaceMethodBody(interface, base, method)
491     
492         if method.type is not stdapi.Void:
493             print '    return __result;'
494         print '}'
495         print
496
497     def implementWrapperInterfaceMethodBody(self, interface, base, method):
498         print '    static const char * __args[%u] = {%s};' % (len(method.args) + 1, ', '.join(['"this"'] + ['"%s"' % arg.name for arg in method.args]))
499         print '    static const trace::FunctionSig __sig = {%u, "%s", %u, __args};' % (method.id, interface.name + '::' + method.name, len(method.args) + 1)
500         print '    unsigned __call = trace::localWriter.beginEnter(&__sig);'
501         print '    trace::localWriter.beginArg(0);'
502         print '    trace::localWriter.writeOpaque((const void *)m_pInstance);'
503         print '    trace::localWriter.endArg();'
504
505         from specs.winapi import REFIID
506         from specs.stdapi import Pointer, Opaque, Interface
507
508         riid = None
509         for arg in method.args:
510             if not arg.output:
511                 self.unwrapArg(method, arg)
512                 self.serializeArg(method, arg)
513                 if arg.type is REFIID:
514                     riid = arg
515         print '    trace::localWriter.endEnter();'
516         
517         self.invokeMethod(interface, base, method)
518
519         print '    trace::localWriter.beginLeave(__call);'
520         for arg in method.args:
521             if arg.output:
522                 self.serializeArg(method, arg)
523                 self.wrapArg(method, arg)
524                 if riid is not None and isinstance(arg.type, Pointer):
525                     if isinstance(arg.type.type, Opaque):
526                         self.wrapIid(riid, arg)
527                     else:
528                         assert isinstance(arg.type.type, Pointer)
529                         assert isinstance(arg.type.type.type, Interface)
530
531         if method.type is not stdapi.Void:
532             print '    trace::localWriter.beginReturn();'
533             self.serializeValue(method.type, "__result")
534             print '    trace::localWriter.endReturn();'
535             self.wrapValue(method.type, '__result')
536         print '    trace::localWriter.endLeave();'
537         if method.name == 'Release':
538             assert method.type is not stdapi.Void
539             print '    if (!__result)'
540             print '        delete this;'
541
542     def wrapIid(self, riid, out):
543             print '    if (%s && *%s) {' % (out.name, out.name)
544             print '        if (*%s == m_pInstance) {' % (out.name,)
545             print '            *%s = this;' % (out.name,)
546             print '        }'
547             for iface in self.api.getAllInterfaces():
548                 print r'        else if (%s == IID_%s) {' % (riid.name, iface.name)
549                 print r'            *%s = new Wrap%s((%s *) *%s);' % (out.name, iface.name, iface.name, out.name)
550                 print r'        }'
551             print r'        else {'
552             print r'            os::log("apitrace: warning: %s: unknown REFIID {0x%08lX,0x%04X,0x%04X,{0x%02X,0x%02X,0x%02X,0x%02X,0x%02X,0x%02X,0x%02X,0x%02X}}\n",'
553             print r'                    __FUNCTION__,'
554             print r'                    %s.Data1, %s.Data2, %s.Data3,' % (riid.name, riid.name, riid.name)
555             print r'                    %s.Data4[0],' % (riid.name,)
556             print r'                    %s.Data4[1],' % (riid.name,)
557             print r'                    %s.Data4[2],' % (riid.name,)
558             print r'                    %s.Data4[3],' % (riid.name,)
559             print r'                    %s.Data4[4],' % (riid.name,)
560             print r'                    %s.Data4[5],' % (riid.name,)
561             print r'                    %s.Data4[6],' % (riid.name,)
562             print r'                    %s.Data4[7]);' % (riid.name,)
563             print r'        }'
564             print '    }'
565
566     def invokeMethod(self, interface, base, method):
567         if method.type is stdapi.Void:
568             result = ''
569         else:
570             result = '__result = '
571         print '    %sstatic_cast<%s *>(m_pInstance)->%s(%s);' % (result, base, method.name, ', '.join([str(arg.name) for arg in method.args]))
572     
573     def emit_memcpy(self, dest, src, length):
574         print '        unsigned __call = trace::localWriter.beginEnter(&trace::memcpy_sig);'
575         print '        trace::localWriter.beginArg(0);'
576         print '        trace::localWriter.writeOpaque(%s);' % dest
577         print '        trace::localWriter.endArg();'
578         print '        trace::localWriter.beginArg(1);'
579         print '        trace::localWriter.writeBlob(%s, %s);' % (src, length)
580         print '        trace::localWriter.endArg();'
581         print '        trace::localWriter.beginArg(2);'
582         print '        trace::localWriter.writeUInt(%s);' % length
583         print '        trace::localWriter.endArg();'
584         print '        trace::localWriter.endEnter();'
585         print '        trace::localWriter.beginLeave(__call);'
586         print '        trace::localWriter.endLeave();'
587