Skip to main content

rustify_ml/generator/
render.rs

//! Rust source code rendering for generated PyO3 extensions.
//!
//! Produces `lib.rs` and `Cargo.toml` content for the generated crate.
4
5use heck::ToSnakeCase;
6
7use crate::utils::TargetSpec;
8use rustpython_parser::ast::Stmt;
9
10use super::infer::render_len_checks;
11use super::translate::translate_function_body;
12
13/// Render a single PyO3 `#[pyfunction]` for the given target.
14///
15/// Returns `(rendered_source, had_fallback)`.
16pub fn render_function_with_options(
17    target: &TargetSpec,
18    module: &[Stmt],
19    use_ndarray: bool,
20) -> (String, bool) {
21    let rust_name = target.func.to_snake_case();
22    let mut translation =
23        translate_function_body(target, module).unwrap_or_else(|| super::translate::Translation {
24            params: vec![("data".to_string(), "Vec<f64>".to_string())],
25            return_type: "Vec<f64>".to_string(),
26            body: "// fallback: echo input\n    Ok(data)".to_string(),
27            fallback: true,
28        });
29
30    // ndarray mode: replace Vec<f64> params with PyReadonlyArray1<f64>
31    if use_ndarray {
32        for (_, ty) in &mut translation.params {
33            if ty == "Vec<f64>" {
34                *ty = "numpy::PyReadonlyArray1<f64>".to_string();
35            }
36        }
37    }
38
39    let len_check = if use_ndarray {
40        String::new()
41    } else {
42        render_len_checks(&translation.params).unwrap_or_default()
43    };
44
45    let params_rendered = translation
46        .params
47        .iter()
48        .map(|(n, t)| format!("{n}: {t}"))
49        .collect::<Vec<_>>()
50        .join(", ");
51
52    let ndarray_note = if use_ndarray {
53        "\n    // ndarray: use p1.as_slice()? to get &[f64] for indexing"
54    } else {
55        ""
56    };
57
58    let rendered = format!(
59        "#[pyfunction]\n\
60    /// Auto-generated from Python hotspot `{orig}` at line {line} ({percent:.2}%): {reason}\n\
61pub fn {rust_name}(py: Python, {params}) -> PyResult<{ret}> {{{ndarray_note}\n    let _ = py; // reserved for future GIL use\n    {len_check}\n    {body}\n}}\n",
62        orig = target.func,
63        line = target.line,
64        percent = target.percent,
65        reason = target.reason,
66        params = params_rendered,
67        ret = translation.return_type,
68        body = translation.body,
69        len_check = len_check,
70        ndarray_note = ndarray_note,
71    );
72
73    (rendered, translation.fallback)
74}
75
76/// Render the full `lib.rs` content for the generated crate.
77pub fn render_lib_rs_with_options(functions: &[String], use_ndarray: bool) -> String {
78    let fns_joined = functions.join("\n");
79    let adders = functions
80        .iter()
81        .map(|f| extract_fn_name(f))
82        .map(|name| format!("m.add_function(wrap_pyfunction!({name}, m)?)?;"))
83        .collect::<Vec<_>>()
84        .join("\n    ");
85    let ndarray_import = if use_ndarray { "use numpy;\n" } else { "" };
86    format!(
87        "use pyo3::prelude::*;\n{ndarray_import}\n{fns_joined}\n\
88#[pymodule]\n\
89fn rustify_ml_ext(_py: Python, m: &PyModule) -> PyResult<()> {{\n\
90    {adders}\n\
91    Ok(())\n\
92}}\n",
93        ndarray_import = ndarray_import,
94        fns_joined = fns_joined,
95        adders = adders
96    )
97}
98
/// Render the `Cargo.toml` content for the generated crate.
///
/// The crate is named `rustify_ml_ext`, builds as a `cdylib`, and depends on
/// PyO3 0.21 with the `extension-module` feature. When `use_ndarray` is true, a
/// `numpy = "0.21"` dependency is appended.
///
/// Uses edition 2021: PyO3 0.21 predates the 2024 edition, and requiring edition
/// 2024 would force Rust 1.85+ on anyone building the generated extension.
pub fn render_cargo_toml_with_options(use_ndarray: bool) -> String {
    let numpy_dep = if use_ndarray {
        "numpy = \"0.21\"\n"
    } else {
        ""
    };
    format!(
        "[package]\n\
         name = \"rustify_ml_ext\"\n\
         version = \"0.1.0\"\n\
         edition = \"2021\"\n\
         \n\
         [lib]\n\
         name = \"rustify_ml_ext\"\n\
         crate-type = [\"cdylib\"]\n\
         \n\
         [dependencies]\n\
         pyo3 = {{ version = \"0.21\", features = [\"extension-module\"] }}\n\
         {numpy_dep}"
    )
}
122
/// Extract the function name from a rendered `pub fn <name>(` line.
///
/// Scans `func_src` line by line and takes the first line that begins (at column 0)
/// with `pub fn `; the name is everything after that prefix up to the first `(`, or
/// the rest of the line when there is no `(`. Returns `"generated"` when no such
/// line exists.
pub fn extract_fn_name(func_src: &str) -> String {
    for line in func_src.lines() {
        if let Some(after_prefix) = line.strip_prefix("pub fn ") {
            let name = match after_prefix.find('(') {
                Some(paren) => &after_prefix[..paren],
                None => after_prefix,
            };
            return name.to_owned();
        }
    }
    String::from("generated")
}