rustify_ml/generator/
expr.rs1use rustpython_parser::ast::{BoolOp, CmpOp, Expr, Operator, UnaryOp};
7
8pub fn expr_to_rust(expr: &Expr) -> String {
10 match expr {
11 Expr::Name(n) => n.id.to_string(),
12 Expr::Constant(c) => constant_to_rust(&c.value),
13 Expr::UnaryOp(unary) => {
14 let operand = expr_to_rust(&unary.operand);
15 match unary.op {
16 UnaryOp::USub => format!("-({operand})"),
17 UnaryOp::Not => format!("!{operand}"),
18 _ => format!("-{operand}"),
19 }
20 }
21 Expr::BoolOp(boolop) if !boolop.values.is_empty() => {
22 let op = match boolop.op {
23 BoolOp::And => "&&",
24 BoolOp::Or => "||",
25 };
26 let rendered: Vec<String> = boolop.values.iter().map(expr_to_rust).collect();
27 rendered.join(&format!(" {op} "))
28 }
29 Expr::Call(call) => {
30 if let Expr::Name(func) = call.func.as_ref() {
31 if func.id.as_str() == "range" && call.args.len() == 1 {
32 return format!("0..{}", expr_to_rust(&call.args[0]));
33 }
34 if func.id.as_str() == "len" && call.args.len() == 1 {
35 return format!("{}.len()", expr_to_rust(&call.args[0]));
36 }
37 if (func.id.as_str() == "max" || func.id.as_str() == "min") && call.args.len() == 2
38 {
39 let a = expr_to_rust(&call.args[0]);
40 let b = expr_to_rust(&call.args[1]);
41 let method = if func.id.as_str() == "max" {
42 "max"
43 } else {
44 "min"
45 };
46 return format!("({a}).{method}({b})");
47 }
48 }
49 format!("/* call {} fallback */", expr_to_rust(call.func.as_ref()))
50 }
51 Expr::BinOp(binop) => {
52 let left = expr_to_rust(&binop.left);
53 let right = expr_to_rust(&binop.right);
54 let op = match binop.op {
55 Operator::Add => "+",
56 Operator::Sub => "-",
57 Operator::Mult => "*",
58 Operator::Div => "/",
59 Operator::Pow => {
60 return format!("({}).powf({})", left, right);
61 }
62 _ => "+",
63 };
64 format!("({} {} {})", left, op, right)
65 }
66 Expr::Compare(comp) if comp.ops.len() == 1 && comp.comparators.len() == 1 => {
67 let left = expr_to_rust(&comp.left);
68 let right = expr_to_rust(&comp.comparators[0]);
69 let op = match comp.ops[0] {
70 CmpOp::Lt => "<",
71 CmpOp::LtE => "<=",
72 CmpOp::Gt => ">",
73 CmpOp::GtE => ">=",
74 CmpOp::Eq => "==",
75 CmpOp::NotEq => "!=",
76 _ => "<",
77 };
78 format!("{left} {op} {right}")
79 }
80 Expr::Subscript(sub) => {
81 let value = expr_to_rust(&sub.value);
82 let index = expr_to_rust(&sub.slice);
83 format!("{}[{}]", value, index)
84 }
85 Expr::Attribute(attr) => {
86 format!("{}.{}", expr_to_rust(&attr.value), attr.attr)
87 }
88 Expr::List(list) => {
89 let elems: Vec<String> = list.elts.iter().map(expr_to_rust).collect();
90 format!("vec![{}]", elems.join(", "))
91 }
92 Expr::Tuple(tuple) => {
93 let elems: Vec<String> = tuple.elts.iter().map(expr_to_rust).collect();
94 format!("({})", elems.join(", "))
95 }
96 _ => "/* unsupported expr */".to_string(),
97 }
98}
99
100pub fn constant_to_rust(value: &rustpython_parser::ast::Constant) -> String {
102 use rustpython_parser::ast::Constant;
103 match value {
104 Constant::Int(i) => i.to_string(),
105 Constant::Float(f) => {
106 let s = format!("{}", f);
107 if s.contains('.') || s.contains('e') {
108 s
109 } else {
110 format!("{}.0", s)
111 }
112 }
113 Constant::Bool(b) => b.to_string(),
114 Constant::Str(s) => format!("\"{}\"", s.escape_default()),
115 Constant::None => "()".to_string(),
116 _ => "0".to_string(),
117 }
118}
119
120pub fn translate_for_iter(iter: &Expr) -> String {
123 if let Expr::Call(call) = iter
124 && let Expr::Name(func) = call.func.as_ref()
125 && func.id.as_str() == "range"
126 {
127 match call.args.len() {
128 1 => return format!("0..{}", expr_to_rust(&call.args[0])),
129 2 => {
130 return format!(
131 "{}..{}",
132 expr_to_rust(&call.args[0]),
133 expr_to_rust(&call.args[1])
134 );
135 }
136 _ => {}
137 }
138 }
139 expr_to_rust(iter)
140}
141
142pub fn translate_while_test(test: &Expr) -> String {
149 match test {
150 Expr::Name(n) => n.id.to_string(),
151 Expr::UnaryOp(unary) => {
152 use rustpython_parser::ast::UnaryOp;
153 if matches!(unary.op, UnaryOp::Not) {
154 format!("!{}", translate_while_test(&unary.operand))
155 } else {
156 expr_to_rust(test)
157 }
158 }
159 Expr::Compare(comp) if comp.ops.len() == 1 && comp.comparators.len() == 1 => {
160 let left = expr_to_rust(&comp.left);
161 let right = expr_to_rust(&comp.comparators[0]);
162 let op = match comp.ops[0] {
163 CmpOp::Lt => "<",
164 CmpOp::LtE => "<=",
165 CmpOp::Gt => ">",
166 CmpOp::GtE => ">=",
167 CmpOp::Eq => "==",
168 CmpOp::NotEq => "!=",
169 _ => "<",
170 };
171 format!("{left} {op} {right}")
172 }
173 _ => expr_to_rust(test),
174 }
175}
176
177pub fn translate_len_guard(test: &Expr) -> Option<String> {
180 if let Expr::Compare(comp) = test
181 && comp.ops.len() == 1
182 && comp.comparators.len() == 1
183 {
184 let op = &comp.ops[0];
185 let left = expr_to_rust(&comp.left);
186 let right = expr_to_rust(&comp.comparators[0]);
187 let cond = match op {
188 CmpOp::Eq => format!("{left} == {right}"),
189 CmpOp::NotEq => format!("{left} != {right}"),
190 _ => return None,
191 };
192 return Some(format!(
193 "if {cond} {{\n return Err(pyo3::exceptions::PyValueError::new_err(\"Vectors must be same length\"));\n }}",
194 cond = cond
195 ));
196 }
197 None
198}