]> git.cworth.org Git - apitrace/blob - retrace.py
Merge branch 'master' into d3dretrace
[apitrace] / retrace.py
1 ##########################################################################
2 #
3 # Copyright 2010 VMware, Inc.
4 # All Rights Reserved.
5 #
6 # Permission is hereby granted, free of charge, to any person obtaining a copy
7 # of this software and associated documentation files (the "Software"), to deal
8 # in the Software without restriction, including without limitation the rights
9 # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 # copies of the Software, and to permit persons to whom the Software is
11 # furnished to do so, subject to the following conditions:
12 #
13 # The above copyright notice and this permission notice shall be included in
14 # all copies or substantial portions of the Software.
15 #
16 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
22 # THE SOFTWARE.
23 #
24 ##########################################################################/
25
26
27 """Generic retracing code generator."""
28
29
30 import sys
31
32 import specs.stdapi as stdapi
33 import specs.glapi as glapi
34
35
36 class ConstRemover(stdapi.Rebuilder):
37     '''Type visitor which strips out const qualifiers from types.'''
38
39     def visitConst(self, const):
40         return const.type
41
42     def visitOpaque(self, opaque):
43         return opaque
44
45
46 def lookupHandle(handle, value):
47     if handle.key is None:
48         return "__%s_map[%s]" % (handle.name, value)
49     else:
50         key_name, key_type = handle.key
51         return "__%s_map[%s][%s]" % (handle.name, key_name, value)
52
53
54 class ValueDeserializer(stdapi.Visitor):
55
56     def visitLiteral(self, literal, lvalue, rvalue):
57         print '    %s = (%s).to%s();' % (lvalue, rvalue, literal.kind)
58
59     def visitConst(self, const, lvalue, rvalue):
60         self.visit(const.type, lvalue, rvalue)
61
62     def visitAlias(self, alias, lvalue, rvalue):
63         self.visit(alias.type, lvalue, rvalue)
64     
65     def visitEnum(self, enum, lvalue, rvalue):
66         print '    %s = (%s).toSInt();' % (lvalue, rvalue)
67
68     def visitBitmask(self, bitmask, lvalue, rvalue):
69         self.visit(bitmask.type, lvalue, rvalue)
70
71     def visitArray(self, array, lvalue, rvalue):
72         print '    const trace::Array *__a%s = dynamic_cast<const trace::Array *>(&%s);' % (array.tag, rvalue)
73         print '    if (__a%s) {' % (array.tag)
74         length = '__a%s->values.size()' % array.tag
75         print '        %s = _allocator.alloc<%s>(%s);' % (lvalue, array.type, length)
76         index = '__j' + array.tag
77         print '        for (size_t {i} = 0; {i} < {length}; ++{i}) {{'.format(i = index, length = length)
78         try:
79             self.visit(array.type, '%s[%s]' % (lvalue, index), '*__a%s->values[%s]' % (array.tag, index))
80         finally:
81             print '        }'
82             print '    } else {'
83             print '        %s = NULL;' % lvalue
84             print '    }'
85     
86     def visitPointer(self, pointer, lvalue, rvalue):
87         print '    const trace::Array *__a%s = dynamic_cast<const trace::Array *>(&%s);' % (pointer.tag, rvalue)
88         print '    if (__a%s) {' % (pointer.tag)
89         print '        %s = _allocator.alloc<%s>();' % (lvalue, pointer.type)
90         try:
91             self.visit(pointer.type, '%s[0]' % (lvalue,), '*__a%s->values[0]' % (pointer.tag,))
92         finally:
93             print '    } else {'
94             print '        %s = NULL;' % lvalue
95             print '    }'
96
97     def visitIntPointer(self, pointer, lvalue, rvalue):
98         print '    %s = static_cast<%s>((%s).toPointer());' % (lvalue, pointer, rvalue)
99
100     def visitLinearPointer(self, pointer, lvalue, rvalue):
101         print '    %s = static_cast<%s>(retrace::toPointer(%s));' % (lvalue, pointer, rvalue)
102
103     def visitHandle(self, handle, lvalue, rvalue):
104         #OpaqueValueDeserializer().visit(handle.type, lvalue, rvalue);
105         self.visit(handle.type, lvalue, rvalue);
106         new_lvalue = lookupHandle(handle, lvalue)
107         print '    if (retrace::verbosity >= 2) {'
108         print '        std::cout << "%s " << size_t(%s) << " <- " << size_t(%s) << "\\n";' % (handle.name, lvalue, new_lvalue)
109         print '    }'
110         print '    %s = %s;' % (lvalue, new_lvalue)
111     
112     def visitBlob(self, blob, lvalue, rvalue):
113         print '    %s = static_cast<%s>((%s).toPointer());' % (lvalue, blob, rvalue)
114     
115     def visitString(self, string, lvalue, rvalue):
116         print '    %s = (%s)((%s).toString());' % (lvalue, string.expr, rvalue)
117
118
119 class OpaqueValueDeserializer(ValueDeserializer):
120     '''Value extractor that also understands opaque values.
121
122     Normally opaque values can't be retraced, unless they are being extracted
123     in the context of handles.'''
124
125     def visitOpaque(self, opaque, lvalue, rvalue):
126         print '    %s = static_cast<%s>(retrace::toPointer(%s));' % (lvalue, opaque, rvalue)
127
128
129 class SwizzledValueRegistrator(stdapi.Visitor):
130     '''Type visitor which will register (un)swizzled value pairs, to later be
131     swizzled.'''
132
133     def visitLiteral(self, literal, lvalue, rvalue):
134         pass
135
136     def visitAlias(self, alias, lvalue, rvalue):
137         self.visit(alias.type, lvalue, rvalue)
138     
139     def visitEnum(self, enum, lvalue, rvalue):
140         pass
141
142     def visitBitmask(self, bitmask, lvalue, rvalue):
143         pass
144
145     def visitArray(self, array, lvalue, rvalue):
146         print '    const trace::Array *__a%s = dynamic_cast<const trace::Array *>(&%s);' % (array.tag, rvalue)
147         print '    if (__a%s) {' % (array.tag)
148         length = '__a%s->values.size()' % array.tag
149         index = '__j' + array.tag
150         print '        for (size_t {i} = 0; {i} < {length}; ++{i}) {{'.format(i = index, length = length)
151         try:
152             self.visit(array.type, '%s[%s]' % (lvalue, index), '*__a%s->values[%s]' % (array.tag, index))
153         finally:
154             print '        }'
155             print '    }'
156     
157     def visitPointer(self, pointer, lvalue, rvalue):
158         print '    const trace::Array *__a%s = dynamic_cast<const trace::Array *>(&%s);' % (pointer.tag, rvalue)
159         print '    if (__a%s) {' % (pointer.tag)
160         try:
161             self.visit(pointer.type, '%s[0]' % (lvalue,), '*__a%s->values[0]' % (pointer.tag,))
162         finally:
163             print '    }'
164     
165     def visitIntPointer(self, pointer, lvalue, rvalue):
166         pass
167     
168     def visitLinearPointer(self, pointer, lvalue, rvalue):
169         assert pointer.size is not None
170         if pointer.size is not None:
171             print r'    retrace::addRegion((%s).toUIntPtr(), %s, %s);' % (rvalue, lvalue, pointer.size)
172
173     def visitHandle(self, handle, lvalue, rvalue):
174         print '    %s __orig_result;' % handle.type
175         OpaqueValueDeserializer().visit(handle.type, '__orig_result', rvalue);
176         if handle.range is None:
177             rvalue = "__orig_result"
178             entry = lookupHandle(handle, rvalue) 
179             print "    %s = %s;" % (entry, lvalue)
180             print '    if (retrace::verbosity >= 2) {'
181             print '        std::cout << "{handle.name} " << {rvalue} << " -> " << {lvalue} << "\\n";'.format(**locals())
182             print '    }'
183         else:
184             i = '__h' + handle.tag
185             lvalue = "%s + %s" % (lvalue, i)
186             rvalue = "__orig_result + %s" % (i,)
187             entry = lookupHandle(handle, rvalue) 
188             print '    for ({handle.type} {i} = 0; {i} < {handle.range}; ++{i}) {{'.format(**locals())
189             print '        {entry} = {lvalue};'.format(**locals())
190             print '        if (retrace::verbosity >= 2) {'
191             print '            std::cout << "{handle.name} " << ({rvalue}) << " -> " << ({lvalue}) << "\\n";'.format(**locals())
192             print '        }'
193             print '    }'
194     
195     def visitBlob(self, blob, lvalue, rvalue):
196         pass
197     
198     def visitString(self, string, lvalue, rvalue):
199         pass
200
201
202 class Retracer:
203
204     def retraceFunction(self, function):
205         print 'static void retrace_%s(trace::Call &call) {' % function.name
206         self.retraceFunctionBody(function)
207         print '}'
208         print
209
210     def retraceInterfaceMethod(self, interface, method):
211         print 'static void retrace_%s__%s(trace::Call &call) {' % (interface.name, method.name)
212         self.retraceInterfaceMethodBody(interface, method)
213         print '}'
214         print
215
216     def retraceFunctionBody(self, function):
217         if not function.sideeffects:
218             print '    (void)call;'
219             return
220
221         self.deserializeArgs(function)
222         
223         self.invokeFunction(function)
224
225         self.swizzleValues(function)
226
227     def retraceInterfaceMethodBody(self, interface, method):
228         if not method.sideeffects:
229             print '    (void)call;'
230             return
231
232         self.deserializeThisPointer(interface)
233
234         self.deserializeArgs(method)
235         
236         self.invokeInterfaceMethod(interface, method)
237
238         self.swizzleValues(method)
239
240     def deserializeThisPointer(self, interface):
241         print '    %s *_this;' % (interface.name,)
242         # FIXME
243
244     def deserializeArgs(self, function):
245         print '    retrace::ScopedAllocator _allocator;'
246         print '    (void)_allocator;'
247         success = True
248         for arg in function.args:
249             arg_type = ConstRemover().visit(arg.type)
250             #print '    // %s ->  %s' % (arg.type, arg_type)
251             print '    %s %s;' % (arg_type, arg.name)
252             rvalue = 'call.arg(%u)' % (arg.index,)
253             lvalue = arg.name
254             try:
255                 self.extractArg(function, arg, arg_type, lvalue, rvalue)
256             except NotImplementedError:
257                 success =  False
258                 print '    %s = 0; // FIXME' % arg.name
259
260         if not success:
261             print '    if (1) {'
262             self.failFunction(function)
263             if function.name[-1].islower():
264                 sys.stderr.write('warning: unsupported %s call\n' % function.name)
265             print '    }'
266
267     def swizzleValues(self, function):
268         for arg in function.args:
269             if arg.output:
270                 arg_type = ConstRemover().visit(arg.type)
271                 rvalue = 'call.arg(%u)' % (arg.index,)
272                 lvalue = arg.name
273                 try:
274                     self.regiterSwizzledValue(arg_type, lvalue, rvalue)
275                 except NotImplementedError:
276                     print '    // XXX: %s' % arg.name
277         if function.type is not stdapi.Void:
278             rvalue = '*call.ret'
279             lvalue = '__result'
280             try:
281                 self.regiterSwizzledValue(function.type, lvalue, rvalue)
282             except NotImplementedError:
283                 print '    // XXX: result'
284
285     def failFunction(self, function):
286         print '    if (retrace::verbosity >= 0) {'
287         print '        retrace::unsupported(call);'
288         print '    }'
289         print '    return;'
290
291     def extractArg(self, function, arg, arg_type, lvalue, rvalue):
292         ValueDeserializer().visit(arg_type, lvalue, rvalue)
293     
294     def extractOpaqueArg(self, function, arg, arg_type, lvalue, rvalue):
295         OpaqueValueDeserializer().visit(arg_type, lvalue, rvalue)
296
297     def regiterSwizzledValue(self, type, lvalue, rvalue):
298         visitor = SwizzledValueRegistrator()
299         visitor.visit(type, lvalue, rvalue)
300
301     def invokeFunction(self, function):
302         arg_names = ", ".join(function.argNames())
303         if function.type is not stdapi.Void:
304             print '    %s __result;' % (function.type)
305             print '    __result = %s(%s);' % (function.name, arg_names)
306             print '    (void)__result;'
307         else:
308             print '    %s(%s);' % (function.name, arg_names)
309
310     def invokeInterfaceMethod(self, interface, method):
311         arg_names = ", ".join(method.argNames())
312         if method.type is not stdapi.Void:
313             print '    %s __result;' % (method.type)
314             print '    __result = _this->%s(%s);' % (method.name, arg_names)
315             print '    (void)__result;'
316         else:
317             print '    _this->%s(%s);' % (method.name, arg_names)
318
319     def filterFunction(self, function):
320         return True
321
322     table_name = 'retrace::callbacks'
323
324     def retraceApi(self, api):
325
326         print '#include "trace_parser.hpp"'
327         print '#include "retrace.hpp"'
328         print
329
330         types = api.getAllTypes()
331         handles = [type for type in types if isinstance(type, stdapi.Handle)]
332         handle_names = set()
333         for handle in handles:
334             if handle.name not in handle_names:
335                 if handle.key is None:
336                     print 'static retrace::map<%s> __%s_map;' % (handle.type, handle.name)
337                 else:
338                     key_name, key_type = handle.key
339                     print 'static std::map<%s, retrace::map<%s> > __%s_map;' % (key_type, handle.type, handle.name)
340                 handle_names.add(handle.name)
341         print
342
343         functions = filter(self.filterFunction, api.functions)
344         for function in functions:
345             self.retraceFunction(function)
346         interfaces = api.getAllInterfaces()
347         for interface in interfaces:
348             for method in interface.iterMethods():
349                 self.retraceInterfaceMethod(interface, method)
350
351         print 'const retrace::Entry %s[] = {' % self.table_name
352         for function in functions:
353             print '    {"%s", &retrace_%s},' % (function.name, function.name)
354         for interface in interfaces:
355             for method in interface.iterMethods():
356                 print '    {"%s::%s", &retrace_%s__%s},' % (interface.name, method.name, interface.name, method.name)
357         print '    {NULL, NULL}'
358         print '};'
359         print
360