1use std::collections::HashMap;
7
8use rustpython_parser::ast::{Expr, Operator, Stmt};
9use tracing::warn;
10
11use crate::utils::TargetSpec;
12
13use super::expr::{expr_to_rust, translate_for_iter, translate_len_guard, translate_while_test};
14use super::infer::{infer_assign_type, infer_params};
15
/// Result of translating one Python function into Rust source fragments.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Translation {
    /// `(name, rust_type)` pairs for the generated function's parameters.
    pub params: Vec<(String, String)>,
    /// Rust return type as source text, e.g. `"f64"` or `"Vec<f64>"`.
    pub return_type: String,
    /// Rust source text for the generated function body.
    pub body: String,
    /// `true` when the Python body could not be translated and an echo-input placeholder
    /// body was emitted instead.
    pub fallback: bool,
}
23
24pub(super) struct BodyTranslation {
26 pub return_type: String,
27 pub body: String,
28 pub fallback: bool,
29}
30
31pub fn translate_function_body(target: &TargetSpec, module: &[Stmt]) -> Option<Translation> {
36 let func_def = module.iter().find_map(|stmt| match stmt {
37 Stmt::FunctionDef(def) if def.name == target.func => Some(def),
38 _ => None,
39 })?;
40
41 let mut params = infer_params(func_def.args.as_ref());
42 if params.is_empty() {
43 params.push(("data".to_string(), "Vec<f64>".to_string()));
44 }
45
46 if let Some(Stmt::Return(ret)) = func_def.body.first()
48 && let Some(expr) = &ret.value
49 {
50 match expr.as_ref() {
51 Expr::Name(name) => {
52 return Some(Translation {
53 params,
54 return_type: "Vec<f64>".to_string(),
55 body: format!(
56 "// returning input name `{}` as-is\n Ok({})",
57 name.id, name.id
58 ),
59 fallback: false,
60 });
61 }
62 Expr::Constant(c) => {
63 return Some(Translation {
64 params,
65 return_type: "f64".to_string(),
66 body: format!(
67 "// returning constant from Python: {:?}\n Ok({})",
68 c.value,
69 expr_to_rust(expr)
70 ),
71 fallback: false,
72 });
73 }
74 _ => {}
75 }
76 }
77
78 if let Some(translated) = translate_body(&func_def.body) {
80 let first_param_name = params
81 .first()
82 .map(|(n, _)| n.as_str())
83 .unwrap_or("data")
84 .to_string();
85 if translated.fallback {
86 return Some(Translation {
87 params,
88 return_type: "Vec<f64>".to_string(),
89 body: format!("// fallback: echo input\n Ok({first_param_name})"),
90 fallback: true,
91 });
92 }
93 return Some(Translation {
94 params,
95 return_type: translated.return_type,
96 body: translated.body,
97 fallback: translated.fallback,
98 });
99 }
100
101 warn!(func = %target.func, "unable to translate function body; echoing input");
102 Some(Translation {
103 params,
104 return_type: "Vec<f64>".to_string(),
105 body: "// fallback: echo input\n Ok(data)".to_string(),
106 fallback: true,
107 })
108}
109
110pub(super) fn translate_body(body: &[Stmt]) -> Option<BodyTranslation> {
111 translate_body_inner(body, 1)
112}
113
114pub(super) fn translate_body_inner(body: &[Stmt], depth: usize) -> Option<BodyTranslation> {
116 if body.is_empty() {
117 return None;
118 }
119
120 let indent = " ".repeat(depth);
121 let mut var_types: HashMap<String, &str> = HashMap::new();
122
123 let mut out = String::new();
125 let mut had_unhandled = false;
126 let mut inferred_return: Option<String> = None;
127 let mut had_return = false;
128
129 for stmt in body {
130 if let Stmt::Assign(assign) = stmt
132 && let Some(target) = assign.targets.first()
133 {
134 if let Expr::BinOp(binop) = assign.value.as_ref()
136 && matches!(binop.op, Operator::Mult)
137 && let Expr::List(lst) = binop.left.as_ref()
138 && lst.elts.len() == 1
139 && let Expr::Name(name_target) = target
140 {
141 var_types.insert(name_target.id.to_string(), "vec");
142 }
143 if let Expr::ListComp(_) = assign.value.as_ref()
144 && let Expr::Name(name_target) = target
145 {
146 var_types.insert(name_target.id.to_string(), "vec");
147 }
148 }
149
150 match translate_stmt_inner(stmt, depth) {
151 Some(line) => {
152 if line.trim_start().starts_with("return ")
153 && let Stmt::Return(ret) = stmt
154 && let Some(expr) = &ret.value
155 {
156 had_return = true;
157 let ret_ty = infer_return_type(expr.as_ref(), &var_types);
158 inferred_return = Some(ret_ty);
159 }
160 out.push_str(&indent);
161 out.push_str(&line);
162 if !line.ends_with('\n') {
163 out.push('\n');
164 }
165 }
166 None => {
167 had_unhandled = true;
168 out.push_str(&indent);
169 out.push_str("// Unhandled stmt\n");
170 }
171 }
172 }
173
174 if !had_return
175 && let Some(ret_var) = var_types.keys().find(|k| *k == "result" || *k == "output")
176 {
177 inferred_return = Some("Vec<f64>".to_string());
178 out.push_str(&format!("{indent}return Ok({});\n", ret_var));
179 }
180
181 Some(BodyTranslation {
182 return_type: inferred_return.unwrap_or_else(|| "f64".to_string()),
183 body: out,
184 fallback: had_unhandled,
185 })
186}
187
188pub(super) fn translate_stmt_inner(stmt: &Stmt, depth: usize) -> Option<String> {
191 match stmt {
192 Stmt::Assign(assign) => {
193 if let (Some(target), value) = (assign.targets.first(), &assign.value) {
194 if let Expr::Subscript(sub) = target {
196 let lhs = format!("{}[{}]", expr_to_rust(&sub.value), expr_to_rust(&sub.slice));
197 let rhs = expr_to_rust(value);
198 return Some(format!("{} = {};", lhs, rhs));
199 }
200 if let Expr::BinOp(binop) = value.as_ref()
202 && matches!(binop.op, Operator::Mult)
203 && let Expr::List(lst) = binop.left.as_ref()
204 && lst.elts.len() == 1
205 {
206 let fill = expr_to_rust(&lst.elts[0]);
207 let size = expr_to_rust(&binop.right);
208 let var_name = match target {
209 Expr::Name(n) => n.id.to_string(),
210 _ => "result".to_string(),
211 };
212 let fill_rust = if fill.contains('.') {
213 format!("{}f64", fill)
214 } else {
215 fill.clone()
216 };
217 return Some(format!(
218 "let mut {var} = vec![{fill}; {size}];",
219 var = var_name,
220 fill = fill_rust,
221 size = size
222 ));
223 }
224 if let Expr::ListComp(lc) = value.as_ref()
227 && lc.generators.len() == 1
228 {
229 let comprehension = &lc.generators[0];
230 let iter_str = expr_to_rust(&comprehension.iter);
231 let loop_var = expr_to_rust(&comprehension.target);
232 let elt = expr_to_rust(&lc.elt);
233 let var_name = match target {
234 Expr::Name(n) => n.id.to_string(),
235 _ => "result".to_string(),
236 };
237 return Some(format!(
238 "let {var}: Vec<f64> = {iter}.iter().map(|{lv}| {elt}).collect();",
239 var = var_name,
240 iter = iter_str,
241 lv = loop_var,
242 elt = elt,
243 ));
244 }
245 let lhs = match target {
247 Expr::Name(n) => {
248 let type_suffix = infer_assign_type(value);
249 format!("let mut {}{}", n.id, type_suffix)
250 }
251 Expr::Attribute(_) => format!("// attribute assign {}", expr_to_rust(target)),
252 _ => format!("// complex assign {}", expr_to_rust(target)),
253 };
254 let rhs = expr_to_rust(value);
255 return Some(format!("{} = {};", lhs, rhs));
256 }
257 None
258 }
259 Stmt::For(for_stmt) => {
260 let iter_str = translate_for_iter(&for_stmt.iter);
261 let loop_var = expr_to_rust(&for_stmt.target);
262 let inner = translate_body_inner(for_stmt.body.as_slice(), depth + 1);
263 let loop_body = inner
264 .map(|b| b.body)
265 .unwrap_or_else(|| " // unhandled loop body".to_string());
266 Some(format!(
267 "for {loop_var} in {iter_str} {{\n{loop_body}\n{indent}}}",
268 loop_var = loop_var,
269 iter_str = iter_str,
270 loop_body = loop_body,
271 indent = " ".repeat(depth)
272 ))
273 }
274 Stmt::AugAssign(aug) => {
275 let lhs = expr_to_rust(&aug.target);
276 let rhs = expr_to_rust(&aug.value);
277 let op = match aug.op {
278 Operator::Add => "+=",
279 Operator::Sub => "-=",
280 Operator::Mult => "*=",
281 Operator::Div => "/=",
282 _ => "+=",
283 };
284 Some(format!("{} {} {};", lhs, op, rhs))
285 }
286 Stmt::While(while_stmt) => {
287 let test = translate_while_test(&while_stmt.test);
288 let inner = translate_body_inner(while_stmt.body.as_slice(), depth + 1);
289 let loop_body = inner
290 .map(|b| b.body)
291 .unwrap_or_else(|| format!("{} // unhandled while body", " ".repeat(depth)));
292 Some(format!(
293 "while {test} {{\n{loop_body}\n{indent}}}",
294 test = test,
295 loop_body = loop_body,
296 indent = " ".repeat(depth)
297 ))
298 }
299 Stmt::Return(ret) => {
300 if let Some(v) = &ret.value {
301 Some(format!("return Ok({});", expr_to_rust(v)))
302 } else {
303 Some("return Ok(());".to_string())
304 }
305 }
306 Stmt::Expr(expr_stmt) => {
307 if let Expr::Constant(c) = expr_stmt.value.as_ref()
309 && matches!(c.value, rustpython_parser::ast::Constant::Str(_))
310 {
311 return Some("// docstring omitted".to_string());
312 }
313 if let Expr::Call(call) = expr_stmt.value.as_ref()
314 && let Expr::Attribute(attr) = call.func.as_ref()
315 && attr.attr.as_str() == "append"
316 && call.args.len() == 1
317 {
318 let target = expr_to_rust(&attr.value);
319 let arg = expr_to_rust(&call.args[0]);
320 return Some(format!("{target}.push({arg});"));
321 }
322 Some(format!("// expr: {}", expr_to_rust(&expr_stmt.value)))
323 }
324 Stmt::If(if_stmt) => {
325 if let Some(guard) = translate_len_guard(&if_stmt.test) {
326 return Some(guard);
327 }
328 let test = expr_to_rust(&if_stmt.test);
329 let body = translate_body_inner(if_stmt.body.as_slice(), depth + 1)
330 .map(|b| b.body)
331 .unwrap_or_else(|| "// unhandled if body".to_string());
332 let orelse = if !if_stmt.orelse.is_empty() {
333 translate_body_inner(if_stmt.orelse.as_slice(), depth + 1)
334 .map(|b| b.body)
335 .unwrap_or_else(|| "// unhandled else body".to_string())
336 } else {
337 String::new()
338 };
339 let else_block = if orelse.is_empty() {
340 String::new()
341 } else {
342 format!(" else {{\n{}\n{}}}", orelse, " ".repeat(depth))
343 };
344 Some(format!(
345 "if {test} {{\n{body}\n{indent}}}{else_block}",
346 test = test,
347 body = body,
348 indent = " ".repeat(depth),
349 else_block = else_block
350 ))
351 }
352 _ => None,
353 }
354}
355
356fn infer_return_type(expr: &Expr, var_types: &HashMap<String, &str>) -> String {
357 match expr {
358 Expr::Name(n) => {
359 if let Some(&"vec") = var_types.get(n.id.as_str()) {
360 return "Vec<f64>".to_string();
361 }
362 "f64".to_string()
363 }
364 Expr::List(_) | Expr::ListComp(_) => "Vec<f64>".to_string(),
365 Expr::Tuple(_) => "Vec<f64>".to_string(),
366 _ => "f64".to_string(),
367 }
368}