Skip to main content

torsh_fx/
benchmarking.rs

1//! Benchmarking utilities for FX graph operations and transformations
2//!
3//! This module provides comprehensive benchmarking capabilities for measuring
4//! performance of graph operations, transformations, and code generation.
5
6use crate::{FxGraph, TorshResult};
7use std::collections::HashMap;
8use std::time::{Duration, Instant};
9
10/// Benchmark results for a single operation
11#[derive(Debug, Clone)]
12pub struct BenchmarkResult {
13    pub operation_name: String,
14    pub execution_time: Duration,
15    pub memory_usage: Option<usize>,
16    pub iterations: usize,
17    pub success_rate: f64,
18}
19
/// Comprehensive benchmark suite for graph operations
#[derive(Debug)]
pub struct GraphBenchmarkSuite {
    /// Benchmark results grouped by category name (e.g. "serialization", "codegen").
    results: HashMap<String, Vec<BenchmarkResult>>,
    /// Untimed iterations run before measurement begins.
    warmup_iterations: usize,
    /// Number of timed iterations per benchmarked operation.
    benchmark_iterations: usize,
}
27
28impl GraphBenchmarkSuite {
29    /// Create a new benchmark suite
30    pub fn new() -> Self {
31        Self {
32            results: HashMap::new(),
33            warmup_iterations: 10,
34            benchmark_iterations: 100,
35        }
36    }
37
38    /// Set the number of warmup iterations
39    pub fn with_warmup_iterations(mut self, iterations: usize) -> Self {
40        self.warmup_iterations = iterations;
41        self
42    }
43
44    /// Set the number of benchmark iterations
45    pub fn with_benchmark_iterations(mut self, iterations: usize) -> Self {
46        self.benchmark_iterations = iterations;
47        self
48    }
49
50    /// Benchmark graph creation operations
51    pub fn benchmark_graph_creation(&mut self) -> TorshResult<()> {
52        // Benchmark single operation graph creation
53        let result = self.benchmark_operation("single_op_creation", || {
54            let _graph = FxGraph::single_op("relu", vec!["input".to_string()]);
55            Ok(())
56        })?;
57
58        // Benchmark sequential operations graph creation
59        let result_seq = self.benchmark_operation("sequential_ops_creation", || {
60            let _graph = FxGraph::sequential_ops(&["relu", "sigmoid", "tanh"]);
61            Ok(())
62        })?;
63
64        // Benchmark large graph creation
65        let result_large = self.benchmark_operation("large_graph_creation", || {
66            let ops: Vec<&str> = (0..100)
67                .map(|i| {
68                    if i % 3 == 0 {
69                        "relu"
70                    } else if i % 3 == 1 {
71                        "sigmoid"
72                    } else {
73                        "tanh"
74                    }
75                })
76                .collect();
77            let _graph = FxGraph::sequential_ops(&ops);
78            Ok(())
79        })?;
80
81        self.results
82            .entry("graph_creation".to_string())
83            .or_insert_with(Vec::new)
84            .extend([result, result_seq, result_large]);
85
86        Ok(())
87    }
88
89    /// Benchmark graph serialization operations
90    pub fn benchmark_serialization(&mut self) -> TorshResult<()> {
91        let test_graph = FxGraph::sequential_ops(&["relu", "sigmoid", "tanh", "softmax"]);
92
93        // Benchmark JSON serialization
94        let json_serialize = self.benchmark_operation("json_serialize", || {
95            let _json = test_graph.to_json()?;
96            Ok(())
97        })?;
98
99        // Benchmark binary serialization
100        let binary_serialize = self.benchmark_operation("binary_serialize", || {
101            let _binary = test_graph.to_binary()?;
102            Ok(())
103        })?;
104
105        // Benchmark JSON deserialization
106        let json_data = test_graph.to_json()?;
107        let json_deserialize = self.benchmark_operation("json_deserialize", || {
108            let _graph = FxGraph::from_json(&json_data)?;
109            Ok(())
110        })?;
111
112        // Benchmark binary deserialization
113        let binary_data = test_graph.to_binary()?;
114        let binary_deserialize = self.benchmark_operation("binary_deserialize", || {
115            let _graph = FxGraph::from_binary(&binary_data)?;
116            Ok(())
117        })?;
118
119        self.results
120            .entry("serialization".to_string())
121            .or_insert_with(Vec::new)
122            .extend([
123                json_serialize,
124                binary_serialize,
125                json_deserialize,
126                binary_deserialize,
127            ]);
128
129        Ok(())
130    }
131
132    /// Benchmark graph analysis operations
133    pub fn benchmark_analysis(&mut self) -> TorshResult<()> {
134        let test_graph =
135            FxGraph::sequential_ops(&["relu", "sigmoid", "tanh", "softmax", "dropout"]);
136
137        // Benchmark validation
138        let validation = self.benchmark_operation("graph_validation", || {
139            let _result = test_graph.validate()?;
140            Ok(())
141        })?;
142
143        // Benchmark node filtering
144        let node_filtering = self.benchmark_operation("node_filtering", || {
145            let _inputs = test_graph.input_nodes();
146            let _outputs = test_graph.output_nodes();
147            let _calls = test_graph.call_nodes();
148            Ok(())
149        })?;
150
151        // Benchmark summary generation
152        let summary = self.benchmark_operation("summary_generation", || {
153            let _summary = test_graph.summary();
154            Ok(())
155        })?;
156
157        self.results
158            .entry("analysis".to_string())
159            .or_insert_with(Vec::new)
160            .extend([validation, node_filtering, summary]);
161
162        Ok(())
163    }
164
165    /// Benchmark code generation operations
166    pub fn benchmark_codegen(&mut self) -> TorshResult<()> {
167        let test_graph = FxGraph::sequential_ops(&["relu", "sigmoid", "tanh"]);
168
169        // Benchmark Python code generation
170        let python_codegen = self.benchmark_operation("python_codegen", || {
171            let _code = test_graph.to_python()?;
172            Ok(())
173        })?;
174
175        // Benchmark C++ code generation
176        let cpp_codegen = self.benchmark_operation("cpp_codegen", || {
177            let _code = test_graph.to_cpp()?;
178            Ok(())
179        })?;
180
181        self.results
182            .entry("codegen".to_string())
183            .or_insert_with(Vec::new)
184            .extend([python_codegen, cpp_codegen]);
185
186        Ok(())
187    }
188
189    /// Generic method to benchmark any operation
190    pub fn benchmark_operation<F>(
191        &self,
192        name: &str,
193        mut operation: F,
194    ) -> TorshResult<BenchmarkResult>
195    where
196        F: FnMut() -> TorshResult<()>,
197    {
198        // Warmup phase
199        for _ in 0..self.warmup_iterations {
200            let _ = operation();
201        }
202
203        // Benchmark phase
204        let mut total_time = Duration::ZERO;
205        let mut successful_runs = 0;
206
207        for _ in 0..self.benchmark_iterations {
208            let start = Instant::now();
209            match operation() {
210                Ok(_) => {
211                    total_time += start.elapsed();
212                    successful_runs += 1;
213                }
214                Err(_) => {} // Count failures but continue
215            }
216        }
217
218        let avg_time = if successful_runs > 0 {
219            total_time / successful_runs as u32
220        } else {
221            Duration::ZERO
222        };
223
224        let success_rate = successful_runs as f64 / self.benchmark_iterations as f64;
225
226        Ok(BenchmarkResult {
227            operation_name: name.to_string(),
228            execution_time: avg_time,
229            memory_usage: None, // Could be extended to measure memory usage
230            iterations: self.benchmark_iterations,
231            success_rate,
232        })
233    }
234
235    /// Run a comprehensive benchmark suite
236    pub fn run_comprehensive_benchmark(&mut self) -> TorshResult<()> {
237        println!("Running comprehensive FX graph benchmark suite...");
238
239        self.benchmark_graph_creation()?;
240        self.benchmark_serialization()?;
241        self.benchmark_analysis()?;
242        self.benchmark_codegen()?;
243
244        Ok(())
245    }
246
247    /// Get benchmark results for a specific category
248    pub fn get_results(&self, category: &str) -> Option<&Vec<BenchmarkResult>> {
249        self.results.get(category)
250    }
251
252    /// Get all benchmark results
253    pub fn get_all_results(&self) -> &HashMap<String, Vec<BenchmarkResult>> {
254        &self.results
255    }
256
257    /// Generate a performance report
258    pub fn generate_report(&self) -> String {
259        let mut report = String::new();
260        report.push_str("FX Graph Performance Benchmark Report\n");
261        report.push_str("=====================================\n\n");
262
263        for (category, results) in &self.results {
264            report.push_str(&format!("Category: {category}\n"));
265            report.push_str("----------------------------\n");
266
267            for result in results {
268                report.push_str(&format!(
269                    "  Operation: {}\n    Time: {:?}\n    Iterations: {}\n    Success Rate: {:.2}%\n\n",
270                    result.operation_name,
271                    result.execution_time,
272                    result.iterations,
273                    result.success_rate * 100.0
274                ));
275            }
276            report.push('\n');
277        }
278
279        report
280    }
281
282    /// Compare performance against baseline benchmarks
283    pub fn compare_with_baseline(&self, baseline: &GraphBenchmarkSuite) -> String {
284        let mut comparison = String::new();
285        comparison.push_str("Performance Comparison with Baseline\n");
286        comparison.push_str("===================================\n\n");
287
288        for (category, results) in &self.results {
289            if let Some(baseline_results) = baseline.get_results(category) {
290                comparison.push_str(&format!("Category: {category}\n"));
291                comparison.push_str("----------------------------\n");
292
293                for (current, baseline_result) in results.iter().zip(baseline_results.iter()) {
294                    if current.operation_name == baseline_result.operation_name {
295                        let ratio = if baseline_result.execution_time.as_nanos() > 0 {
296                            current.execution_time.as_nanos() as f64
297                                / baseline_result.execution_time.as_nanos() as f64
298                        } else {
299                            1.0
300                        };
301
302                        let performance_change = if ratio < 1.0 {
303                            let speedup = 1.0 / ratio;
304                            format!("FASTER by {speedup:.2}x")
305                        } else if ratio > 1.0 {
306                            format!("SLOWER by {ratio:.2}x")
307                        } else {
308                            "SAME".to_string()
309                        };
310
311                        comparison.push_str(&format!(
312                            "  {}: {} (Current: {:?}, Baseline: {:?})\n",
313                            current.operation_name,
314                            performance_change,
315                            current.execution_time,
316                            baseline_result.execution_time
317                        ));
318                    }
319                }
320                comparison.push('\n');
321            }
322        }
323
324        comparison
325    }
326}
327
/// Performance regression testing utilities
pub struct RegressionTester {
    /// Maximum allowed current/baseline time ratio before an operation is
    /// reported as a regression.
    threshold: f64, // Allowed performance degradation (e.g. 1.1 = 10% slower is acceptable)
}
332
333impl RegressionTester {
334    /// Create a new regression tester with a specified threshold
335    pub fn new(threshold: f64) -> Self {
336        Self { threshold }
337    }
338
339    /// Test for performance regressions
340    pub fn test_regression(
341        &self,
342        current: &GraphBenchmarkSuite,
343        baseline: &GraphBenchmarkSuite,
344    ) -> Vec<String> {
345        let mut regressions = Vec::new();
346
347        for (category, current_results) in current.get_all_results() {
348            if let Some(baseline_results) = baseline.get_results(category) {
349                for (current_result, baseline_result) in
350                    current_results.iter().zip(baseline_results.iter())
351                {
352                    if current_result.operation_name == baseline_result.operation_name {
353                        let ratio = if baseline_result.execution_time.as_nanos() > 0 {
354                            current_result.execution_time.as_nanos() as f64
355                                / baseline_result.execution_time.as_nanos() as f64
356                        } else {
357                            1.0
358                        };
359
360                        if ratio > self.threshold {
361                            regressions.push(format!(
362                                "REGRESSION in {}/{}: {:.2}x slower than baseline (threshold: {:.2}x)",
363                                category,
364                                current_result.operation_name,
365                                ratio,
366                                self.threshold
367                            ));
368                        }
369                    }
370                }
371            }
372        }
373
374        regressions
375    }
376}
377
/// Simple benchmark macro for quick measurements
///
/// Times the given block, prints the elapsed duration alongside `$name`,
/// and evaluates to the block's result.
#[macro_export]
macro_rules! benchmark {
    ($name:expr, $code:block) => {{
        let __bench_start = std::time::Instant::now();
        let __bench_value = $code;
        let __bench_elapsed = __bench_start.elapsed();
        println!("Benchmark '{}': {:?}", $name, __bench_elapsed);
        __bench_value
    }};
}
389
#[cfg(test)]
mod tests {
    use super::*;

    // Builder methods should override the default iteration counts.
    #[test]
    fn test_benchmark_suite_creation() {
        let suite = GraphBenchmarkSuite::new()
            .with_warmup_iterations(5)
            .with_benchmark_iterations(50);

        assert_eq!(suite.warmup_iterations, 5);
        assert_eq!(suite.benchmark_iterations, 50);
    }

    // A trivially-succeeding operation should report a 100% success rate
    // and a non-zero averaged execution time.
    #[test]
    fn test_simple_benchmark() {
        let suite = GraphBenchmarkSuite::new()
            .with_warmup_iterations(1)
            .with_benchmark_iterations(5);

        let result = suite
            .benchmark_operation("test_op", || {
                // Simulate some work
                std::thread::sleep(std::time::Duration::from_millis(1));
                Ok(())
            })
            .unwrap();

        assert_eq!(result.operation_name, "test_op");
        assert_eq!(result.iterations, 5);
        assert_eq!(result.success_rate, 1.0);
        assert!(result.execution_time > Duration::ZERO);
    }

    // The graph-creation category should record exactly one result per
    // benchmarked construction pattern, all fully successful.
    #[test]
    fn test_graph_creation_benchmark() {
        let mut suite = GraphBenchmarkSuite::new()
            .with_warmup_iterations(1)
            .with_benchmark_iterations(10);

        suite.benchmark_graph_creation().unwrap();

        let results = suite.get_results("graph_creation").unwrap();
        assert_eq!(results.len(), 3); // single_op, sequential_ops, large_graph

        for result in results {
            assert_eq!(result.success_rate, 1.0);
            assert!(result.iterations > 0);
        }
    }

    // Serialization benchmarks cover both formats in both directions.
    #[test]
    fn test_serialization_benchmark() {
        let mut suite = GraphBenchmarkSuite::new()
            .with_warmup_iterations(1)
            .with_benchmark_iterations(5);

        suite.benchmark_serialization().unwrap();

        let results = suite.get_results("serialization").unwrap();
        assert_eq!(results.len(), 4); // json_serialize, binary_serialize, json_deserialize, binary_deserialize
    }

    // The generated report should contain the header, the category name,
    // and at least one operation name.
    #[test]
    fn test_report_generation() {
        let mut suite = GraphBenchmarkSuite::new()
            .with_warmup_iterations(1)
            .with_benchmark_iterations(5);

        suite.benchmark_graph_creation().unwrap();

        let report = suite.generate_report();
        assert!(report.contains("FX Graph Performance Benchmark Report"));
        assert!(report.contains("graph_creation"));
        assert!(report.contains("single_op_creation"));
    }

    // A 2x slowdown against a 1.5x threshold must be reported as exactly
    // one regression with the expected ratio in the message.
    #[test]
    fn test_regression_tester() {
        let tester = RegressionTester::new(1.5); // 50% degradation threshold

        // Create mock benchmark suites
        let mut baseline = GraphBenchmarkSuite::new();
        baseline.results.insert(
            "test".to_string(),
            vec![BenchmarkResult {
                operation_name: "fast_op".to_string(),
                execution_time: Duration::from_millis(10),
                memory_usage: None,
                iterations: 100,
                success_rate: 1.0,
            }],
        );

        let mut current = GraphBenchmarkSuite::new();
        current.results.insert(
            "test".to_string(),
            vec![BenchmarkResult {
                operation_name: "fast_op".to_string(),
                execution_time: Duration::from_millis(20), // 2x slower
                memory_usage: None,
                iterations: 100,
                success_rate: 1.0,
            }],
        );

        let regressions = tester.test_regression(&current, &baseline);
        assert_eq!(regressions.len(), 1);
        assert!(regressions[0].contains("REGRESSION"));
        assert!(regressions[0].contains("2.00x slower"));
    }

    // The macro should evaluate to the value of its block.
    #[test]
    fn test_benchmark_macro() {
        let result = benchmark!("test_operation", {
            std::thread::sleep(std::time::Duration::from_millis(1));
            42
        });

        assert_eq!(result, 42);
    }
}