1pub mod expr;
14pub mod infer;
15pub mod render;
16pub mod translate;
17
18use std::fs;
19use std::path::Path;
20
21use anyhow::{Context, Result, anyhow};
22use rustpython_parser::Parse;
23use rustpython_parser::ast::{Stmt, Suite};
24use tracing::{info, warn};
25
26use crate::utils::{GenerationResult, InputSource, TargetSpec, extract_code};
27use render::{
28 render_cargo_toml_with_options, render_function_with_options, render_lib_rs_with_options,
29};
30
/// Heuristically decide whether the Python source `code` imports NumPy.
///
/// Recognises `import numpy…` and `from numpy… import …` statements, plus the bare `np`
/// module name that the previous substring heuristic also accepted. Matching is done per
/// statement on whole module names, so `import numpy_financial`, `import npyscreen`,
/// commented-out imports (`# import numpy`) and string literals no longer count.
///
/// Limitation: backslash-continued import lines are only inspected up to the first line.
fn detects_numpy(code: &str) -> bool {
    code.lines()
        // Drop trailing comments, then split `a; b` into separate statements.
        .map(|line| line.split('#').next().unwrap_or_default())
        .flat_map(|line| line.split(';'))
        .any(statement_imports_numpy)
}

/// True when a single Python statement is an `import` / `from … import` of `numpy` (or `np`).
fn statement_imports_numpy(stmt: &str) -> bool {
    // Only the top-level package matters: `numpy.linalg` counts as `numpy`.
    let is_numpy = |module: &str| matches!(module.split('.').next(), Some("numpy" | "np"));
    let stmt = stmt.trim();
    let mut words = stmt.split_whitespace();
    match words.next() {
        Some("from") => words.next().is_some_and(is_numpy),
        // `stmt` starts with the ASCII keyword `import`, so slicing past it is safe.
        Some("import") => stmt["import".len()..]
            .split(',')
            .any(|item| item.split_whitespace().next().is_some_and(is_numpy)),
        _ => false,
    }
}
35
36fn parse_existing_functions(lib_rs: &str) -> Vec<String> {
38 let mut funcs = Vec::new();
39 let mut current = Vec::new();
40 let mut in_fn = false;
41 let mut brace_balance: i32 = 0;
42 let mut seen_open = false;
43 for line in lib_rs.lines() {
44 let trimmed = line.trim_start();
45 if trimmed.starts_with("#[pyfunction]") {
46 current.clear();
47 in_fn = true;
48 brace_balance = 0;
49 seen_open = false;
50 current.push(line.to_string());
51 continue;
52 }
53
54 if in_fn {
55 current.push(line.to_string());
56 let opens = line.matches('{').count() as i32;
57 let closes = line.matches('}').count() as i32;
58 if opens > 0 {
59 seen_open = true;
60 }
61 brace_balance += opens;
62 brace_balance -= closes;
63
64 if seen_open && brace_balance <= 0 {
65 funcs.push(current.join("\n"));
66 current.clear();
67 in_fn = false;
68 }
69 }
70 }
71 funcs
72}
73
74pub fn generate(
76 source: &InputSource,
77 targets: &[TargetSpec],
78 output: &Path,
79 dry_run: bool,
80) -> Result<GenerationResult> {
81 generate_with_options(source, targets, output, dry_run, false)
82}
83
84pub fn generate_ml(
86 source: &InputSource,
87 targets: &[TargetSpec],
88 output: &Path,
89 dry_run: bool,
90) -> Result<GenerationResult> {
91 generate_with_options(source, targets, output, dry_run, true)
92}
93
94fn generate_with_options(
95 source: &InputSource,
96 targets: &[TargetSpec],
97 output: &Path,
98 dry_run: bool,
99 ml_mode: bool,
100) -> Result<GenerationResult> {
101 if targets.is_empty() {
102 return Err(anyhow!("no targets selected for generation"));
103 }
104
105 fs::create_dir_all(output)
106 .with_context(|| format!("failed to create output dir {}", output.display()))?;
107 let crate_dir = output.join("rustify_ml_ext");
108 if crate_dir.exists() {
109 info!(path = %crate_dir.display(), "reusing existing generated crate directory");
110 } else {
111 fs::create_dir_all(crate_dir.join("src")).context("failed to create crate directories")?;
112 }
113
114 let code = extract_code(source)?;
115 let use_ndarray = ml_mode && detects_numpy(&code);
116 if use_ndarray {
117 info!("numpy detected + ml_mode: using PyReadonlyArray1<f64> params");
118 }
119
120 let suite =
121 Suite::parse(&code, "<input>").context("failed to parse Python input for generation")?;
122 let stmts: &[Stmt] = suite.as_slice();
123
124 let mut functions_by_name: std::collections::HashMap<String, String> =
126 std::collections::HashMap::new();
127
128 let existing_lib = crate_dir.join("src/lib.rs");
129 if existing_lib.exists()
130 && let Ok(existing_src) = std::fs::read_to_string(&existing_lib)
131 {
132 for func_src in parse_existing_functions(&existing_src) {
133 let name = render::extract_fn_name(&func_src);
134 functions_by_name.insert(name, func_src);
135 }
136 }
137
138 let mut fallback_functions = 0usize;
139 for t in targets.iter() {
140 let (code, fallback) = render_function_with_options(t, stmts, use_ndarray);
141 let name = render::extract_fn_name(&code);
142 if fallback {
143 fallback_functions += 1;
144 }
145 functions_by_name.insert(name, code);
146 }
147
148 let functions: Vec<String> = functions_by_name.into_values().collect();
149
150 let lib_rs = render_lib_rs_with_options(&functions, use_ndarray);
151 let cargo_toml = render_cargo_toml_with_options(use_ndarray);
152
153 fs::write(crate_dir.join("src/lib.rs"), lib_rs).context("failed to write lib.rs")?;
154 fs::write(crate_dir.join("Cargo.toml"), cargo_toml).context("failed to write Cargo.toml")?;
155
156 if dry_run {
157 info!(path = %crate_dir.display(), "dry-run: wrote generated files (no build)");
158 }
159 if fallback_functions > 0 {
160 warn!(
161 fallback_functions,
162 "some functions fell back to echo translation"
163 );
164 }
165 info!(path = %crate_dir.display(), funcs = functions.len(), "generated Rust stubs");
166
167 Ok(GenerationResult {
168 crate_dir,
169 generated_functions: functions,
170 fallback_functions,
171 })
172}
173
174#[cfg(test)]
177mod tests {
178 use std::path::PathBuf;
179
180 use rustpython_parser::Parse;
181 use rustpython_parser::ast::{Expr, Operator, Suite};
182 use rustpython_parser::text_size::TextRange;
183 use tempfile::tempdir;
184
185 use crate::utils::{InputSource, TargetSpec};
186
187 use super::expr::expr_to_rust;
188 use super::infer::render_len_checks;
189 use super::translate::translate_function_body;
190 use super::*;
191
192 #[test]
195 fn test_expr_to_rust_range_and_len() {
196 let range_expr = Expr::Call(rustpython_parser::ast::ExprCall {
197 func: Box::new(Expr::Name(rustpython_parser::ast::ExprName {
198 range: TextRange::default(),
199 id: "range".into(),
200 ctx: rustpython_parser::ast::ExprContext::Load,
201 })),
202 args: vec![Expr::Constant(rustpython_parser::ast::ExprConstant {
203 range: TextRange::default(),
204 value: rustpython_parser::ast::Constant::Int(10.into()),
205 kind: None,
206 })],
207 keywords: vec![],
208 range: TextRange::default(),
209 });
210 let len_expr = Expr::Call(rustpython_parser::ast::ExprCall {
211 func: Box::new(Expr::Name(rustpython_parser::ast::ExprName {
212 range: TextRange::default(),
213 id: "len".into(),
214 ctx: rustpython_parser::ast::ExprContext::Load,
215 })),
216 args: vec![Expr::Name(rustpython_parser::ast::ExprName {
217 range: TextRange::default(),
218 id: "a".into(),
219 ctx: rustpython_parser::ast::ExprContext::Load,
220 })],
221 keywords: vec![],
222 range: TextRange::default(),
223 });
224 assert_eq!(expr_to_rust(&range_expr), "0..10");
225 assert_eq!(expr_to_rust(&len_expr), "a.len()");
226 }
227
228 #[test]
229 fn test_expr_to_rust_binop_pow() {
230 let bin = Expr::BinOp(rustpython_parser::ast::ExprBinOp {
231 range: TextRange::default(),
232 left: Box::new(Expr::Name(rustpython_parser::ast::ExprName {
233 range: TextRange::default(),
234 id: "x".into(),
235 ctx: rustpython_parser::ast::ExprContext::Load,
236 })),
237 op: Operator::Pow,
238 right: Box::new(Expr::Constant(rustpython_parser::ast::ExprConstant {
239 range: TextRange::default(),
240 value: rustpython_parser::ast::Constant::Int(2.into()),
241 kind: None,
242 })),
243 });
244 assert_eq!(expr_to_rust(&bin), "(x).powf(2)");
245 }
246
247 #[test]
250 fn test_render_len_checks_multiple_vecs() {
251 let params = vec![
252 ("a".to_string(), "Vec<f64>".to_string()),
253 ("b".to_string(), "Vec<f64>".to_string()),
254 ];
255 let rendered = render_len_checks(¶ms).unwrap();
256 assert!(rendered.contains("a.len() != b.len()"));
257 assert!(rendered.contains("PyValueError"));
258 }
259
260 #[test]
263 fn test_translate_euclidean_body() {
264 let code = r#"
265def euclidean(p1, p2):
266 total = 0.0
267 for i in range(len(p1)):
268 diff = p1[i] - p2[i]
269 total += diff * diff
270 return total ** 0.5
271"#;
272 let suite = Suite::parse(code, "<test>").expect("parse failed");
273 let target = TargetSpec {
274 func: "euclidean".to_string(),
275 line: 1,
276 percent: 100.0,
277 reason: "test".to_string(),
278 };
279 let t = translate_function_body(&target, suite.as_slice()).expect("no translation");
280 assert_eq!(t.return_type, "f64");
281 assert!(!t.fallback);
282 assert!(t.body.contains("for i in 0.."));
283 }
284
285 #[test]
286 fn test_translate_stmt_float_assign_init() {
287 let code = "def f(x):\n total = 0.0\n return total\n";
288 let suite = Suite::parse(code, "<test>").expect("parse");
289 let target = TargetSpec {
290 func: "f".to_string(),
291 line: 1,
292 percent: 100.0,
293 reason: "test".to_string(),
294 };
295 let t = translate_function_body(&target, suite.as_slice()).expect("no translation");
296 assert!(t.body.contains("let mut total: f64"), "got: {}", t.body);
297 }
298
299 #[test]
300 fn test_translate_stmt_subscript_assign() {
301 let code = "def f(result, i, val):\n result[i] = val\n return result\n";
302 let suite = Suite::parse(code, "<test>").expect("parse");
303 let target = TargetSpec {
304 func: "f".to_string(),
305 line: 1,
306 percent: 100.0,
307 reason: "test".to_string(),
308 };
309 let t = translate_function_body(&target, suite.as_slice()).expect("no translation");
310 assert!(t.body.contains("result[i] = val"), "got: {}", t.body);
311 }
312
313 #[test]
314 fn test_translate_stmt_list_init() {
315 let code = "def f(n):\n result = [0.0] * n\n return result\n";
316 let suite = Suite::parse(code, "<test>").expect("parse");
317 let target = TargetSpec {
318 func: "f".to_string(),
319 line: 1,
320 percent: 100.0,
321 reason: "test".to_string(),
322 };
323 let t = translate_function_body(&target, suite.as_slice()).expect("no translation");
324 assert!(
325 t.body.contains("vec![") && t.body.contains("result"),
326 "got: {}",
327 t.body
328 );
329 }
330
331 #[test]
332 fn test_translate_list_comprehension() {
333 let code = "def f(data):\n result = [x * 2.0 for x in data]\n return result\n";
335 let suite = Suite::parse(code, "<test>").expect("parse");
336 let target = TargetSpec {
337 func: "f".to_string(),
338 line: 1,
339 percent: 100.0,
340 reason: "test".to_string(),
341 };
342 let t = translate_function_body(&target, suite.as_slice()).expect("no translation");
343 assert!(
344 t.body.contains(".iter().map(") && t.body.contains(".collect()"),
345 "expected iter().map().collect() for list comp, got: {}",
346 t.body
347 );
348 assert_eq!(t.return_type, "Vec<f64>");
349 assert!(!t.fallback);
350 }
351
352 #[test]
353 fn test_translate_nested_for_else_triggers_fallback() {
354 let code = "def f(n):\n total = 0\n for i in range(n):\n for j in range(n):\n total += i + j\n else:\n total += 1\n return total\n";
356 let suite = Suite::parse(code, "<test>").expect("parse");
357 let target = TargetSpec {
358 func: "f".to_string(),
359 line: 1,
360 percent: 100.0,
361 reason: "test nested for".to_string(),
362 };
363 let t = translate_function_body(&target, suite.as_slice()).expect("no translation");
364 assert!(t.fallback, "expected fallback for for..else, got body:\n{}", t.body);
365 }
366
367 #[test]
368 fn test_translate_dot_product_zero_fallback() {
369 let code = "def dot_product(a, b):\n total = 0.0\n for i in range(len(a)):\n total += a[i] * b[i]\n return total\n";
370 let suite = Suite::parse(code, "<test>").expect("parse");
371 let target = TargetSpec {
372 func: "dot_product".to_string(),
373 line: 1,
374 percent: 80.0,
375 reason: "test".to_string(),
376 };
377 let t = translate_function_body(&target, suite.as_slice()).expect("no translation");
378 assert!(
379 !t.fallback,
380 "dot_product should not fallback; body:\n{}",
381 t.body
382 );
383 assert!(t.body.contains("total +="), "got: {}", t.body);
384 }
385
386 #[test]
387 fn test_translate_matmul_nested_loops() {
388 let code = r#"
389def matmul(a, b, n):
390 result = [0.0] * (n * n)
391 for i in range(n):
392 for j in range(n):
393 total = 0.0
394 for k in range(n):
395 total += a[i * n + k] * b[k * n + j]
396 result[i * n + j] = total
397 return result
398"#;
399 let suite = Suite::parse(code, "<test>").expect("parse");
400 let target = TargetSpec {
401 func: "matmul".to_string(),
402 line: 1,
403 percent: 100.0,
404 reason: "test".to_string(),
405 };
406 let t = translate_function_body(&target, suite.as_slice()).expect("no translation");
407 assert!(t.body.contains("for i in 0..n"), "got: {}", t.body);
408 assert!(t.body.contains("for j in 0..n"), "got: {}", t.body);
409 assert!(t.body.contains("for k in 0..n"), "got: {}", t.body);
410 assert!(t.body.contains("vec!["), "got: {}", t.body);
411 }
412
413 #[test]
414 fn test_translate_while_loop_bool_flag() {
415 let code = "def count_pairs(tokens):\n counts = 0\n changed = True\n while changed:\n changed = False\n counts += 1\n return counts\n";
416 let suite = Suite::parse(code, "<test>").expect("parse");
417 let target = TargetSpec {
418 func: "count_pairs".to_string(),
419 line: 1,
420 percent: 100.0,
421 reason: "test".to_string(),
422 };
423 let t = translate_function_body(&target, suite.as_slice()).expect("no translation");
424 assert!(t.body.contains("while changed"), "got:\n{}", t.body);
425 }
426
427 #[test]
428 fn test_translate_while_comparison() {
429 let code = "def scan(tokens):\n i = 0\n while i < len(tokens):\n i += 1\n return i\n";
430 let suite = Suite::parse(code, "<test>").expect("parse");
431 let target = TargetSpec {
432 func: "scan".to_string(),
433 line: 1,
434 percent: 100.0,
435 reason: "test".to_string(),
436 };
437 let t = translate_function_body(&target, suite.as_slice()).expect("no translation");
438 assert!(t.body.contains("while i <"), "got:\n{}", t.body);
439 assert!(t.body.contains("tokens.len()"), "got:\n{}", t.body);
440 }
441
442 #[test]
445 fn test_ndarray_mode_replaces_vec_params() {
446 let code = "import numpy as np\ndef dot_product(a, b):\n total = 0.0\n for i in range(len(a)):\n total += a[i] * b[i]\n return total\n";
447 let source = InputSource::Snippet(code.to_string());
448 let targets = vec![TargetSpec {
449 func: "dot_product".to_string(),
450 line: 1,
451 percent: 100.0,
452 reason: "test".to_string(),
453 }];
454 let tmp = tempdir().expect("tempdir");
455 let result = generate_ml(&source, &targets, tmp.path(), false).expect("generate_ml");
456 let lib_rs = std::fs::read_to_string(tmp.path().join("rustify_ml_ext/src/lib.rs")).unwrap();
457 assert_eq!(result.fallback_functions, 0);
458 assert!(lib_rs.contains("PyReadonlyArray1<f64>"), "got:\n{}", lib_rs);
459 assert!(lib_rs.contains("use numpy;"), "got:\n{}", lib_rs);
460 let cargo = std::fs::read_to_string(tmp.path().join("rustify_ml_ext/Cargo.toml")).unwrap();
461 assert!(cargo.contains("numpy"), "got:\n{}", cargo);
462 }
463
464 #[test]
465 fn test_ndarray_mode_no_numpy_import_stays_vec() {
466 let code = "def dot_product(a, b):\n total = 0.0\n for i in range(len(a)):\n total += a[i] * b[i]\n return total\n";
467 let source = InputSource::Snippet(code.to_string());
468 let targets = vec![TargetSpec {
469 func: "dot_product".to_string(),
470 line: 1,
471 percent: 100.0,
472 reason: "test".to_string(),
473 }];
474 let tmp = tempdir().expect("tempdir");
475 let result = generate_ml(&source, &targets, tmp.path(), false).expect("generate_ml");
476 let lib_rs = std::fs::read_to_string(tmp.path().join("rustify_ml_ext/src/lib.rs")).unwrap();
477 assert_eq!(result.fallback_functions, 0);
478 assert!(!lib_rs.contains("PyReadonlyArray1"), "got:\n{}", lib_rs);
479 assert!(lib_rs.contains("Vec<f64>"), "got:\n{}", lib_rs);
480 }
481
482 #[test]
485 fn test_generate_integration_euclidean() {
486 let path = PathBuf::from("examples/euclidean.py");
487 let code = std::fs::read_to_string(&path).expect("read example");
488 let source = InputSource::File {
489 path: path.clone(),
490 code,
491 };
492 let targets = vec![TargetSpec {
493 func: "euclidean".to_string(),
494 line: 1,
495 percent: 100.0,
496 reason: "test".to_string(),
497 }];
498 let tmp = tempdir().expect("tempdir");
499 let result = generate(&source, &targets, tmp.path(), false).expect("generate");
500 assert_eq!(result.generated_functions.len(), 1);
501 assert_eq!(result.fallback_functions, 0);
502 assert!(tmp.path().join("rustify_ml_ext/src/lib.rs").exists());
503 }
504}