Skip to main content

rustify_ml/generator/
expr.rs

//! Expression-to-Rust translation helpers.
//!
//! Converts Python AST expression nodes into Rust source strings.
//! All functions are pure (no I/O, no state).
5
6use rustpython_parser::ast::{BoolOp, CmpOp, Expr, Operator, UnaryOp};
7
8/// Translate a Python expression to a Rust expression string.
9pub fn expr_to_rust(expr: &Expr) -> String {
10    match expr {
11        Expr::Name(n) => n.id.to_string(),
12        Expr::Constant(c) => constant_to_rust(&c.value),
13        Expr::UnaryOp(unary) => {
14            let operand = expr_to_rust(&unary.operand);
15            match unary.op {
16                UnaryOp::USub => format!("-({operand})"),
17                UnaryOp::Not => format!("!{operand}"),
18                _ => format!("-{operand}"),
19            }
20        }
21        Expr::BoolOp(boolop) if !boolop.values.is_empty() => {
22            let op = match boolop.op {
23                BoolOp::And => "&&",
24                BoolOp::Or => "||",
25            };
26            let rendered: Vec<String> = boolop.values.iter().map(expr_to_rust).collect();
27            rendered.join(&format!(" {op} "))
28        }
29        Expr::Call(call) => {
30            if let Expr::Name(func) = call.func.as_ref() {
31                if func.id.as_str() == "range" && call.args.len() == 1 {
32                    return format!("0..{}", expr_to_rust(&call.args[0]));
33                }
34                if func.id.as_str() == "len" && call.args.len() == 1 {
35                    return format!("{}.len()", expr_to_rust(&call.args[0]));
36                }
37                if (func.id.as_str() == "max" || func.id.as_str() == "min") && call.args.len() == 2
38                {
39                    let a = expr_to_rust(&call.args[0]);
40                    let b = expr_to_rust(&call.args[1]);
41                    let method = if func.id.as_str() == "max" {
42                        "max"
43                    } else {
44                        "min"
45                    };
46                    return format!("({a}).{method}({b})");
47                }
48            }
49            format!("/* call {} fallback */", expr_to_rust(call.func.as_ref()))
50        }
51        Expr::BinOp(binop) => {
52            let left = expr_to_rust(&binop.left);
53            let right = expr_to_rust(&binop.right);
54            let op = match binop.op {
55                Operator::Add => "+",
56                Operator::Sub => "-",
57                Operator::Mult => "*",
58                Operator::Div => "/",
59                Operator::Pow => {
60                    return format!("({}).powf({})", left, right);
61                }
62                _ => "+",
63            };
64            format!("({} {} {})", left, op, right)
65        }
66        Expr::Compare(comp) if comp.ops.len() == 1 && comp.comparators.len() == 1 => {
67            let left = expr_to_rust(&comp.left);
68            let right = expr_to_rust(&comp.comparators[0]);
69            let op = match comp.ops[0] {
70                CmpOp::Lt => "<",
71                CmpOp::LtE => "<=",
72                CmpOp::Gt => ">",
73                CmpOp::GtE => ">=",
74                CmpOp::Eq => "==",
75                CmpOp::NotEq => "!=",
76                _ => "<",
77            };
78            format!("{left} {op} {right}")
79        }
80        Expr::Subscript(sub) => {
81            let value = expr_to_rust(&sub.value);
82            match sub.slice.as_ref() {
83                // Support simple slices without step: x[a:b] → x[a..b].to_vec()
84                rustpython_parser::ast::Expr::Slice(slice) => {
85                    if slice.step.is_some() {
86                        return "/* unsupported slice step */".to_string();
87                    }
88                    let start = slice
89                        .lower
90                        .as_deref()
91                        .map(expr_to_rust)
92                        .unwrap_or_else(|| "0".to_string());
93                    let end = slice
94                        .upper
95                        .as_deref()
96                        .map(expr_to_rust)
97                        .unwrap_or_else(|| format!("{}.len()", value));
98                    format!("{}[{}..{}].to_vec()", value, start, end)
99                }
100                other => {
101                    let index = expr_to_rust(other);
102                    format!("{}[{}]", value, index)
103                }
104            }
105        }
106        Expr::Attribute(attr) => {
107            format!("{}.{}", expr_to_rust(&attr.value), attr.attr)
108        }
109        Expr::List(list) => {
110            let elems: Vec<String> = list.elts.iter().map(expr_to_rust).collect();
111            format!("vec![{}]", elems.join(", "))
112        }
113        Expr::Tuple(tuple) => {
114            let elems: Vec<String> = tuple.elts.iter().map(expr_to_rust).collect();
115            format!("({})", elems.join(", "))
116        }
117        _ => "/* unsupported expr */".to_string(),
118    }
119}
120
121/// Convert a Python constant value to its Rust literal equivalent.
122pub fn constant_to_rust(value: &rustpython_parser::ast::Constant) -> String {
123    use rustpython_parser::ast::Constant;
124    match value {
125        Constant::Int(i) => i.to_string(),
126        Constant::Float(f) => {
127            let s = format!("{}", f);
128            if s.contains('.') || s.contains('e') {
129                s
130            } else {
131                format!("{}.0", s)
132            }
133        }
134        Constant::Bool(b) => b.to_string(),
135        Constant::Str(s) => format!("\"{}\"", s.escape_default()),
136        Constant::None => "()".to_string(),
137        _ => "0".to_string(),
138    }
139}
140
141/// Translate a Python for-loop iterator expression to a Rust range string.
142/// Handles: `range(n)` → `0..n`, `range(a, b)` → `a..b`, fallback to expr_to_rust.
143pub fn translate_for_iter(iter: &Expr) -> String {
144    if let Expr::Call(call) = iter
145        && let Expr::Name(func) = call.func.as_ref()
146        && func.id.as_str() == "range"
147    {
148        match call.args.len() {
149            1 => return format!("0..{}", expr_to_rust(&call.args[0])),
150            2 => {
151                return format!(
152                    "{}..{}",
153                    expr_to_rust(&call.args[0]),
154                    expr_to_rust(&call.args[1])
155                );
156            }
157            _ => {}
158        }
159    }
160    expr_to_rust(iter)
161}
162
163/// Translate a Python while-loop test expression to Rust.
164///
165/// Handles:
166/// - `while changed:` → `while changed`
167/// - `while not changed:` → `while !changed`
168/// - `while i < len(x):` → `while i < x.len()`
169pub fn translate_while_test(test: &Expr) -> String {
170    match test {
171        Expr::Name(n) => n.id.to_string(),
172        Expr::UnaryOp(unary) => {
173            use rustpython_parser::ast::UnaryOp;
174            if matches!(unary.op, UnaryOp::Not) {
175                format!("!{}", translate_while_test(&unary.operand))
176            } else {
177                expr_to_rust(test)
178            }
179        }
180        Expr::Compare(comp) if comp.ops.len() == 1 && comp.comparators.len() == 1 => {
181            let left = expr_to_rust(&comp.left);
182            let right = expr_to_rust(&comp.comparators[0]);
183            let op = match comp.ops[0] {
184                CmpOp::Lt => "<",
185                CmpOp::LtE => "<=",
186                CmpOp::Gt => ">",
187                CmpOp::GtE => ">=",
188                CmpOp::Eq => "==",
189                CmpOp::NotEq => "!=",
190                _ => "<",
191            };
192            format!("{left} {op} {right}")
193        }
194        _ => expr_to_rust(test),
195    }
196}
197
198/// Translate a Python if-test that guards a length check into a Rust guard string.
199/// Returns `None` if the test is not a simple equality/inequality comparison.
200pub fn translate_len_guard(test: &Expr) -> Option<String> {
201    if let Expr::Compare(comp) = test
202        && comp.ops.len() == 1
203        && comp.comparators.len() == 1
204    {
205        let op = &comp.ops[0];
206        let left = expr_to_rust(&comp.left);
207        let right = expr_to_rust(&comp.comparators[0]);
208        let cond = match op {
209            CmpOp::Eq => format!("{left} == {right}"),
210            CmpOp::NotEq => format!("{left} != {right}"),
211            _ => return None,
212        };
213        return Some(format!(
214            "if {cond} {{\n        return Err(pyo3::exceptions::PyValueError::new_err(\"Vectors must be same length\"));\n    }}",
215            cond = cond
216        ));
217    }
218    None
219}