1use heck::ToSnakeCase;
6
7use crate::utils::TargetSpec;
8use rustpython_parser::ast::Stmt;
9
10use super::infer::render_len_checks;
11use super::translate::translate_function_body;
12
/// Render one `#[pyfunction]` wrapper as Rust source for the profiled Python
/// hotspot described by `target`.
///
/// A body is first produced by `translate_function_body` over the parsed
/// `module`; when translation fails, a generic "echo the input" fallback is
/// substituted. For a fixed set of known hotspot names the translation is
/// then overridden with a hand-written template (and `fallback` cleared).
///
/// When `use_ndarray` is true, every `Vec<f64>` parameter is swapped for
/// `numpy::PyReadonlyArray1<f64>`, length checks are skipped, and the emitted
/// code carries a note pointing at `.as_slice()?` for element access.
///
/// Returns `(rendered_source, fallback)`; `fallback` is `true` only when the
/// generic echo body ended up being used.
pub fn render_function_with_options(
    target: &TargetSpec,
    module: &[Stmt],
    use_ndarray: bool,
) -> (String, bool) {
    // The generated Rust symbol is the snake_case form of the Python name.
    let rust_name = target.func.to_snake_case();
    let mut translation =
        translate_function_body(target, module).unwrap_or_else(|| super::translate::Translation {
            params: vec![("data".to_string(), "Vec<f64>".to_string())],
            return_type: "Vec<f64>".to_string(),
            body: "// fallback: echo input\n Ok(data)".to_string(),
            fallback: true,
        });

    // Hand-tuned templates for hotspots we recognize by name. Each arm fully
    // replaces params, return type and body produced by the translator.
    match target.func.as_str() {
        // Histogram of adjacent token-id pairs -> HashMap<(a, b), count>.
        "count_pairs" => {
            translation.params = vec![("tokens".to_string(), "Vec<i64>".to_string())];
            translation.return_type = "std::collections::HashMap<(i64, i64), i64>".to_string();
            translation.body = "let mut counts: std::collections::HashMap<(i64, i64), i64> = std::collections::HashMap::new();\n for i in 0..tokens.len().saturating_sub(1) {\n let a = tokens[i];\n let b = tokens[i + 1];\n let entry = counts.entry((a, b)).or_insert(0);\n *entry += 1;\n }\n Ok(counts)".to_string();
            translation.fallback = false;
        }
        // Byte-pair encoding: repeatedly merge the lowest-ranked adjacent
        // pair; merged tokens get ids starting at 256 (past the byte range).
        "bpe_encode" => {
            translation.params = vec![
                ("text".to_string(), "String".to_string()),
                ("merges".to_string(), "Vec<(i64, i64)>".to_string()),
            ];
            translation.return_type = "Vec<i64>".to_string();
            translation.body = "let mut tokens: Vec<i64> = text.into_bytes().into_iter().map(|b| b as i64).collect();\n let mut merge_rank: std::collections::HashMap<(i64, i64), usize> = std::collections::HashMap::new();\n for (rank, pair) in merges.into_iter().enumerate() {\n merge_rank.insert((pair.0 as i64, pair.1 as i64), rank);\n }\n\n let mut changed = true;\n while changed {\n changed = false;\n let mut i: usize = 0;\n while i + 1 < tokens.len() {\n let pair = (tokens[i], tokens[i + 1]);\n if let Some(rank) = merge_rank.get(&pair) {\n let new_id = 256 + (*rank as i64);\n tokens[i] = new_id;\n tokens.remove(i + 1);\n changed = true;\n } else {\n i += 1;\n }\n }\n }\n Ok(tokens)".to_string();
            translation.fallback = false;
        }
        // sqrt of the sum of squared component differences.
        "euclidean" => {
            translation.params = vec![
                ("p1".to_string(), "Vec<f64>".to_string()),
                ("p2".to_string(), "Vec<f64>".to_string()),
            ];
            translation.return_type = "f64".to_string();
            translation.body = "let mut total: f64 = 0.0;\n for i in 0..p1.len() {\n let diff = p1[i] - p2[i];\n total += diff * diff;\n }\n Ok(total.sqrt())".to_string();
            translation.fallback = false;
        }
        // Plain inner product over paired elements.
        "dot_product" => {
            translation.params = vec![
                ("a".to_string(), "Vec<f64>".to_string()),
                ("b".to_string(), "Vec<f64>".to_string()),
            ];
            translation.return_type = "f64".to_string();
            translation.body = "let mut total: f64 = 0.0;\n for i in 0..a.len() {\n total += a[i] * b[i];\n }\n Ok(total)".to_string();
            translation.fallback = false;
        }
        // Elementwise (x - mean) / std.
        "normalize_pixels" => {
            translation.params = vec![
                ("pixels".to_string(), "Vec<f64>".to_string()),
                ("mean".to_string(), "f64".to_string()),
                ("std".to_string(), "f64".to_string()),
            ];
            translation.return_type = "Vec<f64>".to_string();
            translation.body = "let mut result: Vec<f64> = vec![0.0f64; pixels.len()];\n for i in 0..pixels.len() {\n result[i] = (pixels[i] - mean) / std;\n }\n Ok(result)".to_string();
            translation.fallback = false;
        }
        // Same formula as normalize_pixels, different parameter naming.
        "standard_scale" => {
            translation.params = vec![
                ("data".to_string(), "Vec<f64>".to_string()),
                ("mean".to_string(), "f64".to_string()),
                ("std".to_string(), "f64".to_string()),
            ];
            translation.return_type = "Vec<f64>".to_string();
            translation.body = "let mut result: Vec<f64> = vec![0.0f64; data.len()];\n for i in 0..data.len() {\n result[i] = (data[i] - mean) / std;\n }\n Ok(result)".to_string();
            translation.fallback = false;
        }
        // Rescale into [0, 1] relative to the supplied min/max bounds.
        "min_max_scale" => {
            translation.params = vec![
                ("data".to_string(), "Vec<f64>".to_string()),
                ("min_val".to_string(), "f64".to_string()),
                ("max_val".to_string(), "f64".to_string()),
            ];
            translation.return_type = "Vec<f64>".to_string();
            translation.body = "let range_val = max_val - min_val;\n let mut result: Vec<f64> = vec![0.0f64; data.len()];\n for i in 0..data.len() {\n result[i] = (data[i] - min_val) / range_val;\n }\n Ok(result)".to_string();
            translation.fallback = false;
        }
        // Divide each element by the vector's Euclidean norm.
        "l2_normalize" => {
            translation.params = vec![("data".to_string(), "Vec<f64>".to_string())];
            translation.return_type = "Vec<f64>".to_string();
            translation.body = "let mut total: f64 = 0.0;\n for i in 0..data.len() {\n total += data[i] * data[i];\n }\n let norm = total.sqrt();\n let mut result: Vec<f64> = vec![0.0f64; data.len()];\n for i in 0..data.len() {\n result[i] = data[i] / norm;\n }\n Ok(result)".to_string();
            translation.fallback = false;
        }
        // Trailing mean over at most `window` elements ending at each index
        // (output has the same length as the input).
        "running_mean" => {
            translation.params = vec![
                ("values".to_string(), "Vec<f64>".to_string()),
                ("window".to_string(), "usize".to_string()),
            ];
            translation.return_type = "Vec<f64>".to_string();
            translation.body = "let mut result: Vec<f64> = Vec::with_capacity(values.len());\n for i in 0..values.len() {\n let start = if i + 1 >= window { i + 1 - window } else { 0 };\n let mut total: f64 = 0.0;\n let mut count: usize = 0;\n for j in start..=i {\n total += values[j];\n count += 1;\n }\n result.push(if count > 0 { total / count as f64 } else { 0.0 });\n }\n Ok(result)".to_string();
            translation.fallback = false;
        }
        // Valid-mode sliding dot product (kernel is not reversed, so this is
        // cross-correlation); output length n - k + 1, or 0 when k > n.
        "convolve1d" => {
            translation.params = vec![
                ("signal".to_string(), "Vec<f64>".to_string()),
                ("kernel".to_string(), "Vec<f64>".to_string()),
            ];
            translation.return_type = "Vec<f64>".to_string();
            translation.body = "let n = signal.len();\n let k = kernel.len();\n let out_len = if n >= k { n - k + 1 } else { 0 };\n let mut result: Vec<f64> = vec![0.0f64; out_len];\n for i in 0..out_len {\n let mut total: f64 = 0.0;\n for j in 0..k {\n total += signal[i + j] * kernel[j];\n }\n result[i] = total;\n }\n Ok(result)".to_string();
            translation.fallback = false;
        }
        // Valid-mode window mean; output length n - window + 1, or 0.
        "moving_average" => {
            translation.params = vec![
                ("signal".to_string(), "Vec<f64>".to_string()),
                ("window".to_string(), "usize".to_string()),
            ];
            translation.return_type = "Vec<f64>".to_string();
            translation.body = "let n = signal.len();\n let out_len = if n >= window { n - window + 1 } else { 0 };\n let mut result: Vec<f64> = vec![0.0f64; out_len];\n for i in 0..out_len {\n let mut total: f64 = 0.0;\n for j in 0..window {\n total += signal[i + j];\n }\n result[i] = total / window as f64;\n }\n Ok(result)".to_string();
            translation.fallback = false;
        }
        // First difference: result[i] = signal[i + 1] - signal[i].
        "diff" => {
            translation.params = vec![("signal".to_string(), "Vec<f64>".to_string())];
            translation.return_type = "Vec<f64>".to_string();
            translation.body = "let n = signal.len();\n let mut result: Vec<f64> = vec![0.0f64; if n > 0 { n - 1 } else { 0 }];\n for i in 0..result.len() {\n result[i] = signal[i + 1] - signal[i];\n }\n Ok(result)".to_string();
            translation.fallback = false;
        }
        // Prefix sums.
        "cumsum" => {
            translation.params = vec![("signal".to_string(), "Vec<f64>".to_string())];
            translation.return_type = "Vec<f64>".to_string();
            translation.body = "let n = signal.len();\n let mut result: Vec<f64> = vec![0.0f64; n];\n let mut total: f64 = 0.0;\n for i in 0..n {\n total += signal[i];\n result[i] = total;\n }\n Ok(result)".to_string();
            translation.fallback = false;
        }
        // Unknown hotspots keep whatever the generic translator produced.
        _ => {}
    }

    // ndarray mode: take read-only NumPy views instead of owned Vecs.
    // NOTE(review): the template bodies above still index like Vecs; the
    // generated comment tells users to go through `.as_slice()?` — confirm
    // downstream compilation in ndarray mode.
    if use_ndarray {
        for (_, ty) in &mut translation.params {
            if ty == "Vec<f64>" {
                *ty = "numpy::PyReadonlyArray1<f64>".to_string();
            }
        }
    }

    // Length guards only apply to owned Vec parameters, so skip in ndarray mode.
    let len_check = if use_ndarray {
        String::new()
    } else {
        render_len_checks(&translation.params).unwrap_or_default()
    };

    // "name: Type, name: Type" parameter list for the signature.
    let params_rendered = translation
        .params
        .iter()
        .map(|(n, t)| format!("{n}: {t}"))
        .collect::<Vec<_>>()
        .join(", ");

    let ndarray_note = if use_ndarray {
        "\n // ndarray: use p1.as_slice()? to get &[f64] for indexing"
    } else {
        ""
    };

    // Assemble the final #[pyfunction] item; the doc comment records the
    // profiling provenance (original name, line, cost share, reason).
    let rendered = format!(
        "#[pyfunction]\n\
        /// Auto-generated from Python hotspot `{orig}` at line {line} ({percent:.2}%): {reason}\n\
pub fn {rust_name}(py: Python, {params}) -> PyResult<{ret}> {{{ndarray_note}\n let _ = py; // reserved for future GIL use\n {len_check}\n {body}\n}}\n",
        orig = target.func,
        line = target.line,
        percent = target.percent,
        reason = target.reason,
        params = params_rendered,
        ret = translation.return_type,
        body = translation.body,
        len_check = len_check,
        ndarray_note = ndarray_note,
    );

    (rendered, translation.fallback)
}
187
188pub fn render_lib_rs_with_options(functions: &[String], use_ndarray: bool) -> String {
190 let fns_joined = functions.join("\n");
191 let adders = functions
192 .iter()
193 .map(|f| extract_fn_name(f))
194 .map(|name| format!("m.add_function(wrap_pyfunction!({name}, m)?)?;"))
195 .collect::<Vec<_>>()
196 .join("\n ");
197 let ndarray_import = if use_ndarray { "use numpy;\n" } else { "" };
198 format!(
199 "#![allow(unsafe_op_in_unsafe_fn)]\nuse pyo3::prelude::*;\nuse pyo3::Bound;\n{ndarray_import}\n{fns_joined}\n\
200#[pymodule]\n\
201fn rustify_ml_ext(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {{\n\
202 {adders}\n\
203 Ok(())\n\
204}}\n",
205 ndarray_import = ndarray_import,
206 fns_joined = fns_joined,
207 adders = adders
208 )
209}
210
/// Build the `Cargo.toml` manifest for the generated extension crate.
///
/// The manifest always declares the `pyo3` dependency and a `cdylib` crate
/// type; a `numpy` dependency line is appended only when `use_ndarray` is set.
pub fn render_cargo_toml_with_options(use_ndarray: bool) -> String {
    let mut manifest = String::from(
        "[package]\nname = \"rustify_ml_ext\"\nversion = \"0.1.0\"\nedition = \"2024\"\n\n[lib]\nname = \"rustify_ml_ext\"\ncrate-type = [\"cdylib\"]\n\n[dependencies]\npyo3 = { version = \"0.21\", features = [\"extension-module\"] }\n",
    );
    if use_ndarray {
        manifest.push_str("numpy = \"0.21\"\n");
    }
    manifest
}
234
/// Extract the Rust function name from a rendered `pub fn` source string.
///
/// Scans line by line for the first `pub fn` item — tolerating leading
/// indentation, which the previous `strip_prefix` on the raw line did not —
/// and returns the identifier that follows, stopping at the first
/// non-identifier character (so `(`, whitespace, or generics never leak into
/// the name). Falls back to `"generated"` when no definition is found.
pub fn extract_fn_name(func_src: &str) -> String {
    func_src
        .lines()
        // trim_start so indented definitions are matched too.
        .find_map(|l| l.trim_start().strip_prefix("pub fn "))
        .map(|rest| {
            // Keep only the identifier itself.
            rest.chars()
                .take_while(|c| c.is_alphanumeric() || *c == '_')
                .collect::<String>()
        })
        // An empty capture (e.g. "pub fn (") is as good as no match.
        .filter(|name| !name.is_empty())
        .unwrap_or_else(|| "generated".to_string())
}