1use crate::tensor::Tensor;
4use crate::traits::Model;
5use anyhow::Result;
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::time::{Duration, Instant};
9
/// Inputs for one model forward pass.
///
/// Only `input_ids` is required; the optional tensors are passed through to
/// the model when present.
#[derive(Clone)]
pub struct ModelInput {
    /// Token id tensor (the benchmark fills this with a `[batch, seq_len]`
    /// tensor of zeros — see `run_single_inference_benchmark`).
    pub input_ids: Tensor,
    /// Optional attention mask (the benchmark uses an all-ones tensor).
    pub attention_mask: Option<Tensor>,
    /// Optional token-type/segment ids; the benchmark passes `None`.
    pub token_type_ids: Option<Tensor>,
    /// Optional explicit position ids; the benchmark passes `None`.
    pub position_ids: Option<Tensor>,
}
18
/// Outputs of a model forward pass; every field is optional so a model can
/// populate only what it actually computes.
#[derive(Default)]
pub struct ModelOutput {
    /// Final hidden states, if returned by the model.
    pub hidden_states: Option<Tensor>,
    /// Prediction logits, if returned by the model.
    pub logits: Option<Tensor>,
    /// Per-layer attention tensors, if returned by the model.
    pub attentions: Option<Vec<Tensor>>,
}
26
/// Knobs controlling how a `BenchmarkSuite` runs its measurements.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BenchmarkConfig {
    /// Batch sizes to sweep; crossed with every entry of `sequence_lengths`.
    pub batch_sizes: Vec<usize>,
    /// Sequence lengths to sweep; crossed with every entry of `batch_sizes`.
    pub sequence_lengths: Vec<usize>,
    /// Untimed iterations run before measurement starts.
    pub warmup_iterations: usize,
    /// Timed iterations per (batch, seq_len) combination.
    pub num_iterations: usize,
    /// Whether to sample process memory before/during/after each run.
    pub measure_memory: bool,
    /// Target device label, e.g. "cpu". NOTE(review): not consumed by the
    /// code in this file — confirm where it takes effect.
    pub device: String,
    /// Whether to run in half precision. NOTE(review): not consumed by the
    /// code in this file.
    pub use_fp16: bool,
    /// Whether to also benchmark text generation. NOTE(review): no
    /// generation benchmark is implemented in this file.
    pub include_generation: bool,
    /// Maximum tokens to generate when `include_generation` is set.
    pub max_generation_length: Option<usize>,
}
49
50impl Default for BenchmarkConfig {
51 fn default() -> Self {
52 Self {
53 batch_sizes: vec![1, 4, 8, 16, 32],
54 sequence_lengths: vec![128, 256, 512, 1024, 2048],
55 warmup_iterations: 10,
56 num_iterations: 100,
57 measure_memory: true,
58 device: "cpu".to_string(),
59 use_fp16: false,
60 include_generation: false,
61 max_generation_length: Some(256),
62 }
63 }
64}
65
/// Aggregated statistics for one benchmark run (one model + shape combination).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BenchmarkResult {
    /// Unique run name, e.g. "<model>_inference_b<batch>_s<seq>".
    pub name: String,
    /// Model identifier the run belongs to.
    pub model_type: String,
    /// Mean latency per forward pass, in milliseconds.
    pub avg_latency_ms: f64,
    /// Median (50th percentile) latency, in milliseconds.
    pub p50_latency_ms: f64,
    /// 95th percentile latency, in milliseconds.
    pub p95_latency_ms: f64,
    /// 99th percentile latency, in milliseconds.
    pub p99_latency_ms: f64,
    /// Fastest observed iteration, in milliseconds.
    pub min_latency_ms: f64,
    /// Slowest observed iteration, in milliseconds.
    pub max_latency_ms: f64,
    /// Population standard deviation of the latencies, in milliseconds.
    pub std_dev_ms: f64,
    /// Tokens/sec derived as batch_size * seq_len / mean latency.
    pub throughput_tokens_per_sec: f64,
    /// Forward passes per second (1 / mean latency).
    pub throughput_batches_per_sec: f64,
    /// Net memory growth over the run, in bytes (`None` if not measured).
    pub memory_bytes: Option<usize>,
    /// Peak memory delta observed during the run, in bytes (`None` if not measured).
    pub peak_memory_bytes: Option<usize>,
    /// Run parameters (batch_size, seq_len, num_iterations) as strings.
    pub parameters: HashMap<String, String>,
    /// Unaggregated per-iteration timings, in measurement order.
    pub raw_timings: Vec<Duration>,
    /// UTC time at which this result was created.
    pub timestamp: chrono::DateTime<chrono::Utc>,
}
102
103impl BenchmarkResult {
104 fn percentile(sorted_timings: &[Duration], percentile: f64) -> Duration {
106 let index = ((sorted_timings.len() - 1) as f64 * percentile / 100.0) as usize;
107 sorted_timings[index]
108 }
109
110 pub fn from_timings(
112 name: String,
113 model_type: String,
114 timings: Vec<Duration>,
115 batch_size: usize,
116 seq_len: usize,
117 memory_bytes: Option<usize>,
118 peak_memory_bytes: Option<usize>,
119 ) -> Self {
120 let mut sorted_timings = timings.clone();
121 sorted_timings.sort();
122
123 let total_duration: Duration = timings.iter().sum();
124 let avg_duration = total_duration / timings.len() as u32;
125
126 let avg_ms = avg_duration.as_secs_f64() * 1000.0;
127 let variance = timings
128 .iter()
129 .map(|t| {
130 let diff = t.as_secs_f64() - avg_duration.as_secs_f64();
131 diff * diff
132 })
133 .sum::<f64>()
134 / timings.len() as f64;
135 let std_dev_ms = variance.sqrt() * 1000.0;
136
137 let tokens_per_batch = batch_size * seq_len;
138 let batches_per_sec = 1.0 / avg_duration.as_secs_f64();
139 let tokens_per_sec = tokens_per_batch as f64 * batches_per_sec;
140
141 let mut parameters = HashMap::new();
142 parameters.insert("batch_size".to_string(), batch_size.to_string());
143 parameters.insert("seq_len".to_string(), seq_len.to_string());
144 parameters.insert("num_iterations".to_string(), timings.len().to_string());
145
146 Self {
147 name,
148 model_type,
149 avg_latency_ms: avg_ms,
150 p50_latency_ms: Self::percentile(&sorted_timings, 50.0).as_secs_f64() * 1000.0,
151 p95_latency_ms: Self::percentile(&sorted_timings, 95.0).as_secs_f64() * 1000.0,
152 p99_latency_ms: Self::percentile(&sorted_timings, 99.0).as_secs_f64() * 1000.0,
153 min_latency_ms: sorted_timings[0].as_secs_f64() * 1000.0,
154 max_latency_ms: sorted_timings[sorted_timings.len() - 1].as_secs_f64() * 1000.0,
155 std_dev_ms,
156 throughput_tokens_per_sec: tokens_per_sec,
157 throughput_batches_per_sec: batches_per_sec,
158 memory_bytes,
159 peak_memory_bytes,
160 parameters,
161 raw_timings: timings,
162 timestamp: chrono::Utc::now(),
163 }
164 }
165}
166
/// Runs inference benchmarks and collects their results under one config.
pub struct BenchmarkSuite {
    /// Results accumulated across `benchmark_inference` calls.
    results: Vec<BenchmarkResult>,
    /// Configuration shared by every benchmark in this suite.
    config: BenchmarkConfig,
}
172
173impl BenchmarkSuite {
174 pub fn new(config: BenchmarkConfig) -> Self {
175 Self {
176 results: Vec::new(),
177 config,
178 }
179 }
180
181 pub fn benchmark_inference<M>(&mut self, model: &M, model_name: &str) -> Result<()>
183 where
184 M: Model<Input = ModelInput, Output = ModelOutput>,
185 {
186 println!("Benchmarking {} inference...", model_name);
187
188 for &batch_size in &self.config.batch_sizes {
189 for &seq_len in &self.config.sequence_lengths {
190 let result =
191 self.run_single_inference_benchmark(model, model_name, batch_size, seq_len)?;
192 self.results.push(result);
193 }
194 }
195
196 Ok(())
197 }
198
199 fn run_single_inference_benchmark<M>(
201 &self,
202 model: &M,
203 model_name: &str,
204 batch_size: usize,
205 seq_len: usize,
206 ) -> Result<BenchmarkResult>
207 where
208 M: Model<Input = ModelInput, Output = ModelOutput>,
209 {
210 println!(" Batch size: {}, Sequence length: {}", batch_size, seq_len);
211
212 let input_ids = Tensor::zeros(&[batch_size, seq_len])?;
214 let attention_mask = Some(Tensor::ones(&[batch_size, seq_len])?);
215
216 let model_input = ModelInput {
217 input_ids,
218 attention_mask,
219 token_type_ids: None,
220 position_ids: None,
221 };
222
223 let initial_memory =
225 if self.config.measure_memory { Some(self.get_memory_usage()) } else { None };
226
227 for _ in 0..self.config.warmup_iterations {
229 let _ = model.forward(model_input.clone())?;
230 }
231
232 let mut timings = Vec::with_capacity(self.config.num_iterations);
234 let mut peak_memory = initial_memory;
235
236 for _ in 0..self.config.num_iterations {
237 let start = Instant::now();
238 let _ = model.forward(model_input.clone())?;
239 let duration = start.elapsed();
240 timings.push(duration);
241
242 if self.config.measure_memory {
243 let current_memory = self.get_memory_usage();
244 if let (Some(peak), current) = (peak_memory.as_mut(), current_memory) {
245 *peak = (*peak).max(current);
246 }
247 }
248 }
249
250 let memory_usage = if self.config.measure_memory {
252 let final_memory = self.get_memory_usage();
253 initial_memory.map(|initial| final_memory - initial)
254 } else {
255 None
256 };
257
258 Ok(BenchmarkResult::from_timings(
259 format!("{}_inference_b{}_s{}", model_name, batch_size, seq_len),
260 model_name.to_string(),
261 timings,
262 batch_size,
263 seq_len,
264 memory_usage,
265 peak_memory.map(|p| p - initial_memory.unwrap_or(0)),
266 ))
267 }
268
269 fn get_memory_usage(&self) -> usize {
271 #[cfg(target_os = "linux")]
273 {
274 if let Ok(status) = std::fs::read_to_string("/proc/self/status") {
275 for line in status.lines() {
276 if line.starts_with("VmRSS:") {
277 if let Some(value_str) = line.split_whitespace().nth(1) {
278 if let Ok(kb) = value_str.parse::<usize>() {
279 return kb * 1024; }
281 }
282 }
283 }
284 }
285 }
286
287 #[cfg(target_os = "macos")]
288 {
289 use std::process::Command;
290 if let Ok(output) = Command::new("ps")
291 .args(["-o", "rss=", "-p"])
292 .arg(std::process::id().to_string())
293 .output()
294 {
295 if let Ok(rss_str) = String::from_utf8(output.stdout) {
296 if let Ok(kb) = rss_str.trim().parse::<usize>() {
297 return kb * 1024; }
299 }
300 }
301 }
302
303 #[cfg(target_os = "windows")]
304 {
305 use std::process::Command;
306 if let Ok(output) = Command::new("wmic")
307 .args([
308 "process",
309 "where",
310 &format!("ProcessId={}", std::process::id()),
311 "get",
312 "WorkingSetSize",
313 "/value",
314 ])
315 .output()
316 {
317 if let Ok(output_str) = String::from_utf8(output.stdout) {
318 for line in output_str.lines() {
319 if line.starts_with("WorkingSetSize=") {
320 if let Some(value_str) = line.split('=').nth(1) {
321 if let Ok(bytes) = value_str.parse::<usize>() {
322 return bytes;
323 }
324 }
325 }
326 }
327 }
328 }
329 }
330
331 let estimated_tensor_memory = self.results.len() * 1024 * 1024 * 50; let base_memory = 100 * 1024 * 1024; estimated_tensor_memory + base_memory
335 }
336
337 pub fn print_summary(&self) {
339 println!("\n=== Benchmark Results Summary ===");
340 println!(
341 "{:<40} {:>12} {:>12} {:>12} {:>12} {:>15}",
342 "Benchmark", "Avg (ms)", "P50 (ms)", "P95 (ms)", "P99 (ms)", "Throughput (tok/s)"
343 );
344 println!("{}", "-".repeat(103));
345
346 for result in &self.results {
347 println!(
348 "{:<40} {:>12.2} {:>12.2} {:>12.2} {:>12.2} {:>15.0}",
349 result.name,
350 result.avg_latency_ms,
351 result.p50_latency_ms,
352 result.p95_latency_ms,
353 result.p99_latency_ms,
354 result.throughput_tokens_per_sec,
355 );
356 }
357 }
358
359 pub fn export_json(&self, path: &str) -> Result<()> {
361 let json = serde_json::to_string_pretty(&self.results)?;
362 std::fs::write(path, json)?;
363 Ok(())
364 }
365
366 pub fn export_csv(&self, path: &str) -> Result<()> {
368 use std::io::Write;
369 let mut file = std::fs::File::create(path)?;
370
371 writeln!(file, "name,model_type,batch_size,seq_len,avg_latency_ms,p50_ms,p95_ms,p99_ms,min_ms,max_ms,std_dev_ms,throughput_tokens_sec,throughput_batches_sec,memory_bytes,timestamp")?;
373
374 for result in &self.results {
376 writeln!(
377 file,
378 "{},{},{},{},{:.2},{:.2},{:.2},{:.2},{:.2},{:.2},{:.2},{:.0},{:.2},{},{}",
379 result.name,
380 result.model_type,
381 result.parameters.get("batch_size").unwrap_or(&"0".to_string()),
382 result.parameters.get("seq_len").unwrap_or(&"0".to_string()),
383 result.avg_latency_ms,
384 result.p50_latency_ms,
385 result.p95_latency_ms,
386 result.p99_latency_ms,
387 result.min_latency_ms,
388 result.max_latency_ms,
389 result.std_dev_ms,
390 result.throughput_tokens_per_sec,
391 result.throughput_batches_per_sec,
392 result.memory_bytes.unwrap_or(0),
393 result.timestamp.to_rfc3339(),
394 )?;
395 }
396
397 Ok(())
398 }
399
400 pub fn results(&self) -> &[BenchmarkResult] {
402 &self.results
403 }
404
405 pub fn compare_with_baseline(&self, baseline: &[BenchmarkResult]) -> Vec<ComparisonSummary> {
407 let mut comparisons = Vec::new();
408
409 for result in &self.results {
410 if let Some(baseline_result) = baseline.iter().find(|b| b.name == result.name) {
411 let speedup = baseline_result.avg_latency_ms / result.avg_latency_ms;
412 let throughput_improvement =
413 result.throughput_tokens_per_sec / baseline_result.throughput_tokens_per_sec;
414
415 comparisons.push(ComparisonSummary {
416 benchmark_name: result.name.clone(),
417 speedup,
418 throughput_improvement,
419 latency_reduction_percent: (1.0
420 - result.avg_latency_ms / baseline_result.avg_latency_ms)
421 * 100.0,
422 memory_reduction_percent: if let (Some(current), Some(baseline)) =
423 (result.memory_bytes, baseline_result.memory_bytes)
424 {
425 Some((1.0 - current as f64 / baseline as f64) * 100.0)
426 } else {
427 None
428 },
429 });
430 }
431 }
432
433 comparisons
434 }
435}
436
/// Relative performance of one benchmark versus its baseline counterpart.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ComparisonSummary {
    /// Name of the benchmark being compared (matched by exact name).
    pub benchmark_name: String,
    /// baseline_avg_latency / current_avg_latency; > 1.0 means faster now.
    pub speedup: f64,
    /// current_throughput / baseline_throughput; > 1.0 means improved.
    pub throughput_improvement: f64,
    /// Percent latency reduction vs baseline; negative means a regression.
    pub latency_reduction_percent: f64,
    /// Percent memory reduction vs baseline; `None` when either side lacks
    /// a memory measurement.
    pub memory_reduction_percent: Option<f64>,
}
446
#[cfg(test)]
mod tests {
    use super::*;

    /// Sanity checks for aggregating timings into a `BenchmarkResult`.
    #[test]
    fn test_benchmark_result_from_timings() {
        let samples: Vec<Duration> = [10u64, 12, 11, 15, 13]
            .iter()
            .map(|&ms| Duration::from_millis(ms))
            .collect();

        let result = BenchmarkResult::from_timings(
            "test_benchmark".to_string(),
            "TestModel".to_string(),
            samples,
            4,
            128,
            Some(1024 * 1024),
            Some(2048 * 1024),
        );

        assert_eq!(result.name, "test_benchmark");
        assert_eq!(result.model_type, "TestModel");
        assert!(result.avg_latency_ms > 0.0);
        assert!(result.throughput_tokens_per_sec > 0.0);

        let batch = result.parameters.get("batch_size").expect("expected value not found");
        assert_eq!(batch, "4");
        let seq = result.parameters.get("seq_len").expect("expected value not found");
        assert_eq!(seq, "128");
    }

    /// The default config should match the documented sweep settings.
    #[test]
    fn test_benchmark_config_default() {
        let config = BenchmarkConfig::default();
        assert_eq!(config.batch_sizes, vec![1, 4, 8, 16, 32]);
        assert_eq!(config.warmup_iterations, 10);
        assert_eq!(config.num_iterations, 100);
        assert!(config.measure_memory);
    }
}