// trustformers_core/performance/comparison.rs
1//! Performance comparison with other frameworks (PyTorch, HuggingFace)
2
3#![allow(dead_code)] // Comparison framework with reserved features for future benchmarking
4
5use crate::performance::benchmark::BenchmarkResult;
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8
/// A machine-learning framework whose benchmark results can be compared.
///
/// Used as the key for per-framework metrics and relative-performance maps.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum Framework {
    /// This crate — the baseline every relative metric is computed against.
    TrustformeRS,
    /// PyTorch; results are imported from `PytorchBenchmark` records.
    PyTorch,
    /// HuggingFace Transformers; results imported from `HuggingFaceBenchmark` records.
    HuggingFace,
    /// TensorFlow (reserved — no importer in this module yet).
    TensorFlow,
    /// ONNX (reserved — no importer in this module yet).
    ONNX,
}
18
19impl std::fmt::Display for Framework {
20    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
21        match self {
22            Framework::TrustformeRS => write!(f, "TrustformeRS"),
23            Framework::PyTorch => write!(f, "PyTorch"),
24            Framework::HuggingFace => write!(f, "HuggingFace"),
25            Framework::TensorFlow => write!(f, "TensorFlow"),
26            Framework::ONNX => write!(f, "ONNX"),
27        }
28    }
29}
30
/// Comparison result between frameworks for one benchmark configuration.
///
/// One instance exists per unique (benchmark name, batch size, sequence
/// length) triple; frameworks are merged into it as their results arrive.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ComparisonResult {
    /// Benchmark name (grouping key).
    pub benchmark_name: String,
    /// Model type, e.g. "BERT".
    pub model_type: String,
    /// Batch size (grouping key).
    pub batch_size: usize,
    /// Sequence length (grouping key).
    pub sequence_length: usize,
    /// Raw metrics keyed by the framework that produced them.
    pub framework_results: HashMap<Framework, FrameworkMetrics>,
    /// Relative performance of TrustformeRS vs each other framework;
    /// empty until a TrustformeRS baseline is present.
    pub relative_performance: HashMap<Framework, RelativePerformance>,
}
47
/// Performance metrics for a single framework on one benchmark configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FrameworkMetrics {
    /// Average latency in milliseconds.
    pub avg_latency_ms: f64,
    /// P95 latency in milliseconds.
    pub p95_latency_ms: f64,
    /// P99 latency in milliseconds.
    pub p99_latency_ms: f64,
    /// Throughput in tokens/second.
    pub throughput_tokens_per_sec: f64,
    /// Host memory usage in MB; `None` when not measured.
    pub memory_mb: Option<f64>,
    /// GPU memory usage in MB; `None` when not measured.
    pub gpu_memory_mb: Option<f64>,
    /// Framework version string (e.g. crate version or torch version).
    pub version: String,
}
66
/// Relative performance of TrustformeRS against one other framework.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RelativePerformance {
    /// Speedup factor: other latency / TrustformeRS latency
    /// (>1 means TrustformeRS is faster).
    pub speedup: f64,
    /// Throughput ratio: TrustformeRS tokens/s / other tokens/s
    /// (>1 means TrustformeRS has higher throughput).
    pub throughput_ratio: f64,
    /// Memory efficiency: other memory / TrustformeRS memory; `None` when
    /// either side did not report a memory measurement.
    pub memory_efficiency: Option<f64>,
    /// Latency improvement as a percentage (positive means TrustformeRS
    /// latency is lower).
    pub latency_improvement_percent: f64,
}
79
/// Model comparison manager.
///
/// Collects benchmark results from multiple frameworks, groups them by
/// (benchmark name, batch size, sequence length), and derives the relative
/// performance of TrustformeRS against each other framework.
pub struct ModelComparison {
    // One entry per unique (benchmark_name, batch_size, sequence_length).
    results: Vec<ComparisonResult>,
}
84
85impl Default for ModelComparison {
86    fn default() -> Self {
87        Self::new()
88    }
89}
90
91impl ModelComparison {
92    pub fn new() -> Self {
93        Self {
94            results: Vec::new(),
95        }
96    }
97
98    /// Add TrustformeRS benchmark results
99    pub fn add_trustformers_results(&mut self, results: &[BenchmarkResult]) {
100        for result in results {
101            let framework_metrics = FrameworkMetrics {
102                avg_latency_ms: result.avg_latency_ms,
103                p95_latency_ms: result.p95_latency_ms,
104                p99_latency_ms: result.p99_latency_ms,
105                throughput_tokens_per_sec: result.throughput_tokens_per_sec,
106                memory_mb: result.memory_bytes.map(|b| b as f64 / (1024.0 * 1024.0)),
107                gpu_memory_mb: None, // Would be set if using GPU
108                version: env!("CARGO_PKG_VERSION").to_string(),
109            };
110
111            // Extract batch size and sequence length from parameters
112            let batch_size =
113                result.parameters.get("batch_size").and_then(|s| s.parse().ok()).unwrap_or(1);
114            let seq_len =
115                result.parameters.get("seq_len").and_then(|s| s.parse().ok()).unwrap_or(128);
116
117            // Check if we already have a comparison for this benchmark
118            if let Some(comparison) = self.results.iter_mut().find(|c| {
119                c.benchmark_name == result.name
120                    && c.batch_size == batch_size
121                    && c.sequence_length == seq_len
122            }) {
123                comparison.framework_results.insert(Framework::TrustformeRS, framework_metrics);
124            } else {
125                let mut framework_results = HashMap::new();
126                framework_results.insert(Framework::TrustformeRS, framework_metrics);
127
128                self.results.push(ComparisonResult {
129                    benchmark_name: result.name.clone(),
130                    model_type: result.model_type.clone(),
131                    batch_size,
132                    sequence_length: seq_len,
133                    framework_results,
134                    relative_performance: HashMap::new(),
135                });
136            }
137        }
138    }
139
140    /// Add PyTorch benchmark results
141    pub fn add_pytorch_results(&mut self, pytorch_results: &[PytorchBenchmark]) {
142        for result in pytorch_results {
143            let framework_metrics = FrameworkMetrics {
144                avg_latency_ms: result.avg_latency_ms,
145                p95_latency_ms: result.p95_latency_ms,
146                p99_latency_ms: result.p99_latency_ms,
147                throughput_tokens_per_sec: result.throughput_tokens_per_sec,
148                memory_mb: result.memory_mb,
149                gpu_memory_mb: result.gpu_memory_mb,
150                version: result.torch_version.clone(),
151            };
152
153            // Find or create comparison
154            if let Some(comparison) = self.results.iter_mut().find(|c| {
155                c.benchmark_name == result.name
156                    && c.batch_size == result.batch_size
157                    && c.sequence_length == result.sequence_length
158            }) {
159                comparison.framework_results.insert(Framework::PyTorch, framework_metrics);
160            } else {
161                let mut framework_results = HashMap::new();
162                framework_results.insert(Framework::PyTorch, framework_metrics);
163
164                self.results.push(ComparisonResult {
165                    benchmark_name: result.name.clone(),
166                    model_type: result.model_type.clone(),
167                    batch_size: result.batch_size,
168                    sequence_length: result.sequence_length,
169                    framework_results,
170                    relative_performance: HashMap::new(),
171                });
172            }
173        }
174
175        // Calculate relative performance
176        self.calculate_relative_performance();
177    }
178
179    /// Add HuggingFace benchmark results
180    pub fn add_huggingface_results(&mut self, hf_results: &[HuggingFaceBenchmark]) {
181        for result in hf_results {
182            let framework_metrics = FrameworkMetrics {
183                avg_latency_ms: result.avg_latency_ms,
184                p95_latency_ms: result.p95_latency_ms,
185                p99_latency_ms: result.p99_latency_ms,
186                throughput_tokens_per_sec: result.throughput_tokens_per_sec,
187                memory_mb: result.memory_mb,
188                gpu_memory_mb: result.gpu_memory_mb,
189                version: result.transformers_version.clone(),
190            };
191
192            // Find or create comparison
193            if let Some(comparison) = self.results.iter_mut().find(|c| {
194                c.benchmark_name == result.name
195                    && c.batch_size == result.batch_size
196                    && c.sequence_length == result.sequence_length
197            }) {
198                comparison.framework_results.insert(Framework::HuggingFace, framework_metrics);
199            } else {
200                let mut framework_results = HashMap::new();
201                framework_results.insert(Framework::HuggingFace, framework_metrics);
202
203                self.results.push(ComparisonResult {
204                    benchmark_name: result.name.clone(),
205                    model_type: result.model_type.clone(),
206                    batch_size: result.batch_size,
207                    sequence_length: result.sequence_length,
208                    framework_results,
209                    relative_performance: HashMap::new(),
210                });
211            }
212        }
213
214        // Calculate relative performance
215        self.calculate_relative_performance();
216    }
217
218    /// Calculate relative performance metrics
219    fn calculate_relative_performance(&mut self) {
220        for comparison in &mut self.results {
221            if let Some(trustformers) = comparison.framework_results.get(&Framework::TrustformeRS) {
222                // Compare with each other framework
223                for (framework, metrics) in &comparison.framework_results {
224                    if framework != &Framework::TrustformeRS {
225                        let speedup = metrics.avg_latency_ms / trustformers.avg_latency_ms;
226                        let throughput_ratio = trustformers.throughput_tokens_per_sec
227                            / metrics.throughput_tokens_per_sec;
228                        let latency_improvement =
229                            (1.0 - trustformers.avg_latency_ms / metrics.avg_latency_ms) * 100.0;
230
231                        let memory_efficiency = match (trustformers.memory_mb, metrics.memory_mb) {
232                            (Some(tf_mem), Some(other_mem)) => Some(other_mem / tf_mem),
233                            _ => None,
234                        };
235
236                        comparison.relative_performance.insert(
237                            *framework,
238                            RelativePerformance {
239                                speedup,
240                                throughput_ratio,
241                                memory_efficiency,
242                                latency_improvement_percent: latency_improvement,
243                            },
244                        );
245                    }
246                }
247            }
248        }
249    }
250
251    /// Generate comparison report
252    pub fn generate_report(&self) -> ComparisonReport {
253        let mut summary = ComparisonSummary::default();
254
255        // Calculate average performance across all benchmarks
256        for comparison in &self.results {
257            for (framework, perf) in &comparison.relative_performance {
258                summary
259                    .avg_speedup
260                    .entry(*framework)
261                    .and_modify(|v| v.0 += perf.speedup)
262                    .or_insert((perf.speedup, 1));
263                summary.avg_speedup.entry(*framework).and_modify(|v| v.1 += 1);
264
265                if perf.speedup > 1.0 {
266                    *summary.benchmarks_faster.entry(*framework).or_insert(0) += 1;
267                } else {
268                    *summary.benchmarks_slower.entry(*framework).or_insert(0) += 1;
269                }
270            }
271        }
272
273        // Calculate averages
274        for (_, (sum, count)) in summary.avg_speedup.iter_mut() {
275            if *count > 0 {
276                *sum /= *count as f64;
277            }
278        }
279
280        ComparisonReport {
281            comparisons: self.results.clone(),
282            summary,
283        }
284    }
285
286    /// Print comparison summary
287    pub fn print_summary(&self) {
288        println!("\n=== Performance Comparison Summary ===");
289
290        for comparison in &self.results {
291            println!(
292                "\n{} (batch={}, seq_len={})",
293                comparison.benchmark_name, comparison.batch_size, comparison.sequence_length
294            );
295
296            // Print metrics for each framework
297            println!(
298                "  {:15} {:>10} {:>10} {:>15} {:>10}",
299                "Framework", "Avg (ms)", "P95 (ms)", "Throughput", "Memory (MB)"
300            );
301            println!("  {}", "-".repeat(65));
302
303            for (framework, metrics) in &comparison.framework_results {
304                println!(
305                    "  {:15} {:>10.2} {:>10.2} {:>15.0} {:>10}",
306                    framework.to_string(),
307                    metrics.avg_latency_ms,
308                    metrics.p95_latency_ms,
309                    metrics.throughput_tokens_per_sec,
310                    metrics.memory_mb.map(|m| format!("{:.1}", m)).unwrap_or("-".to_string()),
311                );
312            }
313
314            // Print relative performance
315            if !comparison.relative_performance.is_empty() {
316                println!("\n  Relative Performance (TrustformeRS vs others):");
317                for (framework, perf) in &comparison.relative_performance {
318                    println!(
319                        "    vs {}: {:.2}x speedup, {:.1}% latency improvement",
320                        framework, perf.speedup, perf.latency_improvement_percent
321                    );
322                }
323            }
324        }
325    }
326}
327
/// PyTorch benchmark result structure.
///
/// Deserialized from externally produced PyTorch runs; mirrors
/// `FrameworkMetrics` plus the grouping keys used to match a TrustformeRS
/// result with the same name, batch size, and sequence length.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PytorchBenchmark {
    /// Benchmark name (must match the TrustformeRS benchmark to be grouped).
    pub name: String,
    /// Model type, e.g. "BERT".
    pub model_type: String,
    /// Batch size used in the run.
    pub batch_size: usize,
    /// Sequence length used in the run.
    pub sequence_length: usize,
    /// Average latency in milliseconds.
    pub avg_latency_ms: f64,
    /// P95 latency in milliseconds.
    pub p95_latency_ms: f64,
    /// P99 latency in milliseconds.
    pub p99_latency_ms: f64,
    /// Throughput in tokens per second.
    pub throughput_tokens_per_sec: f64,
    /// Host memory usage in MB, if measured.
    pub memory_mb: Option<f64>,
    /// GPU memory usage in MB, if measured.
    pub gpu_memory_mb: Option<f64>,
    /// PyTorch version the run used, e.g. "2.0.0".
    pub torch_version: String,
}
343
/// HuggingFace benchmark result structure.
///
/// Deserialized from externally produced HuggingFace Transformers runs;
/// mirrors `FrameworkMetrics` plus the grouping keys.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HuggingFaceBenchmark {
    /// Benchmark name (must match the TrustformeRS benchmark to be grouped).
    pub name: String,
    /// Model type, e.g. "BERT".
    pub model_type: String,
    /// Batch size used in the run.
    pub batch_size: usize,
    /// Sequence length used in the run.
    pub sequence_length: usize,
    /// Average latency in milliseconds.
    pub avg_latency_ms: f64,
    /// P95 latency in milliseconds.
    pub p95_latency_ms: f64,
    /// P99 latency in milliseconds.
    pub p99_latency_ms: f64,
    /// Throughput in tokens per second.
    pub throughput_tokens_per_sec: f64,
    /// Host memory usage in MB, if measured.
    pub memory_mb: Option<f64>,
    /// GPU memory usage in MB, if measured.
    pub gpu_memory_mb: Option<f64>,
    /// `transformers` library version the run used.
    pub transformers_version: String,
}
359
/// Comparison report.
///
/// Snapshot of all per-benchmark comparisons plus aggregate statistics,
/// produced by `ModelComparison::generate_report`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ComparisonReport {
    /// All per-benchmark comparison results.
    pub comparisons: Vec<ComparisonResult>,
    /// Aggregate statistics across the comparisons.
    pub summary: ComparisonSummary,
}
366
/// Comparison summary statistics, keyed by the framework TrustformeRS was
/// compared against.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ComparisonSummary {
    /// Average speedup by framework as a (value, count) pair; holds the
    /// running (sum, sample count) during accumulation and the finished
    /// (average, sample count) in the generated report.
    pub avg_speedup: HashMap<Framework, (f64, usize)>,
    /// Number of benchmarks where TrustformeRS is faster (speedup > 1).
    pub benchmarks_faster: HashMap<Framework, usize>,
    /// Number of benchmarks where TrustformeRS is slower or equal.
    pub benchmarks_slower: HashMap<Framework, usize>,
}
377
#[cfg(test)]
mod tests {
    use super::*;

    // End-to-end check: register one TrustformeRS run and one PyTorch run of
    // the same benchmark/config, then verify they are grouped into a single
    // comparison with relative performance computed for PyTorch.
    #[test]
    fn test_comparison() {
        let mut comparison = ModelComparison::new();

        // Add TrustformeRS result
        let tf_result = BenchmarkResult {
            name: "bert_inference".to_string(),
            model_type: "BERT".to_string(),
            avg_latency_ms: 50.0,
            p50_latency_ms: 48.0,
            p95_latency_ms: 55.0,
            p99_latency_ms: 60.0,
            min_latency_ms: 45.0,
            max_latency_ms: 65.0,
            std_dev_ms: 5.0,
            throughput_tokens_per_sec: 2560.0, // 4 * 128 * 1000 / 50
            throughput_batches_per_sec: 20.0,
            memory_bytes: Some(100 * 1024 * 1024),
            peak_memory_bytes: Some(150 * 1024 * 1024),
            // batch_size/seq_len must match the PyTorch record below so both
            // land in the same ComparisonResult.
            parameters: {
                let mut params = HashMap::new();
                params.insert("batch_size".to_string(), "4".to_string());
                params.insert("seq_len".to_string(), "128".to_string());
                params
            },
            raw_timings: vec![],
            timestamp: chrono::Utc::now(),
        };

        comparison.add_trustformers_results(&[tf_result]);

        // Add PyTorch result (slower: 60 ms vs 50 ms average latency)
        let pytorch_result = PytorchBenchmark {
            name: "bert_inference".to_string(),
            model_type: "BERT".to_string(),
            batch_size: 4,
            sequence_length: 128,
            avg_latency_ms: 60.0,
            p95_latency_ms: 65.0,
            p99_latency_ms: 70.0,
            throughput_tokens_per_sec: 2133.3,
            memory_mb: Some(120.0),
            gpu_memory_mb: None,
            torch_version: "2.0.0".to_string(),
        };

        comparison.add_pytorch_results(&[pytorch_result]);

        // Check comparison results: one group containing both frameworks
        assert_eq!(comparison.results.len(), 1);
        let result = &comparison.results[0];
        assert_eq!(result.framework_results.len(), 2);

        // Check relative performance
        let pytorch_perf = result
            .relative_performance
            .get(&Framework::PyTorch)
            .expect("expected value not found");
        assert!(pytorch_perf.speedup > 1.0); // TrustformeRS should be faster
        assert!(pytorch_perf.latency_improvement_percent > 0.0);
    }
}