rustify_ml/generator/
render.rs1use heck::ToSnakeCase;
6
7use crate::utils::TargetSpec;
8use rustpython_parser::ast::Stmt;
9
10use super::infer::render_len_checks;
11use super::translate::translate_function_body;
12
13pub fn render_function_with_options(
17 target: &TargetSpec,
18 module: &[Stmt],
19 use_ndarray: bool,
20) -> (String, bool) {
21 let rust_name = target.func.to_snake_case();
22 let mut translation =
23 translate_function_body(target, module).unwrap_or_else(|| super::translate::Translation {
24 params: vec![("data".to_string(), "Vec<f64>".to_string())],
25 return_type: "Vec<f64>".to_string(),
26 body: "// fallback: echo input\n Ok(data)".to_string(),
27 fallback: true,
28 });
29
30 if use_ndarray {
32 for (_, ty) in &mut translation.params {
33 if ty == "Vec<f64>" {
34 *ty = "numpy::PyReadonlyArray1<f64>".to_string();
35 }
36 }
37 }
38
39 let len_check = if use_ndarray {
40 String::new()
41 } else {
42 render_len_checks(&translation.params).unwrap_or_default()
43 };
44
45 let params_rendered = translation
46 .params
47 .iter()
48 .map(|(n, t)| format!("{n}: {t}"))
49 .collect::<Vec<_>>()
50 .join(", ");
51
52 let ndarray_note = if use_ndarray {
53 "\n // ndarray: use p1.as_slice()? to get &[f64] for indexing"
54 } else {
55 ""
56 };
57
58 let rendered = format!(
59 "#[pyfunction]\n\
60 /// Auto-generated from Python hotspot `{orig}` at line {line} ({percent:.2}%): {reason}\n\
61pub fn {rust_name}(py: Python, {params}) -> PyResult<{ret}> {{{ndarray_note}\n let _ = py; // reserved for future GIL use\n {len_check}\n {body}\n}}\n",
62 orig = target.func,
63 line = target.line,
64 percent = target.percent,
65 reason = target.reason,
66 params = params_rendered,
67 ret = translation.return_type,
68 body = translation.body,
69 len_check = len_check,
70 ndarray_note = ndarray_note,
71 );
72
73 (rendered, translation.fallback)
74}
75
76pub fn render_lib_rs_with_options(functions: &[String], use_ndarray: bool) -> String {
78 let fns_joined = functions.join("\n");
79 let adders = functions
80 .iter()
81 .map(|f| extract_fn_name(f))
82 .map(|name| format!("m.add_function(wrap_pyfunction!({name}, m)?)?;"))
83 .collect::<Vec<_>>()
84 .join("\n ");
85 let ndarray_import = if use_ndarray { "use numpy;\n" } else { "" };
86 format!(
87 "use pyo3::prelude::*;\n{ndarray_import}\n{fns_joined}\n\
88#[pymodule]\n\
89fn rustify_ml_ext(_py: Python, m: &PyModule) -> PyResult<()> {{\n\
90 {adders}\n\
91 Ok(())\n\
92}}\n",
93 ndarray_import = ndarray_import,
94 fns_joined = fns_joined,
95 adders = adders
96 )
97}
98
/// Renders the `Cargo.toml` for the generated `rustify_ml_ext` extension crate.
///
/// The manifest declares a `cdylib` named `rustify_ml_ext` that depends on
/// `pyo3` 0.21 with the `extension-module` feature. The `numpy = "0.21"`
/// dependency line is appended only when `use_ndarray` is set.
pub fn render_cargo_toml_with_options(use_ndarray: bool) -> String {
    const MANIFEST_LINES: [&str; 11] = [
        "[package]",
        "name = \"rustify_ml_ext\"",
        "version = \"0.1.0\"",
        "edition = \"2024\"",
        "",
        "[lib]",
        "name = \"rustify_ml_ext\"",
        "crate-type = [\"cdylib\"]",
        "",
        "[dependencies]",
        "pyo3 = { version = \"0.21\", features = [\"extension-module\"] }",
    ];

    let mut manifest = String::new();
    for line in MANIFEST_LINES {
        manifest.push_str(line);
        manifest.push('\n');
    }
    if use_ndarray {
        manifest.push_str("numpy = \"0.21\"\n");
    }
    manifest
}

/// Returns the name of the first `pub fn` declared in `func_src`.
///
/// Leading indentation on the declaration line is tolerated. The name ends at
/// the first character that cannot be part of an identifier, so generics
/// (`pub fn foo<T>(…)`) and stray spaces (`pub fn foo (…)`) still yield `foo`.
/// Declarations with an empty name are skipped. Falls back to `"generated"`
/// when no usable name is found.
pub fn extract_fn_name(func_src: &str) -> String {
    func_src
        .lines()
        .filter_map(|line| line.trim_start().strip_prefix("pub fn "))
        .map(|rest| {
            let rest = rest.trim_start();
            let end = rest
                .find(|c: char| !(c.is_alphanumeric() || c == '_'))
                .unwrap_or(rest.len());
            &rest[..end]
        })
        .find(|name| !name.is_empty())
        .unwrap_or("generated")
        .to_string()
}