]> git.cworth.org Git - apitrace/blob - retrace.py
Prevent derreference after free when retracing glFeedbackBuffer/glSelectBuffer.
[apitrace] / retrace.py
1 ##########################################################################
2 #
3 # Copyright 2010 VMware, Inc.
4 # All Rights Reserved.
5 #
6 # Permission is hereby granted, free of charge, to any person obtaining a copy
7 # of this software and associated documentation files (the "Software"), to deal
8 # in the Software without restriction, including without limitation the rights
9 # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 # copies of the Software, and to permit persons to whom the Software is
11 # furnished to do so, subject to the following conditions:
12 #
13 # The above copyright notice and this permission notice shall be included in
14 # all copies or substantial portions of the Software.
15 #
16 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
22 # THE SOFTWARE.
23 #
24 ##########################################################################/
25
26
27 """Generic retracing code generator."""
28
29
30 import sys
31
32 import specs.stdapi as stdapi
33 import specs.glapi as glapi
34
35
36 class ConstRemover(stdapi.Rebuilder):
37     '''Type visitor which strips out const qualifiers from types.'''
38
39     def visitConst(self, const):
40         return const.type
41
42     def visitOpaque(self, opaque):
43         return opaque
44
45
46 def lookupHandle(handle, value):
47     if handle.key is None:
48         return "__%s_map[%s]" % (handle.name, value)
49     else:
50         key_name, key_type = handle.key
51         return "__%s_map[%s][%s]" % (handle.name, key_name, value)
52
53
54 class ValueDeserializer(stdapi.Visitor):
55
56     def visitLiteral(self, literal, lvalue, rvalue):
57         print '    %s = (%s).to%s();' % (lvalue, rvalue, literal.kind)
58
59     def visitConst(self, const, lvalue, rvalue):
60         self.visit(const.type, lvalue, rvalue)
61
62     def visitAlias(self, alias, lvalue, rvalue):
63         self.visit(alias.type, lvalue, rvalue)
64     
65     def visitEnum(self, enum, lvalue, rvalue):
66         print '    %s = static_cast<%s>((%s).toSInt());' % (lvalue, enum, rvalue)
67
68     def visitBitmask(self, bitmask, lvalue, rvalue):
69         self.visit(bitmask.type, lvalue, rvalue)
70
71     def visitArray(self, array, lvalue, rvalue):
72         print '    const trace::Array *__a%s = dynamic_cast<const trace::Array *>(&%s);' % (array.tag, rvalue)
73         print '    if (__a%s) {' % (array.tag)
74         length = '__a%s->values.size()' % array.tag
75         print '        %s = _allocator.alloc<%s>(%s);' % (lvalue, array.type, length)
76         index = '__j' + array.tag
77         print '        for (size_t {i} = 0; {i} < {length}; ++{i}) {{'.format(i = index, length = length)
78         try:
79             self.visit(array.type, '%s[%s]' % (lvalue, index), '*__a%s->values[%s]' % (array.tag, index))
80         finally:
81             print '        }'
82             print '    } else {'
83             print '        %s = NULL;' % lvalue
84             print '    }'
85     
86     def visitPointer(self, pointer, lvalue, rvalue):
87         print '    const trace::Array *__a%s = dynamic_cast<const trace::Array *>(&%s);' % (pointer.tag, rvalue)
88         print '    if (__a%s) {' % (pointer.tag)
89         print '        %s = _allocator.alloc<%s>();' % (lvalue, pointer.type)
90         try:
91             self.visit(pointer.type, '%s[0]' % (lvalue,), '*__a%s->values[0]' % (pointer.tag,))
92         finally:
93             print '    } else {'
94             print '        %s = NULL;' % lvalue
95             print '    }'
96
97     def visitIntPointer(self, pointer, lvalue, rvalue):
98         print '    %s = static_cast<%s>((%s).toPointer());' % (lvalue, pointer, rvalue)
99
100     def visitLinearPointer(self, pointer, lvalue, rvalue):
101         print '    %s = static_cast<%s>(retrace::toPointer(%s));' % (lvalue, pointer, rvalue)
102
103     def visitHandle(self, handle, lvalue, rvalue):
104         #OpaqueValueDeserializer().visit(handle.type, lvalue, rvalue);
105         self.visit(handle.type, lvalue, rvalue);
106         new_lvalue = lookupHandle(handle, lvalue)
107         print '    if (retrace::verbosity >= 2) {'
108         print '        std::cout << "%s " << size_t(%s) << " <- " << size_t(%s) << "\\n";' % (handle.name, lvalue, new_lvalue)
109         print '    }'
110         print '    %s = %s;' % (lvalue, new_lvalue)
111     
112     def visitBlob(self, blob, lvalue, rvalue):
113         print '    %s = static_cast<%s>((%s).toPointer());' % (lvalue, blob, rvalue)
114     
115     def visitString(self, string, lvalue, rvalue):
116         print '    %s = (%s)((%s).toString());' % (lvalue, string.expr, rvalue)
117
118
119 class OpaqueValueDeserializer(ValueDeserializer):
120     '''Value extractor that also understands opaque values.
121
122     Normally opaque values can't be retraced, unless they are being extracted
123     in the context of handles.'''
124
125     def visitOpaque(self, opaque, lvalue, rvalue):
126         print '    %s = static_cast<%s>(retrace::toPointer(%s));' % (lvalue, opaque, rvalue)
127
128
129 class SwizzledValueRegistrator(stdapi.Visitor):
130     '''Type visitor which will register (un)swizzled value pairs, to later be
131     swizzled.'''
132
133     def visitLiteral(self, literal, lvalue, rvalue):
134         pass
135
136     def visitAlias(self, alias, lvalue, rvalue):
137         self.visit(alias.type, lvalue, rvalue)
138     
139     def visitEnum(self, enum, lvalue, rvalue):
140         pass
141
142     def visitBitmask(self, bitmask, lvalue, rvalue):
143         pass
144
145     def visitArray(self, array, lvalue, rvalue):
146         print '    const trace::Array *__a%s = dynamic_cast<const trace::Array *>(&%s);' % (array.tag, rvalue)
147         print '    if (__a%s) {' % (array.tag)
148         length = '__a%s->values.size()' % array.tag
149         index = '__j' + array.tag
150         print '        for (size_t {i} = 0; {i} < {length}; ++{i}) {{'.format(i = index, length = length)
151         try:
152             self.visit(array.type, '%s[%s]' % (lvalue, index), '*__a%s->values[%s]' % (array.tag, index))
153         finally:
154             print '        }'
155             print '    }'
156     
157     def visitPointer(self, pointer, lvalue, rvalue):
158         print '    const trace::Array *__a%s = dynamic_cast<const trace::Array *>(&%s);' % (pointer.tag, rvalue)
159         print '    if (__a%s) {' % (pointer.tag)
160         try:
161             self.visit(pointer.type, '%s[0]' % (lvalue,), '*__a%s->values[0]' % (pointer.tag,))
162         finally:
163             print '    }'
164     
165     def visitIntPointer(self, pointer, lvalue, rvalue):
166         pass
167     
168     def visitLinearPointer(self, pointer, lvalue, rvalue):
169         assert pointer.size is not None
170         if pointer.size is not None:
171             print r'    retrace::addRegion((%s).toUIntPtr(), %s, %s);' % (rvalue, lvalue, pointer.size)
172
173     def visitHandle(self, handle, lvalue, rvalue):
174         print '    %s __orig_result;' % handle.type
175         OpaqueValueDeserializer().visit(handle.type, '__orig_result', rvalue);
176         if handle.range is None:
177             rvalue = "__orig_result"
178             entry = lookupHandle(handle, rvalue) 
179             print "    %s = %s;" % (entry, lvalue)
180             print '    if (retrace::verbosity >= 2) {'
181             print '        std::cout << "{handle.name} " << {rvalue} << " -> " << {lvalue} << "\\n";'.format(**locals())
182             print '    }'
183         else:
184             i = '__h' + handle.tag
185             lvalue = "%s + %s" % (lvalue, i)
186             rvalue = "__orig_result + %s" % (i,)
187             entry = lookupHandle(handle, rvalue) 
188             print '    for ({handle.type} {i} = 0; {i} < {handle.range}; ++{i}) {{'.format(**locals())
189             print '        {entry} = {lvalue};'.format(**locals())
190             print '        if (retrace::verbosity >= 2) {'
191             print '            std::cout << "{handle.name} " << ({rvalue}) << " -> " << ({lvalue}) << "\\n";'.format(**locals())
192             print '        }'
193             print '    }'
194     
195     def visitBlob(self, blob, lvalue, rvalue):
196         pass
197     
198     def visitString(self, string, lvalue, rvalue):
199         pass
200
201
202 class Retracer:
203
204     def retraceFunction(self, function):
205         print 'static void retrace_%s(trace::Call &call) {' % function.name
206         self.retraceFunctionBody(function)
207         print '}'
208         print
209
210     def retraceFunctionBody(self, function):
211         assert function.sideeffects
212
213         print '    retrace::ScopedAllocator _allocator;'
214         print '    (void)_allocator;'
215         success = True
216         for arg in function.args:
217             arg_type = ConstRemover().visit(arg.type)
218             #print '    // %s ->  %s' % (arg.type, arg_type)
219             print '    %s %s;' % (arg_type, arg.name)
220             rvalue = 'call.arg(%u)' % (arg.index,)
221             lvalue = arg.name
222             try:
223                 self.extractArg(function, arg, arg_type, lvalue, rvalue)
224             except NotImplementedError:
225                 success = False
226                 print '    %s = 0; // FIXME' % arg.name
227         if not success:
228             print '    if (1) {'
229             self.failFunction(function)
230             print '    }'
231         self.invokeFunction(function)
232         for arg in function.args:
233             if arg.output:
234                 arg_type = ConstRemover().visit(arg.type)
235                 rvalue = 'call.arg(%u)' % (arg.index,)
236                 lvalue = arg.name
237                 try:
238                     self.regiterSwizzledValue(arg_type, lvalue, rvalue)
239                 except NotImplementedError:
240                     print '    // XXX: %s' % arg.name
241         if function.type is not stdapi.Void:
242             rvalue = '*call.ret'
243             lvalue = '__result'
244             try:
245                 self.regiterSwizzledValue(function.type, lvalue, rvalue)
246             except NotImplementedError:
247                 print '    // XXX: result'
248         if not success:
249             if function.name[-1].islower():
250                 sys.stderr.write('warning: unsupported %s call\n' % function.name)
251
252     def failFunction(self, function):
253         print '    if (retrace::verbosity >= 0) {'
254         print '        retrace::unsupported(call);'
255         print '    }'
256         print '    return;'
257
258     def extractArg(self, function, arg, arg_type, lvalue, rvalue):
259         ValueDeserializer().visit(arg_type, lvalue, rvalue)
260     
261     def extractOpaqueArg(self, function, arg, arg_type, lvalue, rvalue):
262         OpaqueValueDeserializer().visit(arg_type, lvalue, rvalue)
263
264     def regiterSwizzledValue(self, type, lvalue, rvalue):
265         visitor = SwizzledValueRegistrator()
266         visitor.visit(type, lvalue, rvalue)
267
268     def invokeFunction(self, function):
269         arg_names = ", ".join(function.argNames())
270         if function.type is not stdapi.Void:
271             print '    %s __result;' % (function.type)
272             print '    __result = %s(%s);' % (function.name, arg_names)
273             print '    (void)__result;'
274         else:
275             print '    %s(%s);' % (function.name, arg_names)
276
277     def filterFunction(self, function):
278         return True
279
280     table_name = 'retrace::callbacks'
281
282     def retraceFunctions(self, functions):
283         functions = filter(self.filterFunction, functions)
284
285         for function in functions:
286             if function.sideeffects:
287                 self.retraceFunction(function)
288
289         print 'const retrace::Entry %s[] = {' % self.table_name
290         for function in functions:
291             if function.sideeffects:
292                 print '    {"%s", &retrace_%s},' % (function.name, function.name)
293             else:
294                 print '    {"%s", &retrace::ignore},' % (function.name,)
295         print '    {NULL, NULL}'
296         print '};'
297         print
298
299
300     def retraceApi(self, api):
301
302         print '#include "os_time.hpp"'
303         print '#include "trace_parser.hpp"'
304         print '#include "retrace.hpp"'
305         print
306
307         types = api.getAllTypes()
308         handles = [type for type in types if isinstance(type, stdapi.Handle)]
309         handle_names = set()
310         for handle in handles:
311             if handle.name not in handle_names:
312                 if handle.key is None:
313                     print 'static retrace::map<%s> __%s_map;' % (handle.type, handle.name)
314                 else:
315                     key_name, key_type = handle.key
316                     print 'static std::map<%s, retrace::map<%s> > __%s_map;' % (key_type, handle.type, handle.name)
317                 handle_names.add(handle.name)
318         print
319
320         self.retraceFunctions(api.functions)
321