// trustformers_core/performance/benchmark.rs
1//! Core benchmarking infrastructure for TrustformeRS
2
3use crate::tensor::Tensor;
4use crate::traits::Model;
5use anyhow::Result;
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::time::{Duration, Instant};
9
/// Model input structure for benchmarking
///
/// Mirrors the tensors a transformer forward pass consumes. Only
/// `input_ids` is mandatory; the optional fields are left `None` by the
/// benchmark driver in this file (see `run_single_inference_benchmark`,
/// which only fills `input_ids` and `attention_mask`).
#[derive(Clone)]
pub struct ModelInput {
    // Token ids; benchmarks create this as a zeros tensor of
    // [batch_size, seq_len] — assumed 2-D, TODO confirm against Model impls.
    pub input_ids: Tensor,
    // Optional attention mask; benchmarks pass an all-ones tensor of the
    // same [batch_size, seq_len] shape.
    pub attention_mask: Option<Tensor>,
    // Optional segment/token-type ids (unused by the benchmark driver).
    pub token_type_ids: Option<Tensor>,
    // Optional explicit position ids (unused by the benchmark driver).
    pub position_ids: Option<Tensor>,
}
18
/// Model output structure for benchmarking
///
/// All fields are optional so a model can populate whichever outputs it
/// produces; `Default` yields an entirely empty output.
#[derive(Default)]
pub struct ModelOutput {
    // Final hidden states, if the model exposes them.
    pub hidden_states: Option<Tensor>,
    // Prediction logits, if the model exposes them.
    pub logits: Option<Tensor>,
    // Per-layer attention maps, if the model exposes them.
    pub attentions: Option<Vec<Tensor>>,
}
26
/// Configuration for benchmark runs
///
/// The suite sweeps the cross product of `batch_sizes` × `sequence_lengths`,
/// running `warmup_iterations` untimed passes followed by `num_iterations`
/// timed passes per combination. See `BenchmarkConfig::default()` for the
/// stock settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BenchmarkConfig {
    /// Batch sizes to test
    pub batch_sizes: Vec<usize>,
    /// Sequence lengths to test
    pub sequence_lengths: Vec<usize>,
    /// Number of warmup iterations
    pub warmup_iterations: usize,
    /// Number of benchmark iterations
    pub num_iterations: usize,
    /// Whether to measure memory usage
    pub measure_memory: bool,
    /// Device to run benchmarks on (cpu, cuda, etc.)
    // NOTE(review): stringly typed — only used as a label here; an enum
    // would be safer if callers branch on it. Confirm before changing.
    pub device: String,
    /// Whether to use mixed precision
    pub use_fp16: bool,
    /// Whether to benchmark generation tasks
    pub include_generation: bool,
    /// Maximum generation length for generation benchmarks
    pub max_generation_length: Option<usize>,
}
49
50impl Default for BenchmarkConfig {
51    fn default() -> Self {
52        Self {
53            batch_sizes: vec![1, 4, 8, 16, 32],
54            sequence_lengths: vec![128, 256, 512, 1024, 2048],
55            warmup_iterations: 10,
56            num_iterations: 100,
57            measure_memory: true,
58            device: "cpu".to_string(),
59            use_fp16: false,
60            include_generation: false,
61            max_generation_length: Some(256),
62        }
63    }
64}
65
/// Result of a single benchmark run
///
/// Produced by `BenchmarkResult::from_timings`; latencies are in
/// milliseconds, throughput is derived from the mean latency.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BenchmarkResult {
    /// Name of the benchmark
    pub name: String,
    /// Model type (BERT, GPT-2, etc.)
    pub model_type: String,
    /// Average latency per forward pass
    pub avg_latency_ms: f64,
    /// P50 latency
    pub p50_latency_ms: f64,
    /// P95 latency
    pub p95_latency_ms: f64,
    /// P99 latency
    pub p99_latency_ms: f64,
    /// Minimum latency
    pub min_latency_ms: f64,
    /// Maximum latency
    pub max_latency_ms: f64,
    /// Standard deviation of latency
    pub std_dev_ms: f64,
    /// Throughput in tokens per second
    pub throughput_tokens_per_sec: f64,
    /// Throughput in batches per second
    pub throughput_batches_per_sec: f64,
    /// Memory usage in bytes
    // `None` when the run was configured with `measure_memory: false`.
    pub memory_bytes: Option<usize>,
    /// Peak memory usage
    pub peak_memory_bytes: Option<usize>,
    /// Configuration parameters
    // Stringly-typed key/value pairs ("batch_size", "seq_len",
    // "num_iterations") recorded by `from_timings`.
    pub parameters: HashMap<String, String>,
    /// Raw timing data
    // Kept in original measurement order (not sorted).
    pub raw_timings: Vec<Duration>,
    /// Timestamp of the benchmark
    pub timestamp: chrono::DateTime<chrono::Utc>,
}
102
103impl BenchmarkResult {
104    /// Calculate percentile from sorted timings
105    fn percentile(sorted_timings: &[Duration], percentile: f64) -> Duration {
106        let index = ((sorted_timings.len() - 1) as f64 * percentile / 100.0) as usize;
107        sorted_timings[index]
108    }
109
110    /// Create result from raw timings
111    pub fn from_timings(
112        name: String,
113        model_type: String,
114        timings: Vec<Duration>,
115        batch_size: usize,
116        seq_len: usize,
117        memory_bytes: Option<usize>,
118        peak_memory_bytes: Option<usize>,
119    ) -> Self {
120        let mut sorted_timings = timings.clone();
121        sorted_timings.sort();
122
123        let total_duration: Duration = timings.iter().sum();
124        let avg_duration = total_duration / timings.len() as u32;
125
126        let avg_ms = avg_duration.as_secs_f64() * 1000.0;
127        let variance = timings
128            .iter()
129            .map(|t| {
130                let diff = t.as_secs_f64() - avg_duration.as_secs_f64();
131                diff * diff
132            })
133            .sum::<f64>()
134            / timings.len() as f64;
135        let std_dev_ms = variance.sqrt() * 1000.0;
136
137        let tokens_per_batch = batch_size * seq_len;
138        let batches_per_sec = 1.0 / avg_duration.as_secs_f64();
139        let tokens_per_sec = tokens_per_batch as f64 * batches_per_sec;
140
141        let mut parameters = HashMap::new();
142        parameters.insert("batch_size".to_string(), batch_size.to_string());
143        parameters.insert("seq_len".to_string(), seq_len.to_string());
144        parameters.insert("num_iterations".to_string(), timings.len().to_string());
145
146        Self {
147            name,
148            model_type,
149            avg_latency_ms: avg_ms,
150            p50_latency_ms: Self::percentile(&sorted_timings, 50.0).as_secs_f64() * 1000.0,
151            p95_latency_ms: Self::percentile(&sorted_timings, 95.0).as_secs_f64() * 1000.0,
152            p99_latency_ms: Self::percentile(&sorted_timings, 99.0).as_secs_f64() * 1000.0,
153            min_latency_ms: sorted_timings[0].as_secs_f64() * 1000.0,
154            max_latency_ms: sorted_timings[sorted_timings.len() - 1].as_secs_f64() * 1000.0,
155            std_dev_ms,
156            throughput_tokens_per_sec: tokens_per_sec,
157            throughput_batches_per_sec: batches_per_sec,
158            memory_bytes,
159            peak_memory_bytes,
160            parameters,
161            raw_timings: timings,
162            timestamp: chrono::Utc::now(),
163        }
164    }
165}
166
/// Main benchmark suite for running performance tests
///
/// Owns the run configuration and accumulates one `BenchmarkResult` per
/// (batch size, sequence length) combination benchmarked.
pub struct BenchmarkSuite {
    // Results appended in sweep order by `benchmark_inference`.
    results: Vec<BenchmarkResult>,
    // Immutable after construction; read by every benchmark run.
    config: BenchmarkConfig,
}
172
173impl BenchmarkSuite {
174    pub fn new(config: BenchmarkConfig) -> Self {
175        Self {
176            results: Vec::new(),
177            config,
178        }
179    }
180
181    /// Run inference benchmark for a model that takes ModelInput
182    pub fn benchmark_inference<M>(&mut self, model: &M, model_name: &str) -> Result<()>
183    where
184        M: Model<Input = ModelInput, Output = ModelOutput>,
185    {
186        println!("Benchmarking {} inference...", model_name);
187
188        for &batch_size in &self.config.batch_sizes {
189            for &seq_len in &self.config.sequence_lengths {
190                let result =
191                    self.run_single_inference_benchmark(model, model_name, batch_size, seq_len)?;
192                self.results.push(result);
193            }
194        }
195
196        Ok(())
197    }
198
199    /// Run a single inference benchmark
200    fn run_single_inference_benchmark<M>(
201        &self,
202        model: &M,
203        model_name: &str,
204        batch_size: usize,
205        seq_len: usize,
206    ) -> Result<BenchmarkResult>
207    where
208        M: Model<Input = ModelInput, Output = ModelOutput>,
209    {
210        println!("  Batch size: {}, Sequence length: {}", batch_size, seq_len);
211
212        // Create dummy input
213        let input_ids = Tensor::zeros(&[batch_size, seq_len])?;
214        let attention_mask = Some(Tensor::ones(&[batch_size, seq_len])?);
215
216        let model_input = ModelInput {
217            input_ids,
218            attention_mask,
219            token_type_ids: None,
220            position_ids: None,
221        };
222
223        // Get initial memory snapshot
224        let initial_memory =
225            if self.config.measure_memory { Some(self.get_memory_usage()) } else { None };
226
227        // Warmup
228        for _ in 0..self.config.warmup_iterations {
229            let _ = model.forward(model_input.clone())?;
230        }
231
232        // Benchmark
233        let mut timings = Vec::with_capacity(self.config.num_iterations);
234        let mut peak_memory = initial_memory;
235
236        for _ in 0..self.config.num_iterations {
237            let start = Instant::now();
238            let _ = model.forward(model_input.clone())?;
239            let duration = start.elapsed();
240            timings.push(duration);
241
242            if self.config.measure_memory {
243                let current_memory = self.get_memory_usage();
244                if let (Some(peak), current) = (peak_memory.as_mut(), current_memory) {
245                    *peak = (*peak).max(current);
246                }
247            }
248        }
249
250        // Calculate memory usage
251        let memory_usage = if self.config.measure_memory {
252            let final_memory = self.get_memory_usage();
253            initial_memory.map(|initial| final_memory - initial)
254        } else {
255            None
256        };
257
258        Ok(BenchmarkResult::from_timings(
259            format!("{}_inference_b{}_s{}", model_name, batch_size, seq_len),
260            model_name.to_string(),
261            timings,
262            batch_size,
263            seq_len,
264            memory_usage,
265            peak_memory.map(|p| p - initial_memory.unwrap_or(0)),
266        ))
267    }
268
269    /// Get current memory usage
270    fn get_memory_usage(&self) -> usize {
271        // Platform-specific memory usage tracking
272        #[cfg(target_os = "linux")]
273        {
274            if let Ok(status) = std::fs::read_to_string("/proc/self/status") {
275                for line in status.lines() {
276                    if line.starts_with("VmRSS:") {
277                        if let Some(value_str) = line.split_whitespace().nth(1) {
278                            if let Ok(kb) = value_str.parse::<usize>() {
279                                return kb * 1024; // Convert KB to bytes
280                            }
281                        }
282                    }
283                }
284            }
285        }
286
287        #[cfg(target_os = "macos")]
288        {
289            use std::process::Command;
290            if let Ok(output) = Command::new("ps")
291                .args(["-o", "rss=", "-p"])
292                .arg(std::process::id().to_string())
293                .output()
294            {
295                if let Ok(rss_str) = String::from_utf8(output.stdout) {
296                    if let Ok(kb) = rss_str.trim().parse::<usize>() {
297                        return kb * 1024; // Convert KB to bytes
298                    }
299                }
300            }
301        }
302
303        #[cfg(target_os = "windows")]
304        {
305            use std::process::Command;
306            if let Ok(output) = Command::new("wmic")
307                .args([
308                    "process",
309                    "where",
310                    &format!("ProcessId={}", std::process::id()),
311                    "get",
312                    "WorkingSetSize",
313                    "/value",
314                ])
315                .output()
316            {
317                if let Ok(output_str) = String::from_utf8(output.stdout) {
318                    for line in output_str.lines() {
319                        if line.starts_with("WorkingSetSize=") {
320                            if let Some(value_str) = line.split('=').nth(1) {
321                                if let Ok(bytes) = value_str.parse::<usize>() {
322                                    return bytes;
323                                }
324                            }
325                        }
326                    }
327                }
328            }
329        }
330
331        // Fallback: estimate based on heap allocations and typical overhead
332        let estimated_tensor_memory = self.results.len() * 1024 * 1024 * 50; // 50MB per benchmark result
333        let base_memory = 100 * 1024 * 1024; // 100MB base overhead
334        estimated_tensor_memory + base_memory
335    }
336
337    /// Print benchmark summary
338    pub fn print_summary(&self) {
339        println!("\n=== Benchmark Results Summary ===");
340        println!(
341            "{:<40} {:>12} {:>12} {:>12} {:>12} {:>15}",
342            "Benchmark", "Avg (ms)", "P50 (ms)", "P95 (ms)", "P99 (ms)", "Throughput (tok/s)"
343        );
344        println!("{}", "-".repeat(103));
345
346        for result in &self.results {
347            println!(
348                "{:<40} {:>12.2} {:>12.2} {:>12.2} {:>12.2} {:>15.0}",
349                result.name,
350                result.avg_latency_ms,
351                result.p50_latency_ms,
352                result.p95_latency_ms,
353                result.p99_latency_ms,
354                result.throughput_tokens_per_sec,
355            );
356        }
357    }
358
359    /// Export results to JSON
360    pub fn export_json(&self, path: &str) -> Result<()> {
361        let json = serde_json::to_string_pretty(&self.results)?;
362        std::fs::write(path, json)?;
363        Ok(())
364    }
365
366    /// Export results to CSV
367    pub fn export_csv(&self, path: &str) -> Result<()> {
368        use std::io::Write;
369        let mut file = std::fs::File::create(path)?;
370
371        // Write header
372        writeln!(file, "name,model_type,batch_size,seq_len,avg_latency_ms,p50_ms,p95_ms,p99_ms,min_ms,max_ms,std_dev_ms,throughput_tokens_sec,throughput_batches_sec,memory_bytes,timestamp")?;
373
374        // Write data
375        for result in &self.results {
376            writeln!(
377                file,
378                "{},{},{},{},{:.2},{:.2},{:.2},{:.2},{:.2},{:.2},{:.2},{:.0},{:.2},{},{}",
379                result.name,
380                result.model_type,
381                result.parameters.get("batch_size").unwrap_or(&"0".to_string()),
382                result.parameters.get("seq_len").unwrap_or(&"0".to_string()),
383                result.avg_latency_ms,
384                result.p50_latency_ms,
385                result.p95_latency_ms,
386                result.p99_latency_ms,
387                result.min_latency_ms,
388                result.max_latency_ms,
389                result.std_dev_ms,
390                result.throughput_tokens_per_sec,
391                result.throughput_batches_per_sec,
392                result.memory_bytes.unwrap_or(0),
393                result.timestamp.to_rfc3339(),
394            )?;
395        }
396
397        Ok(())
398    }
399
400    /// Get results
401    pub fn results(&self) -> &[BenchmarkResult] {
402        &self.results
403    }
404
405    /// Compare with baseline results
406    pub fn compare_with_baseline(&self, baseline: &[BenchmarkResult]) -> Vec<ComparisonSummary> {
407        let mut comparisons = Vec::new();
408
409        for result in &self.results {
410            if let Some(baseline_result) = baseline.iter().find(|b| b.name == result.name) {
411                let speedup = baseline_result.avg_latency_ms / result.avg_latency_ms;
412                let throughput_improvement =
413                    result.throughput_tokens_per_sec / baseline_result.throughput_tokens_per_sec;
414
415                comparisons.push(ComparisonSummary {
416                    benchmark_name: result.name.clone(),
417                    speedup,
418                    throughput_improvement,
419                    latency_reduction_percent: (1.0
420                        - result.avg_latency_ms / baseline_result.avg_latency_ms)
421                        * 100.0,
422                    memory_reduction_percent: if let (Some(current), Some(baseline)) =
423                        (result.memory_bytes, baseline_result.memory_bytes)
424                    {
425                        Some((1.0 - current as f64 / baseline as f64) * 100.0)
426                    } else {
427                        None
428                    },
429                });
430            }
431        }
432
433        comparisons
434    }
435}
436
/// Summary of performance comparison
///
/// One entry per benchmark matched (by name) between the current run and a
/// baseline in `BenchmarkSuite::compare_with_baseline`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ComparisonSummary {
    // Name of the matched benchmark.
    pub benchmark_name: String,
    // baseline_latency / current_latency; > 1.0 means current is faster.
    pub speedup: f64,
    // current_throughput / baseline_throughput; > 1.0 means improvement.
    pub throughput_improvement: f64,
    // Positive when current latency is lower than baseline.
    pub latency_reduction_percent: f64,
    // `None` when either run lacked memory measurements.
    pub memory_reduction_percent: Option<f64>,
}
446
#[cfg(test)]
mod tests {
    use super::*;

    /// Aggregating five known millisecond timings must produce positive
    /// latency/throughput statistics and record the batch/sequence
    /// parameters as strings.
    #[test]
    fn test_benchmark_result_from_timings() {
        let samples: Vec<Duration> = [10u64, 12, 11, 15, 13]
            .iter()
            .map(|&ms| Duration::from_millis(ms))
            .collect();

        let result = BenchmarkResult::from_timings(
            "test_benchmark".to_string(),
            "TestModel".to_string(),
            samples,
            4,
            128,
            Some(1024 * 1024),
            Some(2048 * 1024),
        );

        assert_eq!(result.name, "test_benchmark");
        assert_eq!(result.model_type, "TestModel");
        assert!(result.avg_latency_ms > 0.0);
        assert!(result.throughput_tokens_per_sec > 0.0);

        let batch = result.parameters.get("batch_size").expect("expected value not found");
        assert_eq!(batch, "4");
        let seq = result.parameters.get("seq_len").expect("expected value not found");
        assert_eq!(seq, "128");
    }

    /// The stock configuration must match its documented defaults.
    #[test]
    fn test_benchmark_config_default() {
        let config = BenchmarkConfig::default();
        assert_eq!(config.batch_sizes, vec![1, 4, 8, 16, 32]);
        assert_eq!(config.warmup_iterations, 10);
        assert_eq!(config.num_iterations, 100);
        assert!(config.measure_memory);
    }
}