Skip to main content

torsh_functional/profiling/
benchmarking.rs

1//! Benchmarking framework for functional operations
2//!
3//! This module provides comprehensive benchmarking capabilities with
4//! warmup iterations, statistical analysis, and detailed metrics collection.
5
6use super::core::{OperationMetrics, OperationSummary, Profiler};
7use std::time::Instant;
8use torsh_core::{Result as TorshResult, TorshError};
9use torsh_tensor::Tensor;
10
/// Tunable parameters for a [`benchmark`] run.
///
/// All durations are expressed in seconds and compared against the wall-clock
/// time elapsed since the first measured iteration started.
#[derive(Debug, Clone)]
pub struct BenchmarkConfig {
    /// How many unmeasured calls are made before measuring begins.
    pub warmup_iters: usize,
    /// Upper bound on the number of measured iterations.
    pub bench_iters: usize,
    /// Minimum benchmark duration in seconds.
    ///
    /// NOTE(review): as consumed by `benchmark`, measuring stops once at
    /// least one iteration has run and the elapsed time exceeds this value,
    /// so it currently behaves as an early-stop threshold — confirm intent.
    pub min_duration: f64,
    /// Maximum benchmark duration in seconds; no new iteration starts once
    /// the elapsed time exceeds it.
    pub max_duration: f64,
    /// When `true`, the profiler's memory tracking and FLOPs counting are
    /// switched on for the run.
    pub detailed_metrics: bool,
}
25
26impl Default for BenchmarkConfig {
27    fn default() -> Self {
28        Self {
29            warmup_iters: 5,
30            bench_iters: 100,
31            min_duration: 1.0,
32            max_duration: 60.0,
33            detailed_metrics: true,
34        }
35    }
36}
37
38/// Benchmark results
39#[derive(Debug, Clone)]
40pub struct BenchmarkResults {
41    pub operation_name: String,
42    pub config: BenchmarkConfig,
43    pub metrics: Vec<OperationMetrics>,
44    pub summary: OperationSummary,
45}
46
47/// Benchmark a function with given inputs
48pub fn benchmark<F, R>(
49    name: &str,
50    mut operation: F,
51    inputs: &[&Tensor],
52    config: BenchmarkConfig,
53) -> TorshResult<BenchmarkResults>
54where
55    F: FnMut(&[&Tensor]) -> TorshResult<R>,
56    R: AsRef<[Tensor]>,
57{
58    let mut profiler = Profiler::new();
59    if config.detailed_metrics {
60        profiler.enable_memory_tracking();
61        profiler.enable_flops_counting();
62    }
63
64    // Warmup iterations
65    for _ in 0..config.warmup_iters {
66        let _ = operation(inputs)?;
67    }
68
69    // Benchmark iterations
70    let start_time = Instant::now();
71    let mut iteration = 0;
72
73    while iteration < config.bench_iters {
74        let elapsed = start_time.elapsed().as_secs_f64();
75        if elapsed > config.max_duration {
76            break;
77        }
78        if iteration > 0 && elapsed > config.min_duration {
79            break;
80        }
81
82        profiler.start_operation(name, inputs)?;
83        let result = operation(inputs)?;
84        let output_refs: Vec<&Tensor> = result.as_ref().iter().collect();
85        profiler.finish_operation(&output_refs)?;
86
87        iteration += 1;
88    }
89
90    let summary = profiler
91        .get_summary(name)
92        .ok_or_else(|| TorshError::Other("Failed to generate benchmark summary".to_string()))?;
93
94    Ok(BenchmarkResults {
95        operation_name: name.to_string(),
96        config,
97        metrics: profiler.metrics,
98        summary,
99    })
100}
101
102/// Profile a single operation
103pub fn profile_operation<F, R>(
104    name: &str,
105    mut operation: F,
106    inputs: &[&Tensor],
107) -> TorshResult<OperationMetrics>
108where
109    F: FnMut(&[&Tensor]) -> TorshResult<R>,
110    R: AsRef<[Tensor]>,
111{
112    let mut profiler = Profiler::new();
113    profiler.enable_memory_tracking();
114    profiler.enable_flops_counting();
115
116    profiler.start_operation(name, inputs)?;
117    let result = operation(inputs)?;
118    let output_refs: Vec<&Tensor> = result.as_ref().iter().collect();
119    profiler.finish_operation(&output_refs)?;
120
121    Ok(profiler
122        .metrics
123        .into_iter()
124        .next()
125        .expect("profiler should have at least one metric after finish_operation"))
126}
127
#[cfg(test)]
mod tests {
    use super::*;
    use torsh_tensor::creation::randn;

    /// Trivial operation shared by the tests: returns a copy of the first input.
    fn clone_first(inputs: &[&Tensor]) -> TorshResult<Vec<Tensor>> {
        Ok(vec![inputs[0].clone()])
    }

    /// A short benchmark run reports the right name and never records more
    /// metrics than the configured iteration cap.
    #[test]
    fn test_benchmark_basic() -> TorshResult<()> {
        let tensor = randn(&[128, 128])?;
        let inputs = [&tensor];

        let config = BenchmarkConfig {
            warmup_iters: 1,
            bench_iters: 3,
            min_duration: 0.1,
            max_duration: 1.0,
            detailed_metrics: false,
        };

        let results = benchmark("test_operation", clone_first, &inputs, config)?;

        assert_eq!(results.operation_name, "test_operation");
        assert!(results.metrics.len() <= 3);
        Ok(())
    }

    /// Profiling a single call yields a metric carrying the operation name
    /// and non-empty input/output shape lists.
    #[test]
    fn test_profile_operation() -> TorshResult<()> {
        let tensor = randn(&[64, 64])?;
        let inputs = [&tensor];

        let metrics = profile_operation("test_profile", clone_first, &inputs)?;

        assert_eq!(metrics.name, "test_profile");
        assert!(!metrics.input_shapes.is_empty());
        assert!(!metrics.output_shapes.is_empty());
        Ok(())
    }
}