1use std::collections::HashMap;
7
8use rustpython_parser::ast::{Expr, Operator, Stmt};
9use tracing::warn;
10
11use crate::utils::TargetSpec;
12
13use super::expr::{expr_to_rust, translate_for_iter, translate_len_guard, translate_while_test};
14use super::infer::{infer_assign_type, infer_params};
15
/// Result of translating one Python function into the pieces of a Rust function.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Translation {
    /// Parameters as `(name, Rust type)` pairs, in declaration order.
    pub params: Vec<(String, String)>,
    /// Rust type of the successful return value, as source text.
    pub return_type: String,
    /// Rust source for the function body.
    pub body: String,
    /// `true` when the translation is incomplete (unhandled statements or an echo fallback).
    pub fallback: bool,
}
23
24pub(super) struct BodyTranslation {
26 pub return_type: String,
27 pub body: String,
28 pub fallback: bool,
29}
30
31pub fn translate_function_body(target: &TargetSpec, module: &[Stmt]) -> Option<Translation> {
36 let func_def = module
37 .iter()
38 .find_map(|stmt| match stmt {
39 Stmt::FunctionDef(def) if def.name == target.func => Some(def),
40 _ => None,
41 })
42 .or_else(|| {
43 module.iter().find_map(|stmt| match stmt {
44 Stmt::FunctionDef(def) => Some(def),
45 _ => None,
46 })
47 })?;
48
49 let mut params = infer_params(func_def.args.as_ref());
50 if params.is_empty() {
51 params.push(("data".to_string(), "Vec<f64>".to_string()));
52 }
53
54 if let Some(Stmt::Return(ret)) = func_def.body.first()
56 && let Some(expr) = &ret.value
57 {
58 match expr.as_ref() {
59 Expr::Name(name) => {
60 return Some(Translation {
61 params,
62 return_type: "Vec<f64>".to_string(),
63 body: format!(
64 "// returning input name `{}` as-is\n Ok({})",
65 name.id, name.id
66 ),
67 fallback: false,
68 });
69 }
70 Expr::Constant(c) => {
71 return Some(Translation {
72 params,
73 return_type: "f64".to_string(),
74 body: format!(
75 "// returning constant from Python: {:?}\n Ok({})",
76 c.value,
77 expr_to_rust(expr)
78 ),
79 fallback: false,
80 });
81 }
82 _ => {}
83 }
84 }
85
86 if let Some(translated) = translate_body(&func_def.body) {
88 return Some(Translation {
89 params,
90 return_type: translated.return_type,
91 body: translated.body,
92 fallback: translated.fallback,
93 });
94 }
95
96 warn!(func = %target.func, "unable to translate function body; echoing input");
97 Some(Translation {
98 params,
99 return_type: "Vec<f64>".to_string(),
100 body: "// fallback: echo input\n Ok(data)".to_string(),
101 fallback: true,
102 })
103}
104
105pub(super) fn translate_body(body: &[Stmt]) -> Option<BodyTranslation> {
106 translate_body_inner(body, 1)
107}
108
109pub(super) fn translate_body_inner(body: &[Stmt], depth: usize) -> Option<BodyTranslation> {
111 if body.is_empty() {
112 return None;
113 }
114
115 let indent = " ".repeat(depth);
116 let mut var_types: HashMap<String, &str> = HashMap::new();
117
118 let mut out = String::new();
120 let mut had_unhandled = false;
121 let mut inferred_return: Option<String> = None;
122 let mut had_return = false;
123
124 for stmt in body {
125 if let Stmt::Assign(assign) = stmt
127 && let Some(target) = assign.targets.first()
128 {
129 if let Expr::BinOp(binop) = assign.value.as_ref()
131 && matches!(binop.op, Operator::Mult)
132 && let Expr::List(lst) = binop.left.as_ref()
133 && lst.elts.len() == 1
134 && let Expr::Name(name_target) = target
135 {
136 var_types.insert(name_target.id.to_string(), "vec");
137 }
138 if let Expr::ListComp(_) = assign.value.as_ref()
139 && let Expr::Name(name_target) = target
140 {
141 var_types.insert(name_target.id.to_string(), "vec");
142 }
143 if let Expr::Subscript(sub) = assign.value.as_ref()
144 && matches!(sub.slice.as_ref(), Expr::Slice(_))
145 && let Expr::Name(name_target) = target
146 {
147 var_types.insert(name_target.id.to_string(), "vec");
148 }
149 if let Expr::Dict(dict) = assign.value.as_ref()
150 && dict.keys.is_empty()
151 && let Expr::Name(name_target) = target
152 {
153 var_types.insert(name_target.id.to_string(), "map");
154 }
155 }
156
157 match translate_stmt_inner(stmt, depth) {
158 Some(line) => {
159 if line.trim_start().starts_with("return ")
160 && let Stmt::Return(ret) = stmt
161 && let Some(expr) = &ret.value
162 {
163 had_return = true;
164 let ret_ty = infer_return_type(expr.as_ref(), &var_types);
165 inferred_return = Some(ret_ty);
166 }
167 out.push_str(&indent);
168 out.push_str(&line);
169 if !line.ends_with('\n') {
170 out.push('\n');
171 }
172 }
173 None => {
174 had_unhandled = true;
175 out.push_str(&indent);
176 out.push_str("// Unhandled stmt\n");
177 }
178 }
179 }
180
181 if !had_return
182 && let Some(ret_var) = var_types.keys().find(|k| *k == "result" || *k == "output")
183 {
184 inferred_return = Some("Vec<f64>".to_string());
185 out.push_str(&format!("{indent}return Ok({});\n", ret_var));
186 }
187
188 Some(BodyTranslation {
189 return_type: inferred_return.unwrap_or_else(|| "f64".to_string()),
190 body: out,
191 fallback: had_unhandled,
192 })
193}
194
195pub(super) fn translate_stmt_inner(stmt: &Stmt, depth: usize) -> Option<String> {
198 match stmt {
199 Stmt::Assign(assign) => {
200 if let (Some(target), value) = (assign.targets.first(), &assign.value) {
201 if let Expr::Dict(dict) = value.as_ref()
202 && dict.keys.is_empty()
203 {
204 if let Expr::Name(n) = target {
205 return Some(format!(
206 "let mut {}: std::collections::HashMap<(i64, i64), i64> = std::collections::HashMap::new();",
207 n.id
208 ));
209 }
210 }
211 if let Expr::Subscript(sub) = target {
213 let lhs = format!("{}[{}]", expr_to_rust(&sub.value), expr_to_rust(&sub.slice));
214 let rhs = expr_to_rust(value);
215 return Some(format!("{} = {};", lhs, rhs));
216 }
217 if let Expr::Subscript(sub) = value.as_ref()
219 && matches!(sub.slice.as_ref(), Expr::Slice(_))
220 && let Expr::Name(name_target) = target
221 {
222 let rhs = expr_to_rust(value);
224 return Some(format!("let mut {} = {};", name_target.id, rhs));
225 }
226 if let Expr::BinOp(binop) = value.as_ref()
228 && matches!(binop.op, Operator::Mult)
229 && let Expr::List(lst) = binop.left.as_ref()
230 && lst.elts.len() == 1
231 {
232 let fill = expr_to_rust(&lst.elts[0]);
233 let size = expr_to_rust(&binop.right);
234 let var_name = match target {
235 Expr::Name(n) => n.id.to_string(),
236 _ => "result".to_string(),
237 };
238 let fill_rust = if fill.contains('.') {
239 format!("{}f64", fill)
240 } else {
241 fill.clone()
242 };
243 return Some(format!(
244 "let mut {var} = vec![{fill}; {size}];",
245 var = var_name,
246 fill = fill_rust,
247 size = size
248 ));
249 }
250 if let Expr::ListComp(lc) = value.as_ref()
253 && lc.generators.len() == 1
254 {
255 let comprehension = &lc.generators[0];
256 let iter_str = expr_to_rust(&comprehension.iter);
257 let loop_var = expr_to_rust(&comprehension.target);
258 let elt = expr_to_rust(&lc.elt);
259 let var_name = match target {
260 Expr::Name(n) => n.id.to_string(),
261 _ => "result".to_string(),
262 };
263 return Some(format!(
264 "let {var}: Vec<f64> = {iter}.iter().map(|{lv}| {elt}).collect();",
265 var = var_name,
266 iter = iter_str,
267 lv = loop_var,
268 elt = elt,
269 ));
270 }
271 let lhs = match target {
273 Expr::Name(n) => {
274 let type_suffix = infer_assign_type(value);
275 format!("let mut {}{}", n.id, type_suffix)
276 }
277 Expr::Attribute(_) => format!("// attribute assign {}", expr_to_rust(target)),
278 _ => format!("// complex assign {}", expr_to_rust(target)),
279 };
280 let rhs = expr_to_rust(value);
281 return Some(format!("{} = {};", lhs, rhs));
282 }
283 None
284 }
285 Stmt::For(for_stmt) => {
286 if !for_stmt.orelse.is_empty() {
287 return None;
288 }
289 let iter_str = translate_for_iter(&for_stmt.iter);
290 let loop_var = expr_to_rust(&for_stmt.target);
291 let inner = translate_body_inner(for_stmt.body.as_slice(), depth + 1);
292 let loop_body = inner
293 .map(|b| b.body)
294 .unwrap_or_else(|| " // unhandled loop body".to_string());
295 Some(format!(
296 "for {loop_var} in {iter_str} {{\n{loop_body}\n{indent}}}",
297 loop_var = loop_var,
298 iter_str = iter_str,
299 loop_body = loop_body,
300 indent = " ".repeat(depth)
301 ))
302 }
303 Stmt::AugAssign(aug) => {
304 let lhs = expr_to_rust(&aug.target);
305 let rhs = expr_to_rust(&aug.value);
306 let op = match aug.op {
307 Operator::Add => "+=",
308 Operator::Sub => "-=",
309 Operator::Mult => "*=",
310 Operator::Div => "/=",
311 _ => "+=",
312 };
313 Some(format!("{} {} {};", lhs, op, rhs))
314 }
315 Stmt::While(while_stmt) => {
316 let test = translate_while_test(&while_stmt.test);
317 let inner = translate_body_inner(while_stmt.body.as_slice(), depth + 1);
318 let loop_body = inner
319 .map(|b| b.body)
320 .unwrap_or_else(|| format!("{} // unhandled while body", " ".repeat(depth)));
321 Some(format!(
322 "while {test} {{\n{loop_body}\n{indent}}}",
323 test = test,
324 loop_body = loop_body,
325 indent = " ".repeat(depth)
326 ))
327 }
328 Stmt::Return(ret) => {
329 if let Some(v) = &ret.value {
330 Some(format!("return Ok({});", expr_to_rust(v)))
331 } else {
332 Some("return Ok(());".to_string())
333 }
334 }
335 Stmt::Expr(expr_stmt) => {
336 if let Expr::Constant(c) = expr_stmt.value.as_ref()
338 && matches!(c.value, rustpython_parser::ast::Constant::Str(_))
339 {
340 return Some("// docstring omitted".to_string());
341 }
342 if let Expr::Call(call) = expr_stmt.value.as_ref()
343 && let Expr::Attribute(attr) = call.func.as_ref()
344 && attr.attr.as_str() == "append"
345 && call.args.len() == 1
346 {
347 let target = expr_to_rust(&attr.value);
348 let arg = expr_to_rust(&call.args[0]);
349 return Some(format!("{target}.push({arg});"));
350 }
351 Some(format!("// expr: {}", expr_to_rust(&expr_stmt.value)))
352 }
353 Stmt::If(if_stmt) => {
354 if let Some(guard) = translate_len_guard(&if_stmt.test) {
355 return Some(guard);
356 }
357 let test = expr_to_rust(&if_stmt.test);
358 let body = translate_body_inner(if_stmt.body.as_slice(), depth + 1)
359 .map(|b| b.body)
360 .unwrap_or_else(|| "// unhandled if body".to_string());
361 let orelse = if !if_stmt.orelse.is_empty() {
362 translate_body_inner(if_stmt.orelse.as_slice(), depth + 1)
363 .map(|b| b.body)
364 .unwrap_or_else(|| "// unhandled else body".to_string())
365 } else {
366 String::new()
367 };
368 let else_block = if orelse.is_empty() {
369 String::new()
370 } else {
371 format!(" else {{\n{}\n{}}}", orelse, " ".repeat(depth))
372 };
373 Some(format!(
374 "if {test} {{\n{body}\n{indent}}}{else_block}",
375 test = test,
376 body = body,
377 indent = " ".repeat(depth),
378 else_block = else_block
379 ))
380 }
381 _ => None,
382 }
383}
384
385fn infer_return_type(expr: &Expr, var_types: &HashMap<String, &str>) -> String {
386 match expr {
387 Expr::Name(n) => {
388 if let Some(&"vec") = var_types.get(n.id.as_str()) {
389 return "Vec<f64>".to_string();
390 }
391 if let Some(&"map") = var_types.get(n.id.as_str()) {
392 return "std::collections::HashMap<(i64, i64), i64>".to_string();
393 }
394 "f64".to_string()
395 }
396 Expr::List(_) | Expr::ListComp(_) => "Vec<f64>".to_string(),
397 Expr::Tuple(tuple) => {
398 if tuple.elts.len() == 2 {
400 "(usize, f64)".to_string()
401 } else {
402 "Vec<f64>".to_string()
403 }
404 }
405 _ => "f64".to_string(),
406 }
407}