torsh_functional/profiling/
benchmarking.rs1use super::core::{OperationMetrics, OperationSummary, Profiler};
7use std::time::Instant;
8use torsh_core::{Result as TorshResult, TorshError};
9use torsh_tensor::Tensor;
10
/// Parameters controlling how [`benchmark`] warms up and measures an operation.
///
/// All durations are in seconds; `benchmark` compares them against
/// `Instant::elapsed().as_secs_f64()` measured from the end of the warmup phase.
#[derive(Debug, Clone)]
pub struct BenchmarkConfig {
    /// Number of unprofiled runs executed before measurement starts.
    pub warmup_iters: usize,
    /// Upper bound on the number of measured (profiled) runs.
    pub bench_iters: usize,
    /// Elapsed seconds after which `benchmark` stops measuring, once at least one run has
    /// been measured. NOTE(review): despite the name this acts as a soft time budget, not
    /// a minimum — confirm intended semantics.
    pub min_duration: f64,
    /// Elapsed seconds after which `benchmark` stops measuring.
    pub max_duration: f64,
    /// When `true`, the profiler's memory tracking and FLOP counting are enabled.
    pub detailed_metrics: bool,
}

impl Default for BenchmarkConfig {
    /// Stock settings: 5 warmup runs, at most 100 measured runs, 1 s / 60 s time limits,
    /// and detailed metrics switched on.
    fn default() -> Self {
        const WARMUP_RUNS: usize = 5;
        const MEASURED_RUNS: usize = 100;
        const MIN_DURATION_SECS: f64 = 1.0;
        const MAX_DURATION_SECS: f64 = 60.0;

        BenchmarkConfig {
            detailed_metrics: true,
            warmup_iters: WARMUP_RUNS,
            bench_iters: MEASURED_RUNS,
            max_duration: MAX_DURATION_SECS,
            min_duration: MIN_DURATION_SECS,
        }
    }
}
37
/// Outcome of a [`benchmark`] run.
#[derive(Debug, Clone)]
pub struct BenchmarkResults {
    /// The `name` that was passed to `benchmark`.
    pub operation_name: String,
    /// The configuration the run was performed with.
    pub config: BenchmarkConfig,
    /// Metrics collected by the profiler during measurement — presumably one entry per
    /// measured iteration (TODO confirm against `Profiler::finish_operation`).
    pub metrics: Vec<OperationMetrics>,
    /// Aggregate summary returned by the profiler's `get_summary` for `operation_name`.
    pub summary: OperationSummary,
}
46
47pub fn benchmark<F, R>(
49 name: &str,
50 mut operation: F,
51 inputs: &[&Tensor],
52 config: BenchmarkConfig,
53) -> TorshResult<BenchmarkResults>
54where
55 F: FnMut(&[&Tensor]) -> TorshResult<R>,
56 R: AsRef<[Tensor]>,
57{
58 let mut profiler = Profiler::new();
59 if config.detailed_metrics {
60 profiler.enable_memory_tracking();
61 profiler.enable_flops_counting();
62 }
63
64 for _ in 0..config.warmup_iters {
66 let _ = operation(inputs)?;
67 }
68
69 let start_time = Instant::now();
71 let mut iteration = 0;
72
73 while iteration < config.bench_iters {
74 let elapsed = start_time.elapsed().as_secs_f64();
75 if elapsed > config.max_duration {
76 break;
77 }
78 if iteration > 0 && elapsed > config.min_duration {
79 break;
80 }
81
82 profiler.start_operation(name, inputs)?;
83 let result = operation(inputs)?;
84 let output_refs: Vec<&Tensor> = result.as_ref().iter().collect();
85 profiler.finish_operation(&output_refs)?;
86
87 iteration += 1;
88 }
89
90 let summary = profiler
91 .get_summary(name)
92 .ok_or_else(|| TorshError::Other("Failed to generate benchmark summary".to_string()))?;
93
94 Ok(BenchmarkResults {
95 operation_name: name.to_string(),
96 config,
97 metrics: profiler.metrics,
98 summary,
99 })
100}
101
102pub fn profile_operation<F, R>(
104 name: &str,
105 mut operation: F,
106 inputs: &[&Tensor],
107) -> TorshResult<OperationMetrics>
108where
109 F: FnMut(&[&Tensor]) -> TorshResult<R>,
110 R: AsRef<[Tensor]>,
111{
112 let mut profiler = Profiler::new();
113 profiler.enable_memory_tracking();
114 profiler.enable_flops_counting();
115
116 profiler.start_operation(name, inputs)?;
117 let result = operation(inputs)?;
118 let output_refs: Vec<&Tensor> = result.as_ref().iter().collect();
119 profiler.finish_operation(&output_refs)?;
120
121 Ok(profiler
122 .metrics
123 .into_iter()
124 .next()
125 .expect("profiler should have at least one metric after finish_operation"))
126}
127
#[cfg(test)]
mod tests {
    use super::*;
    use torsh_tensor::creation::randn;

    #[test]
    fn test_benchmark_basic() -> TorshResult<()> {
        let input = randn(&[128, 128])?;
        // A fixed-size array coerces to `&[&Tensor]`; no heap allocation is needed.
        let inputs = [&input];

        let config = BenchmarkConfig {
            warmup_iters: 1,
            bench_iters: 3,
            min_duration: 0.1,
            max_duration: 1.0,
            detailed_metrics: false,
        };

        let results = benchmark(
            "test_operation",
            |inputs| -> TorshResult<Vec<Tensor>> { Ok(vec![inputs[0].clone()]) },
            &inputs,
            config,
        )?;

        assert_eq!(results.operation_name, "test_operation");
        assert_eq!(results.config.bench_iters, 3);
        // With these limits at least one iteration is measured, and never more than three.
        assert!(!results.metrics.is_empty());
        assert!(results.metrics.len() <= 3);
        Ok(())
    }

    #[test]
    fn test_profile_operation() -> TorshResult<()> {
        let input = randn(&[64, 64])?;
        let inputs = [&input];

        let metrics = profile_operation(
            "test_profile",
            |inputs| -> TorshResult<Vec<Tensor>> { Ok(vec![inputs[0].clone()]) },
            &inputs,
        )?;

        assert_eq!(metrics.name, "test_profile");
        assert!(!metrics.input_shapes.is_empty());
        assert!(!metrics.output_shapes.is_empty());
        Ok(())
    }
}