1use std::process::Command;
2
3use anyhow::{Context, Result, anyhow};
4use tracing::{info, warn};
5
6use crate::utils::{GenerationResult, InputSource, TargetSpec, extract_code};
7
8pub fn cargo_check_generated(r#gen: &GenerationResult) -> Result<()> {
12 info!(path = %r#gen.crate_dir.display(), "running cargo check on generated crate");
13
14 let output = Command::new("cargo")
15 .args(["check", "--message-format=short"])
16 .current_dir(&r#gen.crate_dir)
17 .output()
18 .context("failed to spawn cargo; ensure Rust toolchain is installed")?;
19
20 if !output.status.success() {
21 let stderr = String::from_utf8_lossy(&output.stderr);
22 let stdout = String::from_utf8_lossy(&output.stdout);
23 let combined = format!("{}\n{}", stdout.trim(), stderr.trim());
24 warn!(
25 path = %r#gen.crate_dir.display(),
26 "cargo check failed on generated crate — review generated code:\n{}",
27 combined
28 );
29 return Err(anyhow!(
30 "generated Rust code failed cargo check. Review {} and fix translation issues.\n\nCompiler output:\n{}",
31 r#gen.crate_dir.join("src/lib.rs").display(),
32 combined
33 ));
34 }
35
36 info!(path = %r#gen.crate_dir.display(), "cargo check passed on generated crate");
37 Ok(())
38}
39
40pub fn build_extension(r#gen: &GenerationResult, dry_run: bool) -> Result<()> {
41 if dry_run {
42 info!(path = %r#gen.crate_dir.display(), "dry-run: skipping maturin build");
43 return Ok(());
44 }
45
46 if let Err(e) = cargo_check_generated(r#gen) {
49 warn!(
50 path = %r#gen.crate_dir.display(),
51 err = %e,
52 "cargo check failed; proceeding with maturin anyway (review generated code)"
53 );
54 }
55
56 if let Err(e) = Command::new("maturin").arg("--version").output() {
58 return Err(anyhow!(
59 "maturin not found: install with `pip install maturin` and ensure it is on PATH (error: {e})"
60 ));
61 }
62
63 let status = Command::new("maturin")
64 .args(["develop", "--release"])
65 .current_dir(&r#gen.crate_dir)
66 .status()
67 .context("failed to spawn maturin; install via `pip install maturin` and ensure on PATH")?;
68
69 if !status.success() {
70 return Err(anyhow!("maturin build failed with status {status}"));
71 }
72
73 info!(
74 path = %r#gen.crate_dir.display(),
75 fallback_functions = r#gen.fallback_functions,
76 "maturin build completed"
77 );
78 Ok(())
79}
80
81pub fn run_benchmark(
86 source: &InputSource,
87 result: &GenerationResult,
88 targets: &[TargetSpec],
89) -> Result<()> {
90 use crate::profiler::detect_python;
91
92 let python = detect_python()?;
93 let code = extract_code(source)?;
94 let module_name = result
95 .crate_dir
96 .file_name()
97 .and_then(|n| n.to_str())
98 .unwrap_or("rustify_ml_ext");
99
100 let func_names: Vec<String> = targets
102 .iter()
103 .filter(|t| {
104 result
106 .generated_functions
107 .iter()
108 .any(|f| f.contains(&format!("pub fn {}", t.func)) && !f.contains("// fallback"))
109 })
110 .map(|t| t.func.clone())
111 .collect();
112
113 if func_names.is_empty() {
114 warn!("no fully-translated functions to benchmark; skipping");
115 return Ok(());
116 }
117
118 let harness = build_benchmark_harness(&code, module_name, &func_names);
119
120 let output = Command::new(&python)
121 .args(["-c", &harness])
122 .output()
123 .with_context(|| format!("failed to run {} for benchmark", python))?;
124
125 if !output.status.success() {
126 let stderr = String::from_utf8_lossy(&output.stderr);
127 return Err(anyhow!("benchmark harness failed: {}", stderr.trim()));
128 }
129
130 let stdout = String::from_utf8_lossy(&output.stdout);
131 println!("\n{}", stdout.trim());
132 Ok(())
133}
134
/// Builds a self-contained Python script that times each function in `func_names`.
///
/// Each function is run in both the original Python `code` and the extension module
/// `module_name`, 1000 iterations apiece, and the script prints a comparison table.
///
/// `code` is embedded in a non-raw `"""..."""` literal. Backslashes are doubled first, then
/// every `"` is escaped, so the source is reproduced exactly and cannot close the triple
/// quote early. Braces in `code` need no escaping because `format!` substitutes the value
/// verbatim.
///
/// All f-string replacement fields use single-quoted literals. A backslash inside a
/// replacement field (e.g. `{\"Function\"}`) is a SyntaxError on Python < 3.12.
fn build_benchmark_harness(code: &str, module_name: &str, func_names: &[String]) -> String {
    let escaped_code = code.replace('\\', "\\\\").replace('"', "\\\"");
    let funcs_list = func_names
        .iter()
        .map(|f| format!("\"{f}\""))
        .collect::<Vec<_>>()
        .join(", ");

    format!(
        r#"
import timeit, sys, importlib, types, inspect

# --- original Python code ---
_src = """{code}"""
_mod = types.ModuleType("_orig")
exec(compile(_src, "<rustify_bench>", "exec"), _mod.__dict__)

# --- accelerated Rust extension ---
try:
    _ext = importlib.import_module("{module}")
except ImportError as e:
    print(f"Could not import {module}: {{e}}")
    sys.exit(1)

_funcs = [{funcs}]
_iters = 1000

print()
print(f"{{'':-<60}}")
print(f" rustify-ml benchmark ({{_iters}} iterations each)")
print(f"{{'':-<60}}")
print(f" {{'Function':<22}} | {{'Python':>10}} | {{'Rust':>10}} | {{'Speedup':>8}}")
print(f" {{'':-<22}}-+-{{'':-<10}}-+-{{'':-<10}}-+-{{'':-<8}}")

for fn_name in _funcs:
    py_fn = getattr(_mod, fn_name, None)
    rs_fn = getattr(_ext, fn_name, None)
    if py_fn is None or rs_fn is None:
        print(f" {{fn_name:<22}} | skipped (not found)")
        continue
    # Build a simple call with dummy float vectors
    try:
        sig = inspect.signature(py_fn)
        n_params = len(sig.parameters)
        dummy = [float(i) for i in range(100)]
        args = tuple(dummy for _ in range(n_params))
        py_time = timeit.timeit(lambda: py_fn(*args), number=_iters)
        rs_time = timeit.timeit(lambda: rs_fn(*args), number=_iters)
        speedup = py_time / rs_time if rs_time > 0 else float("inf")
        print(f" {{fn_name:<22}} | {{py_time:>9.4f}}s | {{rs_time:>9.4f}}s | {{speedup:>7.1f}}x")
    except Exception as e:
        print(f" {{fn_name:<22}} | error: {{e}}")

print(f"{{'':-<60}}")
print()
"#,
        code = escaped_code,
        module = module_name,
        funcs = funcs_list,
    )
}