Skip to main content

rustify_ml/generator/
translate.rs

1//! Python AST → Rust body translation.
2//!
3//! Walks Python statement/expression trees and emits Rust source strings.
4//! Entry point: `translate_function_body`.
5
6use std::collections::HashMap;
7
8use rustpython_parser::ast::{Expr, Operator, Stmt};
9use tracing::warn;
10
11use crate::utils::TargetSpec;
12
13use super::expr::{expr_to_rust, translate_for_iter, translate_len_guard, translate_while_test};
14use super::infer::{infer_assign_type, infer_params};
15
/// Result of translating a single Python function body.
///
/// Derives `Debug`, `Clone`, `PartialEq` and `Eq` so results can be logged,
/// duplicated and compared directly in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Translation {
    /// `(name, rust_type)` pairs for the generated function's parameters.
    pub params: Vec<(String, String)>,
    /// Rust type the generated function wraps in `Ok(..)`.
    pub return_type: String,
    /// Rust source text of the generated function body.
    pub body: String,
    /// `true` when the body could not be translated and echoes its input instead.
    pub fallback: bool,
}
23
/// Result of translating a block of Python statements.
///
/// Produced by `translate_body_inner` for a function body or a nested block.
pub(super) struct BodyTranslation {
    /// Inferred Rust return type; `translate_body_inner` uses `"f64"` when
    /// nothing more specific was inferred.
    pub return_type: String,
    /// Accumulated Rust source for the block, one indented statement per entry.
    pub body: String,
    /// `true` when at least one statement in the block could not be translated.
    pub fallback: bool,
}
30
31/// Find and translate the body of the named function in `module`.
32///
33/// Returns `None` only if the function is not found.
34/// Returns a `Translation` with `fallback: true` if the body cannot be translated.
35pub fn translate_function_body(target: &TargetSpec, module: &[Stmt]) -> Option<Translation> {
36    let func_def = module.iter().find_map(|stmt| match stmt {
37        Stmt::FunctionDef(def) if def.name == target.func => Some(def),
38        _ => None,
39    })?;
40
41    let mut params = infer_params(func_def.args.as_ref());
42    if params.is_empty() {
43        params.push(("data".to_string(), "Vec<f64>".to_string()));
44    }
45
46    // Fast path: single-statement return of a name or constant
47    if let Some(Stmt::Return(ret)) = func_def.body.first()
48        && let Some(expr) = &ret.value
49    {
50        match expr.as_ref() {
51            Expr::Name(name) => {
52                return Some(Translation {
53                    params,
54                    return_type: "Vec<f64>".to_string(),
55                    body: format!(
56                        "// returning input name `{}` as-is\n    Ok({})",
57                        name.id, name.id
58                    ),
59                    fallback: false,
60                });
61            }
62            Expr::Constant(c) => {
63                return Some(Translation {
64                    params,
65                    return_type: "f64".to_string(),
66                    body: format!(
67                        "// returning constant from Python: {:?}\n    Ok({})",
68                        c.value,
69                        expr_to_rust(expr)
70                    ),
71                    fallback: false,
72                });
73            }
74            _ => {}
75        }
76    }
77
78    // Generic body translation
79    if let Some(translated) = translate_body(&func_def.body) {
80        let first_param_name = params
81            .first()
82            .map(|(n, _)| n.as_str())
83            .unwrap_or("data")
84            .to_string();
85        if translated.fallback {
86            return Some(Translation {
87                params,
88                return_type: "Vec<f64>".to_string(),
89                body: format!("// fallback: echo input\n    Ok({first_param_name})"),
90                fallback: true,
91            });
92        }
93        return Some(Translation {
94            params,
95            return_type: translated.return_type,
96            body: translated.body,
97            fallback: translated.fallback,
98        });
99    }
100
101    warn!(func = %target.func, "unable to translate function body; echoing input");
102    Some(Translation {
103        params,
104        return_type: "Vec<f64>".to_string(),
105        body: "// fallback: echo input\n    Ok(data)".to_string(),
106        fallback: true,
107    })
108}
109
110pub(super) fn translate_body(body: &[Stmt]) -> Option<BodyTranslation> {
111    translate_body_inner(body, 1)
112}
113
114/// Recursive body translator. `depth` tracks nesting level for indentation.
115pub(super) fn translate_body_inner(body: &[Stmt], depth: usize) -> Option<BodyTranslation> {
116    if body.is_empty() {
117        return None;
118    }
119
120    let indent = "    ".repeat(depth);
121    let mut var_types: HashMap<String, &str> = HashMap::new();
122
123    // Generic sequential statement translation
124    let mut out = String::new();
125    let mut had_unhandled = false;
126    let mut inferred_return: Option<String> = None;
127    let mut had_return = false;
128
129    for stmt in body {
130        // Track simple vector-producing assignments for later return inference
131        if let Stmt::Assign(assign) = stmt
132            && let Some(target) = assign.targets.first()
133        {
134            // result = [0.0] * n or result = [expr for ...]
135            if let Expr::BinOp(binop) = assign.value.as_ref()
136                && matches!(binop.op, Operator::Mult)
137                && let Expr::List(lst) = binop.left.as_ref()
138                && lst.elts.len() == 1
139                && let Expr::Name(name_target) = target
140            {
141                var_types.insert(name_target.id.to_string(), "vec");
142            }
143            if let Expr::ListComp(_) = assign.value.as_ref()
144                && let Expr::Name(name_target) = target
145            {
146                var_types.insert(name_target.id.to_string(), "vec");
147            }
148        }
149
150        match translate_stmt_inner(stmt, depth) {
151            Some(line) => {
152                if line.trim_start().starts_with("return ")
153                    && let Stmt::Return(ret) = stmt
154                    && let Some(expr) = &ret.value
155                {
156                    had_return = true;
157                    let ret_ty = infer_return_type(expr.as_ref(), &var_types);
158                    inferred_return = Some(ret_ty);
159                }
160                out.push_str(&indent);
161                out.push_str(&line);
162                if !line.ends_with('\n') {
163                    out.push('\n');
164                }
165            }
166            None => {
167                had_unhandled = true;
168                out.push_str(&indent);
169                out.push_str("// Unhandled stmt\n");
170            }
171        }
172    }
173
174    if !had_return
175        && let Some(ret_var) = var_types.keys().find(|k| *k == "result" || *k == "output")
176    {
177        inferred_return = Some("Vec<f64>".to_string());
178        out.push_str(&format!("{indent}return Ok({});\n", ret_var));
179    }
180
181    Some(BodyTranslation {
182        return_type: inferred_return.unwrap_or_else(|| "f64".to_string()),
183        body: out,
184        fallback: had_unhandled,
185    })
186}
187
188/// Translate a single Python statement to a Rust statement string.
189/// Returns `None` for unhandled statement types (triggers fallback).
190pub(super) fn translate_stmt_inner(stmt: &Stmt, depth: usize) -> Option<String> {
191    match stmt {
192        Stmt::Assign(assign) => {
193            if let (Some(target), value) = (assign.targets.first(), &assign.value) {
194                // Subscript assign: result[i] = val → result[i] = val;
195                if let Expr::Subscript(sub) = target {
196                    let lhs = format!("{}[{}]", expr_to_rust(&sub.value), expr_to_rust(&sub.slice));
197                    let rhs = expr_to_rust(value);
198                    return Some(format!("{} = {};", lhs, rhs));
199                }
200                // List init: result = [0.0] * n → let mut result = vec![0.0f64; n];
201                if let Expr::BinOp(binop) = value.as_ref()
202                    && matches!(binop.op, Operator::Mult)
203                    && let Expr::List(lst) = binop.left.as_ref()
204                    && lst.elts.len() == 1
205                {
206                    let fill = expr_to_rust(&lst.elts[0]);
207                    let size = expr_to_rust(&binop.right);
208                    let var_name = match target {
209                        Expr::Name(n) => n.id.to_string(),
210                        _ => "result".to_string(),
211                    };
212                    let fill_rust = if fill.contains('.') {
213                        format!("{}f64", fill)
214                    } else {
215                        fill.clone()
216                    };
217                    return Some(format!(
218                        "let mut {var} = vec![{fill}; {size}];",
219                        var = var_name,
220                        fill = fill_rust,
221                        size = size
222                    ));
223                }
224                // List comprehension: result = [expr for var in iterable]
225                // → let result: Vec<f64> = iterable.iter().map(|var| expr).collect();
226                if let Expr::ListComp(lc) = value.as_ref()
227                    && lc.generators.len() == 1
228                {
229                    let comprehension = &lc.generators[0];
230                    let iter_str = expr_to_rust(&comprehension.iter);
231                    let loop_var = expr_to_rust(&comprehension.target);
232                    let elt = expr_to_rust(&lc.elt);
233                    let var_name = match target {
234                        Expr::Name(n) => n.id.to_string(),
235                        _ => "result".to_string(),
236                    };
237                    return Some(format!(
238                        "let {var}: Vec<f64> = {iter}.iter().map(|{lv}| {elt}).collect();",
239                        var = var_name,
240                        iter = iter_str,
241                        lv = loop_var,
242                        elt = elt,
243                    ));
244                }
245                // Simple name assign
246                let lhs = match target {
247                    Expr::Name(n) => {
248                        let type_suffix = infer_assign_type(value);
249                        format!("let mut {}{}", n.id, type_suffix)
250                    }
251                    Expr::Attribute(_) => format!("// attribute assign {}", expr_to_rust(target)),
252                    _ => format!("// complex assign {}", expr_to_rust(target)),
253                };
254                let rhs = expr_to_rust(value);
255                return Some(format!("{} = {};", lhs, rhs));
256            }
257            None
258        }
259        Stmt::For(for_stmt) => {
260            let iter_str = translate_for_iter(&for_stmt.iter);
261            let loop_var = expr_to_rust(&for_stmt.target);
262            let inner = translate_body_inner(for_stmt.body.as_slice(), depth + 1);
263            let loop_body = inner
264                .map(|b| b.body)
265                .unwrap_or_else(|| "    // unhandled loop body".to_string());
266            Some(format!(
267                "for {loop_var} in {iter_str} {{\n{loop_body}\n{indent}}}",
268                loop_var = loop_var,
269                iter_str = iter_str,
270                loop_body = loop_body,
271                indent = "    ".repeat(depth)
272            ))
273        }
274        Stmt::AugAssign(aug) => {
275            let lhs = expr_to_rust(&aug.target);
276            let rhs = expr_to_rust(&aug.value);
277            let op = match aug.op {
278                Operator::Add => "+=",
279                Operator::Sub => "-=",
280                Operator::Mult => "*=",
281                Operator::Div => "/=",
282                _ => "+=",
283            };
284            Some(format!("{} {} {};", lhs, op, rhs))
285        }
286        Stmt::While(while_stmt) => {
287            let test = translate_while_test(&while_stmt.test);
288            let inner = translate_body_inner(while_stmt.body.as_slice(), depth + 1);
289            let loop_body = inner
290                .map(|b| b.body)
291                .unwrap_or_else(|| format!("{}    // unhandled while body", "    ".repeat(depth)));
292            Some(format!(
293                "while {test} {{\n{loop_body}\n{indent}}}",
294                test = test,
295                loop_body = loop_body,
296                indent = "    ".repeat(depth)
297            ))
298        }
299        Stmt::Return(ret) => {
300            if let Some(v) = &ret.value {
301                Some(format!("return Ok({});", expr_to_rust(v)))
302            } else {
303                Some("return Ok(());".to_string())
304            }
305        }
306        Stmt::Expr(expr_stmt) => {
307            // Docstring (string constant) → comment, not fallback
308            if let Expr::Constant(c) = expr_stmt.value.as_ref()
309                && matches!(c.value, rustpython_parser::ast::Constant::Str(_))
310            {
311                return Some("// docstring omitted".to_string());
312            }
313            if let Expr::Call(call) = expr_stmt.value.as_ref()
314                && let Expr::Attribute(attr) = call.func.as_ref()
315                && attr.attr.as_str() == "append"
316                && call.args.len() == 1
317            {
318                let target = expr_to_rust(&attr.value);
319                let arg = expr_to_rust(&call.args[0]);
320                return Some(format!("{target}.push({arg});"));
321            }
322            Some(format!("// expr: {}", expr_to_rust(&expr_stmt.value)))
323        }
324        Stmt::If(if_stmt) => {
325            if let Some(guard) = translate_len_guard(&if_stmt.test) {
326                return Some(guard);
327            }
328            let test = expr_to_rust(&if_stmt.test);
329            let body = translate_body_inner(if_stmt.body.as_slice(), depth + 1)
330                .map(|b| b.body)
331                .unwrap_or_else(|| "// unhandled if body".to_string());
332            let orelse = if !if_stmt.orelse.is_empty() {
333                translate_body_inner(if_stmt.orelse.as_slice(), depth + 1)
334                    .map(|b| b.body)
335                    .unwrap_or_else(|| "// unhandled else body".to_string())
336            } else {
337                String::new()
338            };
339            let else_block = if orelse.is_empty() {
340                String::new()
341            } else {
342                format!(" else {{\n{}\n{}}}", orelse, "    ".repeat(depth))
343            };
344            Some(format!(
345                "if {test} {{\n{body}\n{indent}}}{else_block}",
346                test = test,
347                body = body,
348                indent = "    ".repeat(depth),
349                else_block = else_block
350            ))
351        }
352        _ => None,
353    }
354}
355
356fn infer_return_type(expr: &Expr, var_types: &HashMap<String, &str>) -> String {
357    match expr {
358        Expr::Name(n) => {
359            if let Some(&"vec") = var_types.get(n.id.as_str()) {
360                return "Vec<f64>".to_string();
361            }
362            "f64".to_string()
363        }
364        Expr::List(_) | Expr::ListComp(_) => "Vec<f64>".to_string(),
365        Expr::Tuple(_) => "Vec<f64>".to_string(),
366        _ => "f64".to_string(),
367    }
368}