// trustformers_tokenizers/benchmark_utils.rs
//! Benchmark utilities for tokenization performance measurement
//!
//! This module provides simple utilities for measuring tokenization performance,
//! allowing users to benchmark their tokenizers with their own data.

use std::time::Instant;

use trustformers_core::traits::Tokenizer;

9/// Simple benchmark configuration
10#[derive(Debug, Clone)]
11pub struct BenchmarkConfig {
12    /// Number of warmup iterations
13    pub warmup_iterations: usize,
14    /// Number of measurement iterations
15    pub measurement_iterations: usize,
16    /// Whether to include detailed statistics
17    pub detailed_stats: bool,
18}
19
20impl Default for BenchmarkConfig {
21    fn default() -> Self {
22        Self {
23            warmup_iterations: 10,
24            measurement_iterations: 100,
25            detailed_stats: true,
26        }
27    }
28}
29
30/// Benchmark results for tokenization performance
31#[derive(Debug, Clone)]
32pub struct BenchmarkResult {
33    /// Average tokens per second
34    pub tokens_per_second: f64,
35    /// Average characters per second
36    pub characters_per_second: f64,
37    /// Average latency per text (microseconds)
38    pub average_latency_us: f64,
39    /// Minimum latency (microseconds)
40    pub min_latency_us: f64,
41    /// Maximum latency (microseconds)
42    pub max_latency_us: f64,
43    /// Total texts processed
44    pub total_texts: usize,
45    /// Total tokens produced
46    pub total_tokens: usize,
47    /// Total characters processed
48    pub total_characters: usize,
49}
50
51impl BenchmarkResult {
52    /// Create a simple summary string
53    pub fn summary(&self) -> String {
54        format!(
55            "Performance: {:.0} tokens/sec, {:.0} chars/sec, {:.2}μs avg latency",
56            self.tokens_per_second, self.characters_per_second, self.average_latency_us
57        )
58    }
59
60    /// Create a detailed report
61    pub fn detailed_report(&self) -> String {
62        format!(
63            r#"Tokenization Benchmark Results
64==============================
65Throughput:
66  - Tokens per second: {:.2}
67  - Characters per second: {:.2}
68  - Texts per second: {:.2}
69
70Latency (per text):
71  - Average: {:.2} μs
72  - Minimum: {:.2} μs
73  - Maximum: {:.2} μs
74
75Volume:
76  - Total texts: {}
77  - Total tokens: {}
78  - Total characters: {}
79  - Average tokens per text: {:.1}
80  - Average characters per text: {:.1}"#,
81            self.tokens_per_second,
82            self.characters_per_second,
83            self.total_texts as f64 / (self.average_latency_us / 1_000_000.0),
84            self.average_latency_us,
85            self.min_latency_us,
86            self.max_latency_us,
87            self.total_texts,
88            self.total_tokens,
89            self.total_characters,
90            self.total_tokens as f64 / self.total_texts as f64,
91            self.total_characters as f64 / self.total_texts as f64
92        )
93    }
94}
95
96/// Simple benchmark runner for tokenizers
97pub struct TokenizerBenchmark;
98
99impl TokenizerBenchmark {
100    /// Benchmark a tokenizer with the given texts
101    pub fn benchmark<T: Tokenizer>(
102        tokenizer: &T,
103        texts: &[String],
104        config: BenchmarkConfig,
105    ) -> Result<BenchmarkResult, Box<dyn std::error::Error>> {
106        if texts.is_empty() {
107            return Err("No texts provided for benchmarking".into());
108        }
109
110        // Warmup phase
111        for _ in 0..config.warmup_iterations {
112            for text in texts.iter().take(std::cmp::min(texts.len(), 10)) {
113                let _ = tokenizer.encode(text)?;
114            }
115        }
116
117        // Measurement phase
118        let mut latencies = Vec::new();
119        let mut total_tokens = 0;
120        let mut total_characters = 0;
121
122        for _ in 0..config.measurement_iterations {
123            for text in texts {
124                let start = Instant::now();
125                let result = tokenizer.encode(text)?;
126                let elapsed = start.elapsed();
127
128                latencies.push(elapsed.as_micros() as f64);
129                total_tokens += result.input_ids.len();
130                total_characters += text.len();
131            }
132        }
133
134        // Calculate statistics
135        let total_time_seconds = latencies.iter().sum::<f64>() / 1_000_000.0;
136        let average_latency_us = latencies.iter().sum::<f64>() / latencies.len() as f64;
137        let min_latency_us = latencies.iter().cloned().fold(f64::INFINITY, f64::min);
138        let max_latency_us = latencies.iter().cloned().fold(0.0, f64::max);
139
140        let tokens_per_second = total_tokens as f64 / total_time_seconds;
141        let characters_per_second = total_characters as f64 / total_time_seconds;
142        let total_texts = texts.len() * config.measurement_iterations;
143
144        Ok(BenchmarkResult {
145            tokens_per_second,
146            characters_per_second,
147            average_latency_us,
148            min_latency_us,
149            max_latency_us,
150            total_texts,
151            total_tokens,
152            total_characters,
153        })
154    }
155
156    /// Quick benchmark with a single text repeated multiple times
157    pub fn quick_benchmark<T: Tokenizer>(
158        tokenizer: &T,
159        text: &str,
160        iterations: usize,
161    ) -> Result<BenchmarkResult, Box<dyn std::error::Error>> {
162        let texts = vec![text.to_string(); 1];
163        let config = BenchmarkConfig {
164            warmup_iterations: 5,
165            measurement_iterations: iterations,
166            detailed_stats: false,
167        };
168
169        Self::benchmark(tokenizer, &texts, config)
170    }
171
172    /// Benchmark with sample texts of different lengths
173    pub fn multi_length_benchmark<T: Tokenizer>(
174        tokenizer: &T,
175    ) -> Result<Vec<(String, BenchmarkResult)>, Box<dyn std::error::Error>> {
176        let test_cases = vec![
177            ("Short text", "Hello world!".to_string()),
178            ("Medium text", "This is a longer text that contains multiple sentences and should give us a better idea of tokenization performance on medium-length inputs.".to_string()),
179            ("Long text", "This is a much longer text that contains many sentences and words. It is designed to test the performance of tokenization on longer inputs that might be more representative of real-world usage scenarios. The text includes various punctuation marks, numbers like 123 and 456, and different types of content that a tokenizer might encounter in practice. This should help identify any performance differences between short and long text processing.".to_string()),
180        ];
181
182        let mut results = Vec::new();
183        for (name, text) in test_cases {
184            let result = Self::quick_benchmark(tokenizer, &text, 100)?;
185            results.push((name.to_string(), result));
186        }
187
188        Ok(results)
189    }
190}
191
192#[cfg(test)]
193mod tests {
194    use super::*;
195    use crate::char::CharTokenizer;
196    use std::collections::HashMap;
197
198    #[test]
199    fn test_benchmark_config_default() {
200        let config = BenchmarkConfig::default();
201        assert_eq!(config.warmup_iterations, 10);
202        assert_eq!(config.measurement_iterations, 100);
203        assert!(config.detailed_stats);
204    }
205
206    #[test]
207    fn test_quick_benchmark() {
208        let tokenizer = CharTokenizer::new(HashMap::new());
209        let result = TokenizerBenchmark::quick_benchmark(&tokenizer, "Hello world!", 10);
210        assert!(result.is_ok());
211
212        let result = result.expect("Operation failed in test");
213        assert!(result.tokens_per_second > 0.0);
214        assert!(result.characters_per_second > 0.0);
215        assert!(result.total_texts > 0);
216    }
217
218    #[test]
219    fn test_benchmark_with_multiple_texts() {
220        let tokenizer = CharTokenizer::new(HashMap::new());
221        let texts = vec![
222            "Hello world!".to_string(),
223            "This is a test.".to_string(),
224            "Another test text.".to_string(),
225        ];
226
227        let config = BenchmarkConfig {
228            warmup_iterations: 2,
229            measurement_iterations: 5,
230            detailed_stats: true,
231        };
232
233        let result = TokenizerBenchmark::benchmark(&tokenizer, &texts, config);
234        assert!(result.is_ok());
235
236        let result = result.expect("Operation failed in test");
237        assert!(result.tokens_per_second > 0.0);
238        assert_eq!(result.total_texts, 15); // 3 texts * 5 iterations
239    }
240
241    #[test]
242    fn test_multi_length_benchmark() {
243        let tokenizer = CharTokenizer::new(HashMap::new());
244        let results = TokenizerBenchmark::multi_length_benchmark(&tokenizer);
245        assert!(results.is_ok());
246
247        let results = results.expect("Operation failed in test");
248        assert_eq!(results.len(), 3); // Short, medium, long
249
250        for (name, result) in results {
251            assert!(!name.is_empty());
252            assert!(result.tokens_per_second > 0.0);
253        }
254    }
255
256    #[test]
257    fn test_benchmark_result_summary() {
258        let result = BenchmarkResult {
259            tokens_per_second: 1000.0,
260            characters_per_second: 5000.0,
261            average_latency_us: 50.0,
262            min_latency_us: 30.0,
263            max_latency_us: 80.0,
264            total_texts: 100,
265            total_tokens: 500,
266            total_characters: 2500,
267        };
268
269        let summary = result.summary();
270        assert!(summary.contains("1000"));
271        assert!(summary.contains("5000"));
272        assert!(summary.contains("50.00"));
273    }
274
275    #[test]
276    fn test_benchmark_empty_texts() {
277        let tokenizer = CharTokenizer::new(HashMap::new());
278        let texts = vec![];
279        let config = BenchmarkConfig::default();
280
281        let result = TokenizerBenchmark::benchmark(&tokenizer, &texts, config);
282        assert!(result.is_err());
283    }
284}