]> git.cworth.org Git - apitrace/blob - retrace.py
Bumple libpng source.
[apitrace] / retrace.py
1 ##########################################################################
2 #
3 # Copyright 2010 VMware, Inc.
4 # All Rights Reserved.
5 #
6 # Permission is hereby granted, free of charge, to any person obtaining a copy
7 # of this software and associated documentation files (the "Software"), to deal
8 # in the Software without restriction, including without limitation the rights
9 # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 # copies of the Software, and to permit persons to whom the Software is
11 # furnished to do so, subject to the following conditions:
12 #
13 # The above copyright notice and this permission notice shall be included in
14 # all copies or substantial portions of the Software.
15 #
16 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
22 # THE SOFTWARE.
23 #
24 ##########################################################################/
25
26
27 """Generic retracing code generator."""
28
29 import stdapi
30 import glapi
31 from codegen import *
32
33
34 class ConstRemover(stdapi.Rebuilder):
35
36     def visit_const(self, const):
37         return const.type
38
39     def visit_opaque(self, opaque):
40         expr = opaque.expr
41         if expr.startswith('const '):
42             expr = expr[6:]
43         return stdapi.Opaque(expr)
44
45
46 class ValueExtractor(stdapi.Visitor):
47
48     def visit_literal(self, literal, lvalue, rvalue):
49         if literal.format == 'Bool':
50             print '    %s = static_cast<bool>(%s);' % (lvalue, rvalue)
51         else:
52             print '    %s = %s;' % (lvalue, rvalue)
53
54     def visit_const(self, const, lvalue, rvalue):
55         self.visit(const.type, lvalue, rvalue)
56
57     def visit_alias(self, alias, lvalue, rvalue):
58         self.visit(alias.type, lvalue, rvalue)
59     
60     def visit_enum(self, enum, lvalue, rvalue):
61         print '    %s = %s;' % (lvalue, rvalue)
62
63     def visit_bitmask(self, bitmask, lvalue, rvalue):
64         self.visit(bitmask.type, lvalue, rvalue)
65
66     def visit_array(self, array, lvalue, rvalue):
67         print '    const Trace::Array *__a%s = dynamic_cast<const Trace::Array *>(&%s);' % (array.id, rvalue)
68         print '    if (__a%s) {' % (array.id)
69         length = '__a%s->values.size()' % array.id
70         print '        %s = new %s[%s];' % (lvalue, array.type, length)
71         index = '__j' + array.id
72         print '        for(size_t {i} = 0; {i} < {length}; ++{i}) {{'.format(i = index, length = length)
73         try:
74             self.visit(array.type, '%s[%s]' % (lvalue, index), '*__a%s->values[%s]' % (array.id, index))
75         finally:
76             print '        }'
77             print '    } else {'
78             print '        %s = NULL;' % lvalue
79             print '    }'
80     
81     def visit_pointer(self, pointer, lvalue, rvalue):
82         print '    const Trace::Array *__a%s = dynamic_cast<const Trace::Array *>(&%s);' % (pointer.id, rvalue)
83         print '    if (__a%s) {' % (pointer.id)
84         print '        %s = new %s;' % (lvalue, pointer.type)
85         try:
86             self.visit(pointer.type, '%s[0]' % (lvalue,), '*__a%s->values[0]' % (pointer.id,))
87         finally:
88             print '    } else {'
89             print '        %s = NULL;' % lvalue
90             print '    }'
91
92     def visit_handle(self, handle, lvalue, rvalue):
93         self.visit(handle.type, lvalue, "__%s_map[%s]" %(handle.name, rvalue));
94         print '    if (verbosity >= 2)'
95         print '        std::cout << "%s " << static_cast<%s>(%s) << " <- " << %s << "\\n";' % (handle.name, handle.type, rvalue, lvalue)
96     
97     def visit_blob(self, blob, lvalue, rvalue):
98         print '    %s = static_cast<%s>((%s).blob());' % (lvalue, blob, rvalue)
99     
100     def visit_string(self, string, lvalue, rvalue):
101         print '    %s = (%s)((%s).string());' % (lvalue, string.expr, rvalue)
102
103
104
105 class ValueWrapper(stdapi.Visitor):
106
107     def visit_literal(self, literal, lvalue, rvalue):
108         pass
109
110     def visit_alias(self, alias, lvalue, rvalue):
111         self.visit(alias.type, lvalue, rvalue)
112     
113     def visit_enum(self, enum, lvalue, rvalue):
114         pass
115
116     def visit_bitmask(self, bitmask, lvalue, rvalue):
117         pass
118
119     def visit_array(self, array, lvalue, rvalue):
120         print '    const Trace::Array *__a%s = dynamic_cast<const Trace::Array *>(&%s);' % (array.id, rvalue)
121         print '    if (__a%s) {' % (array.id)
122         length = '__a%s->values.size()' % array.id
123         index = '__j' + array.id
124         print '        for(size_t {i} = 0; {i} < {length}; ++{i}) {{'.format(i = index, length = length)
125         try:
126             self.visit(array.type, '%s[%s]' % (lvalue, index), '*__a%s->values[%s]' % (array.id, index))
127         finally:
128             print '        }'
129             print '    }'
130     
131     def visit_pointer(self, pointer, lvalue, rvalue):
132         print '    const Trace::Array *__a%s = dynamic_cast<const Trace::Array *>(&%s);' % (pointer.id, rvalue)
133         print '    if (__a%s) {' % (pointer.id)
134         try:
135             self.visit(pointer.type, '%s[0]' % (lvalue,), '*__a%s->values[0]' % (pointer.id,))
136         finally:
137             print '    }'
138     
139
140     def visit_handle(self, handle, lvalue, rvalue):
141         if handle.range is None:
142             print "    __{handle.name}_map[static_cast<{handle.type}>({rvalue})] = {lvalue};".format(**locals())
143             print '    if (verbosity >= 2)'
144             print '        std::cout << "{handle.name} " << static_cast<{handle.type}>({rvalue}) << " -> " << {lvalue} << "\\n";'.format(**locals())
145         else:
146             i = '__h' + handle.id
147             print '    for({handle.type} {i} = 0; {i} < {handle.range}; ++{i}) {{'.format(**locals())
148             print '        __{handle.name}_map[static_cast<{handle.type}>({rvalue}) + {i}] = {lvalue} + {i};'.format(**locals())
149             print '        if (verbosity >= 2)'
150             print '            std::cout << "{handle.name} " << (static_cast<{handle.type}>({rvalue}) + {i}) << " -> " << ({lvalue} + {i}) << "\\n";'.format(**locals())
151             print '    }'
152     
153     def visit_blob(self, blob, lvalue, rvalue):
154         pass
155     
156     def visit_string(self, string, lvalue, rvalue):
157         pass
158
159
160
161 class Retracer:
162
163     def retrace_function(self, function):
164         print 'static void retrace_%s(Trace::Call &call) {' % function.name
165         success = True
166         for arg in function.args:
167             arg_type = ConstRemover().visit(arg.type)
168             #print '    // %s ->  %s' % (arg.type, arg_type)
169             print '    %s %s;' % (arg_type, arg.name)
170             rvalue = 'call.arg(%u)' % (arg.index,)
171             lvalue = arg.name
172             try:
173                 self.extract_arg(function, arg, arg_type, lvalue, rvalue)
174             except NotImplementedError:
175                 success = False
176                 print '    %s = 0; // FIXME' % arg.name
177         if not success:
178             self.fail_function(function)
179         self.call_function(function)
180         for arg in function.args:
181             if arg.output:
182                 arg_type = ConstRemover().visit(arg.type)
183                 rvalue = 'call.arg(%u)' % (arg.index,)
184                 lvalue = arg.name
185                 try:
186                     ValueWrapper().visit(arg_type, lvalue, rvalue)
187                 except NotImplementedError:
188                     print '   // FIXME: %s' % arg.name
189         if function.type is not stdapi.Void:
190             rvalue = '*call.ret'
191             lvalue = '__result'
192             try:
193                 ValueWrapper().visit(function.type, lvalue, rvalue)
194             except NotImplementedError:
195                 print '   // FIXME: result'
196         print '}'
197         print
198
199     def fail_function(self, function):
200         print '    std::cerr << "warning: unsupported call %s\\n";' % function.name
201         print '    return;'
202
203     def extract_arg(self, function, arg, arg_type, lvalue, rvalue):
204         ValueExtractor().visit(arg_type, lvalue, rvalue)
205
206     def call_function(self, function):
207         arg_names = ", ".join([arg.name for arg in function.args])
208         if function.type is not stdapi.Void:
209             print '    %s __result;' % (function.type)
210             print '    __result = %s(%s);' % (function.name, arg_names)
211         else:
212             print '    %s(%s);' % (function.name, arg_names)
213
214     def filter_function(self, function):
215         return True
216
217     def retrace_functions(self, functions):
218         functions = filter(self.filter_function, functions)
219
220         for function in functions:
221             if function.sideeffects:
222                 self.retrace_function(function)
223
224         print 'static bool retrace_call(Trace::Call &call) {'
225         print '    const char *name = call.name().c_str();'
226         print
227         print '    if (verbosity >=1 ) {'
228         print '        std::cout << call;'
229         print '        std::cout.flush();'
230         print '    };'
231         print
232
233         func_dict = dict([(function.name, function) for function in functions])
234
235         def handle_case(function_name):
236             function = func_dict[function_name]
237             if function.sideeffects:
238                 print '        retrace_%s(call);' % function.name
239             print '        return true;'
240     
241         string_switch('name', func_dict.keys(), handle_case)
242
243         print '    std::cerr << "warning: unknown call " << call.name() << "\\n";'
244         print '    return false;'
245         print '}'
246         print
247
248
249     def retrace_api(self, api):
250
251         print '#include "trace_parser.hpp"'
252         print
253
254         types = api.all_types()
255         handles = [type for type in types if isinstance(type, stdapi.Handle)]
256         handle_names = set()
257         for handle in handles:
258             if handle.name not in handle_names:
259                 print 'static std::map<%s, %s> __%s_map;' % (handle.type, handle.type, handle.name)
260                 handle_names.add(handle.name)
261         print
262
263         print 'unsigned verbosity = 0;'
264         print
265
266         self.retrace_functions(api.functions)
267