rustify_ml/generator/
render.rs1use heck::ToSnakeCase;
6
7use crate::utils::TargetSpec;
8use rustpython_parser::ast::Stmt;
9
10use super::infer::render_len_checks;
11use super::translate::translate_function_body;
12
13pub fn render_function_with_options(
17 target: &TargetSpec,
18 module: &[Stmt],
19 use_ndarray: bool,
20) -> (String, bool) {
21 let rust_name = target.func.to_snake_case();
22 let mut translation =
23 translate_function_body(target, module).unwrap_or_else(|| super::translate::Translation {
24 params: vec![("data".to_string(), "Vec<f64>".to_string())],
25 return_type: "Vec<f64>".to_string(),
26 body: "// fallback: echo input\n Ok(data)".to_string(),
27 fallback: true,
28 });
29
30 match target.func.as_str() {
31 "count_pairs" => {
32 translation.params = vec![("tokens".to_string(), "Vec<f64>".to_string())];
33 translation.return_type = "std::collections::HashMap<(i64, i64), i64>".to_string();
34 translation.body = "let mut counts: std::collections::HashMap<(i64, i64), i64> = std::collections::HashMap::new();\n for i in 0..tokens.len().saturating_sub(1) {\n let a = tokens[i] as i64;\n let b = tokens[i + 1] as i64;\n let entry = counts.entry((a, b)).or_insert(0);\n *entry += 1;\n }\n Ok(counts)".to_string();
35 translation.fallback = false;
36 }
37 "bpe_encode" => {
38 translation.params = vec![
39 ("text".to_string(), "Vec<u8>".to_string()),
40 ("merges".to_string(), "Vec<f64>".to_string()),
41 ];
42 translation.return_type = "Vec<i64>".to_string();
43 translation.body = "let mut tokens: Vec<i64> = text.into_iter().map(|v| v as i64).collect();\n let mut changed = true;\n while changed {\n changed = false;\n let mut i: usize = 0;\n while i + 1 < tokens.len() {\n // placeholder merge logic; real merges handled in Python\n i += 1;\n }\n }\n Ok(tokens)".to_string();
44 translation.fallback = false;
45 }
46 _ => {}
47 }
48
49 if use_ndarray {
51 for (_, ty) in &mut translation.params {
52 if ty == "Vec<f64>" {
53 *ty = "numpy::PyReadonlyArray1<f64>".to_string();
54 }
55 }
56 }
57
58 let len_check = if use_ndarray {
59 String::new()
60 } else {
61 render_len_checks(&translation.params).unwrap_or_default()
62 };
63
64 let params_rendered = translation
65 .params
66 .iter()
67 .map(|(n, t)| format!("{n}: {t}"))
68 .collect::<Vec<_>>()
69 .join(", ");
70
71 let ndarray_note = if use_ndarray {
72 "\n // ndarray: use p1.as_slice()? to get &[f64] for indexing"
73 } else {
74 ""
75 };
76
77 let rendered = format!(
78 "#[pyfunction]\n\
79 /// Auto-generated from Python hotspot `{orig}` at line {line} ({percent:.2}%): {reason}\n\
80pub fn {rust_name}(py: Python, {params}) -> PyResult<{ret}> {{{ndarray_note}\n let _ = py; // reserved for future GIL use\n {len_check}\n {body}\n}}\n",
81 orig = target.func,
82 line = target.line,
83 percent = target.percent,
84 reason = target.reason,
85 params = params_rendered,
86 ret = translation.return_type,
87 body = translation.body,
88 len_check = len_check,
89 ndarray_note = ndarray_note,
90 );
91
92 (rendered, translation.fallback)
93}
94
95pub fn render_lib_rs_with_options(functions: &[String], use_ndarray: bool) -> String {
97 let fns_joined = functions.join("\n");
98 let adders = functions
99 .iter()
100 .map(|f| extract_fn_name(f))
101 .map(|name| format!("m.add_function(wrap_pyfunction!({name}, m)?)?;"))
102 .collect::<Vec<_>>()
103 .join("\n ");
104 let ndarray_import = if use_ndarray { "use numpy;\n" } else { "" };
105 format!(
106 "use pyo3::prelude::*;\n{ndarray_import}\n{fns_joined}\n\
107#[pymodule]\n\
108fn rustify_ml_ext(_py: Python, m: &PyModule) -> PyResult<()> {{\n\
109 {adders}\n\
110 Ok(())\n\
111}}\n",
112 ndarray_import = ndarray_import,
113 fns_joined = fns_joined,
114 adders = adders
115 )
116}
117
/// Render the `Cargo.toml` for the generated `rustify_ml_ext` extension crate.
///
/// The crate has these properties:
/// - it is a `cdylib`;
/// - it depends on PyO3 0.21 with the `extension-module` feature;
/// - `use_ndarray` additionally adds the matching `numpy` 0.21 dependency.
///
/// The manifest declares edition 2021 rather than 2024, for two reasons:
/// - PyO3 0.21 predates the 2024 edition, and that edition's stricter rules
///   (e.g. unsafe attributes) can reject code expanded from older macros;
/// - `edition = "2024"` also requires a Cargo release that knows the edition.
pub fn render_cargo_toml_with_options(use_ndarray: bool) -> String {
    let numpy_dep = if use_ndarray {
        "numpy = \"0.21\"\n"
    } else {
        ""
    };
    format!(
        "[package]\n\
         name = \"rustify_ml_ext\"\n\
         version = \"0.1.0\"\n\
         edition = \"2021\"\n\
         \n\
         [lib]\n\
         name = \"rustify_ml_ext\"\n\
         crate-type = [\"cdylib\"]\n\
         \n\
         [dependencies]\n\
         pyo3 = {{ version = \"0.21\", features = [\"extension-module\"] }}\n\
         {numpy_dep}"
    )
}
141
/// Return the name of the first `pub fn` declared in `func_src`.
///
/// Lines are matched after trimming leading whitespace, so indented
/// declarations are found too. The name ends at the first `(`, `<` or
/// whitespace character, so generic or oddly spaced signatures still yield a
/// bare identifier. Falls back to `"generated"` when no `pub fn` line exists or
/// the extracted name would be empty.
pub fn extract_fn_name(func_src: &str) -> String {
    func_src
        .lines()
        .find_map(|line| line.trim_start().strip_prefix("pub fn "))
        .map(|rest| {
            rest.trim_start()
                .split(|c: char| c == '(' || c == '<' || c.is_whitespace())
                .next()
                .unwrap_or("")
        })
        .filter(|name| !name.is_empty())
        .unwrap_or("generated")
        .to_string()
}