use std::time::{Duration, Instant};

use anyhow::Result;
use tensorlogic_compiler::{compile_to_einsum_with_context, CompilerContext};
use tensorlogic_ir::{EinsumGraph, TLExpr};

use crate::executor::{Backend, CliExecutor, ExecutionConfig};
use crate::optimize::{optimize_einsum_graph, OptimizationConfig, OptimizationLevel};
use crate::output::{print_header, print_info};

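/// Per-iteration timing samples collected by [`Benchmarker`]; all values are in milliseconds.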
#[derive(Debug, Clone)]
pub struct BenchmarkResults {
    pub compilation_times: Vec<f64>,
    pub execution_times: Vec<f64>,
    pub optimization_times: Vec<f64>,
}

impl BenchmarkResults {
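    /// Creates an empty result set with no recorded samples.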
    pub fn new() -> Self {
        Self {
            compilation_times: Vec::new(),
            execution_times: Vec::new(),
            optimization_times: Vec::new(),
        }
    }

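    /// Returns `(mean, std_dev, min, max)` over `times`; all zeros for an empty slice.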
    fn stats(times: &[f64]) -> (f64, f64, f64, f64) {
        if times.is_empty() {
            return (0.0, 0.0, 0.0, 0.0);
        }

        let n = times.len() as f64;
        let mean = times.iter().sum::<f64>() / n;
        let variance = times.iter().map(|t| (t - mean).powi(2)).sum::<f64>() / n;
        let std_dev = variance.sqrt();
        let min = times.iter().cloned().fold(f64::INFINITY, f64::min);
        let max = times.iter().cloned().fold(f64::NEG_INFINITY, f64::max);

        (mean, std_dev, min, max)
    }

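    /// Prints a human-readable summary (iterations, mean, std dev, min, max, and
    /// throughput where applicable) for each non-empty set of samples.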
    pub fn print_summary(&self) {
        if !self.compilation_times.is_empty() {
            let (mean, std, min, max) = Self::stats(&self.compilation_times);
            println!("\nCompilation Benchmark:");
            println!(" Iterations: {}", self.compilation_times.len());
            println!(" Mean: {:.3} ms", mean);
            println!(" Std Dev: {:.3} ms", std);
            println!(" Min: {:.3} ms", min);
            println!(" Max: {:.3} ms", max);
            println!(" Throughput: {:.2} compilations/sec", 1000.0 / mean);
        }

        if !self.execution_times.is_empty() {
            let (mean, std, min, max) = Self::stats(&self.execution_times);
            println!("\nExecution Benchmark:");
            println!(" Iterations: {}", self.execution_times.len());
            println!(" Mean: {:.3} ms", mean);
            println!(" Std Dev: {:.3} ms", std);
            println!(" Min: {:.3} ms", min);
            println!(" Max: {:.3} ms", max);
            println!(" Throughput: {:.2} executions/sec", 1000.0 / mean);
        }

        if !self.optimization_times.is_empty() {
            let (mean, std, min, max) = Self::stats(&self.optimization_times);
            println!("\nOptimization Benchmark:");
            println!(" Iterations: {}", self.optimization_times.len());
            println!(" Mean: {:.3} ms", mean);
            println!(" Std Dev: {:.3} ms", std);
            println!(" Min: {:.3} ms", min);
            println!(" Max: {:.3} ms", max);
        }
    }

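    /// Serializes the per-category statistics and raw samples into a JSON object
    /// with optional `compilation`, `execution`, and `optimization` keys.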
    pub fn to_json(&self) -> serde_json::Value {
        let mut result = serde_json::Map::new();

        if !self.compilation_times.is_empty() {
            let (mean, std, min, max) = Self::stats(&self.compilation_times);
            result.insert(
                "compilation".to_string(),
                serde_json::json!({
                    "iterations": self.compilation_times.len(),
                    "mean_ms": mean,
                    "std_dev_ms": std,
                    "min_ms": min,
                    "max_ms": max,
                    "times_ms": self.compilation_times,
                }),
            );
        }

        if !self.execution_times.is_empty() {
            let (mean, std, min, max) = Self::stats(&self.execution_times);
            result.insert(
                "execution".to_string(),
                serde_json::json!({
                    "iterations": self.execution_times.len(),
                    "mean_ms": mean,
                    "std_dev_ms": std,
                    "min_ms": min,
                    "max_ms": max,
                    "times_ms": self.execution_times,
                }),
            );
        }

        if !self.optimization_times.is_empty() {
            let (mean, std, min, max) = Self::stats(&self.optimization_times);
            result.insert(
                "optimization".to_string(),
                serde_json::json!({
                    "iterations": self.optimization_times.len(),
                    "mean_ms": mean,
                    "std_dev_ms": std,
                    "min_ms": min,
                    "max_ms": max,
                    "times_ms": self.optimization_times,
                }),
            );
        }

        serde_json::Value::Object(result)
    }
}

impl Default for BenchmarkResults {
    fn default() -> Self {
        Self::new()
    }
}

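/// Drives repeated compilation, optimization, and execution runs and records
/// per-iteration wall-clock times.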
pub struct Benchmarker {
    iterations: usize,
    verbose: bool,
    quiet: bool,
}

impl Benchmarker {
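    /// Creates a benchmarker that runs `iterations` timed passes, with quiet output disabled.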
    #[allow(dead_code)]
    pub fn new(iterations: usize, verbose: bool) -> Self {
        Self {
            iterations,
            verbose,
            quiet: false,
        }
    }

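    /// Like `new`, but with explicit control over whether headers and
    /// per-iteration output are suppressed.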
    pub fn with_quiet(iterations: usize, verbose: bool, quiet: bool) -> Self {
        Self {
            iterations,
            verbose,
            quiet,
        }
    }

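    /// Times `compile_to_einsum_with_context` on a fresh clone of `context` for
    /// each iteration (after an untimed warm-up compile) and returns the
    /// per-iteration times in milliseconds.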
    pub fn benchmark_compilation(
        &self,
        expr: &TLExpr,
        context: &CompilerContext,
    ) -> Result<Vec<f64>> {
        let mut times = Vec::with_capacity(self.iterations);

        if !self.quiet {
            print_header("Benchmarking compilation...");
        }

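        // Warm-up pass (not timed) so one-time initialization does not skew the first sample.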
        let mut ctx = context.clone();
        compile_to_einsum_with_context(expr, &mut ctx)?;

        for i in 0..self.iterations {
            let mut ctx = context.clone();
            let start = Instant::now();
            compile_to_einsum_with_context(expr, &mut ctx)?;
            let elapsed = start.elapsed();
            let ms = elapsed.as_secs_f64() * 1000.0;
            times.push(ms);

            if self.verbose && !self.quiet {
                print_info(&format!(" Iteration {}: {:.3} ms", i + 1, ms));
            }
        }

        Ok(times)
    }

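    /// Executes `graph` on the given backend once per iteration (after an untimed
    /// warm-up run) and returns the per-iteration times in milliseconds.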
    pub fn benchmark_execution(&self, graph: &EinsumGraph, backend: Backend) -> Result<Vec<f64>> {
        let mut times = Vec::with_capacity(self.iterations);

        if !self.quiet {
            print_header(&format!("Benchmarking execution ({})...", backend.name()));
        }

        let config = ExecutionConfig {
            backend,
            device: tensorlogic_scirs_backend::DeviceType::Cpu,
            show_metrics: false,
            show_intermediates: false,
            validate_shapes: false,
            trace: false,
        };

        let executor = CliExecutor::new(config)?;

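        // Warm-up run (not timed); execution errors are discarded here and in the timed loop.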
        let _ = executor.execute(graph);

        for i in 0..self.iterations {
            let start = Instant::now();
            let _ = executor.execute(graph);
            let elapsed = start.elapsed();
            let ms = elapsed.as_secs_f64() * 1000.0;
            times.push(ms);

            if self.verbose && !self.quiet {
                print_info(&format!(" Iteration {}: {:.3} ms", i + 1, ms));
            }
        }

        Ok(times)
    }

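    /// Times `optimize_einsum_graph` (basic level with DCE, CSE, and identity passes
    /// enabled) on a freshly compiled graph per iteration; compilation itself is not timed.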
    pub fn benchmark_optimization(
        &self,
        expr: &TLExpr,
        context: &CompilerContext,
    ) -> Result<Vec<f64>> {
        let mut times = Vec::with_capacity(self.iterations);

        if !self.quiet {
            print_header("Benchmarking optimization...");
        }

        let opt_config = OptimizationConfig {
            level: OptimizationLevel::Basic,
            enable_dce: true,
            enable_cse: true,
            enable_identity: true,
            show_stats: false,
            verbose: false,
        };

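        // Warm-up pass: compile once and optimize the result outside the timed loop.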
        let mut ctx = context.clone();
        let graph = compile_to_einsum_with_context(expr, &mut ctx)?;
        let _ = optimize_einsum_graph(graph, &opt_config);

        for i in 0..self.iterations {
            let mut ctx = context.clone();
            let graph = compile_to_einsum_with_context(expr, &mut ctx)?;

            let start = Instant::now();
            let _ = optimize_einsum_graph(graph, &opt_config);
            let elapsed = start.elapsed();
            let ms = elapsed.as_secs_f64() * 1000.0;
            times.push(ms);

            if self.verbose && !self.quiet {
                print_info(&format!(" Iteration {}: {:.3} ms", i + 1, ms));
            }
        }

        Ok(times)
    }
}

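/// Formats a `Duration` with an appropriate unit: microseconds below 1 ms,
/// milliseconds below 1 s, seconds otherwise.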
#[allow(dead_code)]
pub fn format_duration(duration: Duration) -> String {
    let ms = duration.as_secs_f64() * 1000.0;
    if ms < 1.0 {
        format!("{:.1} µs", ms * 1000.0)
    } else if ms < 1000.0 {
        format!("{:.3} ms", ms)
    } else {
        format!("{:.3} s", ms / 1000.0)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_benchmark_results_stats() {
        let times = vec![10.0, 20.0, 30.0];
        let (mean, std, min, max) = BenchmarkResults::stats(&times);

        assert!((mean - 20.0).abs() < 0.001);
        assert!(std > 0.0);
        assert_eq!(min, 10.0);
        assert_eq!(max, 30.0);
    }

    #[test]
    fn test_benchmark_results_empty() {
        let times: Vec<f64> = vec![];
        let (mean, std, min, max) = BenchmarkResults::stats(&times);

        assert_eq!(mean, 0.0);
        assert_eq!(std, 0.0);
        assert_eq!(min, 0.0);
        assert_eq!(max, 0.0);
    }

    #[test]
    fn test_benchmark_results_json() {
        let mut results = BenchmarkResults::new();
        results.compilation_times = vec![10.0, 20.0, 30.0];

        let json = results.to_json();
        assert!(json.get("compilation").is_some());
    }
}