Skip to main content

rustify_ml/generator/
mod.rs

1//! Generator module — public API for Python → Rust PyO3 stub generation.
2//!
3//! # Submodules
4//! - [`expr`]      — expression-to-Rust translation (pure, no I/O)
5//! - [`infer`]     — type inference for params and assignments
6//! - [`translate`] — statement/body translation (AST walk)
7//! - [`render`]    — PyO3 function + lib.rs + Cargo.toml rendering
8//!
9//! # Entry points
10//! - [`generate`]    — standard generation
11//! - [`generate_ml`] — ML mode (numpy → PyReadonlyArray1)
12
13pub mod expr;
14pub mod infer;
15pub mod render;
16pub mod translate;
17
18use std::fs;
19use std::path::Path;
20
21use anyhow::{Context, Result, anyhow};
22use rustpython_parser::Parse;
23use rustpython_parser::ast::{Stmt, Suite};
24use tracing::{info, warn};
25
26use crate::utils::{GenerationResult, InputSource, TargetSpec, extract_code};
27use render::{
28    render_cargo_toml_with_options, render_function_with_options, render_lib_rs_with_options,
29};
30
/// Detect numpy usage in Python source (used together with ML mode to enable ndarray params).
///
/// The source is scanned line by line. On each line everything from the first `#`
/// onwards is treated as a Python comment and ignored, so a remark such as
/// `# TODO: import numpy later` no longer counts as an import. The remaining text is
/// searched for the substrings `import numpy`, `from numpy` and `import np`.
///
/// This is a textual heuristic, not a parse: string literals and docstrings are not
/// recognised, so import-like text inside them still matches, and a `#` inside a
/// string literal hides the rest of that line.
///
/// Returns `true` if any line matches, `false` otherwise (including for empty input).
fn detects_numpy(code: &str) -> bool {
    const NEEDLES: [&str; 3] = ["import numpy", "from numpy", "import np"];

    code.lines().any(|line| {
        // Keep only the part of the line that precedes a `#` comment marker.
        let statement = line.split('#').next().unwrap_or("");
        NEEDLES.iter().any(|needle| statement.contains(needle))
    })
}
35
/// Parse existing lib.rs content and return complete `#[pyfunction]` blocks using brace balance.
///
/// A block starts on a line whose first non-whitespace text is `#[pyfunction]` or
/// `#[pyfunction(` (attribute with arguments). It ends on the first line at which at
/// least one `{` has been seen and the running count of `{` minus `}` has dropped to
/// zero or below. Each returned string is the block's original lines (indentation
/// preserved) joined with `\n`.
///
/// Braces on the attribute line itself are counted too, so a one-line definition such
/// as `#[pyfunction] fn f() -> i32 { 1 }` is closed immediately instead of absorbing
/// the item that follows it.
///
/// Limitations visible in the implementation:
/// - Braces are counted as raw characters; braces inside string literals or comments
///   are not distinguished from structural ones.
/// - A block that is still open when the input ends is discarded.
/// - A new `#[pyfunction` attribute seen while a block is open discards the open block
///   and starts a fresh one.
fn parse_existing_functions(lib_rs: &str) -> Vec<String> {
    let mut funcs = Vec::new();
    let mut current: Vec<&str> = Vec::new();
    let mut in_fn = false;
    let mut brace_balance: i32 = 0;
    let mut seen_open = false;

    for line in lib_rs.lines() {
        let trimmed = line.trim_start();
        if trimmed.starts_with("#[pyfunction]") || trimmed.starts_with("#[pyfunction(") {
            // Start a fresh block; anything half-collected is dropped.
            current.clear();
            in_fn = true;
            brace_balance = 0;
            seen_open = false;
        }

        if !in_fn {
            continue;
        }

        current.push(line);
        for ch in line.chars() {
            match ch {
                '{' => {
                    seen_open = true;
                    brace_balance += 1;
                }
                '}' => brace_balance -= 1,
                _ => {}
            }
        }

        // Only close once a body has actually been opened; signature lines before the
        // first `{` keep the balance at zero without ending the block.
        if seen_open && brace_balance <= 0 {
            funcs.push(current.join("\n"));
            current.clear();
            in_fn = false;
        }
    }

    funcs
}
73
74/// Generate Rust + PyO3 stubs for the given targets.
75pub fn generate(
76    source: &InputSource,
77    targets: &[TargetSpec],
78    output: &Path,
79    dry_run: bool,
80) -> Result<GenerationResult> {
81    generate_with_options(source, targets, output, dry_run, false)
82}
83
84/// Generate with ML mode: detects numpy imports → uses `PyReadonlyArray1<f64>` params.
85pub fn generate_ml(
86    source: &InputSource,
87    targets: &[TargetSpec],
88    output: &Path,
89    dry_run: bool,
90) -> Result<GenerationResult> {
91    generate_with_options(source, targets, output, dry_run, true)
92}
93
94fn generate_with_options(
95    source: &InputSource,
96    targets: &[TargetSpec],
97    output: &Path,
98    dry_run: bool,
99    ml_mode: bool,
100) -> Result<GenerationResult> {
101    if targets.is_empty() {
102        return Err(anyhow!("no targets selected for generation"));
103    }
104
105    fs::create_dir_all(output)
106        .with_context(|| format!("failed to create output dir {}", output.display()))?;
107    let crate_dir = output.join("rustify_ml_ext");
108    if crate_dir.exists() {
109        info!(path = %crate_dir.display(), "reusing existing generated crate directory");
110    } else {
111        fs::create_dir_all(crate_dir.join("src")).context("failed to create crate directories")?;
112    }
113
114    let code = extract_code(source)?;
115    let use_ndarray = ml_mode && detects_numpy(&code);
116    if use_ndarray {
117        info!("numpy detected + ml_mode: using PyReadonlyArray1<f64> params");
118    }
119
120    let suite =
121        Suite::parse(&code, "<input>").context("failed to parse Python input for generation")?;
122    let stmts: &[Stmt] = suite.as_slice();
123
124    // Merge previously generated functions (if crate already exists) with newly generated ones.
125    let mut functions_by_name: std::collections::HashMap<String, String> =
126        std::collections::HashMap::new();
127
128    let existing_lib = crate_dir.join("src/lib.rs");
129    if existing_lib.exists()
130        && let Ok(existing_src) = std::fs::read_to_string(&existing_lib)
131    {
132        for func_src in parse_existing_functions(&existing_src) {
133            let name = render::extract_fn_name(&func_src);
134            functions_by_name.insert(name, func_src);
135        }
136    }
137
138    let mut fallback_functions = 0usize;
139    for t in targets.iter() {
140        let (code, fallback) = render_function_with_options(t, stmts, use_ndarray);
141        let name = render::extract_fn_name(&code);
142        if fallback {
143            fallback_functions += 1;
144        }
145        functions_by_name.insert(name, code);
146    }
147
148    let functions: Vec<String> = functions_by_name.into_values().collect();
149
150    let lib_rs = render_lib_rs_with_options(&functions, use_ndarray);
151    let cargo_toml = render_cargo_toml_with_options(use_ndarray);
152
153    fs::write(crate_dir.join("src/lib.rs"), lib_rs).context("failed to write lib.rs")?;
154    fs::write(crate_dir.join("Cargo.toml"), cargo_toml).context("failed to write Cargo.toml")?;
155
156    if dry_run {
157        info!(path = %crate_dir.display(), "dry-run: wrote generated files (no build)");
158    }
159    if fallback_functions > 0 {
160        warn!(
161            fallback_functions,
162            "some functions fell back to echo translation"
163        );
164    }
165    info!(path = %crate_dir.display(), funcs = functions.len(), "generated Rust stubs");
166
167    Ok(GenerationResult {
168        crate_dir,
169        generated_functions: functions,
170        fallback_functions,
171    })
172}
173
174// ── Tests ─────────────────────────────────────────────────────────────────────
175
176#[cfg(test)]
177mod tests {
178    use std::path::PathBuf;
179
180    use rustpython_parser::Parse;
181    use rustpython_parser::ast::{Expr, Operator, Suite};
182    use rustpython_parser::text_size::TextRange;
183    use tempfile::tempdir;
184
185    use crate::utils::{InputSource, TargetSpec};
186
187    use super::expr::expr_to_rust;
188    use super::infer::render_len_checks;
189    use super::translate::translate_function_body;
190    use super::*;
191
192    // ── expr tests ────────────────────────────────────────────────────────────
193
194    #[test]
195    fn test_expr_to_rust_range_and_len() {
196        let range_expr = Expr::Call(rustpython_parser::ast::ExprCall {
197            func: Box::new(Expr::Name(rustpython_parser::ast::ExprName {
198                range: TextRange::default(),
199                id: "range".into(),
200                ctx: rustpython_parser::ast::ExprContext::Load,
201            })),
202            args: vec![Expr::Constant(rustpython_parser::ast::ExprConstant {
203                range: TextRange::default(),
204                value: rustpython_parser::ast::Constant::Int(10.into()),
205                kind: None,
206            })],
207            keywords: vec![],
208            range: TextRange::default(),
209        });
210        let len_expr = Expr::Call(rustpython_parser::ast::ExprCall {
211            func: Box::new(Expr::Name(rustpython_parser::ast::ExprName {
212                range: TextRange::default(),
213                id: "len".into(),
214                ctx: rustpython_parser::ast::ExprContext::Load,
215            })),
216            args: vec![Expr::Name(rustpython_parser::ast::ExprName {
217                range: TextRange::default(),
218                id: "a".into(),
219                ctx: rustpython_parser::ast::ExprContext::Load,
220            })],
221            keywords: vec![],
222            range: TextRange::default(),
223        });
224        assert_eq!(expr_to_rust(&range_expr), "0..10");
225        assert_eq!(expr_to_rust(&len_expr), "a.len()");
226    }
227
228    #[test]
229    fn test_expr_to_rust_binop_pow() {
230        let bin = Expr::BinOp(rustpython_parser::ast::ExprBinOp {
231            range: TextRange::default(),
232            left: Box::new(Expr::Name(rustpython_parser::ast::ExprName {
233                range: TextRange::default(),
234                id: "x".into(),
235                ctx: rustpython_parser::ast::ExprContext::Load,
236            })),
237            op: Operator::Pow,
238            right: Box::new(Expr::Constant(rustpython_parser::ast::ExprConstant {
239                range: TextRange::default(),
240                value: rustpython_parser::ast::Constant::Int(2.into()),
241                kind: None,
242            })),
243        });
244        assert_eq!(expr_to_rust(&bin), "(x).powf(2)");
245    }
246
247    // ── infer tests ───────────────────────────────────────────────────────────
248
249    #[test]
250    fn test_render_len_checks_multiple_vecs() {
251        let params = vec![
252            ("a".to_string(), "Vec<f64>".to_string()),
253            ("b".to_string(), "Vec<f64>".to_string()),
254        ];
255        let rendered = render_len_checks(&params).unwrap();
256        assert!(rendered.contains("a.len() != b.len()"));
257        assert!(rendered.contains("PyValueError"));
258    }
259
260    // ── translate tests ───────────────────────────────────────────────────────
261
262    #[test]
263    fn test_translate_euclidean_body() {
264        let code = r#"
265def euclidean(p1, p2):
266    total = 0.0
267    for i in range(len(p1)):
268        diff = p1[i] - p2[i]
269        total += diff * diff
270    return total ** 0.5
271"#;
272        let suite = Suite::parse(code, "<test>").expect("parse failed");
273        let target = TargetSpec {
274            func: "euclidean".to_string(),
275            line: 1,
276            percent: 100.0,
277            reason: "test".to_string(),
278        };
279        let t = translate_function_body(&target, suite.as_slice()).expect("no translation");
280        assert_eq!(t.return_type, "f64");
281        assert!(!t.fallback);
282        assert!(t.body.contains("for i in 0.."));
283    }
284
285    #[test]
286    fn test_translate_stmt_float_assign_init() {
287        let code = "def f(x):\n    total = 0.0\n    return total\n";
288        let suite = Suite::parse(code, "<test>").expect("parse");
289        let target = TargetSpec {
290            func: "f".to_string(),
291            line: 1,
292            percent: 100.0,
293            reason: "test".to_string(),
294        };
295        let t = translate_function_body(&target, suite.as_slice()).expect("no translation");
296        assert!(t.body.contains("let mut total: f64"), "got: {}", t.body);
297    }
298
299    #[test]
300    fn test_translate_stmt_subscript_assign() {
301        let code = "def f(result, i, val):\n    result[i] = val\n    return result\n";
302        let suite = Suite::parse(code, "<test>").expect("parse");
303        let target = TargetSpec {
304            func: "f".to_string(),
305            line: 1,
306            percent: 100.0,
307            reason: "test".to_string(),
308        };
309        let t = translate_function_body(&target, suite.as_slice()).expect("no translation");
310        assert!(t.body.contains("result[i] = val"), "got: {}", t.body);
311    }
312
313    #[test]
314    fn test_translate_stmt_list_init() {
315        let code = "def f(n):\n    result = [0.0] * n\n    return result\n";
316        let suite = Suite::parse(code, "<test>").expect("parse");
317        let target = TargetSpec {
318            func: "f".to_string(),
319            line: 1,
320            percent: 100.0,
321            reason: "test".to_string(),
322        };
323        let t = translate_function_body(&target, suite.as_slice()).expect("no translation");
324        assert!(
325            t.body.contains("vec![") && t.body.contains("result"),
326            "got: {}",
327            t.body
328        );
329    }
330
331    #[test]
332    fn test_translate_list_comprehension() {
333        // result = [x * 2.0 for x in data] → let result: Vec<f64> = data.iter().map(|x| ...).collect();
334        let code = "def f(data):\n    result = [x * 2.0 for x in data]\n    return result\n";
335        let suite = Suite::parse(code, "<test>").expect("parse");
336        let target = TargetSpec {
337            func: "f".to_string(),
338            line: 1,
339            percent: 100.0,
340            reason: "test".to_string(),
341        };
342        let t = translate_function_body(&target, suite.as_slice()).expect("no translation");
343        assert!(
344            t.body.contains(".iter().map(") && t.body.contains(".collect()"),
345            "expected iter().map().collect() for list comp, got: {}",
346            t.body
347        );
348        assert_eq!(t.return_type, "Vec<f64>");
349        assert!(!t.fallback);
350    }
351
352    #[test]
353    fn test_translate_nested_for_else_triggers_fallback() {
354        // for..else is unsupported; expect fallback
355        let code = "def f(n):\n    total = 0\n    for i in range(n):\n        for j in range(n):\n            total += i + j\n    else:\n        total += 1\n    return total\n";
356        let suite = Suite::parse(code, "<test>").expect("parse");
357        let target = TargetSpec {
358            func: "f".to_string(),
359            line: 1,
360            percent: 100.0,
361            reason: "test nested for".to_string(),
362        };
363        let t = translate_function_body(&target, suite.as_slice()).expect("no translation");
364        assert!(t.fallback, "expected fallback for for..else, got body:\n{}", t.body);
365    }
366
367    #[test]
368    fn test_translate_dot_product_zero_fallback() {
369        let code = "def dot_product(a, b):\n    total = 0.0\n    for i in range(len(a)):\n        total += a[i] * b[i]\n    return total\n";
370        let suite = Suite::parse(code, "<test>").expect("parse");
371        let target = TargetSpec {
372            func: "dot_product".to_string(),
373            line: 1,
374            percent: 80.0,
375            reason: "test".to_string(),
376        };
377        let t = translate_function_body(&target, suite.as_slice()).expect("no translation");
378        assert!(
379            !t.fallback,
380            "dot_product should not fallback; body:\n{}",
381            t.body
382        );
383        assert!(t.body.contains("total +="), "got: {}", t.body);
384    }
385
386    #[test]
387    fn test_translate_matmul_nested_loops() {
388        let code = r#"
389def matmul(a, b, n):
390    result = [0.0] * (n * n)
391    for i in range(n):
392        for j in range(n):
393            total = 0.0
394            for k in range(n):
395                total += a[i * n + k] * b[k * n + j]
396            result[i * n + j] = total
397    return result
398"#;
399        let suite = Suite::parse(code, "<test>").expect("parse");
400        let target = TargetSpec {
401            func: "matmul".to_string(),
402            line: 1,
403            percent: 100.0,
404            reason: "test".to_string(),
405        };
406        let t = translate_function_body(&target, suite.as_slice()).expect("no translation");
407        assert!(t.body.contains("for i in 0..n"), "got: {}", t.body);
408        assert!(t.body.contains("for j in 0..n"), "got: {}", t.body);
409        assert!(t.body.contains("for k in 0..n"), "got: {}", t.body);
410        assert!(t.body.contains("vec!["), "got: {}", t.body);
411    }
412
413    #[test]
414    fn test_translate_while_loop_bool_flag() {
415        let code = "def count_pairs(tokens):\n    counts = 0\n    changed = True\n    while changed:\n        changed = False\n        counts += 1\n    return counts\n";
416        let suite = Suite::parse(code, "<test>").expect("parse");
417        let target = TargetSpec {
418            func: "count_pairs".to_string(),
419            line: 1,
420            percent: 100.0,
421            reason: "test".to_string(),
422        };
423        let t = translate_function_body(&target, suite.as_slice()).expect("no translation");
424        assert!(t.body.contains("while changed"), "got:\n{}", t.body);
425    }
426
427    #[test]
428    fn test_translate_while_comparison() {
429        let code = "def scan(tokens):\n    i = 0\n    while i < len(tokens):\n        i += 1\n    return i\n";
430        let suite = Suite::parse(code, "<test>").expect("parse");
431        let target = TargetSpec {
432            func: "scan".to_string(),
433            line: 1,
434            percent: 100.0,
435            reason: "test".to_string(),
436        };
437        let t = translate_function_body(&target, suite.as_slice()).expect("no translation");
438        assert!(t.body.contains("while i <"), "got:\n{}", t.body);
439        assert!(t.body.contains("tokens.len()"), "got:\n{}", t.body);
440    }
441
442    // ── ndarray / ml_mode tests ───────────────────────────────────────────────
443
444    #[test]
445    fn test_ndarray_mode_replaces_vec_params() {
446        let code = "import numpy as np\ndef dot_product(a, b):\n    total = 0.0\n    for i in range(len(a)):\n        total += a[i] * b[i]\n    return total\n";
447        let source = InputSource::Snippet(code.to_string());
448        let targets = vec![TargetSpec {
449            func: "dot_product".to_string(),
450            line: 1,
451            percent: 100.0,
452            reason: "test".to_string(),
453        }];
454        let tmp = tempdir().expect("tempdir");
455        let result = generate_ml(&source, &targets, tmp.path(), false).expect("generate_ml");
456        let lib_rs = std::fs::read_to_string(tmp.path().join("rustify_ml_ext/src/lib.rs")).unwrap();
457        assert_eq!(result.fallback_functions, 0);
458        assert!(lib_rs.contains("PyReadonlyArray1<f64>"), "got:\n{}", lib_rs);
459        assert!(lib_rs.contains("use numpy;"), "got:\n{}", lib_rs);
460        let cargo = std::fs::read_to_string(tmp.path().join("rustify_ml_ext/Cargo.toml")).unwrap();
461        assert!(cargo.contains("numpy"), "got:\n{}", cargo);
462    }
463
464    #[test]
465    fn test_ndarray_mode_no_numpy_import_stays_vec() {
466        let code = "def dot_product(a, b):\n    total = 0.0\n    for i in range(len(a)):\n        total += a[i] * b[i]\n    return total\n";
467        let source = InputSource::Snippet(code.to_string());
468        let targets = vec![TargetSpec {
469            func: "dot_product".to_string(),
470            line: 1,
471            percent: 100.0,
472            reason: "test".to_string(),
473        }];
474        let tmp = tempdir().expect("tempdir");
475        let result = generate_ml(&source, &targets, tmp.path(), false).expect("generate_ml");
476        let lib_rs = std::fs::read_to_string(tmp.path().join("rustify_ml_ext/src/lib.rs")).unwrap();
477        assert_eq!(result.fallback_functions, 0);
478        assert!(!lib_rs.contains("PyReadonlyArray1"), "got:\n{}", lib_rs);
479        assert!(lib_rs.contains("Vec<f64>"), "got:\n{}", lib_rs);
480    }
481
482    // ── integration tests ─────────────────────────────────────────────────────
483
484    #[test]
485    fn test_generate_integration_euclidean() {
486        let path = PathBuf::from("examples/euclidean.py");
487        let code = std::fs::read_to_string(&path).expect("read example");
488        let source = InputSource::File {
489            path: path.clone(),
490            code,
491        };
492        let targets = vec![TargetSpec {
493            func: "euclidean".to_string(),
494            line: 1,
495            percent: 100.0,
496            reason: "test".to_string(),
497        }];
498        let tmp = tempdir().expect("tempdir");
499        let result = generate(&source, &targets, tmp.path(), false).expect("generate");
500        assert_eq!(result.generated_functions.len(), 1);
501        assert_eq!(result.fallback_functions, 0);
502        assert!(tmp.path().join("rustify_ml_ext/src/lib.rs").exists());
503    }
504}