Skip to main content

rustify_ml/generator/
expr.rs

1//! Expression-to-Rust translation helpers.
2//!
3//! Converts Python AST expression nodes into Rust source strings.
4//! All functions are pure (no I/O, no state).
5
6use rustpython_parser::ast::{BoolOp, CmpOp, Expr, Operator, UnaryOp};
7
8/// Translate a Python expression to a Rust expression string.
9pub fn expr_to_rust(expr: &Expr) -> String {
10    match expr {
11        Expr::Name(n) => n.id.to_string(),
12        Expr::Constant(c) => constant_to_rust(&c.value),
13        Expr::UnaryOp(unary) => {
14            let operand = expr_to_rust(&unary.operand);
15            match unary.op {
16                UnaryOp::USub => format!("-({operand})"),
17                UnaryOp::Not => format!("!{operand}"),
18                _ => format!("-{operand}"),
19            }
20        }
21        Expr::BoolOp(boolop) if !boolop.values.is_empty() => {
22            let op = match boolop.op {
23                BoolOp::And => "&&",
24                BoolOp::Or => "||",
25            };
26            let rendered: Vec<String> = boolop.values.iter().map(expr_to_rust).collect();
27            rendered.join(&format!(" {op} "))
28        }
29        Expr::Call(call) => {
30            if let Expr::Name(func) = call.func.as_ref() {
31                if func.id.as_str() == "range" && call.args.len() == 1 {
32                    return format!("0..{}", expr_to_rust(&call.args[0]));
33                }
34                if func.id.as_str() == "len" && call.args.len() == 1 {
35                    return format!("{}.len()", expr_to_rust(&call.args[0]));
36                }
37                if (func.id.as_str() == "max" || func.id.as_str() == "min") && call.args.len() == 2
38                {
39                    let a = expr_to_rust(&call.args[0]);
40                    let b = expr_to_rust(&call.args[1]);
41                    let method = if func.id.as_str() == "max" {
42                        "max"
43                    } else {
44                        "min"
45                    };
46                    return format!("({a}).{method}({b})");
47                }
48            }
49            format!("/* call {} fallback */", expr_to_rust(call.func.as_ref()))
50        }
51        Expr::BinOp(binop) => {
52            let left = expr_to_rust(&binop.left);
53            let right = expr_to_rust(&binop.right);
54            let op = match binop.op {
55                Operator::Add => "+",
56                Operator::Sub => "-",
57                Operator::Mult => "*",
58                Operator::Div => "/",
59                Operator::Pow => {
60                    return format!("({}).powf({})", left, right);
61                }
62                _ => "+",
63            };
64            format!("({} {} {})", left, op, right)
65        }
66        Expr::Compare(comp) if comp.ops.len() == 1 && comp.comparators.len() == 1 => {
67            let left = expr_to_rust(&comp.left);
68            let right = expr_to_rust(&comp.comparators[0]);
69            let op = match comp.ops[0] {
70                CmpOp::Lt => "<",
71                CmpOp::LtE => "<=",
72                CmpOp::Gt => ">",
73                CmpOp::GtE => ">=",
74                CmpOp::Eq => "==",
75                CmpOp::NotEq => "!=",
76                _ => "<",
77            };
78            format!("{left} {op} {right}")
79        }
80        Expr::Subscript(sub) => {
81            let value = expr_to_rust(&sub.value);
82            let index = expr_to_rust(&sub.slice);
83            format!("{}[{}]", value, index)
84        }
85        Expr::Attribute(attr) => {
86            format!("{}.{}", expr_to_rust(&attr.value), attr.attr)
87        }
88        Expr::List(list) => {
89            let elems: Vec<String> = list.elts.iter().map(expr_to_rust).collect();
90            format!("vec![{}]", elems.join(", "))
91        }
92        Expr::Tuple(tuple) => {
93            let elems: Vec<String> = tuple.elts.iter().map(expr_to_rust).collect();
94            format!("({})", elems.join(", "))
95        }
96        _ => "/* unsupported expr */".to_string(),
97    }
98}
99
100/// Convert a Python constant value to its Rust literal equivalent.
101pub fn constant_to_rust(value: &rustpython_parser::ast::Constant) -> String {
102    use rustpython_parser::ast::Constant;
103    match value {
104        Constant::Int(i) => i.to_string(),
105        Constant::Float(f) => {
106            let s = format!("{}", f);
107            if s.contains('.') || s.contains('e') {
108                s
109            } else {
110                format!("{}.0", s)
111            }
112        }
113        Constant::Bool(b) => b.to_string(),
114        Constant::Str(s) => format!("\"{}\"", s.escape_default()),
115        Constant::None => "()".to_string(),
116        _ => "0".to_string(),
117    }
118}
119
120/// Translate a Python for-loop iterator expression to a Rust range string.
121/// Handles: `range(n)` → `0..n`, `range(a, b)` → `a..b`, fallback to expr_to_rust.
122pub fn translate_for_iter(iter: &Expr) -> String {
123    if let Expr::Call(call) = iter
124        && let Expr::Name(func) = call.func.as_ref()
125        && func.id.as_str() == "range"
126    {
127        match call.args.len() {
128            1 => return format!("0..{}", expr_to_rust(&call.args[0])),
129            2 => {
130                return format!(
131                    "{}..{}",
132                    expr_to_rust(&call.args[0]),
133                    expr_to_rust(&call.args[1])
134                );
135            }
136            _ => {}
137        }
138    }
139    expr_to_rust(iter)
140}
141
142/// Translate a Python while-loop test expression to Rust.
143///
144/// Handles:
145/// - `while changed:` → `while changed`
146/// - `while not changed:` → `while !changed`
147/// - `while i < len(x):` → `while i < x.len()`
148pub fn translate_while_test(test: &Expr) -> String {
149    match test {
150        Expr::Name(n) => n.id.to_string(),
151        Expr::UnaryOp(unary) => {
152            use rustpython_parser::ast::UnaryOp;
153            if matches!(unary.op, UnaryOp::Not) {
154                format!("!{}", translate_while_test(&unary.operand))
155            } else {
156                expr_to_rust(test)
157            }
158        }
159        Expr::Compare(comp) if comp.ops.len() == 1 && comp.comparators.len() == 1 => {
160            let left = expr_to_rust(&comp.left);
161            let right = expr_to_rust(&comp.comparators[0]);
162            let op = match comp.ops[0] {
163                CmpOp::Lt => "<",
164                CmpOp::LtE => "<=",
165                CmpOp::Gt => ">",
166                CmpOp::GtE => ">=",
167                CmpOp::Eq => "==",
168                CmpOp::NotEq => "!=",
169                _ => "<",
170            };
171            format!("{left} {op} {right}")
172        }
173        _ => expr_to_rust(test),
174    }
175}
176
177/// Translate a Python if-test that guards a length check into a Rust guard string.
178/// Returns `None` if the test is not a simple equality/inequality comparison.
179pub fn translate_len_guard(test: &Expr) -> Option<String> {
180    if let Expr::Compare(comp) = test
181        && comp.ops.len() == 1
182        && comp.comparators.len() == 1
183    {
184        let op = &comp.ops[0];
185        let left = expr_to_rust(&comp.left);
186        let right = expr_to_rust(&comp.comparators[0]);
187        let cond = match op {
188            CmpOp::Eq => format!("{left} == {right}"),
189            CmpOp::NotEq => format!("{left} != {right}"),
190            _ => return None,
191        };
192        return Some(format!(
193            "if {cond} {{\n        return Err(pyo3::exceptions::PyValueError::new_err(\"Vectors must be same length\"));\n    }}",
194            cond = cond
195        ));
196    }
197    None
198}