1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3use std::time::{Duration, Instant};
4use trustformers_core::errors::{Result, TrustformersError};
5use trustformers_core::traits::{TokenizedInput, Tokenizer};
6
/// Tuning knobs for a profiling run: iteration counts, which measurements to
/// take, and the grid of input scenarios (text lengths × batch sizes).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProfilerConfig {
    /// Untimed iterations executed before measurement begins (cache/allocator warmup).
    pub warmup_iterations: usize,
    /// Timed iterations per scenario; also the denominator for `error_rate`.
    pub benchmark_iterations: usize,
    /// When true, a `MemoryStats` is attached to each result.
    pub measure_memory: bool,
    /// When true, a `ThroughputStats` is attached to each result.
    pub measure_throughput: bool,
    /// Thread count recorded in each result; `None` is reported as 1.
    /// NOTE(review): only recorded — no threads are spawned in this file; confirm intent.
    pub concurrent_threads: Option<usize>,
    /// Target sizes (in bytes) for the synthesized test texts.
    pub text_lengths: Vec<usize>,
    /// Number of texts generated per scenario (one "batch").
    pub batch_sizes: Vec<usize>,
    /// NOTE(review): not read anywhere in this file — confirm whether still used.
    pub detailed_timing: bool,
    /// Preferred output format for exported reports.
    pub export_format: ExportFormat,
}
20
/// Output formats supported by `PerformanceProfiler::export_report`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ExportFormat {
    Json,
    Csv,
    Html,
    Markdown,
}
29
impl Default for ProfilerConfig {
    /// Defaults: short warmup, 10 timed iterations, all measurements enabled,
    /// one reported thread per logical CPU, and a broad scenario grid.
    fn default() -> Self {
        Self {
            warmup_iterations: 3,
            benchmark_iterations: 10,
            measure_memory: true,
            measure_throughput: true,
            // Logical CPU count via the num_cpus crate.
            concurrent_threads: Some(num_cpus::get()),
            text_lengths: vec![50, 100, 500, 1000, 5000],
            batch_sizes: vec![1, 8, 16, 32, 64],
            detailed_timing: true,
            export_format: ExportFormat::Json,
        }
    }
}
45
/// Latency distribution over the successful iterations of one scenario.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TimingStats {
    pub mean: Duration,
    pub median: Duration,
    pub min: Duration,
    pub max: Duration,
    /// Population standard deviation (computed over nanoseconds).
    pub std_dev: Duration,
    pub percentile_95: Duration,
    pub percentile_99: Duration,
}
57
/// Coarse memory figures sampled around a scenario (MB).
/// Values come from two RSS snapshots (before/after), not continuous sampling.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryStats {
    /// RSS at the end of the run — a proxy for peak, not a true high-water mark.
    pub peak_memory_mb: f64,
    /// Midpoint of the start/end snapshots.
    pub average_memory_mb: f64,
    /// End minus start RSS; can be negative if memory was released.
    pub memory_growth_mb: f64,
    /// Not populated by this profiler (no allocator hooks).
    pub allocations_count: Option<usize>,
    /// Not populated by this profiler (no allocator hooks).
    pub deallocations_count: Option<usize>,
}
67
/// Throughput figures for one scenario, derived from the timing samples.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ThroughputStats {
    /// Estimated tokens processed per second (see `calculate_throughput_stats`
    /// for how "tokens" are approximated).
    pub tokens_per_second: f64,
    pub characters_per_second: f64,
    /// Scenario batches completed per second (1 / mean iteration time).
    pub batches_per_second: f64,
    /// Best single-iteration throughput observed.
    pub peak_throughput: f64,
    /// Mean per-iteration throughput.
    pub average_throughput: f64,
}
77
/// Everything measured for one (tokenizer, text length, batch size) scenario.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BenchmarkResult {
    pub tokenizer_name: String,
    /// Target input size in bytes for this scenario.
    pub text_length: usize,
    /// Number of texts tokenized per iteration.
    pub batch_size: usize,
    /// Reported (not enforced) concurrency level from the config.
    pub thread_count: usize,
    pub timing: TimingStats,
    /// Present only when `ProfilerConfig::measure_memory` was set.
    pub memory: Option<MemoryStats>,
    /// Present only when `ProfilerConfig::measure_throughput` was set.
    pub throughput: Option<ThroughputStats>,
    /// Fraction of benchmark iterations that returned an error (0.0–1.0).
    pub error_rate: f64,
    /// Free-form extras; empty as produced by this profiler.
    pub metadata: HashMap<String, serde_json::Value>,
}
91
/// Full output of a profiling session: raw results plus derived analysis.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProfilingReport {
    /// The configuration the session ran with.
    pub config: ProfilerConfig,
    /// Every individual scenario result, in recording order.
    pub benchmarks: Vec<BenchmarkResult>,
    pub summary: ProfilingSummary,
    /// Head-to-head comparisons for scenarios run by more than one tokenizer.
    pub comparisons: Vec<TokenizerComparison>,
    /// Human-readable advice derived from the results.
    pub recommendations: Vec<String>,
    /// RFC 3339 UTC timestamp of report generation.
    pub timestamp: String,
}
102
/// Headline findings across all benchmarks; "N/A" when no data is available.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProfilingSummary {
    pub total_benchmarks: usize,
    /// Tokenizer with the lowest mean latency in any benchmark.
    pub fastest_tokenizer: String,
    /// Tokenizer with the lowest peak RSS in any benchmark with memory data.
    pub most_memory_efficient: String,
    /// Tokenizer with the highest peak throughput.
    pub highest_throughput: String,
    /// Tokenizer with the smallest timing standard deviation.
    pub most_consistent: String,
    /// Misc aggregate numbers keyed by metric name.
    pub overall_stats: HashMap<String, f64>,
}
113
/// Head-to-head result for one scenario that multiple tokenizers ran.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenizerComparison {
    /// Scenario key, formatted as `length_{n}_batch_{m}`.
    pub scenario: String,
    /// Each competing tokenizer's result, keyed by tokenizer name.
    pub results: HashMap<String, BenchmarkResult>,
    /// Name of the tokenizer with the lowest mean latency.
    pub winner: String,
    /// Relative slowdown of the slowest vs. the fastest (e.g. 0.5 = 50% slower).
    pub performance_gap: f64,
}
122
/// Stateful benchmark harness: accumulates `BenchmarkResult`s across one or
/// more tokenizers, then turns them into a `ProfilingReport`.
pub struct PerformanceProfiler {
    config: ProfilerConfig,
    /// All results recorded so far, in execution order.
    results: Vec<BenchmarkResult>,
}
128
129impl PerformanceProfiler {
130 pub fn new(config: ProfilerConfig) -> Self {
132 Self {
133 config,
134 results: Vec::new(),
135 }
136 }
137
138 pub fn default() -> Self {
140 Self::new(ProfilerConfig::default())
141 }
142
143 pub fn profile_tokenizer<T: Tokenizer + Sync>(
145 &mut self,
146 name: &str,
147 tokenizer: &T,
148 test_texts: &[String],
149 ) -> Result<Vec<BenchmarkResult>> {
150 let mut tokenizer_results = Vec::new();
151
152 for &text_length in &self.config.text_lengths {
153 for &batch_size in &self.config.batch_sizes {
154 let texts = self.prepare_test_texts(test_texts, text_length, batch_size);
156
157 let result =
159 self.benchmark_scenario(name, tokenizer, &texts, text_length, batch_size)?;
160
161 tokenizer_results.push(result.clone());
162 self.results.push(result);
163 }
164 }
165
166 Ok(tokenizer_results)
167 }
168
169 pub fn profile_multiple<T: Tokenizer + Sync>(
171 &mut self,
172 tokenizers: HashMap<String, &T>,
173 test_texts: &[String],
174 ) -> Result<ProfilingReport> {
175 for (name, tokenizer) in tokenizers {
176 self.profile_tokenizer(&name, tokenizer, test_texts)?;
177 }
178
179 self.generate_report()
180 }
181
    /// Benchmarks one (tokenizer, text length, batch size) scenario:
    /// warmup runs, then timed runs, then stats assembly.
    ///
    /// Failed iterations are counted toward `error_rate` and recorded with a
    /// `Duration::from_millis(u64::MAX)` sentinel, which the stats helpers
    /// filter back out.
    fn benchmark_scenario<T: Tokenizer + Sync>(
        &self,
        name: &str,
        tokenizer: &T,
        texts: &[String],
        text_length: usize,
        batch_size: usize,
    ) -> Result<BenchmarkResult> {
        // Reported in the result only; no threads are spawned here.
        let thread_count = self.config.concurrent_threads.unwrap_or(1);

        // Untimed warmup; a warmup failure aborts the whole scenario via `?`.
        for _ in 0..self.config.warmup_iterations {
            let _ = self.run_tokenization(tokenizer, texts)?;
        }

        let mut timings = Vec::new();
        let mut error_count = 0;
        // RSS snapshot before the timed runs (Linux only; 0.0 elsewhere).
        let start_memory = self.get_memory_usage();

        for _ in 0..self.config.benchmark_iterations {
            let start = Instant::now();
            match self.run_tokenization(tokenizer, texts) {
                Ok(_) => {
                    let duration = start.elapsed();
                    timings.push(duration);
                },
                Err(_) => {
                    error_count += 1;
                    // Sentinel marking a failed iteration; excluded from stats.
                    timings.push(Duration::from_millis(u64::MAX)); },
            }
        }

        let end_memory = self.get_memory_usage();
        let error_rate = error_count as f64 / self.config.benchmark_iterations as f64;

        let timing = self.calculate_timing_stats(&timings);
        let memory = if self.config.measure_memory {
            Some(MemoryStats {
                // End-of-run RSS stands in for "peak"; no continuous sampling.
                peak_memory_mb: end_memory,
                average_memory_mb: (start_memory + end_memory) / 2.0,
                memory_growth_mb: end_memory - start_memory,
                allocations_count: None,
                deallocations_count: None,
            })
        } else {
            None
        };

        let throughput = if self.config.measure_throughput {
            Some(self.calculate_throughput_stats(texts, &timings, batch_size))
        } else {
            None
        };

        Ok(BenchmarkResult {
            tokenizer_name: name.to_string(),
            text_length,
            batch_size,
            thread_count,
            timing,
            memory,
            throughput,
            error_rate,
            metadata: HashMap::new(),
        })
    }
252
253 fn run_tokenization<T: Tokenizer>(
255 &self,
256 tokenizer: &T,
257 texts: &[String],
258 ) -> Result<Vec<TokenizedInput>> {
259 let mut results = Vec::new();
260 for text in texts {
261 let result = tokenizer.encode(text)?;
262 results.push(result);
263 }
264 Ok(results)
265 }
266
267 fn prepare_test_texts(
269 &self,
270 source_texts: &[String],
271 target_length: usize,
272 count: usize,
273 ) -> Vec<String> {
274 let mut texts = Vec::new();
275 let mut text_pool = source_texts.iter().cycle();
276
277 for _ in 0..count {
278 let mut combined_text = String::new();
279
280 while combined_text.len() < target_length {
281 if let Some(text) = text_pool.next() {
282 combined_text.push_str(text);
283 combined_text.push(' ');
284 } else {
285 break;
286 }
287 }
288
289 if combined_text.len() > target_length {
291 combined_text.truncate(target_length);
292 }
293
294 texts.push(combined_text);
295 }
296
297 texts
298 }
299
    /// Computes the latency distribution over `timings`, ignoring the
    /// `u64::MAX`-millisecond sentinels that mark failed iterations.
    ///
    /// Returns all-zero stats when every iteration failed.
    fn calculate_timing_stats(&self, timings: &[Duration]) -> TimingStats {
        // Drop error sentinels inserted by benchmark_scenario.
        let mut valid_timings: Vec<Duration> = timings
            .iter()
            .filter(|&&t| t != Duration::from_millis(u64::MAX))
            .copied()
            .collect();

        // Sorted order is required for median/min/max/percentiles below.
        valid_timings.sort();

        if valid_timings.is_empty() {
            return TimingStats {
                mean: Duration::ZERO,
                median: Duration::ZERO,
                min: Duration::ZERO,
                max: Duration::ZERO,
                std_dev: Duration::ZERO,
                percentile_95: Duration::ZERO,
                percentile_99: Duration::ZERO,
            };
        }

        let sum: Duration = valid_timings.iter().sum();
        let mean = sum / valid_timings.len() as u32;

        // For an even count this takes the upper-middle element (no averaging).
        let median = valid_timings[valid_timings.len() / 2];
        let min = valid_timings[0];
        let max = valid_timings[valid_timings.len() - 1];

        // Population variance computed in nanoseconds as f64.
        let variance: f64 = valid_timings
            .iter()
            .map(|&t| {
                let diff = t.as_nanos() as f64 - mean.as_nanos() as f64;
                diff * diff
            })
            .sum::<f64>()
            / valid_timings.len() as f64;

        let std_dev = Duration::from_nanos(variance.sqrt() as u64);

        // Index may equal len for small samples; .get() falls back to max below.
        let p95_idx = (valid_timings.len() as f64 * 0.95) as usize;
        let p99_idx = (valid_timings.len() as f64 * 0.99) as usize;

        let percentile_95 = valid_timings.get(p95_idx).copied().unwrap_or(max);
        let percentile_99 = valid_timings.get(p99_idx).copied().unwrap_or(max);

        TimingStats {
            mean,
            median,
            min,
            max,
            std_dev,
            percentile_95,
            percentile_99,
        }
    }
357
    /// Derives throughput figures from the timing samples of one scenario,
    /// ignoring the error sentinels. Returns all-zero stats if every
    /// iteration failed.
    fn calculate_throughput_stats(
        &self,
        texts: &[String],
        timings: &[Duration],
        batch_size: usize,
    ) -> ThroughputStats {
        let total_chars: usize = texts.iter().map(|t| t.len()).sum();
        // NOTE(review): "tokens" here is texts.len() * batch_size — a rough
        // proxy, not a real token count from the tokenizer. Confirm intent.
        let total_tokens = texts.len() * batch_size; let valid_timings: Vec<Duration> = timings
            .iter()
            .filter(|&&t| t != Duration::from_millis(u64::MAX))
            .copied()
            .collect();

        if valid_timings.is_empty() {
            return ThroughputStats {
                tokens_per_second: 0.0,
                characters_per_second: 0.0,
                batches_per_second: 0.0,
                peak_throughput: 0.0,
                average_throughput: 0.0,
            };
        }

        // Per-iteration throughput (tokens/sec); zero-duration guards against
        // division by zero.
        let throughputs: Vec<f64> = valid_timings
            .iter()
            .map(|&duration| {
                if duration.as_secs_f64() > 0.0 {
                    total_tokens as f64 / duration.as_secs_f64()
                } else {
                    0.0
                }
            })
            .collect();

        let average_throughput = throughputs.iter().sum::<f64>() / throughputs.len() as f64;
        let peak_throughput = throughputs
            .iter()
            .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .copied()
            .unwrap_or(0.0);

        // Remaining rates are computed from the mean iteration time.
        let avg_duration = valid_timings.iter().sum::<Duration>() / valid_timings.len() as u32;
        let tokens_per_second = if avg_duration.as_secs_f64() > 0.0 {
            total_tokens as f64 / avg_duration.as_secs_f64()
        } else {
            0.0
        };

        let characters_per_second = if avg_duration.as_secs_f64() > 0.0 {
            total_chars as f64 / avg_duration.as_secs_f64()
        } else {
            0.0
        };

        // One iteration == one batch, so this is simply 1 / mean duration.
        let batches_per_second = if avg_duration.as_secs_f64() > 0.0 {
            1.0 / avg_duration.as_secs_f64()
        } else {
            0.0
        };

        ThroughputStats {
            tokens_per_second,
            characters_per_second,
            batches_per_second,
            peak_throughput,
            average_throughput,
        }
    }
429
430 fn get_memory_usage(&self) -> f64 {
432 #[cfg(target_os = "linux")]
436 {
437 if let Ok(contents) = std::fs::read_to_string("/proc/self/status") {
438 for line in contents.lines() {
439 if line.starts_with("VmRSS:") {
440 if let Some(kb_str) = line.split_whitespace().nth(1) {
441 if let Ok(kb) = kb_str.parse::<f64>() {
442 return kb / 1024.0; }
444 }
445 }
446 }
447 }
448 }
449
450 0.0
452 }
453
454 fn generate_report(&self) -> Result<ProfilingReport> {
456 let summary = self.generate_summary();
457 let comparisons = self.generate_comparisons();
458 let recommendations = self.generate_recommendations();
459
460 Ok(ProfilingReport {
461 config: self.config.clone(),
462 benchmarks: self.results.clone(),
463 summary,
464 comparisons,
465 recommendations,
466 timestamp: chrono::Utc::now().to_rfc3339(),
467 })
468 }
469
470 fn generate_summary(&self) -> ProfilingSummary {
472 if self.results.is_empty() {
473 return ProfilingSummary {
474 total_benchmarks: 0,
475 fastest_tokenizer: "N/A".to_string(),
476 most_memory_efficient: "N/A".to_string(),
477 highest_throughput: "N/A".to_string(),
478 most_consistent: "N/A".to_string(),
479 overall_stats: HashMap::new(),
480 };
481 }
482
483 let fastest = self
485 .results
486 .iter()
487 .min_by(|a, b| {
488 a.timing.mean.partial_cmp(&b.timing.mean).unwrap_or(std::cmp::Ordering::Equal)
489 })
490 .map(|r| r.tokenizer_name.clone())
491 .unwrap_or_else(|| "N/A".to_string());
492
493 let most_memory_efficient = self
495 .results
496 .iter()
497 .filter_map(|r| r.memory.as_ref().map(|m| (r, m.peak_memory_mb)))
498 .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
499 .map(|(r, _)| r.tokenizer_name.clone())
500 .unwrap_or_else(|| "N/A".to_string());
501
502 let highest_throughput = self
504 .results
505 .iter()
506 .filter_map(|r| r.throughput.as_ref().map(|t| (r, t.peak_throughput)))
507 .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
508 .map(|(r, _)| r.tokenizer_name.clone())
509 .unwrap_or_else(|| "N/A".to_string());
510
511 let most_consistent = self
513 .results
514 .iter()
515 .min_by(|a, b| {
516 a.timing
517 .std_dev
518 .partial_cmp(&b.timing.std_dev)
519 .unwrap_or(std::cmp::Ordering::Equal)
520 })
521 .map(|r| r.tokenizer_name.clone())
522 .unwrap_or_else(|| "N/A".to_string());
523
524 let mut overall_stats = HashMap::new();
526 let total_time: Duration = self.results.iter().map(|r| r.timing.mean).sum();
527 overall_stats.insert(
528 "total_benchmark_time_ms".to_string(),
529 total_time.as_millis() as f64,
530 );
531
532 let avg_throughput = self
533 .results
534 .iter()
535 .filter_map(|r| r.throughput.as_ref())
536 .map(|t| t.average_throughput)
537 .sum::<f64>()
538 / self.results.len() as f64;
539 overall_stats.insert("average_throughput".to_string(), avg_throughput);
540
541 ProfilingSummary {
542 total_benchmarks: self.results.len(),
543 fastest_tokenizer: fastest,
544 most_memory_efficient,
545 highest_throughput,
546 most_consistent,
547 overall_stats,
548 }
549 }
550
551 fn generate_comparisons(&self) -> Vec<TokenizerComparison> {
553 let mut comparisons = Vec::new();
554
555 let mut scenarios: HashMap<String, Vec<&BenchmarkResult>> = HashMap::new();
557 for result in &self.results {
558 let scenario = format!("length_{}_batch_{}", result.text_length, result.batch_size);
559 scenarios.entry(scenario).or_default().push(result);
560 }
561
562 for (scenario, results) in scenarios {
563 if results.len() > 1 {
564 let mut scenario_results = HashMap::new();
565 for result in &results {
566 scenario_results.insert(result.tokenizer_name.clone(), (*result).clone());
567 }
568
569 let winner = results
571 .iter()
572 .min_by(|a, b| {
573 a.timing
574 .mean
575 .partial_cmp(&b.timing.mean)
576 .unwrap_or(std::cmp::Ordering::Equal)
577 })
578 .map(|r| r.tokenizer_name.clone())
579 .unwrap_or_else(|| "N/A".to_string());
580
581 let fastest_time =
583 results.iter().map(|r| r.timing.mean.as_millis()).min().unwrap_or(0);
584 let slowest_time =
585 results.iter().map(|r| r.timing.mean.as_millis()).max().unwrap_or(0);
586
587 let performance_gap = if fastest_time > 0 {
588 (slowest_time as f64 / fastest_time as f64) - 1.0
589 } else {
590 0.0
591 };
592
593 comparisons.push(TokenizerComparison {
594 scenario,
595 results: scenario_results,
596 winner,
597 performance_gap,
598 });
599 }
600 }
601
602 comparisons
603 }
604
605 fn generate_recommendations(&self) -> Vec<String> {
607 let mut recommendations = Vec::new();
608
609 if self.results.is_empty() {
610 return recommendations;
611 }
612
613 let high_error_rate = self.results.iter().any(|r| r.error_rate > 0.1);
615 if high_error_rate {
616 recommendations
617 .push("Consider investigating tokenizers with high error rates (>10%)".to_string());
618 }
619
620 if let Some(max_memory) = self
622 .results
623 .iter()
624 .filter_map(|r| r.memory.as_ref())
625 .map(|m| m.peak_memory_mb)
626 .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
627 {
628 if max_memory > 1000.0 {
629 recommendations.push(
630 "Consider using memory-efficient tokenizers for large-scale processing"
631 .to_string(),
632 );
633 }
634 }
635
636 let high_variance = self
638 .results
639 .iter()
640 .any(|r| r.timing.std_dev.as_millis() > r.timing.mean.as_millis() / 2);
641 if high_variance {
642 recommendations.push(
643 "Some tokenizers show high timing variance - consider warmup strategies"
644 .to_string(),
645 );
646 }
647
648 let throughputs: Vec<f64> = self
650 .results
651 .iter()
652 .filter_map(|r| r.throughput.as_ref())
653 .map(|t| t.average_throughput)
654 .collect();
655 if !throughputs.is_empty() {
656 let max_throughput = throughputs
657 .iter()
658 .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
659 .copied()
660 .unwrap_or(0.0);
661 let min_throughput = throughputs
662 .iter()
663 .min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
664 .copied()
665 .unwrap_or(0.0);
666
667 if max_throughput > min_throughput * 2.0 {
668 recommendations.push("Significant throughput differences detected - choose tokenizer based on use case".to_string());
669 }
670 }
671
672 recommendations
673 }
674
675 pub fn export_report(&self, report: &ProfilingReport, format: ExportFormat) -> Result<String> {
677 match format {
678 ExportFormat::Json => self.export_json(report),
679 ExportFormat::Csv => self.export_csv(report),
680 ExportFormat::Html => self.export_html(report),
681 ExportFormat::Markdown => self.export_markdown(report),
682 }
683 }
684
685 fn export_json(&self, report: &ProfilingReport) -> Result<String> {
687 serde_json::to_string_pretty(report).map_err(|e| {
688 TrustformersError::other(
689 anyhow::anyhow!("Failed to serialize to JSON: {}", e).to_string(),
690 )
691 })
692 }
693
694 fn export_csv(&self, report: &ProfilingReport) -> Result<String> {
696 let mut csv = String::new();
697 csv.push_str(
698 "tokenizer_name,text_length,batch_size,mean_time_ms,memory_mb,throughput,error_rate\n",
699 );
700
701 for benchmark in &report.benchmarks {
702 csv.push_str(&format!(
703 "{},{},{},{},{},{},{}\n",
704 benchmark.tokenizer_name,
705 benchmark.text_length,
706 benchmark.batch_size,
707 benchmark.timing.mean.as_millis(),
708 benchmark.memory.as_ref().map(|m| m.peak_memory_mb).unwrap_or(0.0),
709 benchmark.throughput.as_ref().map(|t| t.average_throughput).unwrap_or(0.0),
710 benchmark.error_rate
711 ));
712 }
713
714 Ok(csv)
715 }
716
    /// Renders the report as a standalone HTML page: a summary table followed
    /// by one row per benchmark. Missing memory/throughput figures render as 0.0.
    fn export_html(&self, report: &ProfilingReport) -> Result<String> {
        let mut html = String::new();
        html.push_str(
            "<!DOCTYPE html>\n<html>\n<head>\n<title>Tokenizer Performance Report</title>\n",
        );
        // Inline CSS so the exported file is self-contained.
        html.push_str("<style>body{font-family:Arial,sans-serif;margin:40px;}table{border-collapse:collapse;width:100%;}th,td{border:1px solid #ddd;padding:8px;text-align:left;}th{background-color:#f2f2f2;}</style>\n");
        html.push_str("</head>\n<body>\n");
        html.push_str("<h1>Tokenizer Performance Report</h1>\n");

        html.push_str("<h2>Summary</h2>\n");
        html.push_str("<table>\n");
        html.push_str(&format!(
            "<tr><td>Total Benchmarks</td><td>{}</td></tr>\n",
            report.summary.total_benchmarks
        ));
        html.push_str(&format!(
            "<tr><td>Fastest Tokenizer</td><td>{}</td></tr>\n",
            report.summary.fastest_tokenizer
        ));
        html.push_str(&format!(
            "<tr><td>Most Memory Efficient</td><td>{}</td></tr>\n",
            report.summary.most_memory_efficient
        ));
        html.push_str(&format!(
            "<tr><td>Highest Throughput</td><td>{}</td></tr>\n",
            report.summary.highest_throughput
        ));
        html.push_str("</table>\n");

        html.push_str("<h2>Detailed Results</h2>\n");
        html.push_str("<table>\n");
        html.push_str("<tr><th>Tokenizer</th><th>Text Length</th><th>Batch Size</th><th>Mean Time (ms)</th><th>Memory (MB)</th><th>Throughput</th></tr>\n");

        // NOTE(review): tokenizer names are interpolated without HTML escaping;
        // fine for trusted names, confirm if names can be user-supplied.
        for benchmark in &report.benchmarks {
            html.push_str(&format!(
                "<tr><td>{}</td><td>{}</td><td>{}</td><td>{}</td><td>{:.1}</td><td>{:.1}</td></tr>\n",
                benchmark.tokenizer_name,
                benchmark.text_length,
                benchmark.batch_size,
                benchmark.timing.mean.as_millis(),
                benchmark.memory.as_ref().map(|m| m.peak_memory_mb).unwrap_or(0.0),
                benchmark.throughput.as_ref().map(|t| t.average_throughput).unwrap_or(0.0)
            ));
        }

        html.push_str("</table>\n</body>\n</html>");
        Ok(html)
    }
766
767 fn export_markdown(&self, report: &ProfilingReport) -> Result<String> {
769 let mut md = String::new();
770 md.push_str("# Tokenizer Performance Report\n\n");
771
772 md.push_str("## Summary\n\n");
773 md.push_str(&format!(
774 "- **Total Benchmarks**: {}\n",
775 report.summary.total_benchmarks
776 ));
777 md.push_str(&format!(
778 "- **Fastest Tokenizer**: {}\n",
779 report.summary.fastest_tokenizer
780 ));
781 md.push_str(&format!(
782 "- **Most Memory Efficient**: {}\n",
783 report.summary.most_memory_efficient
784 ));
785 md.push_str(&format!(
786 "- **Highest Throughput**: {}\n\n",
787 report.summary.highest_throughput
788 ));
789
790 md.push_str("## Detailed Results\n\n");
791 md.push_str("| Tokenizer | Text Length | Batch Size | Mean Time (ms) | Memory (MB) | Throughput |\n");
792 md.push_str("|-----------|-------------|------------|----------------|-------------|------------|\n");
793
794 for benchmark in &report.benchmarks {
795 md.push_str(&format!(
796 "| {} | {} | {} | {} | {:.1} | {:.1} |\n",
797 benchmark.tokenizer_name,
798 benchmark.text_length,
799 benchmark.batch_size,
800 benchmark.timing.mean.as_millis(),
801 benchmark.memory.as_ref().map(|m| m.peak_memory_mb).unwrap_or(0.0),
802 benchmark.throughput.as_ref().map(|t| t.average_throughput).unwrap_or(0.0)
803 ));
804 }
805
806 if !report.recommendations.is_empty() {
807 md.push_str("\n## Recommendations\n\n");
808 for (i, rec) in report.recommendations.iter().enumerate() {
809 md.push_str(&format!("{}. {}\n", i + 1, rec));
810 }
811 }
812
813 Ok(md)
814 }
815}
816
#[cfg(test)]
mod tests {
    use super::*;
    use crate::char::CharTokenizer;
    use std::collections::HashMap;

    /// Builds a tiny character tokenizer whose vocab covers "hello world"
    /// plus the usual special tokens — enough for the smoke tests below.
    fn create_test_char_tokenizer() -> CharTokenizer {
        let mut vocab = HashMap::new();
        vocab.insert("[PAD]".to_string(), 0);
        vocab.insert("[UNK]".to_string(), 1);
        vocab.insert("[CLS]".to_string(), 2);
        vocab.insert("[SEP]".to_string(), 3);
        vocab.insert("h".to_string(), 4);
        vocab.insert("e".to_string(), 5);
        vocab.insert("l".to_string(), 6);
        vocab.insert("o".to_string(), 7);
        vocab.insert("w".to_string(), 8);
        vocab.insert("r".to_string(), 9);
        vocab.insert("d".to_string(), 10);
        vocab.insert(" ".to_string(), 11);
        vocab.insert("t".to_string(), 12);
        vocab.insert("s".to_string(), 13);
        CharTokenizer::new(vocab)
    }

    // A fresh profiler starts with no recorded results.
    #[test]
    fn test_profiler_creation() {
        let config = ProfilerConfig::default();
        let profiler = PerformanceProfiler::new(config);
        assert_eq!(profiler.results.len(), 0);
    }

    // One text length × one batch size must yield exactly one result,
    // labeled with the tokenizer name passed in.
    #[test]
    fn test_single_tokenizer_profiling() {
        let mut profiler = PerformanceProfiler::new(ProfilerConfig {
            warmup_iterations: 1,
            benchmark_iterations: 2,
            text_lengths: vec![10],
            batch_sizes: vec![1],
            ..Default::default()
        });

        let tokenizer = create_test_char_tokenizer();
        let test_texts = vec!["Hello world!".to_string()];

        let results = profiler
            .profile_tokenizer("char", &tokenizer, &test_texts)
            .expect("Operation failed in test");
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].tokenizer_name, "char");
    }

    // Stats over a known sample must be ordered min <= median <= max.
    #[test]
    fn test_timing_stats_calculation() {
        let profiler = PerformanceProfiler::default();
        let timings = vec![
            Duration::from_millis(100),
            Duration::from_millis(110),
            Duration::from_millis(90),
            Duration::from_millis(105),
        ];

        let stats = profiler.calculate_timing_stats(&timings);
        assert!(stats.mean.as_millis() > 0);
        assert!(stats.min <= stats.median);
        assert!(stats.median <= stats.max);
    }

    // End-to-end: profile one scenario, then generate a report over it.
    #[test]
    fn test_report_generation() {
        let mut profiler = PerformanceProfiler::new(ProfilerConfig {
            warmup_iterations: 1,
            benchmark_iterations: 1,
            text_lengths: vec![5],
            batch_sizes: vec![1],
            ..Default::default()
        });

        let tokenizer = create_test_char_tokenizer();
        let test_texts = vec!["Hi".to_string()];

        profiler
            .profile_tokenizer("test", &tokenizer, &test_texts)
            .expect("Operation failed in test");
        let report = profiler.generate_report().expect("Operation failed in test");

        assert_eq!(report.benchmarks.len(), 1);
        assert_eq!(report.summary.total_benchmarks, 1);
    }

    // Every export format must succeed on an empty report and contain a
    // format-specific marker string.
    #[test]
    fn test_export_formats() {
        let profiler = PerformanceProfiler::default();
        let report = ProfilingReport {
            config: ProfilerConfig::default(),
            benchmarks: vec![],
            summary: ProfilingSummary {
                total_benchmarks: 0,
                fastest_tokenizer: "test".to_string(),
                most_memory_efficient: "test".to_string(),
                highest_throughput: "test".to_string(),
                most_consistent: "test".to_string(),
                overall_stats: HashMap::new(),
            },
            comparisons: vec![],
            recommendations: vec![],
            timestamp: "2023-01-01T00:00:00Z".to_string(),
        };

        let json = profiler
            .export_report(&report, ExportFormat::Json)
            .expect("Operation failed in test");
        assert!(json.contains("fastest_tokenizer"));

        let csv = profiler
            .export_report(&report, ExportFormat::Csv)
            .expect("Operation failed in test");
        assert!(csv.contains("tokenizer_name"));

        let html = profiler
            .export_report(&report, ExportFormat::Html)
            .expect("Operation failed in test");
        assert!(html.contains("<html>"));

        let md = profiler
            .export_report(&report, ExportFormat::Markdown)
            .expect("Operation failed in test");
        assert!(md.contains("# Tokenizer Performance Report"));
    }
}