1use std::process::Command;
2
3use anyhow::{Context, Result, anyhow};
4use tracing::{info, warn};
5
6use crate::utils::{GenerationResult, InputSource, TargetSpec, extract_code};
7
8pub fn cargo_check_generated(r#gen: &GenerationResult) -> Result<()> {
12 info!(path = %r#gen.crate_dir.display(), "running cargo check on generated crate");
13
14 let output = Command::new("cargo")
15 .args(["check", "--message-format=short"])
16 .current_dir(&r#gen.crate_dir)
17 .output()
18 .context("failed to spawn cargo; ensure Rust toolchain is installed")?;
19
20 if !output.status.success() {
21 let stderr = String::from_utf8_lossy(&output.stderr);
22 let stdout = String::from_utf8_lossy(&output.stdout);
23 let combined = format!("{}\n{}", stdout.trim(), stderr.trim());
24 warn!(
25 path = %r#gen.crate_dir.display(),
26 "cargo check failed on generated crate — review generated code:\n{}",
27 combined
28 );
29 return Err(anyhow!(
30 "generated Rust code failed cargo check. Review {} and fix translation issues.\n\nCompiler output:\n{}",
31 r#gen.crate_dir.join("src/lib.rs").display(),
32 combined
33 ));
34 }
35
36 info!(path = %r#gen.crate_dir.display(), "cargo check passed on generated crate");
37 Ok(())
38}
39
40pub fn build_extension(r#gen: &GenerationResult, dry_run: bool) -> Result<()> {
41 if dry_run {
42 info!(path = %r#gen.crate_dir.display(), "dry-run: skipping maturin build");
43 return Ok(());
44 }
45
46 if let Err(e) = cargo_check_generated(r#gen) {
49 warn!(
50 path = %r#gen.crate_dir.display(),
51 err = %e,
52 "cargo check failed; proceeding with maturin anyway (review generated code)"
53 );
54 }
55
56 let status = Command::new("maturin")
57 .args(["develop", "--release"])
58 .current_dir(&r#gen.crate_dir)
59 .status()
60 .context("failed to spawn maturin; install via `pip install maturin` and ensure on PATH")?;
61
62 if !status.success() {
63 return Err(anyhow!("maturin build failed with status {status}"));
64 }
65
66 info!(
67 path = %r#gen.crate_dir.display(),
68 fallback_functions = r#gen.fallback_functions,
69 "maturin build completed"
70 );
71 Ok(())
72}
73
74pub fn run_benchmark(
79 source: &InputSource,
80 result: &GenerationResult,
81 targets: &[TargetSpec],
82) -> Result<()> {
83 use crate::profiler::detect_python;
84
85 let python = detect_python()?;
86 let code = extract_code(source)?;
87 let module_name = result
88 .crate_dir
89 .file_name()
90 .and_then(|n| n.to_str())
91 .unwrap_or("rustify_ml_ext");
92
93 let func_names: Vec<String> = targets
95 .iter()
96 .filter(|t| {
97 result
99 .generated_functions
100 .iter()
101 .any(|f| f.contains(&format!("pub fn {}", t.func)) && !f.contains("// fallback"))
102 })
103 .map(|t| t.func.clone())
104 .collect();
105
106 if func_names.is_empty() {
107 warn!("no fully-translated functions to benchmark; skipping");
108 return Ok(());
109 }
110
111 let harness = build_benchmark_harness(&code, module_name, &func_names);
112
113 let output = Command::new(&python)
114 .args(["-c", &harness])
115 .output()
116 .with_context(|| format!("failed to run {} for benchmark", python))?;
117
118 if !output.status.success() {
119 let stderr = String::from_utf8_lossy(&output.stderr);
120 return Err(anyhow!("benchmark harness failed: {}", stderr.trim()));
121 }
122
123 let stdout = String::from_utf8_lossy(&output.stdout);
124 println!("\n{}", stdout.trim());
125 Ok(())
126}
127
/// Builds a self-contained Python script that times each function in `func_names` two ways:
/// as defined by the original Python `code`, and as exported by the compiled extension
/// module `module_name`.
///
/// The script runs `code` in a fresh module and imports the extension (exiting with status 1
/// if the import fails). It then prints a table of Python time, Rust time and speedup over
/// 1000 calls per function. Each parameter receives the same list of 100 floats.
///
/// Every interpolated value is embedded as an escaped Python string literal. Quotes,
/// backslashes, newlines or braces in the inputs therefore cannot break the generated script.
fn build_benchmark_harness(code: &str, module_name: &str, func_names: &[String]) -> String {
    let funcs_list = func_names
        .iter()
        .map(|f| format!("\"{}\"", escape_py_str(f)))
        .collect::<Vec<_>>()
        .join(", ");

    // Table header labels use single-quoted literals. A backslash-escaped `\"` inside an
    // f-string expression is a SyntaxError in Python, and would abort the whole harness.
    format!(
        r#"
import timeit, sys, importlib, types, inspect

# --- original Python code ---
_src = "{code}"
_mod = types.ModuleType("_orig")
exec(compile(_src, "<rustify_bench>", "exec"), _mod.__dict__)

# --- accelerated Rust extension ---
_modname = "{module}"
try:
    _ext = importlib.import_module(_modname)
except ImportError as e:
    print(f"Could not import {{_modname}}: {{e}}")
    sys.exit(1)

_funcs = [{funcs}]
_iters = 1000

print()
print(f"{{'':-<60}}")
print(f" rustify-ml benchmark ({{_iters}} iterations each)")
print(f"{{'':-<60}}")
print(f" {{'Function':<22}} | {{'Python':>10}} | {{'Rust':>10}} | {{'Speedup':>8}}")
print(f" {{'':-<22}}-+-{{'':-<10}}-+-{{'':-<10}}-+-{{'':-<8}}")

for fn_name in _funcs:
    py_fn = getattr(_mod, fn_name, None)
    rs_fn = getattr(_ext, fn_name, None)
    if py_fn is None or rs_fn is None:
        print(f" {{fn_name:<22}} | skipped (not found)")
        continue
    # Build a simple call with dummy float vectors
    try:
        sig = inspect.signature(py_fn)
        n_params = len(sig.parameters)
        dummy = [float(i) for i in range(100)]
        args = tuple(dummy for _ in range(n_params))
        py_time = timeit.timeit(lambda: py_fn(*args), number=_iters)
        rs_time = timeit.timeit(lambda: rs_fn(*args), number=_iters)
        speedup = py_time / rs_time if rs_time > 0 else float("inf")
        print(f" {{fn_name:<22}} | {{py_time:>9.4f}}s | {{rs_time:>9.4f}}s | {{speedup:>7.1f}}x")
    except Exception as e:
        print(f" {{fn_name:<22}} | error: {{e}}")

print(f"{{'':-<60}}")
print()
"#,
        code = escape_py_str(code),
        module = escape_py_str(module_name),
        funcs = funcs_list,
    )
}

/// Escapes `s` for placement between double quotes in a (non-raw) Python string literal.
///
/// Backslashes, double quotes and line breaks are escaped, so the Python literal decodes
/// back to exactly `s`.
fn escape_py_str(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    for c in s.chars() {
        match c {
            '\\' => out.push_str("\\\\"),
            '"' => out.push_str("\\\""),
            '\n' => out.push_str("\\n"),
            '\r' => out.push_str("\\r"),
            _ => out.push(c),
        }
    }
    out
}