1use crate::{FxGraph, TorshResult};
7use std::collections::HashMap;
8use std::time::{Duration, Instant};
9
/// Aggregated outcome of benchmarking one named operation.
#[derive(Debug, Clone)]
pub struct BenchmarkResult {
    /// Label the operation was benchmarked under (as passed to `benchmark_operation`).
    pub operation_name: String,
    /// Mean wall-clock time per *successful* iteration (zero if none succeeded).
    pub execution_time: Duration,
    /// Peak memory used, if measured. Currently never populated (always `None`).
    pub memory_usage: Option<usize>,
    /// Number of timed iterations that were attempted.
    pub iterations: usize,
    /// Fraction of timed iterations that returned `Ok`, in `0.0..=1.0`.
    pub success_rate: f64,
}
19
/// Collects benchmark results grouped by category string
/// (e.g. "graph_creation", "serialization"), with configurable
/// warmup and measurement iteration counts.
#[derive(Debug)]
pub struct GraphBenchmarkSuite {
    // Results keyed by category; each category holds one entry per operation.
    results: HashMap<String, Vec<BenchmarkResult>>,
    // Untimed iterations run before measuring (default 10).
    warmup_iterations: usize,
    // Timed iterations per operation (default 100).
    benchmark_iterations: usize,
}
27
28impl GraphBenchmarkSuite {
29 pub fn new() -> Self {
31 Self {
32 results: HashMap::new(),
33 warmup_iterations: 10,
34 benchmark_iterations: 100,
35 }
36 }
37
38 pub fn with_warmup_iterations(mut self, iterations: usize) -> Self {
40 self.warmup_iterations = iterations;
41 self
42 }
43
44 pub fn with_benchmark_iterations(mut self, iterations: usize) -> Self {
46 self.benchmark_iterations = iterations;
47 self
48 }
49
50 pub fn benchmark_graph_creation(&mut self) -> TorshResult<()> {
52 let result = self.benchmark_operation("single_op_creation", || {
54 let _graph = FxGraph::single_op("relu", vec!["input".to_string()]);
55 Ok(())
56 })?;
57
58 let result_seq = self.benchmark_operation("sequential_ops_creation", || {
60 let _graph = FxGraph::sequential_ops(&["relu", "sigmoid", "tanh"]);
61 Ok(())
62 })?;
63
64 let result_large = self.benchmark_operation("large_graph_creation", || {
66 let ops: Vec<&str> = (0..100)
67 .map(|i| {
68 if i % 3 == 0 {
69 "relu"
70 } else if i % 3 == 1 {
71 "sigmoid"
72 } else {
73 "tanh"
74 }
75 })
76 .collect();
77 let _graph = FxGraph::sequential_ops(&ops);
78 Ok(())
79 })?;
80
81 self.results
82 .entry("graph_creation".to_string())
83 .or_insert_with(Vec::new)
84 .extend([result, result_seq, result_large]);
85
86 Ok(())
87 }
88
89 pub fn benchmark_serialization(&mut self) -> TorshResult<()> {
91 let test_graph = FxGraph::sequential_ops(&["relu", "sigmoid", "tanh", "softmax"]);
92
93 let json_serialize = self.benchmark_operation("json_serialize", || {
95 let _json = test_graph.to_json()?;
96 Ok(())
97 })?;
98
99 let binary_serialize = self.benchmark_operation("binary_serialize", || {
101 let _binary = test_graph.to_binary()?;
102 Ok(())
103 })?;
104
105 let json_data = test_graph.to_json()?;
107 let json_deserialize = self.benchmark_operation("json_deserialize", || {
108 let _graph = FxGraph::from_json(&json_data)?;
109 Ok(())
110 })?;
111
112 let binary_data = test_graph.to_binary()?;
114 let binary_deserialize = self.benchmark_operation("binary_deserialize", || {
115 let _graph = FxGraph::from_binary(&binary_data)?;
116 Ok(())
117 })?;
118
119 self.results
120 .entry("serialization".to_string())
121 .or_insert_with(Vec::new)
122 .extend([
123 json_serialize,
124 binary_serialize,
125 json_deserialize,
126 binary_deserialize,
127 ]);
128
129 Ok(())
130 }
131
132 pub fn benchmark_analysis(&mut self) -> TorshResult<()> {
134 let test_graph =
135 FxGraph::sequential_ops(&["relu", "sigmoid", "tanh", "softmax", "dropout"]);
136
137 let validation = self.benchmark_operation("graph_validation", || {
139 let _result = test_graph.validate()?;
140 Ok(())
141 })?;
142
143 let node_filtering = self.benchmark_operation("node_filtering", || {
145 let _inputs = test_graph.input_nodes();
146 let _outputs = test_graph.output_nodes();
147 let _calls = test_graph.call_nodes();
148 Ok(())
149 })?;
150
151 let summary = self.benchmark_operation("summary_generation", || {
153 let _summary = test_graph.summary();
154 Ok(())
155 })?;
156
157 self.results
158 .entry("analysis".to_string())
159 .or_insert_with(Vec::new)
160 .extend([validation, node_filtering, summary]);
161
162 Ok(())
163 }
164
165 pub fn benchmark_codegen(&mut self) -> TorshResult<()> {
167 let test_graph = FxGraph::sequential_ops(&["relu", "sigmoid", "tanh"]);
168
169 let python_codegen = self.benchmark_operation("python_codegen", || {
171 let _code = test_graph.to_python()?;
172 Ok(())
173 })?;
174
175 let cpp_codegen = self.benchmark_operation("cpp_codegen", || {
177 let _code = test_graph.to_cpp()?;
178 Ok(())
179 })?;
180
181 self.results
182 .entry("codegen".to_string())
183 .or_insert_with(Vec::new)
184 .extend([python_codegen, cpp_codegen]);
185
186 Ok(())
187 }
188
189 pub fn benchmark_operation<F>(
191 &self,
192 name: &str,
193 mut operation: F,
194 ) -> TorshResult<BenchmarkResult>
195 where
196 F: FnMut() -> TorshResult<()>,
197 {
198 for _ in 0..self.warmup_iterations {
200 let _ = operation();
201 }
202
203 let mut total_time = Duration::ZERO;
205 let mut successful_runs = 0;
206
207 for _ in 0..self.benchmark_iterations {
208 let start = Instant::now();
209 match operation() {
210 Ok(_) => {
211 total_time += start.elapsed();
212 successful_runs += 1;
213 }
214 Err(_) => {} }
216 }
217
218 let avg_time = if successful_runs > 0 {
219 total_time / successful_runs as u32
220 } else {
221 Duration::ZERO
222 };
223
224 let success_rate = successful_runs as f64 / self.benchmark_iterations as f64;
225
226 Ok(BenchmarkResult {
227 operation_name: name.to_string(),
228 execution_time: avg_time,
229 memory_usage: None, iterations: self.benchmark_iterations,
231 success_rate,
232 })
233 }
234
235 pub fn run_comprehensive_benchmark(&mut self) -> TorshResult<()> {
237 println!("Running comprehensive FX graph benchmark suite...");
238
239 self.benchmark_graph_creation()?;
240 self.benchmark_serialization()?;
241 self.benchmark_analysis()?;
242 self.benchmark_codegen()?;
243
244 Ok(())
245 }
246
247 pub fn get_results(&self, category: &str) -> Option<&Vec<BenchmarkResult>> {
249 self.results.get(category)
250 }
251
252 pub fn get_all_results(&self) -> &HashMap<String, Vec<BenchmarkResult>> {
254 &self.results
255 }
256
257 pub fn generate_report(&self) -> String {
259 let mut report = String::new();
260 report.push_str("FX Graph Performance Benchmark Report\n");
261 report.push_str("=====================================\n\n");
262
263 for (category, results) in &self.results {
264 report.push_str(&format!("Category: {category}\n"));
265 report.push_str("----------------------------\n");
266
267 for result in results {
268 report.push_str(&format!(
269 " Operation: {}\n Time: {:?}\n Iterations: {}\n Success Rate: {:.2}%\n\n",
270 result.operation_name,
271 result.execution_time,
272 result.iterations,
273 result.success_rate * 100.0
274 ));
275 }
276 report.push('\n');
277 }
278
279 report
280 }
281
282 pub fn compare_with_baseline(&self, baseline: &GraphBenchmarkSuite) -> String {
284 let mut comparison = String::new();
285 comparison.push_str("Performance Comparison with Baseline\n");
286 comparison.push_str("===================================\n\n");
287
288 for (category, results) in &self.results {
289 if let Some(baseline_results) = baseline.get_results(category) {
290 comparison.push_str(&format!("Category: {category}\n"));
291 comparison.push_str("----------------------------\n");
292
293 for (current, baseline_result) in results.iter().zip(baseline_results.iter()) {
294 if current.operation_name == baseline_result.operation_name {
295 let ratio = if baseline_result.execution_time.as_nanos() > 0 {
296 current.execution_time.as_nanos() as f64
297 / baseline_result.execution_time.as_nanos() as f64
298 } else {
299 1.0
300 };
301
302 let performance_change = if ratio < 1.0 {
303 let speedup = 1.0 / ratio;
304 format!("FASTER by {speedup:.2}x")
305 } else if ratio > 1.0 {
306 format!("SLOWER by {ratio:.2}x")
307 } else {
308 "SAME".to_string()
309 };
310
311 comparison.push_str(&format!(
312 " {}: {} (Current: {:?}, Baseline: {:?})\n",
313 current.operation_name,
314 performance_change,
315 current.execution_time,
316 baseline_result.execution_time
317 ));
318 }
319 }
320 comparison.push('\n');
321 }
322 }
323
324 comparison
325 }
326}
327
/// Flags operations whose current execution time exceeds `threshold`
/// times the corresponding baseline measurement.
pub struct RegressionTester {
    // Ratio (current / baseline execution time) above which a
    // regression is reported, e.g. 1.5 = "50% slower".
    threshold: f64, }
332
333impl RegressionTester {
334 pub fn new(threshold: f64) -> Self {
336 Self { threshold }
337 }
338
339 pub fn test_regression(
341 &self,
342 current: &GraphBenchmarkSuite,
343 baseline: &GraphBenchmarkSuite,
344 ) -> Vec<String> {
345 let mut regressions = Vec::new();
346
347 for (category, current_results) in current.get_all_results() {
348 if let Some(baseline_results) = baseline.get_results(category) {
349 for (current_result, baseline_result) in
350 current_results.iter().zip(baseline_results.iter())
351 {
352 if current_result.operation_name == baseline_result.operation_name {
353 let ratio = if baseline_result.execution_time.as_nanos() > 0 {
354 current_result.execution_time.as_nanos() as f64
355 / baseline_result.execution_time.as_nanos() as f64
356 } else {
357 1.0
358 };
359
360 if ratio > self.threshold {
361 regressions.push(format!(
362 "REGRESSION in {}/{}: {:.2}x slower than baseline (threshold: {:.2}x)",
363 category,
364 current_result.operation_name,
365 ratio,
366 self.threshold
367 ));
368 }
369 }
370 }
371 }
372 }
373
374 regressions
375 }
376}
377
/// Time an arbitrary block, print its duration to stdout, and yield the
/// block's value. Intended for quick ad-hoc measurements, not for the
/// statistical runs handled by `GraphBenchmarkSuite`.
#[macro_export]
macro_rules! benchmark {
    ($name:expr, $code:block) => {{
        // Double braces so the expansion is a single expression that
        // evaluates to the block's result.
        let __bench_started_at = std::time::Instant::now();
        let __bench_value = $code;
        let __bench_elapsed = __bench_started_at.elapsed();
        println!("Benchmark '{}': {:?}", $name, __bench_elapsed);
        __bench_value
    }};
}
389
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_benchmark_suite_creation() {
        let suite = GraphBenchmarkSuite::new()
            .with_warmup_iterations(5)
            .with_benchmark_iterations(50);

        assert_eq!(suite.warmup_iterations, 5);
        assert_eq!(suite.benchmark_iterations, 50);
    }

    #[test]
    fn test_simple_benchmark() {
        let suite = GraphBenchmarkSuite::new()
            .with_warmup_iterations(1)
            .with_benchmark_iterations(5);

        let result = suite
            .benchmark_operation("test_op", || {
                std::thread::sleep(std::time::Duration::from_millis(1));
                Ok(())
            })
            .unwrap();

        assert_eq!(result.operation_name, "test_op");
        assert_eq!(result.iterations, 5);
        assert_eq!(result.success_rate, 1.0);
        assert!(result.execution_time > Duration::ZERO);
    }

    #[test]
    fn test_graph_creation_benchmark() {
        let mut suite = GraphBenchmarkSuite::new()
            .with_warmup_iterations(1)
            .with_benchmark_iterations(10);

        suite.benchmark_graph_creation().unwrap();

        let results = suite.get_results("graph_creation").unwrap();
        // Three operations: single, sequential, and large graph creation.
        assert_eq!(results.len(), 3);
        for result in results {
            assert_eq!(result.success_rate, 1.0);
            assert!(result.iterations > 0);
        }
    }

    #[test]
    fn test_serialization_benchmark() {
        let mut suite = GraphBenchmarkSuite::new()
            .with_warmup_iterations(1)
            .with_benchmark_iterations(5);

        suite.benchmark_serialization().unwrap();

        let results = suite.get_results("serialization").unwrap();
        // Four operations: JSON/binary serialize + deserialize.
        assert_eq!(results.len(), 4);
    }

    #[test]
    fn test_report_generation() {
        let mut suite = GraphBenchmarkSuite::new()
            .with_warmup_iterations(1)
            .with_benchmark_iterations(5);

        suite.benchmark_graph_creation().unwrap();

        let report = suite.generate_report();
        assert!(report.contains("FX Graph Performance Benchmark Report"));
        assert!(report.contains("graph_creation"));
        assert!(report.contains("single_op_creation"));
    }

    #[test]
    fn test_regression_tester() {
        let tester = RegressionTester::new(1.5);

        let mut baseline = GraphBenchmarkSuite::new();
        baseline.results.insert(
            "test".to_string(),
            vec![BenchmarkResult {
                operation_name: "fast_op".to_string(),
                execution_time: Duration::from_millis(10),
                memory_usage: None,
                iterations: 100,
                success_rate: 1.0,
            }],
        );

        let mut current = GraphBenchmarkSuite::new();
        current.results.insert(
            "test".to_string(),
            vec![BenchmarkResult {
                operation_name: "fast_op".to_string(),
                // Exactly 2x the baseline, above the 1.5x threshold.
                execution_time: Duration::from_millis(20),
                memory_usage: None,
                iterations: 100,
                success_rate: 1.0,
            }],
        );

        // FIX: was `¤t` (HTML-entity mojibake for `&current`), which
        // does not compile.
        let regressions = tester.test_regression(&current, &baseline);
        assert_eq!(regressions.len(), 1);
        assert!(regressions[0].contains("REGRESSION"));
        assert!(regressions[0].contains("2.00x slower"));
    }

    #[test]
    fn test_benchmark_macro() {
        let result = benchmark!("test_operation", {
            std::thread::sleep(std::time::Duration::from_millis(1));
            42
        });

        assert_eq!(result, 42);
    }
}