Skip to main content

rustify_ml/generator/
infer.rs

1//! Type inference helpers for Python → Rust parameter and assignment types.
2
3use rustpython_parser::ast::Expr;
4
5/// Infer Rust parameter types from a Python function's argument list.
6pub fn infer_params(args: &rustpython_parser::ast::Arguments) -> Vec<(String, String)> {
7    args.args
8        .iter()
9        .map(|a| {
10            let ty = if a.def.annotation.is_none() {
11                infer_type_from_name(a.def.arg.as_str())
12            } else {
13                infer_type_from_annotation(a.def.annotation.as_deref())
14            };
15            (a.def.arg.to_string(), ty)
16        })
17        .collect()
18}
19
20/// Infer a Rust type string from a Python type annotation expression.
21///
22/// Supported annotations:
23/// - `int` → `usize`
24/// - `float` → `f64`
25/// - `np.ndarray`, `numpy.ndarray` → `numpy::PyReadonlyArray1<f64>`
26/// - `torch.Tensor` → `Vec<f64>`
27/// - anything else → `Vec<f64>` (safe default)
28pub fn infer_type_from_annotation(annotation: Option<&Expr>) -> String {
29    match annotation {
30        Some(Expr::Name(n)) if n.id.as_str() == "int" => "usize".to_string(),
31        Some(Expr::Name(n)) if n.id.as_str() == "float" => "f64".to_string(),
32        Some(Expr::Attribute(attr)) => {
33            if let Expr::Name(base) = attr.value.as_ref() {
34                if base.id.as_str() == "np" || base.id.as_str() == "numpy" {
35                    return "numpy::PyReadonlyArray1<f64>".to_string();
36                }
37                if base.id.as_str() == "torch" && attr.attr.as_str() == "Tensor" {
38                    return "Vec<f64>".to_string();
39                }
40            }
41            "Vec<f64>".to_string()
42        }
43        _ => "Vec<f64>".to_string(),
44    }
45}
46
/// Guess a Rust type for an unannotated Python parameter from its name alone.
///
/// Mapping (exact, case-sensitive name match):
/// - `text`, `string`, `s`, `line` → `String`
/// - `merges` → `Vec<i64>`
/// - scalar loop/size names (`window`, `k`, `n`, `m`, `length`, `size`,
///   `count`, `steps`) → `usize`
/// - any other name → `Vec<f64>`, the default ML vector type
fn infer_type_from_name(name: &str) -> String {
    const TEXT_NAMES: [&str; 4] = ["text", "string", "s", "line"];
    const SCALAR_NAMES: [&str; 8] = ["window", "k", "n", "m", "length", "size", "count", "steps"];

    let rust_ty = if TEXT_NAMES.contains(&name) {
        "String"
    } else if name == "merges" {
        "Vec<i64>"
    } else if SCALAR_NAMES.contains(&name) {
        "usize"
    } else {
        "Vec<f64>"
    };
    rust_ty.to_owned()
}
59
60/// Infer a Rust type annotation suffix for a simple assignment RHS.
61///
62/// Returns `": f64"` for float literals, `": i64"` for int literals, `""` otherwise.
63pub fn infer_assign_type(value: &Expr) -> &'static str {
64    match value {
65        Expr::Constant(c) => match &c.value {
66            rustpython_parser::ast::Constant::Float(_) => ": f64",
67            rustpython_parser::ast::Constant::Int(_) => ": i64",
68            _ => "",
69        },
70        _ => "",
71    }
72}
73
/// Emit Rust length-check guards for sequence parameters.
///
/// A parameter counts as a sequence when its type string contains `Vec<` or
/// `[f64]`. The first sequence parameter is compared against every later one.
/// Returns `None` when fewer than two sequence parameters exist. Otherwise it
/// returns one guard per later parameter, each indented by four spaces and
/// terminated by a newline:
///
/// ```rust,ignore
/// // assumes `pyo3::exceptions::PyValueError` is in scope and `a`, `b` are params
/// if a.len() != b.len() {
///     return Err(pyo3::exceptions::PyValueError::new_err("length mismatch"));
/// }
/// ```
pub fn render_len_checks(params: &[(String, String)]) -> Option<String> {
    let is_sequence = |ty: &str| ty.contains("Vec<") || ty.contains("[f64]");
    let mut sequence_names = params
        .iter()
        .filter(|(_, ty)| is_sequence(ty.as_str()))
        .map(|(name, _)| name.as_str());

    // No sequence params at all -> nothing to compare against.
    let anchor = sequence_names.next()?;
    let others: Vec<&str> = sequence_names.collect();
    if others.is_empty() {
        return None;
    }

    let guards: String = others
        .into_iter()
        .map(|other| {
            format!(
                "    if {anchor}.len() != {other}.len() {{\n        return Err(pyo3::exceptions::PyValueError::new_err(\"length mismatch\"));\n    }}\n",
                anchor = anchor,
                other = other
            )
        })
        .collect();
    Some(guards)
}