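//! Tokenizer performance profiling (`trustformers_tokenizers/performance_profiler.rs`).
//!
//! Benchmarks tokenizers over configurable text lengths and batch sizes and exports the
//! collected results as JSON, CSV, HTML, or Markdown.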

use serde::{Deserialize, Serialize};
use std::collections::{BTreeMap, HashMap};
use std::time::{Duration, Instant};
use trustformers_core::errors::{Result, TrustformersError};
use trustformers_core::traits::{TokenizedInput, Tokenizer};

7/// Configuration for performance profiling
8#[derive(Debug, Clone, Serialize, Deserialize)]
9pub struct ProfilerConfig {
10    pub warmup_iterations: usize,
11    pub benchmark_iterations: usize,
12    pub measure_memory: bool,
13    pub measure_throughput: bool,
14    pub concurrent_threads: Option<usize>,
15    pub text_lengths: Vec<usize>,
16    pub batch_sizes: Vec<usize>,
17    pub detailed_timing: bool,
18    pub export_format: ExportFormat,
19}
20
21/// Export format for profiling results
22#[derive(Debug, Clone, Serialize, Deserialize)]
23pub enum ExportFormat {
24    Json,
25    Csv,
26    Html,
27    Markdown,
28}
29
30impl Default for ProfilerConfig {
31    fn default() -> Self {
32        Self {
33            warmup_iterations: 3,
34            benchmark_iterations: 10,
35            measure_memory: true,
36            measure_throughput: true,
37            concurrent_threads: Some(num_cpus::get()),
38            text_lengths: vec![50, 100, 500, 1000, 5000],
39            batch_sizes: vec![1, 8, 16, 32, 64],
40            detailed_timing: true,
41            export_format: ExportFormat::Json,
42        }
43    }
44}
45
46/// Timing measurements
47#[derive(Debug, Clone, Serialize, Deserialize)]
48pub struct TimingStats {
49    pub mean: Duration,
50    pub median: Duration,
51    pub min: Duration,
52    pub max: Duration,
53    pub std_dev: Duration,
54    pub percentile_95: Duration,
55    pub percentile_99: Duration,
56}
57
58/// Memory usage statistics
59#[derive(Debug, Clone, Serialize, Deserialize)]
60pub struct MemoryStats {
61    pub peak_memory_mb: f64,
62    pub average_memory_mb: f64,
63    pub memory_growth_mb: f64,
64    pub allocations_count: Option<usize>,
65    pub deallocations_count: Option<usize>,
66}
67
68/// Throughput measurements
69#[derive(Debug, Clone, Serialize, Deserialize)]
70pub struct ThroughputStats {
71    pub tokens_per_second: f64,
72    pub characters_per_second: f64,
73    pub batches_per_second: f64,
74    pub peak_throughput: f64,
75    pub average_throughput: f64,
76}
77
78/// Individual benchmark result
79#[derive(Debug, Clone, Serialize, Deserialize)]
80pub struct BenchmarkResult {
81    pub tokenizer_name: String,
82    pub text_length: usize,
83    pub batch_size: usize,
84    pub thread_count: usize,
85    pub timing: TimingStats,
86    pub memory: Option<MemoryStats>,
87    pub throughput: Option<ThroughputStats>,
88    pub error_rate: f64,
89    pub metadata: HashMap<String, serde_json::Value>,
90}
91
92/// Complete profiling session results
93#[derive(Debug, Clone, Serialize, Deserialize)]
94pub struct ProfilingReport {
95    pub config: ProfilerConfig,
96    pub benchmarks: Vec<BenchmarkResult>,
97    pub summary: ProfilingSummary,
98    pub comparisons: Vec<TokenizerComparison>,
99    pub recommendations: Vec<String>,
100    pub timestamp: String,
101}
102
103/// Summary statistics across all benchmarks
104#[derive(Debug, Clone, Serialize, Deserialize)]
105pub struct ProfilingSummary {
106    pub total_benchmarks: usize,
107    pub fastest_tokenizer: String,
108    pub most_memory_efficient: String,
109    pub highest_throughput: String,
110    pub most_consistent: String,
111    pub overall_stats: HashMap<String, f64>,
112}
113
114/// Comparison between tokenizers
115#[derive(Debug, Clone, Serialize, Deserialize)]
116pub struct TokenizerComparison {
117    pub scenario: String,
118    pub results: HashMap<String, BenchmarkResult>,
119    pub winner: String,
120    pub performance_gap: f64,
121}
122
123/// Performance profiler implementation
124pub struct PerformanceProfiler {
125    config: ProfilerConfig,
126    results: Vec<BenchmarkResult>,
127}
128
129impl PerformanceProfiler {
130    /// Create a new performance profiler
131    pub fn new(config: ProfilerConfig) -> Self {
132        Self {
133            config,
134            results: Vec::new(),
135        }
136    }
137
138    /// Create profiler with default configuration
139    pub fn default() -> Self {
140        Self::new(ProfilerConfig::default())
141    }
142
143    /// Profile a single tokenizer
144    pub fn profile_tokenizer<T: Tokenizer + Sync>(
145        &mut self,
146        name: &str,
147        tokenizer: &T,
148        test_texts: &[String],
149    ) -> Result<Vec<BenchmarkResult>> {
150        let mut tokenizer_results = Vec::new();
151
152        for &text_length in &self.config.text_lengths {
153            for &batch_size in &self.config.batch_sizes {
154                // Prepare test data
155                let texts = self.prepare_test_texts(test_texts, text_length, batch_size);
156
157                // Run benchmark
158                let result =
159                    self.benchmark_scenario(name, tokenizer, &texts, text_length, batch_size)?;
160
161                tokenizer_results.push(result.clone());
162                self.results.push(result);
163            }
164        }
165
166        Ok(tokenizer_results)
167    }
168
169    /// Profile multiple tokenizers
170    pub fn profile_multiple<T: Tokenizer + Sync>(
171        &mut self,
172        tokenizers: HashMap<String, &T>,
173        test_texts: &[String],
174    ) -> Result<ProfilingReport> {
175        for (name, tokenizer) in tokenizers {
176            self.profile_tokenizer(&name, tokenizer, test_texts)?;
177        }
178
179        self.generate_report()
180    }
181
182    /// Benchmark a specific scenario
183    fn benchmark_scenario<T: Tokenizer + Sync>(
184        &self,
185        name: &str,
186        tokenizer: &T,
187        texts: &[String],
188        text_length: usize,
189        batch_size: usize,
190    ) -> Result<BenchmarkResult> {
191        let thread_count = self.config.concurrent_threads.unwrap_or(1);
192
193        // Warmup
194        for _ in 0..self.config.warmup_iterations {
195            let _ = self.run_tokenization(tokenizer, texts)?;
196        }
197
198        // Collect timing measurements
199        let mut timings = Vec::new();
200        let mut error_count = 0;
201        let start_memory = self.get_memory_usage();
202
203        for _ in 0..self.config.benchmark_iterations {
204            let start = Instant::now();
205            match self.run_tokenization(tokenizer, texts) {
206                Ok(_) => {
207                    let duration = start.elapsed();
208                    timings.push(duration);
209                },
210                Err(_) => {
211                    error_count += 1;
212                    timings.push(Duration::from_millis(u64::MAX)); // Mark as failed
213                },
214            }
215        }
216
217        let end_memory = self.get_memory_usage();
218        let error_rate = error_count as f64 / self.config.benchmark_iterations as f64;
219
220        // Calculate statistics
221        let timing = self.calculate_timing_stats(&timings);
222        let memory = if self.config.measure_memory {
223            Some(MemoryStats {
224                peak_memory_mb: end_memory,
225                average_memory_mb: (start_memory + end_memory) / 2.0,
226                memory_growth_mb: end_memory - start_memory,
227                allocations_count: None,
228                deallocations_count: None,
229            })
230        } else {
231            None
232        };
233
234        let throughput = if self.config.measure_throughput {
235            Some(self.calculate_throughput_stats(texts, &timings, batch_size))
236        } else {
237            None
238        };
239
240        Ok(BenchmarkResult {
241            tokenizer_name: name.to_string(),
242            text_length,
243            batch_size,
244            thread_count,
245            timing,
246            memory,
247            throughput,
248            error_rate,
249            metadata: HashMap::new(),
250        })
251    }
252
253    /// Run tokenization on texts
254    fn run_tokenization<T: Tokenizer>(
255        &self,
256        tokenizer: &T,
257        texts: &[String],
258    ) -> Result<Vec<TokenizedInput>> {
259        let mut results = Vec::new();
260        for text in texts {
261            let result = tokenizer.encode(text)?;
262            results.push(result);
263        }
264        Ok(results)
265    }
266
267    /// Prepare test texts for benchmarking
268    fn prepare_test_texts(
269        &self,
270        source_texts: &[String],
271        target_length: usize,
272        count: usize,
273    ) -> Vec<String> {
274        let mut texts = Vec::new();
275        let mut text_pool = source_texts.iter().cycle();
276
277        for _ in 0..count {
278            let mut combined_text = String::new();
279
280            while combined_text.len() < target_length {
281                if let Some(text) = text_pool.next() {
282                    combined_text.push_str(text);
283                    combined_text.push(' ');
284                } else {
285                    break;
286                }
287            }
288
289            // Truncate to exact length
290            if combined_text.len() > target_length {
291                combined_text.truncate(target_length);
292            }
293
294            texts.push(combined_text);
295        }
296
297        texts
298    }
299
300    /// Calculate timing statistics
301    fn calculate_timing_stats(&self, timings: &[Duration]) -> TimingStats {
302        let mut valid_timings: Vec<Duration> = timings
303            .iter()
304            .filter(|&&t| t != Duration::from_millis(u64::MAX))
305            .copied()
306            .collect();
307
308        valid_timings.sort();
309
310        if valid_timings.is_empty() {
311            return TimingStats {
312                mean: Duration::ZERO,
313                median: Duration::ZERO,
314                min: Duration::ZERO,
315                max: Duration::ZERO,
316                std_dev: Duration::ZERO,
317                percentile_95: Duration::ZERO,
318                percentile_99: Duration::ZERO,
319            };
320        }
321
322        let sum: Duration = valid_timings.iter().sum();
323        let mean = sum / valid_timings.len() as u32;
324
325        let median = valid_timings[valid_timings.len() / 2];
326        let min = valid_timings[0];
327        let max = valid_timings[valid_timings.len() - 1];
328
329        // Calculate standard deviation
330        let variance: f64 = valid_timings
331            .iter()
332            .map(|&t| {
333                let diff = t.as_nanos() as f64 - mean.as_nanos() as f64;
334                diff * diff
335            })
336            .sum::<f64>()
337            / valid_timings.len() as f64;
338
339        let std_dev = Duration::from_nanos(variance.sqrt() as u64);
340
341        let p95_idx = (valid_timings.len() as f64 * 0.95) as usize;
342        let p99_idx = (valid_timings.len() as f64 * 0.99) as usize;
343
344        let percentile_95 = valid_timings.get(p95_idx).copied().unwrap_or(max);
345        let percentile_99 = valid_timings.get(p99_idx).copied().unwrap_or(max);
346
347        TimingStats {
348            mean,
349            median,
350            min,
351            max,
352            std_dev,
353            percentile_95,
354            percentile_99,
355        }
356    }
357
358    /// Calculate throughput statistics
359    fn calculate_throughput_stats(
360        &self,
361        texts: &[String],
362        timings: &[Duration],
363        batch_size: usize,
364    ) -> ThroughputStats {
365        let total_chars: usize = texts.iter().map(|t| t.len()).sum();
366        let total_tokens = texts.len() * batch_size; // Approximate
367
368        let valid_timings: Vec<Duration> = timings
369            .iter()
370            .filter(|&&t| t != Duration::from_millis(u64::MAX))
371            .copied()
372            .collect();
373
374        if valid_timings.is_empty() {
375            return ThroughputStats {
376                tokens_per_second: 0.0,
377                characters_per_second: 0.0,
378                batches_per_second: 0.0,
379                peak_throughput: 0.0,
380                average_throughput: 0.0,
381            };
382        }
383
384        let throughputs: Vec<f64> = valid_timings
385            .iter()
386            .map(|&duration| {
387                if duration.as_secs_f64() > 0.0 {
388                    total_tokens as f64 / duration.as_secs_f64()
389                } else {
390                    0.0
391                }
392            })
393            .collect();
394
395        let average_throughput = throughputs.iter().sum::<f64>() / throughputs.len() as f64;
396        let peak_throughput = throughputs
397            .iter()
398            .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
399            .copied()
400            .unwrap_or(0.0);
401
402        let avg_duration = valid_timings.iter().sum::<Duration>() / valid_timings.len() as u32;
403        let tokens_per_second = if avg_duration.as_secs_f64() > 0.0 {
404            total_tokens as f64 / avg_duration.as_secs_f64()
405        } else {
406            0.0
407        };
408
409        let characters_per_second = if avg_duration.as_secs_f64() > 0.0 {
410            total_chars as f64 / avg_duration.as_secs_f64()
411        } else {
412            0.0
413        };
414
415        let batches_per_second = if avg_duration.as_secs_f64() > 0.0 {
416            1.0 / avg_duration.as_secs_f64()
417        } else {
418            0.0
419        };
420
421        ThroughputStats {
422            tokens_per_second,
423            characters_per_second,
424            batches_per_second,
425            peak_throughput,
426            average_throughput,
427        }
428    }
429
430    /// Get current memory usage (simplified)
431    fn get_memory_usage(&self) -> f64 {
432        // This is a simplified implementation
433        // In a real implementation, you'd use platform-specific APIs
434        // or libraries like `memory-stats` for accurate memory measurement
435        #[cfg(target_os = "linux")]
436        {
437            if let Ok(contents) = std::fs::read_to_string("/proc/self/status") {
438                for line in contents.lines() {
439                    if line.starts_with("VmRSS:") {
440                        if let Some(kb_str) = line.split_whitespace().nth(1) {
441                            if let Ok(kb) = kb_str.parse::<f64>() {
442                                return kb / 1024.0; // Convert to MB
443                            }
444                        }
445                    }
446                }
447            }
448        }
449
450        // Fallback: return 0 if we can't measure memory
451        0.0
452    }
453
454    /// Generate profiling report
455    fn generate_report(&self) -> Result<ProfilingReport> {
456        let summary = self.generate_summary();
457        let comparisons = self.generate_comparisons();
458        let recommendations = self.generate_recommendations();
459
460        Ok(ProfilingReport {
461            config: self.config.clone(),
462            benchmarks: self.results.clone(),
463            summary,
464            comparisons,
465            recommendations,
466            timestamp: chrono::Utc::now().to_rfc3339(),
467        })
468    }
469
470    /// Generate summary statistics
471    fn generate_summary(&self) -> ProfilingSummary {
472        if self.results.is_empty() {
473            return ProfilingSummary {
474                total_benchmarks: 0,
475                fastest_tokenizer: "N/A".to_string(),
476                most_memory_efficient: "N/A".to_string(),
477                highest_throughput: "N/A".to_string(),
478                most_consistent: "N/A".to_string(),
479                overall_stats: HashMap::new(),
480            };
481        }
482
483        // Find fastest tokenizer (lowest mean time)
484        let fastest = self
485            .results
486            .iter()
487            .min_by(|a, b| {
488                a.timing.mean.partial_cmp(&b.timing.mean).unwrap_or(std::cmp::Ordering::Equal)
489            })
490            .map(|r| r.tokenizer_name.clone())
491            .unwrap_or_else(|| "N/A".to_string());
492
493        // Find most memory efficient
494        let most_memory_efficient = self
495            .results
496            .iter()
497            .filter_map(|r| r.memory.as_ref().map(|m| (r, m.peak_memory_mb)))
498            .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
499            .map(|(r, _)| r.tokenizer_name.clone())
500            .unwrap_or_else(|| "N/A".to_string());
501
502        // Find highest throughput
503        let highest_throughput = self
504            .results
505            .iter()
506            .filter_map(|r| r.throughput.as_ref().map(|t| (r, t.peak_throughput)))
507            .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
508            .map(|(r, _)| r.tokenizer_name.clone())
509            .unwrap_or_else(|| "N/A".to_string());
510
511        // Find most consistent (lowest std deviation)
512        let most_consistent = self
513            .results
514            .iter()
515            .min_by(|a, b| {
516                a.timing
517                    .std_dev
518                    .partial_cmp(&b.timing.std_dev)
519                    .unwrap_or(std::cmp::Ordering::Equal)
520            })
521            .map(|r| r.tokenizer_name.clone())
522            .unwrap_or_else(|| "N/A".to_string());
523
524        // Calculate overall statistics
525        let mut overall_stats = HashMap::new();
526        let total_time: Duration = self.results.iter().map(|r| r.timing.mean).sum();
527        overall_stats.insert(
528            "total_benchmark_time_ms".to_string(),
529            total_time.as_millis() as f64,
530        );
531
532        let avg_throughput = self
533            .results
534            .iter()
535            .filter_map(|r| r.throughput.as_ref())
536            .map(|t| t.average_throughput)
537            .sum::<f64>()
538            / self.results.len() as f64;
539        overall_stats.insert("average_throughput".to_string(), avg_throughput);
540
541        ProfilingSummary {
542            total_benchmarks: self.results.len(),
543            fastest_tokenizer: fastest,
544            most_memory_efficient,
545            highest_throughput,
546            most_consistent,
547            overall_stats,
548        }
549    }
550
551    /// Generate tokenizer comparisons
552    fn generate_comparisons(&self) -> Vec<TokenizerComparison> {
553        let mut comparisons = Vec::new();
554
555        // Group results by scenario (text_length + batch_size)
556        let mut scenarios: HashMap<String, Vec<&BenchmarkResult>> = HashMap::new();
557        for result in &self.results {
558            let scenario = format!("length_{}_batch_{}", result.text_length, result.batch_size);
559            scenarios.entry(scenario).or_default().push(result);
560        }
561
562        for (scenario, results) in scenarios {
563            if results.len() > 1 {
564                let mut scenario_results = HashMap::new();
565                for result in &results {
566                    scenario_results.insert(result.tokenizer_name.clone(), (*result).clone());
567                }
568
569                // Find winner (fastest)
570                let winner = results
571                    .iter()
572                    .min_by(|a, b| {
573                        a.timing
574                            .mean
575                            .partial_cmp(&b.timing.mean)
576                            .unwrap_or(std::cmp::Ordering::Equal)
577                    })
578                    .map(|r| r.tokenizer_name.clone())
579                    .unwrap_or_else(|| "N/A".to_string());
580
581                // Calculate performance gap
582                let fastest_time =
583                    results.iter().map(|r| r.timing.mean.as_millis()).min().unwrap_or(0);
584                let slowest_time =
585                    results.iter().map(|r| r.timing.mean.as_millis()).max().unwrap_or(0);
586
587                let performance_gap = if fastest_time > 0 {
588                    (slowest_time as f64 / fastest_time as f64) - 1.0
589                } else {
590                    0.0
591                };
592
593                comparisons.push(TokenizerComparison {
594                    scenario,
595                    results: scenario_results,
596                    winner,
597                    performance_gap,
598                });
599            }
600        }
601
602        comparisons
603    }
604
605    /// Generate recommendations based on results
606    fn generate_recommendations(&self) -> Vec<String> {
607        let mut recommendations = Vec::new();
608
609        if self.results.is_empty() {
610            return recommendations;
611        }
612
613        // Analyze error rates
614        let high_error_rate = self.results.iter().any(|r| r.error_rate > 0.1);
615        if high_error_rate {
616            recommendations
617                .push("Consider investigating tokenizers with high error rates (>10%)".to_string());
618        }
619
620        // Analyze memory usage
621        if let Some(max_memory) = self
622            .results
623            .iter()
624            .filter_map(|r| r.memory.as_ref())
625            .map(|m| m.peak_memory_mb)
626            .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
627        {
628            if max_memory > 1000.0 {
629                recommendations.push(
630                    "Consider using memory-efficient tokenizers for large-scale processing"
631                        .to_string(),
632                );
633            }
634        }
635
636        // Analyze consistency
637        let high_variance = self
638            .results
639            .iter()
640            .any(|r| r.timing.std_dev.as_millis() > r.timing.mean.as_millis() / 2);
641        if high_variance {
642            recommendations.push(
643                "Some tokenizers show high timing variance - consider warmup strategies"
644                    .to_string(),
645            );
646        }
647
648        // Analyze throughput
649        let throughputs: Vec<f64> = self
650            .results
651            .iter()
652            .filter_map(|r| r.throughput.as_ref())
653            .map(|t| t.average_throughput)
654            .collect();
655        if !throughputs.is_empty() {
656            let max_throughput = throughputs
657                .iter()
658                .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
659                .copied()
660                .unwrap_or(0.0);
661            let min_throughput = throughputs
662                .iter()
663                .min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
664                .copied()
665                .unwrap_or(0.0);
666
667            if max_throughput > min_throughput * 2.0 {
668                recommendations.push("Significant throughput differences detected - choose tokenizer based on use case".to_string());
669            }
670        }
671
672        recommendations
673    }
674
675    /// Export report to different formats
676    pub fn export_report(&self, report: &ProfilingReport, format: ExportFormat) -> Result<String> {
677        match format {
678            ExportFormat::Json => self.export_json(report),
679            ExportFormat::Csv => self.export_csv(report),
680            ExportFormat::Html => self.export_html(report),
681            ExportFormat::Markdown => self.export_markdown(report),
682        }
683    }
684
685    /// Export to JSON
686    fn export_json(&self, report: &ProfilingReport) -> Result<String> {
687        serde_json::to_string_pretty(report).map_err(|e| {
688            TrustformersError::other(
689                anyhow::anyhow!("Failed to serialize to JSON: {}", e).to_string(),
690            )
691        })
692    }
693
694    /// Export to CSV
695    fn export_csv(&self, report: &ProfilingReport) -> Result<String> {
696        let mut csv = String::new();
697        csv.push_str(
698            "tokenizer_name,text_length,batch_size,mean_time_ms,memory_mb,throughput,error_rate\n",
699        );
700
701        for benchmark in &report.benchmarks {
702            csv.push_str(&format!(
703                "{},{},{},{},{},{},{}\n",
704                benchmark.tokenizer_name,
705                benchmark.text_length,
706                benchmark.batch_size,
707                benchmark.timing.mean.as_millis(),
708                benchmark.memory.as_ref().map(|m| m.peak_memory_mb).unwrap_or(0.0),
709                benchmark.throughput.as_ref().map(|t| t.average_throughput).unwrap_or(0.0),
710                benchmark.error_rate
711            ));
712        }
713
714        Ok(csv)
715    }
716
717    /// Export to HTML
718    fn export_html(&self, report: &ProfilingReport) -> Result<String> {
719        let mut html = String::new();
720        html.push_str(
721            "<!DOCTYPE html>\n<html>\n<head>\n<title>Tokenizer Performance Report</title>\n",
722        );
723        html.push_str("<style>body{font-family:Arial,sans-serif;margin:40px;}table{border-collapse:collapse;width:100%;}th,td{border:1px solid #ddd;padding:8px;text-align:left;}th{background-color:#f2f2f2;}</style>\n");
724        html.push_str("</head>\n<body>\n");
725        html.push_str("<h1>Tokenizer Performance Report</h1>\n");
726
727        html.push_str("<h2>Summary</h2>\n");
728        html.push_str("<table>\n");
729        html.push_str(&format!(
730            "<tr><td>Total Benchmarks</td><td>{}</td></tr>\n",
731            report.summary.total_benchmarks
732        ));
733        html.push_str(&format!(
734            "<tr><td>Fastest Tokenizer</td><td>{}</td></tr>\n",
735            report.summary.fastest_tokenizer
736        ));
737        html.push_str(&format!(
738            "<tr><td>Most Memory Efficient</td><td>{}</td></tr>\n",
739            report.summary.most_memory_efficient
740        ));
741        html.push_str(&format!(
742            "<tr><td>Highest Throughput</td><td>{}</td></tr>\n",
743            report.summary.highest_throughput
744        ));
745        html.push_str("</table>\n");
746
747        html.push_str("<h2>Detailed Results</h2>\n");
748        html.push_str("<table>\n");
749        html.push_str("<tr><th>Tokenizer</th><th>Text Length</th><th>Batch Size</th><th>Mean Time (ms)</th><th>Memory (MB)</th><th>Throughput</th></tr>\n");
750
751        for benchmark in &report.benchmarks {
752            html.push_str(&format!(
753                "<tr><td>{}</td><td>{}</td><td>{}</td><td>{}</td><td>{:.1}</td><td>{:.1}</td></tr>\n",
754                benchmark.tokenizer_name,
755                benchmark.text_length,
756                benchmark.batch_size,
757                benchmark.timing.mean.as_millis(),
758                benchmark.memory.as_ref().map(|m| m.peak_memory_mb).unwrap_or(0.0),
759                benchmark.throughput.as_ref().map(|t| t.average_throughput).unwrap_or(0.0)
760            ));
761        }
762
763        html.push_str("</table>\n</body>\n</html>");
764        Ok(html)
765    }
766
767    /// Export to Markdown
768    fn export_markdown(&self, report: &ProfilingReport) -> Result<String> {
769        let mut md = String::new();
770        md.push_str("# Tokenizer Performance Report\n\n");
771
772        md.push_str("## Summary\n\n");
773        md.push_str(&format!(
774            "- **Total Benchmarks**: {}\n",
775            report.summary.total_benchmarks
776        ));
777        md.push_str(&format!(
778            "- **Fastest Tokenizer**: {}\n",
779            report.summary.fastest_tokenizer
780        ));
781        md.push_str(&format!(
782            "- **Most Memory Efficient**: {}\n",
783            report.summary.most_memory_efficient
784        ));
785        md.push_str(&format!(
786            "- **Highest Throughput**: {}\n\n",
787            report.summary.highest_throughput
788        ));
789
790        md.push_str("## Detailed Results\n\n");
791        md.push_str("| Tokenizer | Text Length | Batch Size | Mean Time (ms) | Memory (MB) | Throughput |\n");
792        md.push_str("|-----------|-------------|------------|----------------|-------------|------------|\n");
793
794        for benchmark in &report.benchmarks {
795            md.push_str(&format!(
796                "| {} | {} | {} | {} | {:.1} | {:.1} |\n",
797                benchmark.tokenizer_name,
798                benchmark.text_length,
799                benchmark.batch_size,
800                benchmark.timing.mean.as_millis(),
801                benchmark.memory.as_ref().map(|m| m.peak_memory_mb).unwrap_or(0.0),
802                benchmark.throughput.as_ref().map(|t| t.average_throughput).unwrap_or(0.0)
803            ));
804        }
805
806        if !report.recommendations.is_empty() {
807            md.push_str("\n## Recommendations\n\n");
808            for (i, rec) in report.recommendations.iter().enumerate() {
809                md.push_str(&format!("{}. {}\n", i + 1, rec));
810            }
811        }
812
813        Ok(md)
814    }
815}
816
#[cfg(test)]
mod tests {
    use super::*;
    use crate::char::CharTokenizer;
    use std::collections::HashMap;

    /// Build a small character-level tokenizer: four special tokens, then the characters
    /// `h e l o w r d`, a space, `t` and `s`.
    fn create_test_char_tokenizer() -> CharTokenizer {
        let entries = [
            ("[PAD]", 0),
            ("[UNK]", 1),
            ("[CLS]", 2),
            ("[SEP]", 3),
            ("h", 4),
            ("e", 5),
            ("l", 6),
            ("o", 7),
            ("w", 8),
            ("r", 9),
            ("d", 10),
            (" ", 11),
            ("t", 12),
            ("s", 13),
        ];

        let mut vocab = HashMap::new();
        for (token, id) in entries {
            vocab.insert(token.to_string(), id);
        }
        CharTokenizer::new(vocab)
    }

    /// A fast configuration: one warmup, one text length, one batch of a single text.
    fn quick_config(benchmark_iterations: usize, text_length: usize) -> ProfilerConfig {
        ProfilerConfig {
            warmup_iterations: 1,
            benchmark_iterations,
            text_lengths: vec![text_length],
            batch_sizes: vec![1],
            ..Default::default()
        }
    }

    #[test]
    fn test_profiler_creation() {
        let profiler = PerformanceProfiler::new(ProfilerConfig::default());
        assert!(profiler.results.is_empty());
    }

    #[test]
    fn test_single_tokenizer_profiling() {
        let mut profiler = PerformanceProfiler::new(quick_config(2, 10));
        let tokenizer = create_test_char_tokenizer();
        let corpus = vec!["Hello world!".to_string()];

        let results = profiler
            .profile_tokenizer("char", &tokenizer, &corpus)
            .expect("Operation failed in test");

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].tokenizer_name, "char");
    }

    #[test]
    fn test_timing_stats_calculation() {
        let profiler = PerformanceProfiler::default();
        let timings: Vec<Duration> =
            [100, 110, 90, 105].iter().map(|&ms| Duration::from_millis(ms)).collect();

        let stats = profiler.calculate_timing_stats(&timings);

        assert!(stats.mean.as_millis() > 0);
        assert!(stats.min <= stats.median);
        assert!(stats.median <= stats.max);
    }

    #[test]
    fn test_report_generation() {
        let mut profiler = PerformanceProfiler::new(quick_config(1, 5));
        let tokenizer = create_test_char_tokenizer();
        let corpus = vec!["Hi".to_string()];

        profiler
            .profile_tokenizer("test", &tokenizer, &corpus)
            .expect("Operation failed in test");
        let report = profiler.generate_report().expect("Operation failed in test");

        assert_eq!(report.benchmarks.len(), 1);
        assert_eq!(report.summary.total_benchmarks, 1);
    }

    #[test]
    fn test_export_formats() {
        let profiler = PerformanceProfiler::default();
        let placeholder = || "test".to_string();
        let report = ProfilingReport {
            config: ProfilerConfig::default(),
            benchmarks: Vec::new(),
            summary: ProfilingSummary {
                total_benchmarks: 0,
                fastest_tokenizer: placeholder(),
                most_memory_efficient: placeholder(),
                highest_throughput: placeholder(),
                most_consistent: placeholder(),
                overall_stats: HashMap::new(),
            },
            comparisons: Vec::new(),
            recommendations: Vec::new(),
            timestamp: "2023-01-01T00:00:00Z".to_string(),
        };

        let expectations = [
            (ExportFormat::Json, "fastest_tokenizer"),
            (ExportFormat::Csv, "tokenizer_name"),
            (ExportFormat::Html, "<html>"),
            (ExportFormat::Markdown, "# Tokenizer Performance Report"),
        ];
        for (format, needle) in expectations {
            let rendered = profiler
                .export_report(&report, format)
                .expect("Operation failed in test");
            assert!(rendered.contains(needle));
        }
    }
}