rustify_ml/generator/
infer.rs1use rustpython_parser::ast::Expr;
4
5pub fn infer_params(args: &rustpython_parser::ast::Arguments) -> Vec<(String, String)> {
7 args.args
8 .iter()
9 .map(|a| {
10 let ty = if a.def.annotation.is_none() {
11 infer_type_from_name(a.def.arg.as_str())
12 } else {
13 infer_type_from_annotation(a.def.annotation.as_deref())
14 };
15 (a.def.arg.to_string(), ty)
16 })
17 .collect()
18}
19
20pub fn infer_type_from_annotation(annotation: Option<&Expr>) -> String {
29 match annotation {
30 Some(Expr::Name(n)) if n.id.as_str() == "int" => "usize".to_string(),
31 Some(Expr::Name(n)) if n.id.as_str() == "float" => "f64".to_string(),
32 Some(Expr::Attribute(attr)) => {
33 if let Expr::Name(base) = attr.value.as_ref() {
34 if base.id.as_str() == "np" || base.id.as_str() == "numpy" {
35 return "numpy::PyReadonlyArray1<f64>".to_string();
36 }
37 if base.id.as_str() == "torch" && attr.attr.as_str() == "Tensor" {
38 return "Vec<f64>".to_string();
39 }
40 }
41 "Vec<f64>".to_string()
42 }
43 _ => "Vec<f64>".to_string(),
44 }
45}
46
/// Guesses the Rust type of an unannotated Python parameter from its name alone.
///
/// - text-like names (`text`, `string`, `s`, `line`) → `String`
/// - `merges` → `Vec<i64>`
/// - count/size-like names (`window`, `k`, `n`, `m`, `length`, `size`, `count`, `steps`)
///   → `usize`
/// - any other name → `Vec<f64>`
fn infer_type_from_name(name: &str) -> String {
    const TEXT_NAMES: &[&str] = &["text", "string", "s", "line"];
    const COUNT_NAMES: &[&str] = &["window", "k", "n", "m", "length", "size", "count", "steps"];

    let rust_ty = if TEXT_NAMES.contains(&name) {
        "String"
    } else if name == "merges" {
        "Vec<i64>"
    } else if COUNT_NAMES.contains(&name) {
        "usize"
    } else {
        "Vec<f64>"
    };
    String::from(rust_ty)
}
59
60pub fn infer_assign_type(value: &Expr) -> &'static str {
64 match value {
65 Expr::Constant(c) => match &c.value {
66 rustpython_parser::ast::Constant::Float(_) => ": f64",
67 rustpython_parser::ast::Constant::Int(_) => ": i64",
68 _ => "",
69 },
70 _ => "",
71 }
72}
73
/// Renders runtime guards asserting that all sequence-typed parameters share one length.
///
/// A parameter counts as sequence-typed when its Rust type string contains `Vec<` or `[f64]`.
/// Every such parameter after the first is compared against the first. Each comparison
/// becomes an `if` that returns a `PyValueError("length mismatch")` from the generated
/// function.
///
/// Returns `None` when fewer than two sequence-typed parameters exist (nothing to compare).
///
/// NOTE(review): the emitted `return Err(...)` assumes the generated function returns a
/// `PyResult` — confirm against the caller that splices this text in.
pub fn render_len_checks(params: &[(String, String)]) -> Option<String> {
    let is_sequence = |ty: &str| ty.contains("Vec<") || ty.contains("[f64]");

    let mut seq_names = params
        .iter()
        .filter(|(_, ty)| is_sequence(ty.as_str()))
        .map(|(name, _)| name.as_str());

    // The first sequence parameter is the reference length; no sequences at all means no checks.
    let anchor = seq_names.next()?;
    let others: Vec<&str> = seq_names.collect();
    if others.is_empty() {
        return None;
    }

    Some(
        others
            .iter()
            .map(|other| {
                format!(
                    " if {first}.len() != {other}.len() {{\n return Err(pyo3::exceptions::PyValueError::new_err(\"length mismatch\"));\n }}\n",
                    first = anchor,
                    other = other
                )
            })
            .collect(),
    )
}