1pub mod expr;
14pub mod infer;
15pub mod render;
16pub mod translate;
17
18use std::fs;
19use std::path::Path;
20
21use anyhow::{Context, Result, anyhow};
22use rustpython_parser::Parse;
23use rustpython_parser::ast::{Stmt, Suite};
24use tracing::{info, warn};
25
26use crate::utils::{GenerationResult, InputSource, TargetSpec, extract_code};
27use render::{
28 render_cargo_toml_with_options, render_function_with_options, render_lib_rs_with_options,
29};
30
/// Heuristically reports whether a piece of Python source appears to use numpy.
///
/// This is a plain substring search over the raw text (no parsing is involved), so a
/// marker occurring inside a comment or string literal also counts as a match.
fn detects_numpy(code: &str) -> bool {
    // Any one of these fragments is treated as evidence of numpy usage.
    const MARKERS: [&str; 6] = [
        "import numpy",
        "from numpy",
        "import np",
        "np.ndarray",
        "numpy.ndarray",
        "np.array",
    ];

    MARKERS.iter().any(|&marker| code.contains(marker))
}
40
/// Extracts the source text of every `#[pyfunction]` item found in `lib_rs`.
///
/// Scanning is line based: a line whose first non-whitespace text is `#[pyfunction]`
/// starts a new item, and the item ends on the first later line at which at least one
/// `{` has been seen and the running brace depth is back to zero (or below). Each
/// returned string is the item's lines joined with `\n`, attribute line included.
///
/// Braces are counted textually, so braces inside string literals or comments are
/// counted too. An item that is still open at end of input is not returned, and a
/// second `#[pyfunction]` line seen before the current item closes discards what was
/// collected so far and starts over.
fn parse_existing_functions(lib_rs: &str) -> Vec<String> {
    let mut collected: Vec<String> = Vec::new();
    // `Some(lines)` while inside a `#[pyfunction]` item, `None` between items.
    let mut pending: Option<Vec<&str>> = None;
    let mut depth: i32 = 0;
    let mut body_started = false;

    for raw in lib_rs.lines() {
        if raw.trim_start().starts_with("#[pyfunction]") {
            pending = Some(vec![raw]);
            depth = 0;
            body_started = false;
            continue;
        }

        if let Some(buf) = pending.as_mut() {
            buf.push(raw);
        } else {
            // Not inside an item: nothing to track for this line.
            continue;
        }

        for ch in raw.chars() {
            match ch {
                '{' => {
                    depth += 1;
                    body_started = true;
                }
                '}' => depth -= 1,
                _ => {}
            }
        }

        // The depth is only inspected once the whole line has been accounted for.
        if body_started && depth <= 0 {
            collected.extend(pending.take().map(|lines| lines.join("\n")));
        }
    }

    collected
}
78
79pub fn generate(
81 source: &InputSource,
82 targets: &[TargetSpec],
83 output: &Path,
84 dry_run: bool,
85) -> Result<GenerationResult> {
86 generate_with_options(source, targets, output, dry_run, false)
87}
88
89pub fn generate_ml(
91 source: &InputSource,
92 targets: &[TargetSpec],
93 output: &Path,
94 dry_run: bool,
95) -> Result<GenerationResult> {
96 generate_with_options(source, targets, output, dry_run, true)
97}
98
99fn generate_with_options(
100 source: &InputSource,
101 targets: &[TargetSpec],
102 output: &Path,
103 dry_run: bool,
104 ml_mode: bool,
105) -> Result<GenerationResult> {
106 if targets.is_empty() {
107 return Err(anyhow!("no targets selected for generation"));
108 }
109
110 fs::create_dir_all(output)
111 .with_context(|| format!("failed to create output dir {}", output.display()))?;
112 let crate_dir = output.join("rustify_ml_ext");
113 if crate_dir.exists() {
114 info!(path = %crate_dir.display(), "reusing existing generated crate directory");
115 } else {
116 fs::create_dir_all(crate_dir.join("src")).context("failed to create crate directories")?;
117 }
118
119 let code = extract_code(source)?;
120 let use_ndarray = ml_mode && detects_numpy(&code);
121 if use_ndarray {
122 info!("numpy detected + ml_mode: using PyReadonlyArray1<f64> params");
123 }
124
125 let suite =
126 Suite::parse(&code, "<input>").context("failed to parse Python input for generation")?;
127 let stmts: &[Stmt] = suite.as_slice();
128
129 let mut functions_by_name: std::collections::HashMap<String, String> =
131 std::collections::HashMap::new();
132
133 let existing_lib = crate_dir.join("src/lib.rs");
134 if existing_lib.exists()
135 && let Ok(existing_src) = std::fs::read_to_string(&existing_lib)
136 {
137 for func_src in parse_existing_functions(&existing_src) {
138 let name = render::extract_fn_name(&func_src);
139 functions_by_name.insert(name, func_src);
140 }
141 }
142
143 let mut fallback_functions = 0usize;
144 for t in targets.iter() {
145 let (code, fallback) = render_function_with_options(t, stmts, use_ndarray);
146 let name = render::extract_fn_name(&code);
147 if fallback {
148 fallback_functions += 1;
149 }
150 functions_by_name.insert(name, code);
151 }
152
153 let functions: Vec<String> = functions_by_name.into_values().collect();
154
155 let lib_rs = render_lib_rs_with_options(&functions, use_ndarray);
156 let cargo_toml = render_cargo_toml_with_options(use_ndarray);
157
158 fs::write(crate_dir.join("src/lib.rs"), lib_rs).context("failed to write lib.rs")?;
159 fs::write(crate_dir.join("Cargo.toml"), cargo_toml).context("failed to write Cargo.toml")?;
160
161 if dry_run {
162 info!(path = %crate_dir.display(), "dry-run: wrote generated files (no build)");
163 }
164 if fallback_functions > 0 {
165 warn!(
166 fallback_functions,
167 "some functions fell back to echo translation"
168 );
169 }
170 info!(path = %crate_dir.display(), funcs = functions.len(), "generated Rust stubs");
171
172 Ok(GenerationResult {
173 crate_dir,
174 generated_functions: functions,
175 fallback_functions,
176 })
177}
178
179#[cfg(test)]
182mod tests {
183 use std::path::PathBuf;
184
185 use rustpython_parser::Parse;
186 use rustpython_parser::ast::{Expr, Operator, Suite};
187 use rustpython_parser::text_size::TextRange;
188 use tempfile::tempdir;
189
190 use crate::utils::{InputSource, TargetSpec};
191
192 use super::expr::expr_to_rust;
193 use super::infer::render_len_checks;
194 use super::translate::translate_function_body;
195 use super::*;
196
197 #[test]
200 fn test_expr_to_rust_range_and_len() {
201 let range_expr = Expr::Call(rustpython_parser::ast::ExprCall {
202 func: Box::new(Expr::Name(rustpython_parser::ast::ExprName {
203 range: TextRange::default(),
204 id: "range".into(),
205 ctx: rustpython_parser::ast::ExprContext::Load,
206 })),
207 args: vec![Expr::Constant(rustpython_parser::ast::ExprConstant {
208 range: TextRange::default(),
209 value: rustpython_parser::ast::Constant::Int(10.into()),
210 kind: None,
211 })],
212 keywords: vec![],
213 range: TextRange::default(),
214 });
215 let len_expr = Expr::Call(rustpython_parser::ast::ExprCall {
216 func: Box::new(Expr::Name(rustpython_parser::ast::ExprName {
217 range: TextRange::default(),
218 id: "len".into(),
219 ctx: rustpython_parser::ast::ExprContext::Load,
220 })),
221 args: vec![Expr::Name(rustpython_parser::ast::ExprName {
222 range: TextRange::default(),
223 id: "a".into(),
224 ctx: rustpython_parser::ast::ExprContext::Load,
225 })],
226 keywords: vec![],
227 range: TextRange::default(),
228 });
229 assert_eq!(expr_to_rust(&range_expr), "0..10");
230 assert_eq!(expr_to_rust(&len_expr), "a.len()");
231 }
232
233 #[test]
234 fn test_expr_to_rust_binop_pow() {
235 let bin = Expr::BinOp(rustpython_parser::ast::ExprBinOp {
236 range: TextRange::default(),
237 left: Box::new(Expr::Name(rustpython_parser::ast::ExprName {
238 range: TextRange::default(),
239 id: "x".into(),
240 ctx: rustpython_parser::ast::ExprContext::Load,
241 })),
242 op: Operator::Pow,
243 right: Box::new(Expr::Constant(rustpython_parser::ast::ExprConstant {
244 range: TextRange::default(),
245 value: rustpython_parser::ast::Constant::Int(2.into()),
246 kind: None,
247 })),
248 });
249 assert_eq!(expr_to_rust(&bin), "(x).powf(2)");
250 }
251
252 #[test]
255 fn test_render_len_checks_multiple_vecs() {
256 let params = vec![
257 ("a".to_string(), "Vec<f64>".to_string()),
258 ("b".to_string(), "Vec<f64>".to_string()),
259 ];
260 let rendered = render_len_checks(¶ms).unwrap();
261 assert!(rendered.contains("a.len() != b.len()"));
262 assert!(rendered.contains("PyValueError"));
263 }
264
265 #[test]
268 fn test_translate_euclidean_body() {
269 let code = r#"
270def euclidean(p1, p2):
271 total = 0.0
272 for i in range(len(p1)):
273 diff = p1[i] - p2[i]
274 total += diff * diff
275 return total ** 0.5
276"#;
277 let suite = Suite::parse(code, "<test>").expect("parse failed");
278 let target = TargetSpec {
279 func: "euclidean".to_string(),
280 line: 1,
281 percent: 100.0,
282 reason: "test".to_string(),
283 };
284 let t = translate_function_body(&target, suite.as_slice()).expect("no translation");
285 assert_eq!(t.return_type, "f64");
286 assert!(!t.fallback);
287 assert!(t.body.contains("for i in 0.."));
288 }
289
290 #[test]
291 fn test_translate_stmt_float_assign_init() {
292 let code = "def f(x):\n total = 0.0\n return total\n";
293 let suite = Suite::parse(code, "<test>").expect("parse");
294 let target = TargetSpec {
295 func: "f".to_string(),
296 line: 1,
297 percent: 100.0,
298 reason: "test".to_string(),
299 };
300 let t = translate_function_body(&target, suite.as_slice()).expect("no translation");
301 assert!(t.body.contains("let mut total: f64"), "got: {}", t.body);
302 }
303
304 #[test]
305 fn test_translate_stmt_subscript_assign() {
306 let code = "def f(result, i, val):\n result[i] = val\n return result\n";
307 let suite = Suite::parse(code, "<test>").expect("parse");
308 let target = TargetSpec {
309 func: "f".to_string(),
310 line: 1,
311 percent: 100.0,
312 reason: "test".to_string(),
313 };
314 let t = translate_function_body(&target, suite.as_slice()).expect("no translation");
315 assert!(t.body.contains("result[i] = val"), "got: {}", t.body);
316 }
317
318 #[test]
319 fn test_translate_stmt_list_init() {
320 let code = "def f(n):\n result = [0.0] * n\n return result\n";
321 let suite = Suite::parse(code, "<test>").expect("parse");
322 let target = TargetSpec {
323 func: "f".to_string(),
324 line: 1,
325 percent: 100.0,
326 reason: "test".to_string(),
327 };
328 let t = translate_function_body(&target, suite.as_slice()).expect("no translation");
329 assert!(
330 t.body.contains("vec![") && t.body.contains("result"),
331 "got: {}",
332 t.body
333 );
334 }
335
336 #[test]
337 fn test_translate_list_comprehension() {
338 let code = "def f(data):\n result = [x * 2.0 for x in data]\n return result\n";
340 let suite = Suite::parse(code, "<test>").expect("parse");
341 let target = TargetSpec {
342 func: "f".to_string(),
343 line: 1,
344 percent: 100.0,
345 reason: "test".to_string(),
346 };
347 let t = translate_function_body(&target, suite.as_slice()).expect("no translation");
348 assert!(
349 t.body.contains(".iter().map(") && t.body.contains(".collect()"),
350 "expected iter().map().collect() for list comp, got: {}",
351 t.body
352 );
353 assert_eq!(t.return_type, "Vec<f64>");
354 assert!(!t.fallback);
355 }
356
357 #[test]
358 fn test_translate_argmax_tuple_return() {
359 let code = "def argmax(xs):\n best_idx = 0\n best_val = xs[0]\n for i in range(len(xs)):\n if xs[i] > best_val:\n best_val = xs[i]\n best_idx = i\n return (best_idx, best_val)\n";
361 let suite = Suite::parse(code, "<test>").expect("parse");
362 let target = TargetSpec {
363 func: "argmax".to_string(),
364 line: 1,
365 percent: 100.0,
366 reason: "argmax tuple".to_string(),
367 };
368 let t = translate_function_body(&target, suite.as_slice()).expect("no translation");
369 assert_eq!(t.return_type, "(usize, f64)");
370 assert!(
371 t.body.contains("return Ok((best_idx, best_val));"),
372 "body: {}",
373 t.body
374 );
375 assert!(!t.fallback);
376 }
377
378 #[test]
379 fn test_translate_nested_for_else_triggers_fallback() {
380 let code = "def f(n):\n total = 0\n for i in range(n):\n for j in range(n):\n total += i + j\n else:\n total += 1\n return total\n";
382 let suite = Suite::parse(code, "<test>").expect("parse");
383 let target = TargetSpec {
384 func: "f".to_string(),
385 line: 1,
386 percent: 100.0,
387 reason: "test nested for".to_string(),
388 };
389 let t = translate_function_body(&target, suite.as_slice()).expect("no translation");
390 assert!(
391 t.fallback,
392 "expected fallback for for..else, got body:\n{}",
393 t.body
394 );
395 }
396
397 #[test]
398 fn test_translate_dot_product_zero_fallback() {
399 let code = "def dot_product(a, b):\n total = 0.0\n for i in range(len(a)):\n total += a[i] * b[i]\n return total\n";
400 let suite = Suite::parse(code, "<test>").expect("parse");
401 let target = TargetSpec {
402 func: "dot_product".to_string(),
403 line: 1,
404 percent: 80.0,
405 reason: "test".to_string(),
406 };
407 let t = translate_function_body(&target, suite.as_slice()).expect("no translation");
408 assert!(
409 !t.fallback,
410 "dot_product should not fallback; body:\n{}",
411 t.body
412 );
413 assert!(t.body.contains("total +="), "got: {}", t.body);
414 }
415
416 #[test]
417 fn test_translate_matmul_nested_loops() {
418 let code = r#"
419def matmul(a, b, n):
420 result = [0.0] * (n * n)
421 for i in range(n):
422 for j in range(n):
423 total = 0.0
424 for k in range(n):
425 total += a[i * n + k] * b[k * n + j]
426 result[i * n + j] = total
427 return result
428"#;
429 let suite = Suite::parse(code, "<test>").expect("parse");
430 let target = TargetSpec {
431 func: "matmul".to_string(),
432 line: 1,
433 percent: 100.0,
434 reason: "test".to_string(),
435 };
436 let t = translate_function_body(&target, suite.as_slice()).expect("no translation");
437 assert!(t.body.contains("for i in 0..n"), "got: {}", t.body);
438 assert!(t.body.contains("for j in 0..n"), "got: {}", t.body);
439 assert!(t.body.contains("for k in 0..n"), "got: {}", t.body);
440 assert!(t.body.contains("vec!["), "got: {}", t.body);
441 }
442
443 #[test]
444 fn test_translate_while_loop_bool_flag() {
445 let code = "def count_pairs(tokens):\n counts = 0\n changed = True\n while changed:\n changed = False\n counts += 1\n return counts\n";
446 let suite = Suite::parse(code, "<test>").expect("parse");
447 let target = TargetSpec {
448 func: "count_pairs".to_string(),
449 line: 1,
450 percent: 100.0,
451 reason: "test".to_string(),
452 };
453 let t = translate_function_body(&target, suite.as_slice()).expect("no translation");
454 assert!(t.body.contains("while changed"), "got:\n{}", t.body);
455 }
456
457 #[test]
458 fn test_translate_while_comparison() {
459 let code = "def scan(tokens):\n i = 0\n while i < len(tokens):\n i += 1\n return i\n";
460 let suite = Suite::parse(code, "<test>").expect("parse");
461 let target = TargetSpec {
462 func: "scan".to_string(),
463 line: 1,
464 percent: 100.0,
465 reason: "test".to_string(),
466 };
467 let t = translate_function_body(&target, suite.as_slice()).expect("no translation");
468 assert!(t.body.contains("while i <"), "got:\n{}", t.body);
469 assert!(t.body.contains("tokens.len()"), "got:\n{}", t.body);
470 }
471
472 #[test]
475 fn test_ndarray_mode_replaces_vec_params() {
476 let code = "import numpy as np\ndef dot_product(a, b):\n total = 0.0\n for i in range(len(a)):\n total += a[i] * b[i]\n return total\n";
477 let source = InputSource::Snippet(code.to_string());
478 let targets = vec![TargetSpec {
479 func: "dot_product".to_string(),
480 line: 1,
481 percent: 100.0,
482 reason: "test".to_string(),
483 }];
484 let tmp = tempdir().expect("tempdir");
485 let result = generate_ml(&source, &targets, tmp.path(), false).expect("generate_ml");
486 let lib_rs = std::fs::read_to_string(tmp.path().join("rustify_ml_ext/src/lib.rs")).unwrap();
487 assert_eq!(result.fallback_functions, 0);
488 assert!(lib_rs.contains("PyReadonlyArray1<f64>"), "got:\n{}", lib_rs);
489 assert!(lib_rs.contains("use numpy;"), "got:\n{}", lib_rs);
490 let cargo = std::fs::read_to_string(tmp.path().join("rustify_ml_ext/Cargo.toml")).unwrap();
491 assert!(cargo.contains("numpy"), "got:\n{}", cargo);
492 }
493
494 #[test]
495 fn test_ndarray_mode_no_numpy_import_stays_vec() {
496 let code = "def dot_product(a, b):\n total = 0.0\n for i in range(len(a)):\n total += a[i] * b[i]\n return total\n";
497 let source = InputSource::Snippet(code.to_string());
498 let targets = vec![TargetSpec {
499 func: "dot_product".to_string(),
500 line: 1,
501 percent: 100.0,
502 reason: "test".to_string(),
503 }];
504 let tmp = tempdir().expect("tempdir");
505 let result = generate_ml(&source, &targets, tmp.path(), false).expect("generate_ml");
506 let lib_rs = std::fs::read_to_string(tmp.path().join("rustify_ml_ext/src/lib.rs")).unwrap();
507 assert_eq!(result.fallback_functions, 0);
508 assert!(!lib_rs.contains("PyReadonlyArray1"), "got:\n{}", lib_rs);
509 assert!(lib_rs.contains("Vec<f64>"), "got:\n{}", lib_rs);
510 }
511
512 #[test]
515 fn test_generate_integration_euclidean() {
516 let path = PathBuf::from("examples/euclidean.py");
517 let code = std::fs::read_to_string(&path).expect("read example");
518 let source = InputSource::File {
519 path: path.clone(),
520 code,
521 };
522 let targets = vec![TargetSpec {
523 func: "euclidean".to_string(),
524 line: 1,
525 percent: 100.0,
526 reason: "test".to_string(),
527 }];
528 let tmp = tempdir().expect("tempdir");
529 let result = generate(&source, &targets, tmp.path(), false).expect("generate");
530 assert_eq!(result.generated_functions.len(), 1);
531 assert_eq!(result.fallback_functions, 0);
532 assert!(tmp.path().join("rustify_ml_ext/src/lib.rs").exists());
533 }
534}