// rustify_ml/generator/infer.rs
use rustpython_parser::ast::Expr;

5pub fn infer_params(args: &rustpython_parser::ast::Arguments) -> Vec<(String, String)> {
7 args.args
8 .iter()
9 .map(|a| {
10 let ty = if a.def.annotation.is_none() {
11 infer_type_from_name(a.def.arg.as_str())
12 } else {
13 infer_type_from_annotation(a.def.annotation.as_deref())
14 };
15 (a.def.arg.to_string(), ty)
16 })
17 .collect()
18}
19
20pub fn infer_type_from_annotation(annotation: Option<&Expr>) -> String {
29 match annotation {
30 Some(Expr::Name(n)) if n.id.as_str() == "int" => "usize".to_string(),
31 Some(Expr::Name(n)) if n.id.as_str() == "float" => "f64".to_string(),
32 Some(Expr::Attribute(attr)) => {
33 if let Expr::Name(base) = attr.value.as_ref() {
34 if base.id.as_str() == "np" || base.id.as_str() == "numpy" {
35 return "Vec<f64>".to_string();
36 }
37 if base.id.as_str() == "torch" && attr.attr.as_str() == "Tensor" {
38 return "Vec<f64>".to_string();
39 }
40 }
41 "Vec<f64>".to_string()
42 }
43 _ => "Vec<f64>".to_string(),
44 }
45}
46
/// Guesses a Rust type for an unannotated parameter from its name alone.
///
/// A fixed list of count/size-like names maps to `usize`. Every other name is
/// treated as a float sequence (`Vec<f64>`). Matching is exact and
/// case-sensitive.
fn infer_type_from_name(name: &str) -> String {
    const SCALAR_NAMES: [&str; 8] = [
        "window", "k", "n", "m", "length", "size", "count", "steps",
    ];

    if SCALAR_NAMES.contains(&name) {
        String::from("usize")
    } else {
        String::from("Vec<f64>")
    }
}

58pub fn infer_assign_type(value: &Expr) -> &'static str {
62 match value {
63 Expr::Constant(c) => match &c.value {
64 rustpython_parser::ast::Constant::Float(_) => ": f64",
65 rustpython_parser::ast::Constant::Int(_) => ": i64",
66 _ => "",
67 },
68 _ => "",
69 }
70}
71
/// Renders guard statements for a generated PyO3 function. The guards check
/// that every sequence-typed parameter has the same length as the first one.
///
/// A parameter counts as a sequence when its Rust type string contains `Vec<`
/// or `[f64]`. Each sequence parameter after the first produces one `if`
/// block. That block returns
/// `Err(pyo3::exceptions::PyValueError::new_err("length mismatch"))` when the
/// lengths differ. Blocks appear in parameter order.
///
/// Returns `None` when there are fewer than two sequence parameters, since
/// there is nothing to compare.
pub fn render_len_checks(params: &[(String, String)]) -> Option<String> {
    let mut sequence_names = params
        .iter()
        .filter(|(_, ty)| ty.contains("Vec<") || ty.contains("[f64]"))
        .map(|(name, _)| name.as_str());

    // Every later sequence is compared against the first one.
    let anchor = sequence_names.next()?;

    let rendered: String = sequence_names
        .map(|other| {
            format!(
                " if {first}.len() != {other}.len() {{\n return Err(pyo3::exceptions::PyValueError::new_err(\"length mismatch\"));\n }}\n",
                first = anchor,
                other = other
            )
        })
        .collect();

    // Each rendered check is non-empty, so an empty result means only the
    // anchor was a sequence.
    if rendered.is_empty() {
        None
    } else {
        Some(rendered)
    }
}