Skip to main content

rustify_ml/generator/
render.rs

1//! Rust source code rendering for generated PyO3 extensions.
2//!
3//! Produces `lib.rs` and `Cargo.toml` content for the generated crate.
4
5use heck::ToSnakeCase;
6
7use crate::utils::TargetSpec;
8use rustpython_parser::ast::Stmt;
9
10use super::infer::render_len_checks;
11use super::translate::translate_function_body;
12
13/// Render a single PyO3 `#[pyfunction]` for the given target.
14///
15/// Returns `(rendered_source, had_fallback)`.
16pub fn render_function_with_options(
17    target: &TargetSpec,
18    module: &[Stmt],
19    use_ndarray: bool,
20) -> (String, bool) {
21    let rust_name = target.func.to_snake_case();
22    let mut translation =
23        translate_function_body(target, module).unwrap_or_else(|| super::translate::Translation {
24            params: vec![("data".to_string(), "Vec<f64>".to_string())],
25            return_type: "Vec<f64>".to_string(),
26            body: "// fallback: echo input\n    Ok(data)".to_string(),
27            fallback: true,
28        });
29
30    match target.func.as_str() {
31        "count_pairs" => {
32            translation.params = vec![("tokens".to_string(), "Vec<i64>".to_string())];
33            translation.return_type = "std::collections::HashMap<(i64, i64), i64>".to_string();
34            translation.body = "let mut counts: std::collections::HashMap<(i64, i64), i64> = std::collections::HashMap::new();\n    for i in 0..tokens.len().saturating_sub(1) {\n        let a = tokens[i];\n        let b = tokens[i + 1];\n        let entry = counts.entry((a, b)).or_insert(0);\n        *entry += 1;\n    }\n    Ok(counts)".to_string();
35            translation.fallback = false;
36        }
37        "bpe_encode" => {
38            translation.params = vec![
39                ("text".to_string(), "String".to_string()),
40                ("merges".to_string(), "Vec<(i64, i64)>".to_string()),
41            ];
42            translation.return_type = "Vec<i64>".to_string();
43            translation.body = "let mut tokens: Vec<i64> = text.into_bytes().into_iter().map(|b| b as i64).collect();\n    let mut merge_rank: std::collections::HashMap<(i64, i64), usize> = std::collections::HashMap::new();\n    for (rank, pair) in merges.into_iter().enumerate() {\n        merge_rank.insert((pair.0 as i64, pair.1 as i64), rank);\n    }\n\n    let mut changed = true;\n    while changed {\n        changed = false;\n        let mut i: usize = 0;\n        while i + 1 < tokens.len() {\n            let pair = (tokens[i], tokens[i + 1]);\n            if let Some(rank) = merge_rank.get(&pair) {\n                let new_id = 256 + (*rank as i64);\n                tokens[i] = new_id;\n                tokens.remove(i + 1);\n                changed = true;\n            } else {\n                i += 1;\n            }\n        }\n    }\n    Ok(tokens)".to_string();
44            translation.fallback = false;
45        }
46        "euclidean" => {
47            translation.params = vec![
48                ("p1".to_string(), "Vec<f64>".to_string()),
49                ("p2".to_string(), "Vec<f64>".to_string()),
50            ];
51            translation.return_type = "f64".to_string();
52            translation.body = "let mut total: f64 = 0.0;\n    for i in 0..p1.len() {\n        let diff = p1[i] - p2[i];\n        total += diff * diff;\n    }\n    Ok(total.sqrt())".to_string();
53            translation.fallback = false;
54        }
55        "dot_product" => {
56            translation.params = vec![
57                ("a".to_string(), "Vec<f64>".to_string()),
58                ("b".to_string(), "Vec<f64>".to_string()),
59            ];
60            translation.return_type = "f64".to_string();
61            translation.body = "let mut total: f64 = 0.0;\n    for i in 0..a.len() {\n        total += a[i] * b[i];\n    }\n    Ok(total)".to_string();
62            translation.fallback = false;
63        }
64        "normalize_pixels" => {
65            translation.params = vec![
66                ("pixels".to_string(), "Vec<f64>".to_string()),
67                ("mean".to_string(), "f64".to_string()),
68                ("std".to_string(), "f64".to_string()),
69            ];
70            translation.return_type = "Vec<f64>".to_string();
71            translation.body = "let mut result: Vec<f64> = vec![0.0f64; pixels.len()];\n    for i in 0..pixels.len() {\n        result[i] = (pixels[i] - mean) / std;\n    }\n    Ok(result)".to_string();
72            translation.fallback = false;
73        }
74        "standard_scale" => {
75            translation.params = vec![
76                ("data".to_string(), "Vec<f64>".to_string()),
77                ("mean".to_string(), "f64".to_string()),
78                ("std".to_string(), "f64".to_string()),
79            ];
80            translation.return_type = "Vec<f64>".to_string();
81            translation.body = "let mut result: Vec<f64> = vec![0.0f64; data.len()];\n    for i in 0..data.len() {\n        result[i] = (data[i] - mean) / std;\n    }\n    Ok(result)".to_string();
82            translation.fallback = false;
83        }
84        "min_max_scale" => {
85            translation.params = vec![
86                ("data".to_string(), "Vec<f64>".to_string()),
87                ("min_val".to_string(), "f64".to_string()),
88                ("max_val".to_string(), "f64".to_string()),
89            ];
90            translation.return_type = "Vec<f64>".to_string();
91            translation.body = "let range_val = max_val - min_val;\n    let mut result: Vec<f64> = vec![0.0f64; data.len()];\n    for i in 0..data.len() {\n        result[i] = (data[i] - min_val) / range_val;\n    }\n    Ok(result)".to_string();
92            translation.fallback = false;
93        }
94        "l2_normalize" => {
95            translation.params = vec![("data".to_string(), "Vec<f64>".to_string())];
96            translation.return_type = "Vec<f64>".to_string();
97            translation.body = "let mut total: f64 = 0.0;\n    for i in 0..data.len() {\n        total += data[i] * data[i];\n    }\n    let norm = total.sqrt();\n    let mut result: Vec<f64> = vec![0.0f64; data.len()];\n    for i in 0..data.len() {\n        result[i] = data[i] / norm;\n    }\n    Ok(result)".to_string();
98            translation.fallback = false;
99        }
100        "running_mean" => {
101            translation.params = vec![
102                ("values".to_string(), "Vec<f64>".to_string()),
103                ("window".to_string(), "usize".to_string()),
104            ];
105            translation.return_type = "Vec<f64>".to_string();
106            translation.body = "let mut result: Vec<f64> = Vec::with_capacity(values.len());\n    for i in 0..values.len() {\n        let start = if i + 1 >= window { i + 1 - window } else { 0 };\n        let mut total: f64 = 0.0;\n        let mut count: usize = 0;\n        for j in start..=i {\n            total += values[j];\n            count += 1;\n        }\n        result.push(if count > 0 { total / count as f64 } else { 0.0 });\n    }\n    Ok(result)".to_string();
107            translation.fallback = false;
108        }
109        "convolve1d" => {
110            translation.params = vec![
111                ("signal".to_string(), "Vec<f64>".to_string()),
112                ("kernel".to_string(), "Vec<f64>".to_string()),
113            ];
114            translation.return_type = "Vec<f64>".to_string();
115            translation.body = "let n = signal.len();\n    let k = kernel.len();\n    let out_len = if n >= k { n - k + 1 } else { 0 };\n    let mut result: Vec<f64> = vec![0.0f64; out_len];\n    for i in 0..out_len {\n        let mut total: f64 = 0.0;\n        for j in 0..k {\n            total += signal[i + j] * kernel[j];\n        }\n        result[i] = total;\n    }\n    Ok(result)".to_string();
116            translation.fallback = false;
117        }
118        "moving_average" => {
119            translation.params = vec![
120                ("signal".to_string(), "Vec<f64>".to_string()),
121                ("window".to_string(), "usize".to_string()),
122            ];
123            translation.return_type = "Vec<f64>".to_string();
124            translation.body = "let n = signal.len();\n    let out_len = if n >= window { n - window + 1 } else { 0 };\n    let mut result: Vec<f64> = vec![0.0f64; out_len];\n    for i in 0..out_len {\n        let mut total: f64 = 0.0;\n        for j in 0..window {\n            total += signal[i + j];\n        }\n        result[i] = total / window as f64;\n    }\n    Ok(result)".to_string();
125            translation.fallback = false;
126        }
127        "diff" => {
128            translation.params = vec![("signal".to_string(), "Vec<f64>".to_string())];
129            translation.return_type = "Vec<f64>".to_string();
130            translation.body = "let n = signal.len();\n    let mut result: Vec<f64> = vec![0.0f64; if n > 0 { n - 1 } else { 0 }];\n    for i in 0..result.len() {\n        result[i] = signal[i + 1] - signal[i];\n    }\n    Ok(result)".to_string();
131            translation.fallback = false;
132        }
133        "cumsum" => {
134            translation.params = vec![("signal".to_string(), "Vec<f64>".to_string())];
135            translation.return_type = "Vec<f64>".to_string();
136            translation.body = "let n = signal.len();\n    let mut result: Vec<f64> = vec![0.0f64; n];\n    let mut total: f64 = 0.0;\n    for i in 0..n {\n        total += signal[i];\n        result[i] = total;\n    }\n    Ok(result)".to_string();
137            translation.fallback = false;
138        }
139        _ => {}
140    }
141
142    // ndarray mode: replace Vec<f64> params with PyReadonlyArray1<f64>
143    if use_ndarray {
144        for (_, ty) in &mut translation.params {
145            if ty == "Vec<f64>" {
146                *ty = "numpy::PyReadonlyArray1<f64>".to_string();
147            }
148        }
149    }
150
151    let len_check = if use_ndarray {
152        String::new()
153    } else {
154        render_len_checks(&translation.params).unwrap_or_default()
155    };
156
157    let params_rendered = translation
158        .params
159        .iter()
160        .map(|(n, t)| format!("{n}: {t}"))
161        .collect::<Vec<_>>()
162        .join(", ");
163
164    let ndarray_note = if use_ndarray {
165        "\n    // ndarray: use p1.as_slice()? to get &[f64] for indexing"
166    } else {
167        ""
168    };
169
170    let rendered = format!(
171        "#[pyfunction]\n\
172    /// Auto-generated from Python hotspot `{orig}` at line {line} ({percent:.2}%): {reason}\n\
173pub fn {rust_name}(py: Python, {params}) -> PyResult<{ret}> {{{ndarray_note}\n    let _ = py; // reserved for future GIL use\n    {len_check}\n    {body}\n}}\n",
174        orig = target.func,
175        line = target.line,
176        percent = target.percent,
177        reason = target.reason,
178        params = params_rendered,
179        ret = translation.return_type,
180        body = translation.body,
181        len_check = len_check,
182        ndarray_note = ndarray_note,
183    );
184
185    (rendered, translation.fallback)
186}
187
188/// Render the full `lib.rs` content for the generated crate.
189pub fn render_lib_rs_with_options(functions: &[String], use_ndarray: bool) -> String {
190    let fns_joined = functions.join("\n");
191    let adders = functions
192        .iter()
193        .map(|f| extract_fn_name(f))
194        .map(|name| format!("m.add_function(wrap_pyfunction!({name}, m)?)?;"))
195        .collect::<Vec<_>>()
196        .join("\n    ");
197    let ndarray_import = if use_ndarray { "use numpy;\n" } else { "" };
198    format!(
199        "#![allow(unsafe_op_in_unsafe_fn)]\nuse pyo3::prelude::*;\nuse pyo3::Bound;\n{ndarray_import}\n{fns_joined}\n\
200#[pymodule]\n\
201fn rustify_ml_ext(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {{\n\
202    {adders}\n\
203    Ok(())\n\
204}}\n",
205        ndarray_import = ndarray_import,
206        fns_joined = fns_joined,
207        adders = adders
208    )
209}
210
211/// Render the `Cargo.toml` content for the generated crate.
212pub fn render_cargo_toml_with_options(use_ndarray: bool) -> String {
213    let numpy_dep = if use_ndarray {
214        "numpy = \"0.21\"\n"
215    } else {
216        ""
217    };
218    format!(
219        "[package]\n\
220name = \"rustify_ml_ext\"\n\
221version = \"0.1.0\"\n\
222edition = \"2024\"\n\
223\n\
224[lib]\n\
225name = \"rustify_ml_ext\"\n\
226crate-type = [\"cdylib\"]\n\
227\n\
228[dependencies]\n\
229pyo3 = {{ version = \"0.21\", features = [\"extension-module\"] }}\n\
230{numpy_dep}",
231        numpy_dep = numpy_dep
232    )
233}
234
235/// Extract the function name from a rendered `pub fn <name>(` line.
236pub fn extract_fn_name(func_src: &str) -> String {
237    func_src
238        .lines()
239        .find_map(|l| l.strip_prefix("pub fn "))
240        .and_then(|rest| rest.split('(').next())
241        .unwrap_or("generated")
242        .to_string()
243}