Skip to main content

rustify_ml/generator/
mod.rs

1//! Generator module — public API for Python → Rust PyO3 stub generation.
2//!
3//! # Submodules
4//! - [`expr`]      — expression-to-Rust translation (pure, no I/O)
5//! - [`infer`]     — type inference for params and assignments
6//! - [`translate`] — statement/body translation (AST walk)
7//! - [`render`]    — PyO3 function + lib.rs + Cargo.toml rendering
8//!
9//! # Entry points
10//! - [`generate`]    — standard generation
11//! - [`generate_ml`] — ML mode (numpy → PyReadonlyArray1)
12
13pub mod expr;
14pub mod infer;
15pub mod render;
16pub mod translate;
17
18use std::fs;
19use std::path::Path;
20
21use anyhow::{Context, Result, anyhow};
22use rustpython_parser::Parse;
23use rustpython_parser::ast::{Stmt, Suite};
24use tracing::{info, warn};
25
26use crate::utils::{GenerationResult, InputSource, TargetSpec, extract_code};
27use render::{
28    render_cargo_toml_with_options, render_function_with_options, render_lib_rs_with_options,
29};
30
/// Report whether the Python source text mentions numpy (used to enable ndarray mode).
///
/// This is a plain substring scan over the raw text, so a marker inside a
/// comment or string literal counts as a match too.
fn detects_numpy(code: &str) -> bool {
    // Any single marker is enough to treat the source as numpy-based.
    const NUMPY_MARKERS: [&str; 6] = [
        "import numpy",
        "from numpy",
        "import np",
        "np.ndarray",
        "numpy.ndarray",
        "np.array",
    ];

    NUMPY_MARKERS.iter().any(|marker| code.contains(marker))
}
40
/// Extract every complete `#[pyfunction]` block from existing `lib.rs` text.
///
/// A block starts on a line whose first non-whitespace text is `#[pyfunction]`
/// and ends on the first later line where at least one `{` has been seen and
/// the running `{`/`}` balance is back to zero or below. Braces are counted
/// textually, line by line; braces on the attribute line itself are not
/// counted. A block that is still open at end of input is discarded.
fn parse_existing_functions(lib_rs: &str) -> Vec<String> {
    let mut blocks: Vec<String> = Vec::new();
    // `Some(lines)` while a block is being collected, `None` between blocks.
    let mut pending: Option<Vec<&str>> = None;
    let mut depth: i32 = 0;
    let mut body_started = false;

    for raw in lib_rs.lines() {
        if raw.trim_start().starts_with("#[pyfunction]") {
            // A new attribute always restarts collection, abandoning any
            // block that was still open.
            pending = Some(vec![raw]);
            depth = 0;
            body_started = false;
            continue;
        }

        if let Some(mut lines) = pending.take() {
            lines.push(raw);

            let opened = raw.matches('{').count() as i32;
            let closed = raw.matches('}').count() as i32;
            body_started |= opened > 0;
            depth += opened - closed;

            if body_started && depth <= 0 {
                blocks.push(lines.join("\n"));
            } else {
                // Still inside the function: keep collecting.
                pending = Some(lines);
            }
        }
    }

    blocks
}
78
79/// Generate Rust + PyO3 stubs for the given targets.
80pub fn generate(
81    source: &InputSource,
82    targets: &[TargetSpec],
83    output: &Path,
84    dry_run: bool,
85) -> Result<GenerationResult> {
86    generate_with_options(source, targets, output, dry_run, false)
87}
88
89/// Generate with ML mode: detects numpy imports → uses `PyReadonlyArray1<f64>` params.
90pub fn generate_ml(
91    source: &InputSource,
92    targets: &[TargetSpec],
93    output: &Path,
94    dry_run: bool,
95) -> Result<GenerationResult> {
96    generate_with_options(source, targets, output, dry_run, true)
97}
98
99fn generate_with_options(
100    source: &InputSource,
101    targets: &[TargetSpec],
102    output: &Path,
103    dry_run: bool,
104    ml_mode: bool,
105) -> Result<GenerationResult> {
106    if targets.is_empty() {
107        return Err(anyhow!("no targets selected for generation"));
108    }
109
110    fs::create_dir_all(output)
111        .with_context(|| format!("failed to create output dir {}", output.display()))?;
112    let crate_dir = output.join("rustify_ml_ext");
113    if crate_dir.exists() {
114        info!(path = %crate_dir.display(), "reusing existing generated crate directory");
115    } else {
116        fs::create_dir_all(crate_dir.join("src")).context("failed to create crate directories")?;
117    }
118
119    let code = extract_code(source)?;
120    let use_ndarray = ml_mode && detects_numpy(&code);
121    if use_ndarray {
122        info!("numpy detected + ml_mode: using PyReadonlyArray1<f64> params");
123    }
124
125    let suite =
126        Suite::parse(&code, "<input>").context("failed to parse Python input for generation")?;
127    let stmts: &[Stmt] = suite.as_slice();
128
129    // Merge previously generated functions (if crate already exists) with newly generated ones.
130    let mut functions_by_name: std::collections::HashMap<String, String> =
131        std::collections::HashMap::new();
132
133    let existing_lib = crate_dir.join("src/lib.rs");
134    if existing_lib.exists()
135        && let Ok(existing_src) = std::fs::read_to_string(&existing_lib)
136    {
137        for func_src in parse_existing_functions(&existing_src) {
138            let name = render::extract_fn_name(&func_src);
139            functions_by_name.insert(name, func_src);
140        }
141    }
142
143    let mut fallback_functions = 0usize;
144    for t in targets.iter() {
145        let (code, fallback) = render_function_with_options(t, stmts, use_ndarray);
146        let name = render::extract_fn_name(&code);
147        if fallback {
148            fallback_functions += 1;
149        }
150        functions_by_name.insert(name, code);
151    }
152
153    let functions: Vec<String> = functions_by_name.into_values().collect();
154
155    let lib_rs = render_lib_rs_with_options(&functions, use_ndarray);
156    let cargo_toml = render_cargo_toml_with_options(use_ndarray);
157
158    fs::write(crate_dir.join("src/lib.rs"), lib_rs).context("failed to write lib.rs")?;
159    fs::write(crate_dir.join("Cargo.toml"), cargo_toml).context("failed to write Cargo.toml")?;
160
161    if dry_run {
162        info!(path = %crate_dir.display(), "dry-run: wrote generated files (no build)");
163    }
164    if fallback_functions > 0 {
165        warn!(
166            fallback_functions,
167            "some functions fell back to echo translation"
168        );
169    }
170    info!(path = %crate_dir.display(), funcs = functions.len(), "generated Rust stubs");
171
172    Ok(GenerationResult {
173        crate_dir,
174        generated_functions: functions,
175        fallback_functions,
176    })
177}
178
179// ── Tests ─────────────────────────────────────────────────────────────────────
180
181#[cfg(test)]
182mod tests {
183    use std::path::PathBuf;
184
185    use rustpython_parser::Parse;
186    use rustpython_parser::ast::{Expr, Operator, Suite};
187    use rustpython_parser::text_size::TextRange;
188    use tempfile::tempdir;
189
190    use crate::utils::{InputSource, TargetSpec};
191
192    use super::expr::expr_to_rust;
193    use super::infer::render_len_checks;
194    use super::translate::translate_function_body;
195    use super::*;
196
197    // ── expr tests ────────────────────────────────────────────────────────────
198
199    #[test]
200    fn test_expr_to_rust_range_and_len() {
201        let range_expr = Expr::Call(rustpython_parser::ast::ExprCall {
202            func: Box::new(Expr::Name(rustpython_parser::ast::ExprName {
203                range: TextRange::default(),
204                id: "range".into(),
205                ctx: rustpython_parser::ast::ExprContext::Load,
206            })),
207            args: vec![Expr::Constant(rustpython_parser::ast::ExprConstant {
208                range: TextRange::default(),
209                value: rustpython_parser::ast::Constant::Int(10.into()),
210                kind: None,
211            })],
212            keywords: vec![],
213            range: TextRange::default(),
214        });
215        let len_expr = Expr::Call(rustpython_parser::ast::ExprCall {
216            func: Box::new(Expr::Name(rustpython_parser::ast::ExprName {
217                range: TextRange::default(),
218                id: "len".into(),
219                ctx: rustpython_parser::ast::ExprContext::Load,
220            })),
221            args: vec![Expr::Name(rustpython_parser::ast::ExprName {
222                range: TextRange::default(),
223                id: "a".into(),
224                ctx: rustpython_parser::ast::ExprContext::Load,
225            })],
226            keywords: vec![],
227            range: TextRange::default(),
228        });
229        assert_eq!(expr_to_rust(&range_expr), "0..10");
230        assert_eq!(expr_to_rust(&len_expr), "a.len()");
231    }
232
233    #[test]
234    fn test_expr_to_rust_binop_pow() {
235        let bin = Expr::BinOp(rustpython_parser::ast::ExprBinOp {
236            range: TextRange::default(),
237            left: Box::new(Expr::Name(rustpython_parser::ast::ExprName {
238                range: TextRange::default(),
239                id: "x".into(),
240                ctx: rustpython_parser::ast::ExprContext::Load,
241            })),
242            op: Operator::Pow,
243            right: Box::new(Expr::Constant(rustpython_parser::ast::ExprConstant {
244                range: TextRange::default(),
245                value: rustpython_parser::ast::Constant::Int(2.into()),
246                kind: None,
247            })),
248        });
249        assert_eq!(expr_to_rust(&bin), "(x).powf(2)");
250    }
251
252    // ── infer tests ───────────────────────────────────────────────────────────
253
254    #[test]
255    fn test_render_len_checks_multiple_vecs() {
256        let params = vec![
257            ("a".to_string(), "Vec<f64>".to_string()),
258            ("b".to_string(), "Vec<f64>".to_string()),
259        ];
260        let rendered = render_len_checks(&params).unwrap();
261        assert!(rendered.contains("a.len() != b.len()"));
262        assert!(rendered.contains("PyValueError"));
263    }
264
265    // ── translate tests ───────────────────────────────────────────────────────
266
267    #[test]
268    fn test_translate_euclidean_body() {
269        let code = r#"
270def euclidean(p1, p2):
271    total = 0.0
272    for i in range(len(p1)):
273        diff = p1[i] - p2[i]
274        total += diff * diff
275    return total ** 0.5
276"#;
277        let suite = Suite::parse(code, "<test>").expect("parse failed");
278        let target = TargetSpec {
279            func: "euclidean".to_string(),
280            line: 1,
281            percent: 100.0,
282            reason: "test".to_string(),
283        };
284        let t = translate_function_body(&target, suite.as_slice()).expect("no translation");
285        assert_eq!(t.return_type, "f64");
286        assert!(!t.fallback);
287        assert!(t.body.contains("for i in 0.."));
288    }
289
290    #[test]
291    fn test_translate_stmt_float_assign_init() {
292        let code = "def f(x):\n    total = 0.0\n    return total\n";
293        let suite = Suite::parse(code, "<test>").expect("parse");
294        let target = TargetSpec {
295            func: "f".to_string(),
296            line: 1,
297            percent: 100.0,
298            reason: "test".to_string(),
299        };
300        let t = translate_function_body(&target, suite.as_slice()).expect("no translation");
301        assert!(t.body.contains("let mut total: f64"), "got: {}", t.body);
302    }
303
304    #[test]
305    fn test_translate_stmt_subscript_assign() {
306        let code = "def f(result, i, val):\n    result[i] = val\n    return result\n";
307        let suite = Suite::parse(code, "<test>").expect("parse");
308        let target = TargetSpec {
309            func: "f".to_string(),
310            line: 1,
311            percent: 100.0,
312            reason: "test".to_string(),
313        };
314        let t = translate_function_body(&target, suite.as_slice()).expect("no translation");
315        assert!(t.body.contains("result[i] = val"), "got: {}", t.body);
316    }
317
318    #[test]
319    fn test_translate_stmt_list_init() {
320        let code = "def f(n):\n    result = [0.0] * n\n    return result\n";
321        let suite = Suite::parse(code, "<test>").expect("parse");
322        let target = TargetSpec {
323            func: "f".to_string(),
324            line: 1,
325            percent: 100.0,
326            reason: "test".to_string(),
327        };
328        let t = translate_function_body(&target, suite.as_slice()).expect("no translation");
329        assert!(
330            t.body.contains("vec![") && t.body.contains("result"),
331            "got: {}",
332            t.body
333        );
334    }
335
336    #[test]
337    fn test_translate_list_comprehension() {
338        // result = [x * 2.0 for x in data] → let result: Vec<f64> = data.iter().map(|x| ...).collect();
339        let code = "def f(data):\n    result = [x * 2.0 for x in data]\n    return result\n";
340        let suite = Suite::parse(code, "<test>").expect("parse");
341        let target = TargetSpec {
342            func: "f".to_string(),
343            line: 1,
344            percent: 100.0,
345            reason: "test".to_string(),
346        };
347        let t = translate_function_body(&target, suite.as_slice()).expect("no translation");
348        assert!(
349            t.body.contains(".iter().map(") && t.body.contains(".collect()"),
350            "expected iter().map().collect() for list comp, got: {}",
351            t.body
352        );
353        assert_eq!(t.return_type, "Vec<f64>");
354        assert!(!t.fallback);
355    }
356
357    #[test]
358    fn test_translate_argmax_tuple_return() {
359        // returns (index, value) tuple
360        let code = "def argmax(xs):\n    best_idx = 0\n    best_val = xs[0]\n    for i in range(len(xs)):\n        if xs[i] > best_val:\n            best_val = xs[i]\n            best_idx = i\n    return (best_idx, best_val)\n";
361        let suite = Suite::parse(code, "<test>").expect("parse");
362        let target = TargetSpec {
363            func: "argmax".to_string(),
364            line: 1,
365            percent: 100.0,
366            reason: "argmax tuple".to_string(),
367        };
368        let t = translate_function_body(&target, suite.as_slice()).expect("no translation");
369        assert_eq!(t.return_type, "(usize, f64)");
370        assert!(
371            t.body.contains("return Ok((best_idx, best_val));"),
372            "body: {}",
373            t.body
374        );
375        assert!(!t.fallback);
376    }
377
378    #[test]
379    fn test_translate_nested_for_else_triggers_fallback() {
380        // for..else is unsupported; expect fallback
381        let code = "def f(n):\n    total = 0\n    for i in range(n):\n        for j in range(n):\n            total += i + j\n    else:\n        total += 1\n    return total\n";
382        let suite = Suite::parse(code, "<test>").expect("parse");
383        let target = TargetSpec {
384            func: "f".to_string(),
385            line: 1,
386            percent: 100.0,
387            reason: "test nested for".to_string(),
388        };
389        let t = translate_function_body(&target, suite.as_slice()).expect("no translation");
390        assert!(
391            t.fallback,
392            "expected fallback for for..else, got body:\n{}",
393            t.body
394        );
395    }
396
397    #[test]
398    fn test_translate_dot_product_zero_fallback() {
399        let code = "def dot_product(a, b):\n    total = 0.0\n    for i in range(len(a)):\n        total += a[i] * b[i]\n    return total\n";
400        let suite = Suite::parse(code, "<test>").expect("parse");
401        let target = TargetSpec {
402            func: "dot_product".to_string(),
403            line: 1,
404            percent: 80.0,
405            reason: "test".to_string(),
406        };
407        let t = translate_function_body(&target, suite.as_slice()).expect("no translation");
408        assert!(
409            !t.fallback,
410            "dot_product should not fallback; body:\n{}",
411            t.body
412        );
413        assert!(t.body.contains("total +="), "got: {}", t.body);
414    }
415
416    #[test]
417    fn test_translate_matmul_nested_loops() {
418        let code = r#"
419def matmul(a, b, n):
420    result = [0.0] * (n * n)
421    for i in range(n):
422        for j in range(n):
423            total = 0.0
424            for k in range(n):
425                total += a[i * n + k] * b[k * n + j]
426            result[i * n + j] = total
427    return result
428"#;
429        let suite = Suite::parse(code, "<test>").expect("parse");
430        let target = TargetSpec {
431            func: "matmul".to_string(),
432            line: 1,
433            percent: 100.0,
434            reason: "test".to_string(),
435        };
436        let t = translate_function_body(&target, suite.as_slice()).expect("no translation");
437        assert!(t.body.contains("for i in 0..n"), "got: {}", t.body);
438        assert!(t.body.contains("for j in 0..n"), "got: {}", t.body);
439        assert!(t.body.contains("for k in 0..n"), "got: {}", t.body);
440        assert!(t.body.contains("vec!["), "got: {}", t.body);
441    }
442
443    #[test]
444    fn test_translate_while_loop_bool_flag() {
445        let code = "def count_pairs(tokens):\n    counts = 0\n    changed = True\n    while changed:\n        changed = False\n        counts += 1\n    return counts\n";
446        let suite = Suite::parse(code, "<test>").expect("parse");
447        let target = TargetSpec {
448            func: "count_pairs".to_string(),
449            line: 1,
450            percent: 100.0,
451            reason: "test".to_string(),
452        };
453        let t = translate_function_body(&target, suite.as_slice()).expect("no translation");
454        assert!(t.body.contains("while changed"), "got:\n{}", t.body);
455    }
456
457    #[test]
458    fn test_translate_while_comparison() {
459        let code = "def scan(tokens):\n    i = 0\n    while i < len(tokens):\n        i += 1\n    return i\n";
460        let suite = Suite::parse(code, "<test>").expect("parse");
461        let target = TargetSpec {
462            func: "scan".to_string(),
463            line: 1,
464            percent: 100.0,
465            reason: "test".to_string(),
466        };
467        let t = translate_function_body(&target, suite.as_slice()).expect("no translation");
468        assert!(t.body.contains("while i <"), "got:\n{}", t.body);
469        assert!(t.body.contains("tokens.len()"), "got:\n{}", t.body);
470    }
471
472    // ── ndarray / ml_mode tests ───────────────────────────────────────────────
473
474    #[test]
475    fn test_ndarray_mode_replaces_vec_params() {
476        let code = "import numpy as np\ndef dot_product(a, b):\n    total = 0.0\n    for i in range(len(a)):\n        total += a[i] * b[i]\n    return total\n";
477        let source = InputSource::Snippet(code.to_string());
478        let targets = vec![TargetSpec {
479            func: "dot_product".to_string(),
480            line: 1,
481            percent: 100.0,
482            reason: "test".to_string(),
483        }];
484        let tmp = tempdir().expect("tempdir");
485        let result = generate_ml(&source, &targets, tmp.path(), false).expect("generate_ml");
486        let lib_rs = std::fs::read_to_string(tmp.path().join("rustify_ml_ext/src/lib.rs")).unwrap();
487        assert_eq!(result.fallback_functions, 0);
488        assert!(lib_rs.contains("PyReadonlyArray1<f64>"), "got:\n{}", lib_rs);
489        assert!(lib_rs.contains("use numpy;"), "got:\n{}", lib_rs);
490        let cargo = std::fs::read_to_string(tmp.path().join("rustify_ml_ext/Cargo.toml")).unwrap();
491        assert!(cargo.contains("numpy"), "got:\n{}", cargo);
492    }
493
494    #[test]
495    fn test_ndarray_mode_no_numpy_import_stays_vec() {
496        let code = "def dot_product(a, b):\n    total = 0.0\n    for i in range(len(a)):\n        total += a[i] * b[i]\n    return total\n";
497        let source = InputSource::Snippet(code.to_string());
498        let targets = vec![TargetSpec {
499            func: "dot_product".to_string(),
500            line: 1,
501            percent: 100.0,
502            reason: "test".to_string(),
503        }];
504        let tmp = tempdir().expect("tempdir");
505        let result = generate_ml(&source, &targets, tmp.path(), false).expect("generate_ml");
506        let lib_rs = std::fs::read_to_string(tmp.path().join("rustify_ml_ext/src/lib.rs")).unwrap();
507        assert_eq!(result.fallback_functions, 0);
508        assert!(!lib_rs.contains("PyReadonlyArray1"), "got:\n{}", lib_rs);
509        assert!(lib_rs.contains("Vec<f64>"), "got:\n{}", lib_rs);
510    }
511
512    // ── integration tests ─────────────────────────────────────────────────────
513
514    #[test]
515    fn test_generate_integration_euclidean() {
516        let path = PathBuf::from("examples/euclidean.py");
517        let code = std::fs::read_to_string(&path).expect("read example");
518        let source = InputSource::File {
519            path: path.clone(),
520            code,
521        };
522        let targets = vec![TargetSpec {
523            func: "euclidean".to_string(),
524            line: 1,
525            percent: 100.0,
526            reason: "test".to_string(),
527        }];
528        let tmp = tempdir().expect("tempdir");
529        let result = generate(&source, &targets, tmp.path(), false).expect("generate");
530        assert_eq!(result.generated_functions.len(), 1);
531        assert_eq!(result.fallback_functions, 0);
532        assert!(tmp.path().join("rustify_ml_ext/src/lib.rs").exists());
533    }
534}