Skip to main content

rustify_ml/generator/
translate.rs

1//! Python AST → Rust body translation.
2//!
3//! Walks Python statement/expression trees and emits Rust source strings.
4//! Entry point: `translate_function_body`.
5
6use std::collections::HashMap;
7
8use rustpython_parser::ast::{Expr, Operator, Stmt};
9use tracing::warn;
10
11use crate::utils::TargetSpec;
12
13use super::expr::{expr_to_rust, translate_for_iter, translate_len_guard, translate_while_test};
14use super::infer::{infer_assign_type, infer_params};
15
/// Result of translating a single Python function body.
///
/// Produced by `translate_function_body`. All fields are plain strings so the
/// value can be spliced directly into a generated Rust source template.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Translation {
    /// `(name, rust_type)` pairs for the generated function's parameters.
    pub params: Vec<(String, String)>,
    /// Rust type of the value the generated body wraps in `Ok(..)`.
    pub return_type: String,
    /// Generated Rust source for the function body.
    pub body: String,
    /// `true` when at least one statement could not be translated, or when the
    /// whole body was replaced by an echo of the input.
    pub fallback: bool,
}
23
24/// Result of translating a block of Python statements.
25pub(super) struct BodyTranslation {
26    pub return_type: String,
27    pub body: String,
28    pub fallback: bool,
29}
30
31/// Find and translate the body of the named function in `module`.
32///
33/// Returns `None` only if the function is not found.
34/// Returns a `Translation` with `fallback: true` if the body cannot be translated.
35pub fn translate_function_body(target: &TargetSpec, module: &[Stmt]) -> Option<Translation> {
36    let func_def = module
37        .iter()
38        .find_map(|stmt| match stmt {
39            Stmt::FunctionDef(def) if def.name == target.func => Some(def),
40            _ => None,
41        })
42        .or_else(|| {
43            module.iter().find_map(|stmt| match stmt {
44                Stmt::FunctionDef(def) => Some(def),
45                _ => None,
46            })
47        })?;
48
49    let mut params = infer_params(func_def.args.as_ref());
50    if params.is_empty() {
51        params.push(("data".to_string(), "Vec<f64>".to_string()));
52    }
53
54    // Fast path: single-statement return of a name or constant
55    if let Some(Stmt::Return(ret)) = func_def.body.first()
56        && let Some(expr) = &ret.value
57    {
58        match expr.as_ref() {
59            Expr::Name(name) => {
60                return Some(Translation {
61                    params,
62                    return_type: "Vec<f64>".to_string(),
63                    body: format!(
64                        "// returning input name `{}` as-is\n    Ok({})",
65                        name.id, name.id
66                    ),
67                    fallback: false,
68                });
69            }
70            Expr::Constant(c) => {
71                return Some(Translation {
72                    params,
73                    return_type: "f64".to_string(),
74                    body: format!(
75                        "// returning constant from Python: {:?}\n    Ok({})",
76                        c.value,
77                        expr_to_rust(expr)
78                    ),
79                    fallback: false,
80                });
81            }
82            _ => {}
83        }
84    }
85
86    // Generic body translation
87    if let Some(translated) = translate_body(&func_def.body) {
88        return Some(Translation {
89            params,
90            return_type: translated.return_type,
91            body: translated.body,
92            fallback: translated.fallback,
93        });
94    }
95
96    warn!(func = %target.func, "unable to translate function body; echoing input");
97    Some(Translation {
98        params,
99        return_type: "Vec<f64>".to_string(),
100        body: "// fallback: echo input\n    Ok(data)".to_string(),
101        fallback: true,
102    })
103}
104
105pub(super) fn translate_body(body: &[Stmt]) -> Option<BodyTranslation> {
106    translate_body_inner(body, 1)
107}
108
109/// Recursive body translator. `depth` tracks nesting level for indentation.
110pub(super) fn translate_body_inner(body: &[Stmt], depth: usize) -> Option<BodyTranslation> {
111    if body.is_empty() {
112        return None;
113    }
114
115    let indent = "    ".repeat(depth);
116    let mut var_types: HashMap<String, &str> = HashMap::new();
117
118    // Generic sequential statement translation
119    let mut out = String::new();
120    let mut had_unhandled = false;
121    let mut inferred_return: Option<String> = None;
122    let mut had_return = false;
123
124    for stmt in body {
125        // Track simple vector-producing assignments for later return inference
126        if let Stmt::Assign(assign) = stmt
127            && let Some(target) = assign.targets.first()
128        {
129            // result = [0.0] * n or result = [expr for ...]
130            if let Expr::BinOp(binop) = assign.value.as_ref()
131                && matches!(binop.op, Operator::Mult)
132                && let Expr::List(lst) = binop.left.as_ref()
133                && lst.elts.len() == 1
134                && let Expr::Name(name_target) = target
135            {
136                var_types.insert(name_target.id.to_string(), "vec");
137            }
138            if let Expr::ListComp(_) = assign.value.as_ref()
139                && let Expr::Name(name_target) = target
140            {
141                var_types.insert(name_target.id.to_string(), "vec");
142            }
143            if let Expr::Subscript(sub) = assign.value.as_ref()
144                && matches!(sub.slice.as_ref(), Expr::Slice(_))
145                && let Expr::Name(name_target) = target
146            {
147                var_types.insert(name_target.id.to_string(), "vec");
148            }
149            if let Expr::Dict(dict) = assign.value.as_ref()
150                && dict.keys.is_empty()
151                && let Expr::Name(name_target) = target
152            {
153                var_types.insert(name_target.id.to_string(), "map");
154            }
155        }
156
157        match translate_stmt_inner(stmt, depth) {
158            Some(line) => {
159                if line.trim_start().starts_with("return ")
160                    && let Stmt::Return(ret) = stmt
161                    && let Some(expr) = &ret.value
162                {
163                    had_return = true;
164                    let ret_ty = infer_return_type(expr.as_ref(), &var_types);
165                    inferred_return = Some(ret_ty);
166                }
167                out.push_str(&indent);
168                out.push_str(&line);
169                if !line.ends_with('\n') {
170                    out.push('\n');
171                }
172            }
173            None => {
174                had_unhandled = true;
175                out.push_str(&indent);
176                out.push_str("// Unhandled stmt\n");
177            }
178        }
179    }
180
181    if !had_return
182        && let Some(ret_var) = var_types.keys().find(|k| *k == "result" || *k == "output")
183    {
184        inferred_return = Some("Vec<f64>".to_string());
185        out.push_str(&format!("{indent}return Ok({});\n", ret_var));
186    }
187
188    Some(BodyTranslation {
189        return_type: inferred_return.unwrap_or_else(|| "f64".to_string()),
190        body: out,
191        fallback: had_unhandled,
192    })
193}
194
195/// Translate a single Python statement to a Rust statement string.
196/// Returns `None` for unhandled statement types (triggers fallback).
197pub(super) fn translate_stmt_inner(stmt: &Stmt, depth: usize) -> Option<String> {
198    match stmt {
199        Stmt::Assign(assign) => {
200            if let (Some(target), value) = (assign.targets.first(), &assign.value) {
201                if let Expr::Dict(dict) = value.as_ref()
202                    && dict.keys.is_empty()
203                {
204                    if let Expr::Name(n) = target {
205                        return Some(format!(
206                            "let mut {}: std::collections::HashMap<(i64, i64), i64> = std::collections::HashMap::new();",
207                            n.id
208                        ));
209                    }
210                }
211                // Subscript assign: result[i] = val → result[i] = val;
212                if let Expr::Subscript(sub) = target {
213                    let lhs = format!("{}[{}]", expr_to_rust(&sub.value), expr_to_rust(&sub.slice));
214                    let rhs = expr_to_rust(value);
215                    return Some(format!("{} = {};", lhs, rhs));
216                }
217                // Track slice assigns as vec for return inference
218                if let Expr::Subscript(sub) = value.as_ref()
219                    && matches!(sub.slice.as_ref(), Expr::Slice(_))
220                    && let Expr::Name(name_target) = target
221                {
222                    // mark as vec downstream via a variable declaration
223                    let rhs = expr_to_rust(value);
224                    return Some(format!("let mut {} = {};", name_target.id, rhs));
225                }
226                // List init: result = [0.0] * n → let mut result = vec![0.0f64; n];
227                if let Expr::BinOp(binop) = value.as_ref()
228                    && matches!(binop.op, Operator::Mult)
229                    && let Expr::List(lst) = binop.left.as_ref()
230                    && lst.elts.len() == 1
231                {
232                    let fill = expr_to_rust(&lst.elts[0]);
233                    let size = expr_to_rust(&binop.right);
234                    let var_name = match target {
235                        Expr::Name(n) => n.id.to_string(),
236                        _ => "result".to_string(),
237                    };
238                    let fill_rust = if fill.contains('.') {
239                        format!("{}f64", fill)
240                    } else {
241                        fill.clone()
242                    };
243                    return Some(format!(
244                        "let mut {var} = vec![{fill}; {size}];",
245                        var = var_name,
246                        fill = fill_rust,
247                        size = size
248                    ));
249                }
250                // List comprehension: result = [expr for var in iterable]
251                // → let result: Vec<f64> = iterable.iter().map(|var| expr).collect();
252                if let Expr::ListComp(lc) = value.as_ref()
253                    && lc.generators.len() == 1
254                {
255                    let comprehension = &lc.generators[0];
256                    let iter_str = expr_to_rust(&comprehension.iter);
257                    let loop_var = expr_to_rust(&comprehension.target);
258                    let elt = expr_to_rust(&lc.elt);
259                    let var_name = match target {
260                        Expr::Name(n) => n.id.to_string(),
261                        _ => "result".to_string(),
262                    };
263                    return Some(format!(
264                        "let {var}: Vec<f64> = {iter}.iter().map(|{lv}| {elt}).collect();",
265                        var = var_name,
266                        iter = iter_str,
267                        lv = loop_var,
268                        elt = elt,
269                    ));
270                }
271                // Simple name assign
272                let lhs = match target {
273                    Expr::Name(n) => {
274                        let type_suffix = infer_assign_type(value);
275                        format!("let mut {}{}", n.id, type_suffix)
276                    }
277                    Expr::Attribute(_) => format!("// attribute assign {}", expr_to_rust(target)),
278                    _ => format!("// complex assign {}", expr_to_rust(target)),
279                };
280                let rhs = expr_to_rust(value);
281                return Some(format!("{} = {};", lhs, rhs));
282            }
283            None
284        }
285        Stmt::For(for_stmt) => {
286            if !for_stmt.orelse.is_empty() {
287                return None;
288            }
289            let iter_str = translate_for_iter(&for_stmt.iter);
290            let loop_var = expr_to_rust(&for_stmt.target);
291            let inner = translate_body_inner(for_stmt.body.as_slice(), depth + 1);
292            let loop_body = inner
293                .map(|b| b.body)
294                .unwrap_or_else(|| "    // unhandled loop body".to_string());
295            Some(format!(
296                "for {loop_var} in {iter_str} {{\n{loop_body}\n{indent}}}",
297                loop_var = loop_var,
298                iter_str = iter_str,
299                loop_body = loop_body,
300                indent = "    ".repeat(depth)
301            ))
302        }
303        Stmt::AugAssign(aug) => {
304            let lhs = expr_to_rust(&aug.target);
305            let rhs = expr_to_rust(&aug.value);
306            let op = match aug.op {
307                Operator::Add => "+=",
308                Operator::Sub => "-=",
309                Operator::Mult => "*=",
310                Operator::Div => "/=",
311                _ => "+=",
312            };
313            Some(format!("{} {} {};", lhs, op, rhs))
314        }
315        Stmt::While(while_stmt) => {
316            let test = translate_while_test(&while_stmt.test);
317            let inner = translate_body_inner(while_stmt.body.as_slice(), depth + 1);
318            let loop_body = inner
319                .map(|b| b.body)
320                .unwrap_or_else(|| format!("{}    // unhandled while body", "    ".repeat(depth)));
321            Some(format!(
322                "while {test} {{\n{loop_body}\n{indent}}}",
323                test = test,
324                loop_body = loop_body,
325                indent = "    ".repeat(depth)
326            ))
327        }
328        Stmt::Return(ret) => {
329            if let Some(v) = &ret.value {
330                Some(format!("return Ok({});", expr_to_rust(v)))
331            } else {
332                Some("return Ok(());".to_string())
333            }
334        }
335        Stmt::Expr(expr_stmt) => {
336            // Docstring (string constant) → comment, not fallback
337            if let Expr::Constant(c) = expr_stmt.value.as_ref()
338                && matches!(c.value, rustpython_parser::ast::Constant::Str(_))
339            {
340                return Some("// docstring omitted".to_string());
341            }
342            if let Expr::Call(call) = expr_stmt.value.as_ref()
343                && let Expr::Attribute(attr) = call.func.as_ref()
344                && attr.attr.as_str() == "append"
345                && call.args.len() == 1
346            {
347                let target = expr_to_rust(&attr.value);
348                let arg = expr_to_rust(&call.args[0]);
349                return Some(format!("{target}.push({arg});"));
350            }
351            Some(format!("// expr: {}", expr_to_rust(&expr_stmt.value)))
352        }
353        Stmt::If(if_stmt) => {
354            if let Some(guard) = translate_len_guard(&if_stmt.test) {
355                return Some(guard);
356            }
357            let test = expr_to_rust(&if_stmt.test);
358            let body = translate_body_inner(if_stmt.body.as_slice(), depth + 1)
359                .map(|b| b.body)
360                .unwrap_or_else(|| "// unhandled if body".to_string());
361            let orelse = if !if_stmt.orelse.is_empty() {
362                translate_body_inner(if_stmt.orelse.as_slice(), depth + 1)
363                    .map(|b| b.body)
364                    .unwrap_or_else(|| "// unhandled else body".to_string())
365            } else {
366                String::new()
367            };
368            let else_block = if orelse.is_empty() {
369                String::new()
370            } else {
371                format!(" else {{\n{}\n{}}}", orelse, "    ".repeat(depth))
372            };
373            Some(format!(
374                "if {test} {{\n{body}\n{indent}}}{else_block}",
375                test = test,
376                body = body,
377                indent = "    ".repeat(depth),
378                else_block = else_block
379            ))
380        }
381        _ => None,
382    }
383}
384
385fn infer_return_type(expr: &Expr, var_types: &HashMap<String, &str>) -> String {
386    match expr {
387        Expr::Name(n) => {
388            if let Some(&"vec") = var_types.get(n.id.as_str()) {
389                return "Vec<f64>".to_string();
390            }
391            if let Some(&"map") = var_types.get(n.id.as_str()) {
392                return "std::collections::HashMap<(i64, i64), i64>".to_string();
393            }
394            "f64".to_string()
395        }
396        Expr::List(_) | Expr::ListComp(_) => "Vec<f64>".to_string(),
397        Expr::Tuple(tuple) => {
398            // Heuristic: two-element tuple used for argmax/argmin → (usize, f64)
399            if tuple.elts.len() == 2 {
400                "(usize, f64)".to_string()
401            } else {
402                "Vec<f64>".to_string()
403            }
404        }
405        _ => "f64".to_string(),
406    }
407}