]> git.cworth.org Git - apitrace/blob - retrace.py
D3D retrace checkpoint.
[apitrace] / retrace.py
1 ##########################################################################
2 #
3 # Copyright 2010 VMware, Inc.
4 # All Rights Reserved.
5 #
6 # Permission is hereby granted, free of charge, to any person obtaining a copy
7 # of this software and associated documentation files (the "Software"), to deal
8 # in the Software without restriction, including without limitation the rights
9 # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 # copies of the Software, and to permit persons to whom the Software is
11 # furnished to do so, subject to the following conditions:
12 #
13 # The above copyright notice and this permission notice shall be included in
14 # all copies or substantial portions of the Software.
15 #
16 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
22 # THE SOFTWARE.
23 #
24 ##########################################################################/
25
26
27 """Generic retracing code generator."""
28
29
30 import sys
31
32 import specs.stdapi as stdapi
33 import specs.glapi as glapi
34
35
36 class ConstRemover(stdapi.Rebuilder):
37     '''Type visitor which strips out const qualifiers from types.'''
38
39     def visitConst(self, const):
40         return const.type
41
42     def visitOpaque(self, opaque):
43         return opaque
44
45
46 def lookupHandle(handle, value):
47     if handle.key is None:
48         return "__%s_map[%s]" % (handle.name, value)
49     else:
50         key_name, key_type = handle.key
51         return "__%s_map[%s][%s]" % (handle.name, key_name, value)
52
53
54 class ValueDeserializer(stdapi.Visitor):
55
56     def visitLiteral(self, literal, lvalue, rvalue):
57         print '    %s = (%s).to%s();' % (lvalue, rvalue, literal.kind)
58
59     def visitConst(self, const, lvalue, rvalue):
60         self.visit(const.type, lvalue, rvalue)
61
62     def visitAlias(self, alias, lvalue, rvalue):
63         self.visit(alias.type, lvalue, rvalue)
64     
65     def visitEnum(self, enum, lvalue, rvalue):
66         print '    %s = (%s).toSInt();' % (lvalue, rvalue)
67
68     def visitBitmask(self, bitmask, lvalue, rvalue):
69         self.visit(bitmask.type, lvalue, rvalue)
70
71     def visitArray(self, array, lvalue, rvalue):
72         print '    const trace::Array *__a%s = dynamic_cast<const trace::Array *>(&%s);' % (array.tag, rvalue)
73         print '    if (__a%s) {' % (array.tag)
74         length = '__a%s->values.size()' % array.tag
75         print '        %s = new %s[%s];' % (lvalue, array.type, length)
76         index = '__j' + array.tag
77         print '        for (size_t {i} = 0; {i} < {length}; ++{i}) {{'.format(i = index, length = length)
78         try:
79             self.visit(array.type, '%s[%s]' % (lvalue, index), '*__a%s->values[%s]' % (array.tag, index))
80         finally:
81             print '        }'
82             print '    } else {'
83             print '        %s = NULL;' % lvalue
84             print '    }'
85     
86     def visitPointer(self, pointer, lvalue, rvalue):
87         print '    const trace::Array *__a%s = dynamic_cast<const trace::Array *>(&%s);' % (pointer.tag, rvalue)
88         print '    if (__a%s) {' % (pointer.tag)
89         print '        %s = new %s;' % (lvalue, pointer.type)
90         try:
91             self.visit(pointer.type, '%s[0]' % (lvalue,), '*__a%s->values[0]' % (pointer.tag,))
92         finally:
93             print '    } else {'
94             print '        %s = NULL;' % lvalue
95             print '    }'
96
97     def visitIntPointer(self, pointer, lvalue, rvalue):
98         print '    %s = static_cast<%s>((%s).toPointer());' % (lvalue, pointer, rvalue)
99
100     def visitLinearPointer(self, pointer, lvalue, rvalue):
101         print '    %s = static_cast<%s>(retrace::toPointer(%s));' % (lvalue, pointer, rvalue)
102
103     def visitHandle(self, handle, lvalue, rvalue):
104         #OpaqueValueDeserializer().visit(handle.type, lvalue, rvalue);
105         self.visit(handle.type, lvalue, rvalue);
106         new_lvalue = lookupHandle(handle, lvalue)
107         print '    if (retrace::verbosity >= 2) {'
108         print '        std::cout << "%s " << size_t(%s) << " <- " << size_t(%s) << "\\n";' % (handle.name, lvalue, new_lvalue)
109         print '    }'
110         print '    %s = %s;' % (lvalue, new_lvalue)
111     
112     def visitBlob(self, blob, lvalue, rvalue):
113         print '    %s = static_cast<%s>((%s).toPointer());' % (lvalue, blob, rvalue)
114     
115     def visitString(self, string, lvalue, rvalue):
116         print '    %s = (%s)((%s).toString());' % (lvalue, string.expr, rvalue)
117
118
119 class OpaqueValueDeserializer(ValueDeserializer):
120     '''Value extractor that also understands opaque values.
121
122     Normally opaque values can't be retraced, unless they are being extracted
123     in the context of handles.'''
124
125     def visitOpaque(self, opaque, lvalue, rvalue):
126         print '    %s = static_cast<%s>(retrace::toPointer(%s));' % (lvalue, opaque, rvalue)
127
128
129 class SwizzledValueRegistrator(stdapi.Visitor):
130     '''Type visitor which will register (un)swizzled value pairs, to later be
131     swizzled.'''
132
133     def visitLiteral(self, literal, lvalue, rvalue):
134         pass
135
136     def visitAlias(self, alias, lvalue, rvalue):
137         self.visit(alias.type, lvalue, rvalue)
138     
139     def visitEnum(self, enum, lvalue, rvalue):
140         pass
141
142     def visitBitmask(self, bitmask, lvalue, rvalue):
143         pass
144
145     def visitArray(self, array, lvalue, rvalue):
146         print '    const trace::Array *__a%s = dynamic_cast<const trace::Array *>(&%s);' % (array.tag, rvalue)
147         print '    if (__a%s) {' % (array.tag)
148         length = '__a%s->values.size()' % array.tag
149         index = '__j' + array.tag
150         print '        for (size_t {i} = 0; {i} < {length}; ++{i}) {{'.format(i = index, length = length)
151         try:
152             self.visit(array.type, '%s[%s]' % (lvalue, index), '*__a%s->values[%s]' % (array.tag, index))
153         finally:
154             print '        }'
155             print '    }'
156     
157     def visitPointer(self, pointer, lvalue, rvalue):
158         print '    const trace::Array *__a%s = dynamic_cast<const trace::Array *>(&%s);' % (pointer.tag, rvalue)
159         print '    if (__a%s) {' % (pointer.tag)
160         try:
161             self.visit(pointer.type, '%s[0]' % (lvalue,), '*__a%s->values[0]' % (pointer.tag,))
162         finally:
163             print '    }'
164     
165     def visitIntPointer(self, pointer, lvalue, rvalue):
166         pass
167     
168     def visitLinearPointer(self, pointer, lvalue, rvalue):
169         assert pointer.size is not None
170         if pointer.size is not None:
171             print r'    retrace::addRegion((%s).toUIntPtr(), %s, %s);' % (rvalue, lvalue, pointer.size)
172
173     def visitHandle(self, handle, lvalue, rvalue):
174         print '    %s __orig_result;' % handle.type
175         OpaqueValueDeserializer().visit(handle.type, '__orig_result', rvalue);
176         if handle.range is None:
177             rvalue = "__orig_result"
178             entry = lookupHandle(handle, rvalue) 
179             print "    %s = %s;" % (entry, lvalue)
180             print '    if (retrace::verbosity >= 2) {'
181             print '        std::cout << "{handle.name} " << {rvalue} << " -> " << {lvalue} << "\\n";'.format(**locals())
182             print '    }'
183         else:
184             i = '__h' + handle.tag
185             lvalue = "%s + %s" % (lvalue, i)
186             rvalue = "__orig_result + %s" % (i,)
187             entry = lookupHandle(handle, rvalue) 
188             print '    for ({handle.type} {i} = 0; {i} < {handle.range}; ++{i}) {{'.format(**locals())
189             print '        {entry} = {lvalue};'.format(**locals())
190             print '        if (retrace::verbosity >= 2) {'
191             print '            std::cout << "{handle.name} " << ({rvalue}) << " -> " << ({lvalue}) << "\\n";'.format(**locals())
192             print '        }'
193             print '    }'
194     
195     def visitBlob(self, blob, lvalue, rvalue):
196         pass
197     
198     def visitString(self, string, lvalue, rvalue):
199         pass
200
201
202 class Retracer:
203
204     def retraceFunction(self, function):
205         print 'static void retrace_%s(trace::Call &call) {' % function.name
206         self.retraceFunctionBody(function)
207         print '}'
208         print
209
210     def retraceInterfaceMethod(self, interface, method):
211         print 'static void retrace_%s__%s(trace::Call &call) {' % (interface.name, method.name)
212         self.retraceInterfaceMethodBody(interface, method)
213         print '}'
214         print
215
216     def retraceFunctionBody(self, function):
217         if not function.sideeffects:
218             print '    (void)call;'
219             return
220
221         self.deserializeArgs(function)
222         
223         self.invokeFunction(function)
224
225         self.swizzleValues(function)
226
227     def retraceInterfaceMethodBody(self, interface, method):
228         if not method.sideeffects:
229             print '    (void)call;'
230             return
231
232         self.deserializeThisPointer(interface)
233
234         self.deserializeArgs(method)
235         
236         self.invokeInterfaceMethod(interface, method)
237
238         self.swizzleValues(method)
239
240     def deserializeThisPointer(self, interface):
241         print '    %s *_this;' % (interface.name,)
242         # FIXME
243
244     def deserializeArgs(self, function):
245         success = True
246         for arg in function.args:
247             arg_type = ConstRemover().visit(arg.type)
248             #print '    // %s ->  %s' % (arg.type, arg_type)
249             print '    %s %s;' % (arg_type, arg.name)
250             rvalue = 'call.arg(%u)' % (arg.index,)
251             lvalue = arg.name
252             try:
253                 self.extractArg(function, arg, arg_type, lvalue, rvalue)
254             except NotImplementedError:
255                 success =  False
256                 print '    %s = 0; // FIXME' % arg.name
257
258         if not success:
259             print '    if (1) {'
260             self.failFunction(function)
261             if function.name[-1].islower():
262                 sys.stderr.write('warning: unsupported %s call\n' % function.name)
263             print '    }'
264
265     def swizzleValues(self, function):
266         for arg in function.args:
267             if arg.output:
268                 arg_type = ConstRemover().visit(arg.type)
269                 rvalue = 'call.arg(%u)' % (arg.index,)
270                 lvalue = arg.name
271                 try:
272                     self.regiterSwizzledValue(arg_type, lvalue, rvalue)
273                 except NotImplementedError:
274                     print '    // XXX: %s' % arg.name
275         if function.type is not stdapi.Void:
276             rvalue = '*call.ret'
277             lvalue = '__result'
278             try:
279                 self.regiterSwizzledValue(function.type, lvalue, rvalue)
280             except NotImplementedError:
281                 print '    // XXX: result'
282
283     def failFunction(self, function):
284         print '    if (retrace::verbosity >= 0) {'
285         print '        retrace::unsupported(call);'
286         print '    }'
287         print '    return;'
288
289     def extractArg(self, function, arg, arg_type, lvalue, rvalue):
290         ValueDeserializer().visit(arg_type, lvalue, rvalue)
291     
292     def extractOpaqueArg(self, function, arg, arg_type, lvalue, rvalue):
293         OpaqueValueDeserializer().visit(arg_type, lvalue, rvalue)
294
295     def regiterSwizzledValue(self, type, lvalue, rvalue):
296         visitor = SwizzledValueRegistrator()
297         visitor.visit(type, lvalue, rvalue)
298
299     def invokeFunction(self, function):
300         arg_names = ", ".join(function.argNames())
301         if function.type is not stdapi.Void:
302             print '    %s __result;' % (function.type)
303             print '    __result = %s(%s);' % (function.name, arg_names)
304             print '    (void)__result;'
305         else:
306             print '    %s(%s);' % (function.name, arg_names)
307
308     def invokeInterfaceMethod(self, interface, method):
309         arg_names = ", ".join(method.argNames())
310         if method.type is not stdapi.Void:
311             print '    %s __result;' % (method.type)
312             print '    __result = _this->%s(%s);' % (method.name, arg_names)
313             print '    (void)__result;'
314         else:
315             print '    _this->%s(%s);' % (method.name, arg_names)
316
317     def filterFunction(self, function):
318         return True
319
320     table_name = 'retrace::callbacks'
321
322     def retraceApi(self, api):
323
324         print '#include "trace_parser.hpp"'
325         print '#include "retrace.hpp"'
326         print
327
328         types = api.getAllTypes()
329         handles = [type for type in types if isinstance(type, stdapi.Handle)]
330         handle_names = set()
331         for handle in handles:
332             if handle.name not in handle_names:
333                 if handle.key is None:
334                     print 'static retrace::map<%s> __%s_map;' % (handle.type, handle.name)
335                 else:
336                     key_name, key_type = handle.key
337                     print 'static std::map<%s, retrace::map<%s> > __%s_map;' % (key_type, handle.type, handle.name)
338                 handle_names.add(handle.name)
339         print
340
341         functions = filter(self.filterFunction, api.functions)
342         for function in functions:
343             self.retraceFunction(function)
344         interfaces = api.getAllInterfaces()
345         for interface in interfaces:
346             for method in interface.iterMethods():
347                 self.retraceInterfaceMethod(interface, method)
348
349         print 'const retrace::Entry %s[] = {' % self.table_name
350         for function in functions:
351             print '    {"%s", &retrace_%s},' % (function.name, function.name)
352         for interface in interfaces:
353             for method in interface.iterMethods():
354                 print '    {"%s::%s", &retrace_%s__%s},' % (interface.name, method.name, interface.name, method.name)
355         print '    {NULL, NULL}'
356         print '};'
357         print
358