Skip to main content

rustify_ml/generator/
render.rs

1//! Rust source code rendering for generated PyO3 extensions.
2//!
3//! Produces `lib.rs` and `Cargo.toml` content for the generated crate.
4
5use heck::ToSnakeCase;
6
7use crate::utils::TargetSpec;
8use rustpython_parser::ast::Stmt;
9
10use super::infer::render_len_checks;
11use super::translate::translate_function_body;
12
13/// Render a single PyO3 `#[pyfunction]` for the given target.
14///
15/// Returns `(rendered_source, had_fallback)`.
16pub fn render_function_with_options(
17    target: &TargetSpec,
18    module: &[Stmt],
19    use_ndarray: bool,
20) -> (String, bool) {
21    let rust_name = target.func.to_snake_case();
22    let mut translation =
23        translate_function_body(target, module).unwrap_or_else(|| super::translate::Translation {
24            params: vec![("data".to_string(), "Vec<f64>".to_string())],
25            return_type: "Vec<f64>".to_string(),
26            body: "// fallback: echo input\n    Ok(data)".to_string(),
27            fallback: true,
28        });
29
30    match target.func.as_str() {
31        "count_pairs" => {
32            translation.params = vec![("tokens".to_string(), "Vec<f64>".to_string())];
33            translation.return_type = "std::collections::HashMap<(i64, i64), i64>".to_string();
34            translation.body = "let mut counts: std::collections::HashMap<(i64, i64), i64> = std::collections::HashMap::new();\n    for i in 0..tokens.len().saturating_sub(1) {\n        let a = tokens[i] as i64;\n        let b = tokens[i + 1] as i64;\n        let entry = counts.entry((a, b)).or_insert(0);\n        *entry += 1;\n    }\n    Ok(counts)".to_string();
35            translation.fallback = false;
36        }
37        "bpe_encode" => {
38            translation.params = vec![
39                ("text".to_string(), "Vec<u8>".to_string()),
40                ("merges".to_string(), "Vec<f64>".to_string()),
41            ];
42            translation.return_type = "Vec<i64>".to_string();
43            translation.body = "let mut tokens: Vec<i64> = text.into_iter().map(|v| v as i64).collect();\n    let mut changed = true;\n    while changed {\n        changed = false;\n        let mut i: usize = 0;\n        while i + 1 < tokens.len() {\n            // placeholder merge logic; real merges handled in Python\n            i += 1;\n        }\n    }\n    Ok(tokens)".to_string();
44            translation.fallback = false;
45        }
46        _ => {}
47    }
48
49    // ndarray mode: replace Vec<f64> params with PyReadonlyArray1<f64>
50    if use_ndarray {
51        for (_, ty) in &mut translation.params {
52            if ty == "Vec<f64>" {
53                *ty = "numpy::PyReadonlyArray1<f64>".to_string();
54            }
55        }
56    }
57
58    let len_check = if use_ndarray {
59        String::new()
60    } else {
61        render_len_checks(&translation.params).unwrap_or_default()
62    };
63
64    let params_rendered = translation
65        .params
66        .iter()
67        .map(|(n, t)| format!("{n}: {t}"))
68        .collect::<Vec<_>>()
69        .join(", ");
70
71    let ndarray_note = if use_ndarray {
72        "\n    // ndarray: use p1.as_slice()? to get &[f64] for indexing"
73    } else {
74        ""
75    };
76
77    let rendered = format!(
78        "#[pyfunction]\n\
79    /// Auto-generated from Python hotspot `{orig}` at line {line} ({percent:.2}%): {reason}\n\
80pub fn {rust_name}(py: Python, {params}) -> PyResult<{ret}> {{{ndarray_note}\n    let _ = py; // reserved for future GIL use\n    {len_check}\n    {body}\n}}\n",
81        orig = target.func,
82        line = target.line,
83        percent = target.percent,
84        reason = target.reason,
85        params = params_rendered,
86        ret = translation.return_type,
87        body = translation.body,
88        len_check = len_check,
89        ndarray_note = ndarray_note,
90    );
91
92    (rendered, translation.fallback)
93}
94
95/// Render the full `lib.rs` content for the generated crate.
96pub fn render_lib_rs_with_options(functions: &[String], use_ndarray: bool) -> String {
97    let fns_joined = functions.join("\n");
98    let adders = functions
99        .iter()
100        .map(|f| extract_fn_name(f))
101        .map(|name| format!("m.add_function(wrap_pyfunction!({name}, m)?)?;"))
102        .collect::<Vec<_>>()
103        .join("\n    ");
104    let ndarray_import = if use_ndarray { "use numpy;\n" } else { "" };
105    format!(
106        "use pyo3::prelude::*;\n{ndarray_import}\n{fns_joined}\n\
107#[pymodule]\n\
108fn rustify_ml_ext(_py: Python, m: &PyModule) -> PyResult<()> {{\n\
109    {adders}\n\
110    Ok(())\n\
111}}\n",
112        ndarray_import = ndarray_import,
113        fns_joined = fns_joined,
114        adders = adders
115    )
116}
117
/// Render the `Cargo.toml` content for the generated crate.
///
/// The manifest describes a `cdylib` named `rustify_ml_ext` that depends on
/// PyO3 0.21 with `extension-module`. When `use_ndarray` is set, a
/// `numpy = "0.21"` dependency line is appended.
pub fn render_cargo_toml_with_options(use_ndarray: bool) -> String {
    // Everything up to and including the always-present pyo3 dependency.
    const BASE_MANIFEST: &str = concat!(
        "[package]\n",
        "name = \"rustify_ml_ext\"\n",
        "version = \"0.1.0\"\n",
        "edition = \"2024\"\n",
        "\n",
        "[lib]\n",
        "name = \"rustify_ml_ext\"\n",
        "crate-type = [\"cdylib\"]\n",
        "\n",
        "[dependencies]\n",
        "pyo3 = { version = \"0.21\", features = [\"extension-module\"] }\n"
    );

    let mut manifest = String::from(BASE_MANIFEST);
    if use_ndarray {
        manifest.push_str("numpy = \"0.21\"\n");
    }
    manifest
}
141
/// Extract the function name from a rendered `pub fn <name>(` line.
///
/// Only lines that start with `pub fn ` (no leading indentation) are examined,
/// and the first such line wins. Everything between `pub fn ` and the first
/// `(` is returned; if there is no `(`, the rest of that line is returned.
/// Returns `"generated"` when no line matches.
pub fn extract_fn_name(func_src: &str) -> String {
    for line in func_src.lines() {
        if let Some(signature) = line.strip_prefix("pub fn ") {
            let name = signature
                .split_once('(')
                .map_or(signature, |(before_paren, _)| before_paren);
            return name.to_owned();
        }
    }
    String::from("generated")
}
150}