Skip to main content

rustify_ml/
builder.rs

1use std::process::Command;
2
3use anyhow::{Context, Result, anyhow};
4use tracing::{info, warn};
5
6use crate::utils::{GenerationResult, InputSource, TargetSpec, extract_code};
7
8/// Run `cargo check` on the generated crate to catch translation errors early.
9/// Returns Ok(()) if the check passes, or an error with the compiler output.
10/// This is a fast-fail step: it does NOT require maturin or a Python environment.
11pub fn cargo_check_generated(r#gen: &GenerationResult) -> Result<()> {
12    info!(path = %r#gen.crate_dir.display(), "running cargo check on generated crate");
13
14    let output = Command::new("cargo")
15        .args(["check", "--message-format=short"])
16        .current_dir(&r#gen.crate_dir)
17        .output()
18        .context("failed to spawn cargo; ensure Rust toolchain is installed")?;
19
20    if !output.status.success() {
21        let stderr = String::from_utf8_lossy(&output.stderr);
22        let stdout = String::from_utf8_lossy(&output.stdout);
23        let combined = format!("{}\n{}", stdout.trim(), stderr.trim());
24        warn!(
25            path = %r#gen.crate_dir.display(),
26            "cargo check failed on generated crate — review generated code:\n{}",
27            combined
28        );
29        return Err(anyhow!(
30            "generated Rust code failed cargo check. Review {} and fix translation issues.\n\nCompiler output:\n{}",
31            r#gen.crate_dir.join("src/lib.rs").display(),
32            combined
33        ));
34    }
35
36    info!(path = %r#gen.crate_dir.display(), "cargo check passed on generated crate");
37    Ok(())
38}
39
40pub fn build_extension(r#gen: &GenerationResult, dry_run: bool) -> Result<()> {
41    if dry_run {
42        info!(path = %r#gen.crate_dir.display(), "dry-run: skipping maturin build");
43        return Ok(());
44    }
45
46    // Run cargo check first as a fast-fail before the full maturin build.
47    // Warn but don't abort if cargo is not available (e.g., unusual CI setups).
48    if let Err(e) = cargo_check_generated(r#gen) {
49        warn!(
50            path = %r#gen.crate_dir.display(),
51            err = %e,
52            "cargo check failed; proceeding with maturin anyway (review generated code)"
53        );
54    }
55
56    // Ensure maturin is available with a user-friendly hint.
57    if let Err(e) = Command::new("maturin").arg("--version").output() {
58        return Err(anyhow!(
59            "maturin not found: install with `pip install maturin` and ensure it is on PATH (error: {e})"
60        ));
61    }
62
63    let status = Command::new("maturin")
64        .args(["develop", "--release"])
65        .current_dir(&r#gen.crate_dir)
66        .status()
67        .context("failed to spawn maturin; install via `pip install maturin` and ensure on PATH")?;
68
69    if !status.success() {
70        return Err(anyhow!("maturin build failed with status {status}"));
71    }
72
73    info!(
74        path = %r#gen.crate_dir.display(),
75        fallback_functions = r#gen.fallback_functions,
76        "maturin build completed"
77    );
78    Ok(())
79}
80
81/// Run a Python timing harness comparing the original Python function against the
82/// generated Rust extension. Prints a speedup table to stdout.
83///
84/// Requires: maturin develop already run (extension importable), Python on PATH.
85pub fn run_benchmark(
86    source: &InputSource,
87    result: &GenerationResult,
88    targets: &[TargetSpec],
89) -> Result<()> {
90    use crate::profiler::detect_python;
91
92    let python = detect_python()?;
93    let code = extract_code(source)?;
94    let module_name = result
95        .crate_dir
96        .file_name()
97        .and_then(|n| n.to_str())
98        .unwrap_or("rustify_ml_ext");
99
100    // Build a self-contained Python benchmark script
101    let func_names: Vec<String> = targets
102        .iter()
103        .filter(|t| {
104            // Only benchmark functions that were fully translated (no fallback)
105            result
106                .generated_functions
107                .iter()
108                .any(|f| f.contains(&format!("pub fn {}", t.func)) && !f.contains("// fallback"))
109        })
110        .map(|t| t.func.clone())
111        .collect();
112
113    if func_names.is_empty() {
114        warn!("no fully-translated functions to benchmark; skipping");
115        return Ok(());
116    }
117
118    let harness = build_benchmark_harness(&code, module_name, &func_names);
119
120    let output = Command::new(&python)
121        .args(["-c", &harness])
122        .output()
123        .with_context(|| format!("failed to run {} for benchmark", python))?;
124
125    if !output.status.success() {
126        let stderr = String::from_utf8_lossy(&output.stderr);
127        return Err(anyhow!("benchmark harness failed: {}", stderr.trim()));
128    }
129
130    let stdout = String::from_utf8_lossy(&output.stdout);
131    println!("\n{}", stdout.trim());
132    Ok(())
133}
134
/// Generate a Python benchmark script that times original vs Rust for each function.
///
/// * `code` - the original Python source; it is embedded in the script as a
///   triple-quoted string literal and executed into a scratch module.
/// * `module_name` - name of the compiled extension module to import.
/// * `func_names` - functions to time; each is looked up by name in both the
///   scratch module and the extension.
///
/// The returned script calls every function with one list of 100 floats per
/// parameter, runs 1000 iterations of each implementation, and prints a
/// table of timings and speedups. An import failure of the extension is
/// reported on stderr and exits with status 1.
fn build_benchmark_harness(code: &str, module_name: &str, func_names: &[String]) -> String {
    // Escape backslashes first, then double quotes, so the source survives
    // being placed inside a Python `"""..."""` literal unchanged.
    let escaped_code = code.replace('\\', "\\\\").replace('"', "\\\"");
    let funcs_list = func_names
        .iter()
        .map(|f| format!("\"{f}\""))
        .collect::<Vec<_>>()
        .join(", ");

    // NOTE: this is a raw string, so backslashes reach Python verbatim. String
    // literals inside f-string replacement fields therefore use single quotes;
    // a backslash-escaped quote there is a Python syntax error.
    format!(
        r#"
import inspect, timeit, sys, importlib, types

# --- original Python code ---
_src = """{code}"""
_mod = types.ModuleType("_orig")
exec(compile(_src, "<rustify_bench>", "exec"), _mod.__dict__)

# --- accelerated Rust extension ---
try:
    _ext = importlib.import_module("{module}")
except ImportError as e:
    print(f"Could not import {module}: {{e}}", file=sys.stderr)
    sys.exit(1)

_funcs = [{funcs}]
_iters = 1000

print()
print(f"{{'':-<60}}")
print(f"  rustify-ml benchmark  ({{_iters}} iterations each)")
print(f"{{'':-<60}}")
print(f"  {{'Function':<22}} | {{'Python':>10}} | {{'Rust':>10}} | {{'Speedup':>8}}")
print(f"  {{'':-<22}}-+-{{'':-<10}}-+-{{'':-<10}}-+-{{'':-<8}}")

for fn_name in _funcs:
    py_fn = getattr(_mod, fn_name, None)
    rs_fn = getattr(_ext, fn_name, None)
    if py_fn is None or rs_fn is None:
        print(f"  {{fn_name:<22}} | skipped (not found)")
        continue
    # Build a simple call with dummy float vectors
    try:
        sig = inspect.signature(py_fn)
        n_params = len(sig.parameters)
        dummy = [float(i) for i in range(100)]
        args = tuple(dummy for _ in range(n_params))
        py_time = timeit.timeit(lambda: py_fn(*args), number=_iters)
        rs_time = timeit.timeit(lambda: rs_fn(*args), number=_iters)
        speedup = py_time / rs_time if rs_time > 0 else float("inf")
        print(f"  {{fn_name:<22}} | {{py_time:>9.4f}}s | {{rs_time:>9.4f}}s | {{speedup:>7.1f}}x")
    except Exception as e:
        print(f"  {{fn_name:<22}} | error: {{e}}")

print(f"{{'':-<60}}")
print()
"#,
        code = escaped_code,
        module = module_name,
        funcs = funcs_list,
    )
}