1use std::time::Instant;
7use trustformers_core::traits::Tokenizer;
8
/// Settings that control how a tokenizer benchmark run is executed.
#[derive(Debug, Clone)]
pub struct BenchmarkConfig {
    // Untimed passes executed before measurement to warm caches/allocator.
    pub warmup_iterations: usize,
    // Timed passes over the full input set.
    pub measurement_iterations: usize,
    // Stored for callers; not read by the benchmark loop in this module —
    // NOTE(review): presumably toggles detailed reporting downstream; confirm.
    pub detailed_stats: bool,
}

impl Default for BenchmarkConfig {
    /// Defaults: 10 warmup passes, 100 measured passes, detailed stats enabled.
    fn default() -> Self {
        BenchmarkConfig {
            detailed_stats: true,
            warmup_iterations: 10,
            measurement_iterations: 100,
        }
    }
}
29
/// Aggregate statistics produced by a single benchmark run.
#[derive(Debug, Clone)]
pub struct BenchmarkResult {
    // Tokens produced per second of measured encode time.
    pub tokens_per_second: f64,
    // Input bytes consumed per second of measured encode time
    // (counted via `str::len`, so bytes rather than Unicode chars).
    pub characters_per_second: f64,
    // Mean per-text encode latency, in microseconds.
    pub average_latency_us: f64,
    // Fastest observed per-text latency, in microseconds.
    pub min_latency_us: f64,
    // Slowest observed per-text latency, in microseconds.
    pub max_latency_us: f64,
    // Number of encode calls measured (texts × measurement iterations).
    pub total_texts: usize,
    // Total tokens produced across all measured encodings.
    pub total_tokens: usize,
    // Total input bytes across all measured encodings.
    pub total_characters: usize,
}

impl BenchmarkResult {
    /// One-line human-readable throughput/latency summary.
    pub fn summary(&self) -> String {
        format!(
            "Performance: {:.0} tokens/sec, {:.0} chars/sec, {:.2}μs avg latency",
            self.tokens_per_second, self.characters_per_second, self.average_latency_us
        )
    }

    /// Multi-line report covering throughput, per-text latency, and volume.
    pub fn detailed_report(&self) -> String {
        format!(
            r#"Tokenization Benchmark Results
==============================
Throughput:
  - Tokens per second: {:.2}
  - Characters per second: {:.2}
  - Texts per second: {:.2}

Latency (per text):
  - Average: {:.2} μs
  - Minimum: {:.2} μs
  - Maximum: {:.2} μs

Volume:
  - Total texts: {}
  - Total tokens: {}
  - Total characters: {}
  - Average tokens per text: {:.1}
  - Average characters per text: {:.1}"#,
            self.tokens_per_second,
            self.characters_per_second,
            // Texts/sec is the reciprocal of the mean per-text latency.
            // The previous formula (total_texts / avg_latency_seconds)
            // over-reported by a factor of `total_texts`.
            1_000_000.0 / self.average_latency_us,
            self.average_latency_us,
            self.min_latency_us,
            self.max_latency_us,
            self.total_texts,
            self.total_tokens,
            self.total_characters,
            self.total_tokens as f64 / self.total_texts as f64,
            self.total_characters as f64 / self.total_texts as f64
        )
    }
}
95
96pub struct TokenizerBenchmark;
98
99impl TokenizerBenchmark {
100 pub fn benchmark<T: Tokenizer>(
102 tokenizer: &T,
103 texts: &[String],
104 config: BenchmarkConfig,
105 ) -> Result<BenchmarkResult, Box<dyn std::error::Error>> {
106 if texts.is_empty() {
107 return Err("No texts provided for benchmarking".into());
108 }
109
110 for _ in 0..config.warmup_iterations {
112 for text in texts.iter().take(std::cmp::min(texts.len(), 10)) {
113 let _ = tokenizer.encode(text)?;
114 }
115 }
116
117 let mut latencies = Vec::new();
119 let mut total_tokens = 0;
120 let mut total_characters = 0;
121
122 for _ in 0..config.measurement_iterations {
123 for text in texts {
124 let start = Instant::now();
125 let result = tokenizer.encode(text)?;
126 let elapsed = start.elapsed();
127
128 latencies.push(elapsed.as_micros() as f64);
129 total_tokens += result.input_ids.len();
130 total_characters += text.len();
131 }
132 }
133
134 let total_time_seconds = latencies.iter().sum::<f64>() / 1_000_000.0;
136 let average_latency_us = latencies.iter().sum::<f64>() / latencies.len() as f64;
137 let min_latency_us = latencies.iter().cloned().fold(f64::INFINITY, f64::min);
138 let max_latency_us = latencies.iter().cloned().fold(0.0, f64::max);
139
140 let tokens_per_second = total_tokens as f64 / total_time_seconds;
141 let characters_per_second = total_characters as f64 / total_time_seconds;
142 let total_texts = texts.len() * config.measurement_iterations;
143
144 Ok(BenchmarkResult {
145 tokens_per_second,
146 characters_per_second,
147 average_latency_us,
148 min_latency_us,
149 max_latency_us,
150 total_texts,
151 total_tokens,
152 total_characters,
153 })
154 }
155
156 pub fn quick_benchmark<T: Tokenizer>(
158 tokenizer: &T,
159 text: &str,
160 iterations: usize,
161 ) -> Result<BenchmarkResult, Box<dyn std::error::Error>> {
162 let texts = vec![text.to_string(); 1];
163 let config = BenchmarkConfig {
164 warmup_iterations: 5,
165 measurement_iterations: iterations,
166 detailed_stats: false,
167 };
168
169 Self::benchmark(tokenizer, &texts, config)
170 }
171
172 pub fn multi_length_benchmark<T: Tokenizer>(
174 tokenizer: &T,
175 ) -> Result<Vec<(String, BenchmarkResult)>, Box<dyn std::error::Error>> {
176 let test_cases = vec![
177 ("Short text", "Hello world!".to_string()),
178 ("Medium text", "This is a longer text that contains multiple sentences and should give us a better idea of tokenization performance on medium-length inputs.".to_string()),
179 ("Long text", "This is a much longer text that contains many sentences and words. It is designed to test the performance of tokenization on longer inputs that might be more representative of real-world usage scenarios. The text includes various punctuation marks, numbers like 123 and 456, and different types of content that a tokenizer might encounter in practice. This should help identify any performance differences between short and long text processing.".to_string()),
180 ];
181
182 let mut results = Vec::new();
183 for (name, text) in test_cases {
184 let result = Self::quick_benchmark(tokenizer, &text, 100)?;
185 results.push((name.to_string(), result));
186 }
187
188 Ok(results)
189 }
190}
191
#[cfg(test)]
mod tests {
    use super::*;
    use crate::char::CharTokenizer;
    use std::collections::HashMap;

    // Default config carries the documented default values.
    #[test]
    fn test_benchmark_config_default() {
        let cfg = BenchmarkConfig::default();
        assert_eq!(cfg.warmup_iterations, 10);
        assert_eq!(cfg.measurement_iterations, 100);
        assert!(cfg.detailed_stats);
    }

    // A single-text quick benchmark yields strictly positive throughput.
    #[test]
    fn test_quick_benchmark() {
        let tokenizer = CharTokenizer::new(HashMap::new());
        let outcome = TokenizerBenchmark::quick_benchmark(&tokenizer, "Hello world!", 10);
        assert!(outcome.is_ok());

        let stats = outcome.expect("Operation failed in test");
        assert!(stats.tokens_per_second > 0.0);
        assert!(stats.characters_per_second > 0.0);
        assert!(stats.total_texts > 0);
    }

    // Three texts over five measured passes count as 15 processed texts.
    #[test]
    fn test_benchmark_with_multiple_texts() {
        let tokenizer = CharTokenizer::new(HashMap::new());
        let texts = vec![
            "Hello world!".to_string(),
            "This is a test.".to_string(),
            "Another test text.".to_string(),
        ];
        let config = BenchmarkConfig {
            warmup_iterations: 2,
            measurement_iterations: 5,
            detailed_stats: true,
        };

        let outcome = TokenizerBenchmark::benchmark(&tokenizer, &texts, config);
        assert!(outcome.is_ok());

        let stats = outcome.expect("Operation failed in test");
        assert!(stats.tokens_per_second > 0.0);
        assert_eq!(stats.total_texts, 15);
    }

    // The multi-length sweep returns one labelled result per sample text.
    #[test]
    fn test_multi_length_benchmark() {
        let tokenizer = CharTokenizer::new(HashMap::new());
        let outcome = TokenizerBenchmark::multi_length_benchmark(&tokenizer);
        assert!(outcome.is_ok());

        let cases = outcome.expect("Operation failed in test");
        assert_eq!(cases.len(), 3);
        for (label, stats) in cases {
            assert!(!label.is_empty());
            assert!(stats.tokens_per_second > 0.0);
        }
    }

    // `summary` embeds the throughput and latency figures.
    #[test]
    fn test_benchmark_result_summary() {
        let result = BenchmarkResult {
            tokens_per_second: 1000.0,
            characters_per_second: 5000.0,
            average_latency_us: 50.0,
            min_latency_us: 30.0,
            max_latency_us: 80.0,
            total_texts: 100,
            total_tokens: 500,
            total_characters: 2500,
        };

        let summary = result.summary();
        assert!(summary.contains("1000"));
        assert!(summary.contains("5000"));
        assert!(summary.contains("50.00"));
    }

    // Benchmarking with no input texts must fail fast.
    #[test]
    fn test_benchmark_empty_texts() {
        let tokenizer = CharTokenizer::new(HashMap::new());
        let texts: Vec<String> = vec![];
        let result = TokenizerBenchmark::benchmark(&tokenizer, &texts, BenchmarkConfig::default());
        assert!(result.is_err());
    }
}