//! Real model benchmarking (torsh_cli/commands/model/real_benchmarking.rs).

#![allow(dead_code)]
9use anyhow::Result;
10use std::time::Instant;
11use tracing::{debug, info};
12
13use scirs2_core::random::{thread_rng, Distribution, Normal};
15
16use torsh::core::device::DeviceType;
18use torsh::tensor::Tensor;
19
20use super::tensor_integration::forward_pass;
21use super::types::{TimingResult, TorshModel};
22
/// Configuration for a single model benchmark run.
#[derive(Debug, Clone)]
pub struct BenchmarkConfig {
    /// Number of untimed iterations run before measurement begins.
    pub warmup_iterations: usize,
    /// Number of timed iterations used to compute latency statistics.
    pub measurement_iterations: usize,
    /// Batch size of the synthetic input tensor.
    pub batch_size: usize,
    /// Device the benchmark runs on.
    pub device: DeviceType,
    /// Whether to sample estimated memory usage during measurement.
    pub measure_memory: bool,
    /// Whether to collect detailed statistics.
    /// NOTE(review): not read by the code visible in this file — confirm its consumer.
    pub collect_detailed_stats: bool,
}
39
40impl Default for BenchmarkConfig {
41 fn default() -> Self {
42 Self {
43 warmup_iterations: 10,
44 measurement_iterations: 100,
45 batch_size: 1,
46 device: DeviceType::Cpu,
47 measure_memory: true,
48 collect_detailed_stats: true,
49 }
50 }
51}
52
/// Full results of a benchmark run, including per-iteration raw data.
#[derive(Debug, Clone)]
pub struct DetailedBenchmarkResults {
    /// Duration of every measured iteration, in milliseconds.
    pub iteration_timings: Vec<f64>,
    /// Total time spent in the warmup phase, in milliseconds.
    pub warmup_duration: f64,
    /// Total time spent in the measurement phase, in milliseconds.
    pub measurement_duration: f64,
    /// Estimated memory-usage samples taken during measurement, in MB.
    pub memory_samples: Vec<f64>,
    /// Largest memory sample observed, in MB (0.0 when no samples were taken).
    pub peak_memory: f64,
    /// Mean of the memory samples, in MB (0.0 when no samples were taken).
    pub avg_memory: f64,
    /// Samples processed per second, derived from mean latency and batch size.
    pub throughput: f64,
    /// Aggregate latency statistics over `iteration_timings`.
    pub latency_stats: LatencyStatistics,
}
73
/// Summary statistics over a set of latency samples (all in milliseconds).
#[derive(Debug, Clone)]
pub struct LatencyStatistics {
    /// Arithmetic mean.
    pub mean: f64,
    /// Median (midpoint of the two central samples for even counts).
    pub median: f64,
    /// Population standard deviation (variance divides by n, not n - 1).
    pub std_dev: f64,
    /// Smallest sample.
    pub min: f64,
    /// Largest sample.
    pub max: f64,
    /// 95th-percentile sample.
    pub p95: f64,
    /// 99th-percentile sample.
    pub p99: f64,
}
92
93pub fn benchmark_model_real(
95 model: &TorshModel,
96 config: &BenchmarkConfig,
97) -> Result<DetailedBenchmarkResults> {
98 info!(
99 "Starting real model benchmark with {} warmup and {} measurement iterations",
100 config.warmup_iterations, config.measurement_iterations
101 );
102
103 let input_shape = model
105 .layers
106 .first()
107 .map(|l| l.input_shape.clone())
108 .unwrap_or_else(|| vec![784]);
109
110 let input = create_input_tensor(&input_shape, config.batch_size, config.device)?;
111
112 debug!("Running warmup phase...");
114 let warmup_start = Instant::now();
115 for i in 0..config.warmup_iterations {
116 let _ = forward_pass(model, &input)?;
117
118 if i % 10 == 0 {
119 debug!("Warmup iteration {}/{}", i, config.warmup_iterations);
120 }
121 }
122 let warmup_duration = warmup_start.elapsed().as_secs_f64() * 1000.0;
123
124 debug!("Warmup completed in {:.2} ms", warmup_duration);
125
126 debug!("Running measurement phase...");
128 let mut iteration_timings = Vec::with_capacity(config.measurement_iterations);
129 let mut memory_samples = Vec::new();
130
131 let measurement_start = Instant::now();
132
133 for i in 0..config.measurement_iterations {
134 let iter_start = Instant::now();
136 let _ = forward_pass(model, &input)?;
137 let iter_duration = iter_start.elapsed().as_secs_f64() * 1000.0;
138
139 iteration_timings.push(iter_duration);
140
141 if config.measure_memory && i % 10 == 0 {
143 let memory_mb = estimate_memory_usage(model);
144 memory_samples.push(memory_mb);
145 }
146
147 if i % 20 == 0 {
148 debug!(
149 "Measurement iteration {}/{} - {:.2} ms",
150 i, config.measurement_iterations, iter_duration
151 );
152 }
153 }
154
155 let measurement_duration = measurement_start.elapsed().as_secs_f64() * 1000.0;
156
157 debug!("Measurement completed in {:.2} ms", measurement_duration);
158
159 let latency_stats = calculate_latency_statistics(&iteration_timings)?;
161
162 let peak_memory = memory_samples
163 .iter()
164 .copied()
165 .max_by(|a, b| {
166 a.partial_cmp(b)
167 .expect("memory sample values should be comparable")
168 })
169 .unwrap_or(0.0);
170
171 let avg_memory = if !memory_samples.is_empty() {
172 memory_samples.iter().sum::<f64>() / memory_samples.len() as f64
173 } else {
174 0.0
175 };
176
177 let throughput = if latency_stats.mean > 0.0 {
178 (1000.0 / latency_stats.mean) * config.batch_size as f64
179 } else {
180 0.0
181 };
182
183 Ok(DetailedBenchmarkResults {
184 iteration_timings,
185 warmup_duration,
186 measurement_duration,
187 memory_samples,
188 peak_memory,
189 avg_memory,
190 throughput,
191 latency_stats,
192 })
193}
194
195pub fn to_timing_result(results: &DetailedBenchmarkResults) -> TimingResult {
197 TimingResult {
198 throughput_fps: results.throughput,
199 latency_ms: results.latency_stats.mean,
200 memory_mb: results.peak_memory,
201 warmup_time_ms: results.warmup_duration,
202 avg_inference_time_ms: results.latency_stats.mean,
203 min_inference_time_ms: results.latency_stats.min,
204 max_inference_time_ms: results.latency_stats.max,
205 std_dev_ms: results.latency_stats.std_dev,
206 device_utilization: None, }
208}
209
210fn calculate_latency_statistics(timings: &[f64]) -> Result<LatencyStatistics> {
212 if timings.is_empty() {
213 anyhow::bail!("No timing samples available for statistics");
214 }
215
216 let mut sorted_timings = timings.to_vec();
217 sorted_timings.sort_by(|a, b| {
218 a.partial_cmp(b)
219 .expect("timing values should be comparable")
220 });
221
222 let mean = sorted_timings.iter().sum::<f64>() / sorted_timings.len() as f64;
223
224 let variance = sorted_timings
225 .iter()
226 .map(|&t| (t - mean).powi(2))
227 .sum::<f64>()
228 / sorted_timings.len() as f64;
229
230 let std_dev = variance.sqrt();
231
232 let median = if sorted_timings.len() % 2 == 0 {
233 let mid = sorted_timings.len() / 2;
234 (sorted_timings[mid - 1] + sorted_timings[mid]) / 2.0
235 } else {
236 sorted_timings[sorted_timings.len() / 2]
237 };
238
239 let min = sorted_timings[0];
240 let max = sorted_timings[sorted_timings.len() - 1];
241
242 let p95_idx = ((sorted_timings.len() as f64 * 0.95) as usize).min(sorted_timings.len() - 1);
243 let p95 = sorted_timings[p95_idx];
244
245 let p99_idx = ((sorted_timings.len() as f64 * 0.99) as usize).min(sorted_timings.len() - 1);
246 let p99 = sorted_timings[p99_idx];
247
248 Ok(LatencyStatistics {
249 mean,
250 median,
251 std_dev,
252 min,
253 max,
254 p95,
255 p99,
256 })
257}
258
259fn create_input_tensor(
261 shape: &[usize],
262 batch_size: usize,
263 device: DeviceType,
264) -> Result<Tensor<f32>> {
265 let mut full_shape = vec![batch_size];
266 full_shape.extend_from_slice(shape);
267
268 let mut rng = thread_rng();
269 let normal = Normal::new(0.0, 1.0)?;
270
271 let num_elements: usize = full_shape.iter().product();
272 let data: Vec<f32> = (0..num_elements)
273 .map(|_| normal.sample(&mut rng) as f32)
274 .collect();
275
276 Ok(Tensor::from_data(data, full_shape, device)?)
277}
278
279fn estimate_memory_usage(model: &TorshModel) -> f64 {
281 let param_count: u64 = model.layers.iter().map(|l| l.parameters).sum();
282
283 let memory_bytes = param_count * 4 * 2; memory_bytes as f64 / (1024.0 * 1024.0)
288}
289
290pub fn benchmark_batch_sizes(
292 model: &TorshModel,
293 batch_sizes: &[usize],
294 device: DeviceType,
295) -> Result<Vec<(usize, DetailedBenchmarkResults)>> {
296 let mut results = Vec::new();
297
298 for &batch_size in batch_sizes {
299 info!("Benchmarking with batch size: {}", batch_size);
300
301 let config = BenchmarkConfig {
302 batch_size,
303 device,
304 ..Default::default()
305 };
306
307 let bench_results = benchmark_model_real(model, &config)?;
308 results.push((batch_size, bench_results));
309 }
310
311 Ok(results)
312}
313
314pub fn benchmark_devices(
316 model: &TorshModel,
317 devices: &[DeviceType],
318) -> Result<Vec<(DeviceType, DetailedBenchmarkResults)>> {
319 let mut results = Vec::new();
320
321 for &device in devices {
322 info!("Benchmarking on device: {:?}", device);
323
324 let config = BenchmarkConfig {
325 device,
326 ..Default::default()
327 };
328
329 let bench_results = benchmark_model_real(model, &config)?;
330 results.push((device, bench_results));
331 }
332
333 Ok(results)
334}
335
/// Render benchmark results as a human-readable multi-line report covering
/// throughput, latency percentiles, memory, and phase timings.
pub fn format_benchmark_results(results: &DetailedBenchmarkResults) -> String {
    format!(
        r#"
Benchmark Results:
==================
Throughput: {:.2} samples/sec
Latency (avg): {:.2} ms
Latency (median): {:.2} ms
Latency (min): {:.2} ms
Latency (max): {:.2} ms
Latency (p95): {:.2} ms
Latency (p99): {:.2} ms
Latency (std dev): {:.2} ms

Memory:
-------
Peak: {:.2} MB
Average: {:.2} MB

Timing:
-------
Warmup: {:.2} ms
Total Measurement: {:.2} ms
Iterations: {}
"#,
        results.throughput,
        results.latency_stats.mean,
        results.latency_stats.median,
        results.latency_stats.min,
        results.latency_stats.max,
        results.latency_stats.p95,
        results.latency_stats.p99,
        results.latency_stats.std_dev,
        results.peak_memory,
        results.avg_memory,
        results.warmup_duration,
        results.measurement_duration,
        results.iteration_timings.len(),
    )
}
377
#[cfg(test)]
mod tests {
    use super::super::tensor_integration::create_real_model;
    use super::*;

    // Statistics over 1..=10: mean and median are both 5.5; min/max are exact.
    #[test]
    fn test_latency_statistics() {
        let timings = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
        let stats = calculate_latency_statistics(&timings)
            .expect("calculate latency statistics should succeed");

        assert!((stats.mean - 5.5).abs() < 0.1);
        assert!((stats.median - 5.5).abs() < 0.1);
        assert_eq!(stats.min, 1.0);
        assert_eq!(stats.max, 10.0);
    }

    // Pins the documented defaults of `BenchmarkConfig::default()`.
    #[test]
    fn test_benchmark_config_default() {
        let config = BenchmarkConfig::default();
        assert_eq!(config.warmup_iterations, 10);
        assert_eq!(config.measurement_iterations, 100);
        assert_eq!(config.batch_size, 1);
    }

    // The batch dimension must be prepended to the requested per-sample shape.
    #[test]
    fn test_create_input_tensor() {
        let tensor = create_input_tensor(&[3, 224, 224], 2, DeviceType::Cpu)
            .expect("create input tensor should succeed");
        assert_eq!(tensor.shape().dims(), &[2, 3, 224, 224]);
    }

    // End-to-end smoke test over a tiny real model; ignored by default because
    // timing-sensitive assertions are flaky under full-suite load.
    #[test]
    #[ignore = "Flaky test - passes individually but may fail in full suite"]
    fn test_benchmark_model_real() {
        let model = create_real_model("test", 2, DeviceType::Cpu)
            .expect("create real model should succeed");
        let config = BenchmarkConfig {
            warmup_iterations: 2,
            measurement_iterations: 5,
            batch_size: 1,
            device: DeviceType::Cpu,
            measure_memory: true,
            collect_detailed_stats: true,
        };

        let results =
            benchmark_model_real(&model, &config).expect("benchmark model real should succeed");
        assert_eq!(results.iteration_timings.len(), 5);
        assert!(results.throughput > 0.0);
        assert!(results.peak_memory > 0.0);
    }
}
430}