rustify_ml/generator/
expr.rs1use rustpython_parser::ast::{BoolOp, CmpOp, Expr, Operator, UnaryOp};
7
8pub fn expr_to_rust(expr: &Expr) -> String {
10 match expr {
11 Expr::Name(n) => n.id.to_string(),
12 Expr::Constant(c) => constant_to_rust(&c.value),
13 Expr::UnaryOp(unary) => {
14 let operand = expr_to_rust(&unary.operand);
15 match unary.op {
16 UnaryOp::USub => format!("-({operand})"),
17 UnaryOp::Not => format!("!{operand}"),
18 _ => format!("-{operand}"),
19 }
20 }
21 Expr::BoolOp(boolop) if !boolop.values.is_empty() => {
22 let op = match boolop.op {
23 BoolOp::And => "&&",
24 BoolOp::Or => "||",
25 };
26 let rendered: Vec<String> = boolop.values.iter().map(expr_to_rust).collect();
27 rendered.join(&format!(" {op} "))
28 }
29 Expr::Call(call) => {
30 if let Expr::Name(func) = call.func.as_ref() {
31 if func.id.as_str() == "range" && call.args.len() == 1 {
32 return format!("0..{}", expr_to_rust(&call.args[0]));
33 }
34 if func.id.as_str() == "len" && call.args.len() == 1 {
35 return format!("{}.len()", expr_to_rust(&call.args[0]));
36 }
37 if (func.id.as_str() == "max" || func.id.as_str() == "min") && call.args.len() == 2
38 {
39 let a = expr_to_rust(&call.args[0]);
40 let b = expr_to_rust(&call.args[1]);
41 let method = if func.id.as_str() == "max" {
42 "max"
43 } else {
44 "min"
45 };
46 return format!("({a}).{method}({b})");
47 }
48 }
49 format!("/* call {} fallback */", expr_to_rust(call.func.as_ref()))
50 }
51 Expr::BinOp(binop) => {
52 let left = expr_to_rust(&binop.left);
53 let right = expr_to_rust(&binop.right);
54 let op = match binop.op {
55 Operator::Add => "+",
56 Operator::Sub => "-",
57 Operator::Mult => "*",
58 Operator::Div => "/",
59 Operator::Pow => {
60 return format!("({}).powf({})", left, right);
61 }
62 _ => "+",
63 };
64 format!("({} {} {})", left, op, right)
65 }
66 Expr::Compare(comp) if comp.ops.len() == 1 && comp.comparators.len() == 1 => {
67 let left = expr_to_rust(&comp.left);
68 let right = expr_to_rust(&comp.comparators[0]);
69 let op = match comp.ops[0] {
70 CmpOp::Lt => "<",
71 CmpOp::LtE => "<=",
72 CmpOp::Gt => ">",
73 CmpOp::GtE => ">=",
74 CmpOp::Eq => "==",
75 CmpOp::NotEq => "!=",
76 _ => "<",
77 };
78 format!("{left} {op} {right}")
79 }
80 Expr::Subscript(sub) => {
81 let value = expr_to_rust(&sub.value);
82 match sub.slice.as_ref() {
83 rustpython_parser::ast::Expr::Slice(slice) => {
85 if slice.step.is_some() {
86 return "/* unsupported slice step */".to_string();
87 }
88 let start = slice
89 .lower
90 .as_deref()
91 .map(expr_to_rust)
92 .unwrap_or_else(|| "0".to_string());
93 let end = slice
94 .upper
95 .as_deref()
96 .map(expr_to_rust)
97 .unwrap_or_else(|| format!("{}.len()", value));
98 format!("{}[{}..{}].to_vec()", value, start, end)
99 }
100 other => {
101 let index = expr_to_rust(other);
102 format!("{}[{}]", value, index)
103 }
104 }
105 }
106 Expr::Attribute(attr) => {
107 format!("{}.{}", expr_to_rust(&attr.value), attr.attr)
108 }
109 Expr::List(list) => {
110 let elems: Vec<String> = list.elts.iter().map(expr_to_rust).collect();
111 format!("vec![{}]", elems.join(", "))
112 }
113 Expr::Tuple(tuple) => {
114 let elems: Vec<String> = tuple.elts.iter().map(expr_to_rust).collect();
115 format!("({})", elems.join(", "))
116 }
117 _ => "/* unsupported expr */".to_string(),
118 }
119}
120
121pub fn constant_to_rust(value: &rustpython_parser::ast::Constant) -> String {
123 use rustpython_parser::ast::Constant;
124 match value {
125 Constant::Int(i) => i.to_string(),
126 Constant::Float(f) => {
127 let s = format!("{}", f);
128 if s.contains('.') || s.contains('e') {
129 s
130 } else {
131 format!("{}.0", s)
132 }
133 }
134 Constant::Bool(b) => b.to_string(),
135 Constant::Str(s) => format!("\"{}\"", s.escape_default()),
136 Constant::None => "()".to_string(),
137 _ => "0".to_string(),
138 }
139}
140
141pub fn translate_for_iter(iter: &Expr) -> String {
144 if let Expr::Call(call) = iter
145 && let Expr::Name(func) = call.func.as_ref()
146 && func.id.as_str() == "range"
147 {
148 match call.args.len() {
149 1 => return format!("0..{}", expr_to_rust(&call.args[0])),
150 2 => {
151 return format!(
152 "{}..{}",
153 expr_to_rust(&call.args[0]),
154 expr_to_rust(&call.args[1])
155 );
156 }
157 _ => {}
158 }
159 }
160 expr_to_rust(iter)
161}
162
163pub fn translate_while_test(test: &Expr) -> String {
170 match test {
171 Expr::Name(n) => n.id.to_string(),
172 Expr::UnaryOp(unary) => {
173 use rustpython_parser::ast::UnaryOp;
174 if matches!(unary.op, UnaryOp::Not) {
175 format!("!{}", translate_while_test(&unary.operand))
176 } else {
177 expr_to_rust(test)
178 }
179 }
180 Expr::Compare(comp) if comp.ops.len() == 1 && comp.comparators.len() == 1 => {
181 let left = expr_to_rust(&comp.left);
182 let right = expr_to_rust(&comp.comparators[0]);
183 let op = match comp.ops[0] {
184 CmpOp::Lt => "<",
185 CmpOp::LtE => "<=",
186 CmpOp::Gt => ">",
187 CmpOp::GtE => ">=",
188 CmpOp::Eq => "==",
189 CmpOp::NotEq => "!=",
190 _ => "<",
191 };
192 format!("{left} {op} {right}")
193 }
194 _ => expr_to_rust(test),
195 }
196}
197
198pub fn translate_len_guard(test: &Expr) -> Option<String> {
201 if let Expr::Compare(comp) = test
202 && comp.ops.len() == 1
203 && comp.comparators.len() == 1
204 {
205 let op = &comp.ops[0];
206 let left = expr_to_rust(&comp.left);
207 let right = expr_to_rust(&comp.comparators[0]);
208 let cond = match op {
209 CmpOp::Eq => format!("{left} == {right}"),
210 CmpOp::NotEq => format!("{left} != {right}"),
211 _ => return None,
212 };
213 return Some(format!(
214 "if {cond} {{\n return Err(pyo3::exceptions::PyValueError::new_err(\"Vectors must be same length\"));\n }}",
215 cond = cond
216 ));
217 }
218 None
219}