Skip to main content

rustify_ml/
builder.rs

1use std::process::Command;
2
3use anyhow::{Context, Result, anyhow};
4use tracing::{info, warn};
5
6use crate::utils::{GenerationResult, InputSource, TargetSpec, extract_code};
7
8/// Run `cargo check` on the generated crate to catch translation errors early.
9/// Returns Ok(()) if the check passes, or an error with the compiler output.
10/// This is a fast-fail step: it does NOT require maturin or a Python environment.
11pub fn cargo_check_generated(r#gen: &GenerationResult) -> Result<()> {
12    info!(path = %r#gen.crate_dir.display(), "running cargo check on generated crate");
13
14    let output = Command::new("cargo")
15        .args(["check", "--message-format=short"])
16        .current_dir(&r#gen.crate_dir)
17        .output()
18        .context("failed to spawn cargo; ensure Rust toolchain is installed")?;
19
20    if !output.status.success() {
21        let stderr = String::from_utf8_lossy(&output.stderr);
22        let stdout = String::from_utf8_lossy(&output.stdout);
23        let combined = format!("{}\n{}", stdout.trim(), stderr.trim());
24        warn!(
25            path = %r#gen.crate_dir.display(),
26            "cargo check failed on generated crate — review generated code:\n{}",
27            combined
28        );
29        return Err(anyhow!(
30            "generated Rust code failed cargo check. Review {} and fix translation issues.\n\nCompiler output:\n{}",
31            r#gen.crate_dir.join("src/lib.rs").display(),
32            combined
33        ));
34    }
35
36    info!(path = %r#gen.crate_dir.display(), "cargo check passed on generated crate");
37    Ok(())
38}
39
40pub fn build_extension(r#gen: &GenerationResult, dry_run: bool) -> Result<()> {
41    if dry_run {
42        info!(path = %r#gen.crate_dir.display(), "dry-run: skipping maturin build");
43        return Ok(());
44    }
45
46    // Run cargo check first as a fast-fail before the full maturin build.
47    // Warn but don't abort if cargo is not available (e.g., unusual CI setups).
48    if let Err(e) = cargo_check_generated(r#gen) {
49        warn!(
50            path = %r#gen.crate_dir.display(),
51            err = %e,
52            "cargo check failed; proceeding with maturin anyway (review generated code)"
53        );
54    }
55
56    let status = Command::new("maturin")
57        .args(["develop", "--release"])
58        .current_dir(&r#gen.crate_dir)
59        .status()
60        .context("failed to spawn maturin; install via `pip install maturin` and ensure on PATH")?;
61
62    if !status.success() {
63        return Err(anyhow!("maturin build failed with status {status}"));
64    }
65
66    info!(
67        path = %r#gen.crate_dir.display(),
68        fallback_functions = r#gen.fallback_functions,
69        "maturin build completed"
70    );
71    Ok(())
72}
73
74/// Run a Python timing harness comparing the original Python function against the
75/// generated Rust extension. Prints a speedup table to stdout.
76///
77/// Requires: maturin develop already run (extension importable), Python on PATH.
78pub fn run_benchmark(
79    source: &InputSource,
80    result: &GenerationResult,
81    targets: &[TargetSpec],
82) -> Result<()> {
83    use crate::profiler::detect_python;
84
85    let python = detect_python()?;
86    let code = extract_code(source)?;
87    let module_name = result
88        .crate_dir
89        .file_name()
90        .and_then(|n| n.to_str())
91        .unwrap_or("rustify_ml_ext");
92
93    // Build a self-contained Python benchmark script
94    let func_names: Vec<String> = targets
95        .iter()
96        .filter(|t| {
97            // Only benchmark functions that were fully translated (no fallback)
98            result
99                .generated_functions
100                .iter()
101                .any(|f| f.contains(&format!("pub fn {}", t.func)) && !f.contains("// fallback"))
102        })
103        .map(|t| t.func.clone())
104        .collect();
105
106    if func_names.is_empty() {
107        warn!("no fully-translated functions to benchmark; skipping");
108        return Ok(());
109    }
110
111    let harness = build_benchmark_harness(&code, module_name, &func_names);
112
113    let output = Command::new(&python)
114        .args(["-c", &harness])
115        .output()
116        .with_context(|| format!("failed to run {} for benchmark", python))?;
117
118    if !output.status.success() {
119        let stderr = String::from_utf8_lossy(&output.stderr);
120        return Err(anyhow!("benchmark harness failed: {}", stderr.trim()));
121    }
122
123    let stdout = String::from_utf8_lossy(&output.stdout);
124    println!("\n{}", stdout.trim());
125    Ok(())
126}
127
/// Generate a Python benchmark script that times original vs Rust for each function.
///
/// The returned script works as follows:
/// - It embeds `code` in a triple-quoted string and `exec`s it into a fresh
///   module object named `_orig`. Backslashes and double quotes are escaped so
///   the string holds `code` verbatim.
/// - It imports the extension module `module_name`. If the import fails, it
///   prints the reason to stdout and exits with status 1.
/// - For each name in `func_names`, it calls both implementations 1000 times.
///   Each call passes one 100-element float list per parameter in the Python
///   function's signature. It then prints a table row with both timings and
///   the speedup; a missing function or a raised exception prints a
///   "skipped"/"error" row instead.
///
/// No f-string replacement field in the script contains a backslash. Such
/// fields are a `SyntaxError` on Python versions before 3.12.
fn build_benchmark_harness(code: &str, module_name: &str, func_names: &[String]) -> String {
    // Escape for embedding inside a double-quoted / triple-double-quoted Python literal.
    let py_escape = |s: &str| s.replace('\\', "\\\\").replace('"', "\\\"");

    let escaped_code = py_escape(code);
    let funcs_list = func_names
        .iter()
        .map(|f| format!("\"{}\"", py_escape(f)))
        .collect::<Vec<_>>()
        .join(", ");

    format!(
        r#"
import inspect, timeit, sys, importlib, types

# --- original Python code ---
_src = """{code}"""
_mod = types.ModuleType("_orig")
exec(compile(_src, "<rustify_bench>", "exec"), _mod.__dict__)

# --- accelerated Rust extension ---
try:
    _ext = importlib.import_module("{module}")
except ImportError as e:
    print(f"Could not import {module}: {{e}}")
    sys.exit(1)

_funcs = [{funcs}]
_iters = 1000

print()
print(f"{{'':-<60}}")
print(f"  rustify-ml benchmark  ({{_iters}} iterations each)")
print(f"{{'':-<60}}")
print(f"  {{'Function':<22}} | {{'Python':>10}} | {{'Rust':>10}} | {{'Speedup':>8}}")
print(f"  {{'':-<22}}-+-{{'':-<10}}-+-{{'':-<10}}-+-{{'':-<8}}")

for fn_name in _funcs:
    py_fn = getattr(_mod, fn_name, None)
    rs_fn = getattr(_ext, fn_name, None)
    if py_fn is None or rs_fn is None:
        print(f"  {{fn_name:<22}} | skipped (not found)")
        continue
    # Build a simple call with dummy float vectors
    try:
        sig = inspect.signature(py_fn)
        n_params = len(sig.parameters)
        dummy = [float(i) for i in range(100)]
        args = tuple(dummy for _ in range(n_params))
        py_time = timeit.timeit(lambda: py_fn(*args), number=_iters)
        rs_time = timeit.timeit(lambda: rs_fn(*args), number=_iters)
        speedup = py_time / rs_time if rs_time > 0 else float("inf")
        print(f"  {{fn_name:<22}} | {{py_time:>9.4f}}s | {{rs_time:>9.4f}}s | {{speedup:>7.1f}}x")
    except Exception as e:
        print(f"  {{fn_name:<22}} | error: {{e}}")

print(f"{{'':-<60}}")
print()
"#,
        code = escaped_code,
        module = module_name,
        funcs = funcs_list,
    )
}