Skip to main content

rustify_ml/generator/
infer.rs

//! Type inference helpers for Python → Rust parameter and assignment types.
2
3use rustpython_parser::ast::Expr;
4
5/// Infer Rust parameter types from a Python function's argument list.
6pub fn infer_params(args: &rustpython_parser::ast::Arguments) -> Vec<(String, String)> {
7    args.args
8        .iter()
9        .map(|a| {
10            let ty = if a.def.annotation.is_none() {
11                infer_type_from_name(a.def.arg.as_str())
12            } else {
13                infer_type_from_annotation(a.def.annotation.as_deref())
14            };
15            (a.def.arg.to_string(), ty)
16        })
17        .collect()
18}
19
20/// Infer a Rust type string from a Python type annotation expression.
21///
22/// Supported annotations:
23/// - `int` → `usize`
24/// - `float` → `f64`
25/// - `np.ndarray`, `numpy.ndarray` → `Vec<f64>`
26/// - `torch.Tensor` → `Vec<f64>`
27/// - anything else → `Vec<f64>` (safe default)
28pub fn infer_type_from_annotation(annotation: Option<&Expr>) -> String {
29    match annotation {
30        Some(Expr::Name(n)) if n.id.as_str() == "int" => "usize".to_string(),
31        Some(Expr::Name(n)) if n.id.as_str() == "float" => "f64".to_string(),
32        Some(Expr::Attribute(attr)) => {
33            if let Expr::Name(base) = attr.value.as_ref() {
34                if base.id.as_str() == "np" || base.id.as_str() == "numpy" {
35                    return "Vec<f64>".to_string();
36                }
37                if base.id.as_str() == "torch" && attr.attr.as_str() == "Tensor" {
38                    return "Vec<f64>".to_string();
39                }
40            }
41            "Vec<f64>".to_string()
42        }
43        _ => "Vec<f64>".to_string(),
44    }
45}
46
/// Guess a Rust type for a parameter that carries no annotation.
///
/// Names conventionally used for scalar sizes or loop bounds (`window`, `k`, `n`,
/// `m`, `length`, `size`, `count`, `steps`) become `usize`. Every other name is
/// treated as an ML vector and becomes `Vec<f64>`. Matching is exact and
/// case-sensitive.
fn infer_type_from_name(name: &str) -> String {
    const SCALAR_NAMES: [&str; 8] = [
        "window", "k", "n", "m", "length", "size", "count", "steps",
    ];

    let rust_ty = if SCALAR_NAMES.contains(&name) {
        "usize"
    } else {
        "Vec<f64>"
    };
    String::from(rust_ty)
}
57
58/// Infer a Rust type annotation suffix for a simple assignment RHS.
59///
60/// Returns `": f64"` for float literals, `": i64"` for int literals, `""` otherwise.
61pub fn infer_assign_type(value: &Expr) -> &'static str {
62    match value {
63        Expr::Constant(c) => match &c.value {
64            rustpython_parser::ast::Constant::Float(_) => ": f64",
65            rustpython_parser::ast::Constant::Int(_) => ": i64",
66            _ => "",
67        },
68        _ => "",
69    }
70}
71
/// Emit Rust length-check guards for vector-like parameters.
///
/// A parameter counts as vector-like when its Rust type string contains `Vec<`
/// or `[f64]`. When at least two such parameters exist, each one after the first
/// is compared against the first, producing one guard per extra parameter. For
/// parameters `a` and `b` the emitted code is:
///
/// ```text
///     if a.len() != b.len() {
///         return Err(pyo3::exceptions::PyValueError::new_err("length mismatch"));
///     }
/// ```
///
/// Returns `None` when fewer than two vector-like parameters are present.
pub fn render_len_checks(params: &[(String, String)]) -> Option<String> {
    use std::fmt::Write as _;

    let mut vec_params = params
        .iter()
        .filter(|(_, ty)| ty.contains("Vec<") || ty.contains("[f64]"))
        .map(|(name, _)| name.as_str());

    // Comparing every other vector against the first is enough: equality of
    // lengths is transitive.
    let first = vec_params.next()?;
    let mut checks = String::new();
    for other in vec_params {
        // Writing into a `String` cannot fail.
        let _ = write!(
            checks,
            "    if {first}.len() != {other}.len() {{\n        return Err(pyo3::exceptions::PyValueError::new_err(\"length mismatch\"));\n    }}\n"
        );
    }

    if checks.is_empty() {
        None
    } else {
        Some(checks)
    }
}