// trustformers_tokenizers/visualization.rs

use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use trustformers_core::errors::{Result, TrustformersError};
use trustformers_core::traits::{TokenizedInput, Tokenizer};

/// Configuration for token visualization
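///
/// A minimal sketch of overriding the defaults via struct update syntax
/// (all field names are from this struct):
/// ```ignore
/// let config = VisualizationConfig {
///     show_attention_mask: true,
///     max_display_length: None,
///     ..VisualizationConfig::default()
/// };
/// ```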
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VisualizationConfig {
    pub show_token_ids: bool,
    pub show_attention_mask: bool,
    pub show_special_tokens: bool,
    pub show_position_info: bool,
    pub use_colors: bool,
    pub max_display_length: Option<usize>,
    pub highlight_patterns: Vec<String>,
    pub custom_token_colors: HashMap<String, String>,
}

impl Default for VisualizationConfig {
    fn default() -> Self {
        Self {
            show_token_ids: true,
            show_attention_mask: false,
            show_special_tokens: true,
            show_position_info: false,
            use_colors: true,
            max_display_length: Some(100),
            highlight_patterns: Vec::new(),
            custom_token_colors: HashMap::new(),
        }
    }
}

/// Statistics about tokenization
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenizationStats {
    pub total_tokens: usize,
    pub unique_tokens: usize,
    pub special_tokens_count: usize,
    pub average_token_length: f64,
    pub compression_ratio: f64,
    pub oov_count: usize,
    pub token_type_distribution: HashMap<String, usize>,
    pub longest_token: Option<String>,
    pub shortest_token: Option<String>,
}

/// Detailed token information for visualization
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenInfo {
    pub token: String,
    pub token_id: u32,
    pub position: usize,
    pub start_char: Option<usize>,
    pub end_char: Option<usize>,
    pub is_special: bool,
    pub attention_value: u8,
    pub token_type: Option<String>,
    pub frequency: Option<f64>,
}

/// Visualization of tokenized input
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenVisualization {
    pub original_text: String,
    pub tokens: Vec<TokenInfo>,
    pub statistics: TokenizationStats,
    pub config: VisualizationConfig,
}

/// Comparison between different tokenizers
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenizerComparison {
    pub original_text: String,
    pub tokenizations: HashMap<String, TokenVisualization>,
    pub comparison_stats: ComparisonStats,
}

/// Statistics comparing multiple tokenizers
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ComparisonStats {
    pub token_count_variance: f64,
    pub common_tokens: Vec<String>,
    pub unique_tokens_by_tokenizer: HashMap<String, Vec<String>>,
    pub compression_ratio_comparison: HashMap<String, f64>,
    pub similarity_scores: HashMap<String, HashMap<String, f64>>,
}

/// Token visualizer implementation
pub struct TokenVisualizer {
    config: VisualizationConfig,
    special_tokens: HashMap<String, u32>,
}

impl TokenVisualizer {
    /// Create a new token visualizer
    pub fn new(config: VisualizationConfig) -> Self {
        Self {
            config,
            special_tokens: HashMap::new(),
        }
    }

    /// Create visualizer with default configuration
    pub fn default() -> Self {
        Self::new(VisualizationConfig::default())
    }

    /// Add special tokens for recognition
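    ///
    /// A minimal sketch (the `[CLS]` id matches the vocab used in the tests below):
    /// ```ignore
    /// let visualizer = TokenVisualizer::default()
    ///     .with_special_tokens(HashMap::from([("[CLS]".to_string(), 2)]));
    /// ```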
    pub fn with_special_tokens(mut self, special_tokens: HashMap<String, u32>) -> Self {
        self.special_tokens = special_tokens;
        self
    }

    /// Visualize tokenization from any tokenizer
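    ///
    /// A usage sketch (`tokenizer` is any type implementing `Tokenizer`):
    /// ```ignore
    /// let visualizer = TokenVisualizer::default();
    /// let viz = visualizer.visualize(&tokenizer, "Hello world")?;
    /// println!("{}", visualizer.to_text(&viz));
    /// ```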
    pub fn visualize<T: Tokenizer>(&self, tokenizer: &T, text: &str) -> Result<TokenVisualization> {
        let tokenized = tokenizer.encode(text)?;
        let tokens = self.extract_token_info(tokenizer, text, &tokenized)?;
        let statistics = self.calculate_statistics(text, &tokens);

        Ok(TokenVisualization {
            original_text: text.to_string(),
            tokens,
            statistics,
            config: self.config.clone(),
        })
    }

    /// Extract detailed token information
    fn extract_token_info<T: Tokenizer>(
        &self,
        tokenizer: &T,
        _original_text: &str,
        tokenized: &TokenizedInput,
    ) -> Result<Vec<TokenInfo>> {
        let mut tokens = Vec::new();

        for (i, &token_id) in tokenized.input_ids.iter().enumerate() {
            // Try to decode individual token
            let token_text = match tokenizer.decode(&[token_id]) {
                Ok(text) => text,
                Err(_) => format!("[UNK:{}]", token_id),
            };

            let is_special = self.special_tokens.values().any(|&id| id == token_id)
                || (token_text.starts_with('[') && token_text.ends_with(']'));

            let attention_value = tokenized.attention_mask.get(i).copied().unwrap_or(0);

            tokens.push(TokenInfo {
                token: token_text,
                token_id,
                position: i,
                start_char: None, // Would need offset mapping from tokenizer
                end_char: None,
                is_special,
                attention_value,
                token_type: None, // Could be enhanced with token type classification
                frequency: None,  // Could be enhanced with frequency data
            });
        }

        Ok(tokens)
    }

    /// Calculate tokenization statistics
    fn calculate_statistics(&self, original_text: &str, tokens: &[TokenInfo]) -> TokenizationStats {
        let total_tokens = tokens.len();
        let unique_tokens =
            tokens.iter().map(|t| &t.token).collect::<std::collections::HashSet<_>>().len();

        let special_tokens_count = tokens.iter().filter(|t| t.is_special).count();

        let total_char_length: usize = tokens.iter().map(|t| t.token.len()).sum();

        let average_token_length = if total_tokens > 0 {
            total_char_length as f64 / total_tokens as f64
        } else {
            0.0
        };

        // Note: `len()` counts bytes, so this is tokens per byte of input;
        // for non-ASCII text it differs from tokens per character.
        let compression_ratio = if !original_text.is_empty() {
            total_tokens as f64 / original_text.len() as f64
        } else {
            0.0
        };

        // `contains("[UNK")` already matches the exact "[UNK]" token, so no
        // separate equality check is needed.
        let oov_count = tokens.iter().filter(|t| t.token.contains("[UNK")).count();

        let mut token_type_distribution = HashMap::new();
        for token in tokens {
            let token_type = self.classify_token(&token.token);
            *token_type_distribution.entry(token_type).or_insert(0) += 1;
        }

        let longest_token = tokens.iter().max_by_key(|t| t.token.len()).map(|t| t.token.clone());

        let shortest_token = tokens
            .iter()
            .filter(|t| !t.is_special)
            .min_by_key(|t| t.token.len())
            .map(|t| t.token.clone());

        TokenizationStats {
            total_tokens,
            unique_tokens,
            special_tokens_count,
            average_token_length,
            compression_ratio,
            oov_count,
            token_type_distribution,
            longest_token,
            shortest_token,
        }
    }

    /// Classify token type for statistics
    fn classify_token(&self, token: &str) -> String {
        if token.starts_with('[') && token.ends_with(']') {
            "special".to_string()
        } else if token.chars().all(|c| c.is_numeric()) {
            "numeric".to_string()
        } else if token.chars().all(|c| c.is_alphabetic()) {
            "alphabetic".to_string()
        } else if token.chars().all(|c| c.is_alphanumeric()) {
            "alphanumeric".to_string()
        } else if token.chars().all(|c| c.is_whitespace()) {
            "whitespace".to_string()
        } else if token.chars().all(|c| c.is_ascii_punctuation()) {
            "punctuation".to_string()
        } else {
            "mixed".to_string()
        }
    }

    /// Compare multiple tokenizers
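    ///
    /// Note that all entries must share one concrete tokenizer type `T`;
    /// comparing different tokenizer types would require trait objects.
    ///
    /// A usage sketch (`bpe_a` and `bpe_b` are hypothetical instances of the
    /// same tokenizer type):
    /// ```ignore
    /// let mut tokenizers = HashMap::new();
    /// tokenizers.insert("a".to_string(), &bpe_a);
    /// tokenizers.insert("b".to_string(), &bpe_b);
    /// let comparison = visualizer.compare_tokenizers(tokenizers, "Hello")?;
    /// ```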
    pub fn compare_tokenizers<T: Tokenizer>(
        &self,
        tokenizers: HashMap<String, &T>,
        text: &str,
    ) -> Result<TokenizerComparison> {
        let mut tokenizations = HashMap::new();

        for (name, tokenizer) in tokenizers {
            let visualization = self.visualize(tokenizer, text)?;
            tokenizations.insert(name, visualization);
        }

        let comparison_stats = self.calculate_comparison_stats(&tokenizations);

        Ok(TokenizerComparison {
            original_text: text.to_string(),
            tokenizations,
            comparison_stats,
        })
    }

    /// Calculate comparison statistics
    fn calculate_comparison_stats(
        &self,
        tokenizations: &HashMap<String, TokenVisualization>,
    ) -> ComparisonStats {
        let token_counts: Vec<usize> =
            tokenizations.values().map(|t| t.statistics.total_tokens).collect();

        let token_count_variance = if token_counts.len() > 1 {
            let mean = token_counts.iter().sum::<usize>() as f64 / token_counts.len() as f64;
            let variance_sum: f64 =
                token_counts.iter().map(|&count| (count as f64 - mean).powi(2)).sum();
            variance_sum / token_counts.len() as f64
        } else {
            0.0
        };

        // Find common tokens across all tokenizers
        let all_tokens: Vec<Vec<String>> = tokenizations
            .values()
            .map(|t| t.tokens.iter().map(|token| token.token.clone()).collect())
            .collect();

        let mut common_tokens = Vec::new();
        if !all_tokens.is_empty() {
            let first_tokens: std::collections::HashSet<String> =
                all_tokens[0].iter().cloned().collect();
            common_tokens = first_tokens
                .into_iter()
                .filter(|token| all_tokens.iter().skip(1).all(|tokens| tokens.contains(token)))
                .collect();
        }

        // Find unique tokens by tokenizer
        let mut unique_tokens_by_tokenizer = HashMap::new();
        for (name, visualization) in tokenizations {
            let tokens: std::collections::HashSet<String> =
                visualization.tokens.iter().map(|t| t.token.clone()).collect();

            let unique: Vec<String> = tokens
                .into_iter()
                .filter(|token| {
                    tokenizations
                        .iter()
                        .filter(|(other_name, _)| *other_name != name)
                        .all(|(_, other_viz)| !other_viz.tokens.iter().any(|t| &t.token == token))
                })
                .collect();

            unique_tokens_by_tokenizer.insert(name.clone(), unique);
        }

        // Compression ratio comparison
        let compression_ratio_comparison: HashMap<String, f64> = tokenizations
            .iter()
            .map(|(name, viz)| (name.clone(), viz.statistics.compression_ratio))
            .collect();

        // Calculate similarity scores (Jaccard similarity)
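        // Similarity here is |A ∩ B| / |A ∪ B| over token *sets*: duplicate
        // tokens collapse into the set, so repetition does not affect the
        // score, and each pair is computed in both directions.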
        let mut similarity_scores = HashMap::new();
        for (name1, viz1) in tokenizations {
            let mut scores = HashMap::new();
            let tokens1: std::collections::HashSet<String> =
                viz1.tokens.iter().map(|t| t.token.clone()).collect();

            for (name2, viz2) in tokenizations {
                if name1 != name2 {
                    let tokens2: std::collections::HashSet<String> =
                        viz2.tokens.iter().map(|t| t.token.clone()).collect();

                    let intersection = tokens1.intersection(&tokens2).count();
                    let union = tokens1.union(&tokens2).count();
                    let similarity =
                        if union > 0 { intersection as f64 / union as f64 } else { 0.0 };

                    scores.insert(name2.clone(), similarity);
                }
            }
            similarity_scores.insert(name1.clone(), scores);
        }

        ComparisonStats {
            token_count_variance,
            common_tokens,
            unique_tokens_by_tokenizer,
            compression_ratio_comparison,
            similarity_scores,
        }
    }

    /// Generate HTML visualization
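    ///
    /// A sketch of writing the page to disk (`viz` as in `visualize` above):
    /// ```ignore
    /// let html = visualizer.to_html(&viz);
    /// std::fs::write("tokens.html", html)?;
    /// ```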
    pub fn to_html(&self, visualization: &TokenVisualization) -> String {
        let mut html = String::new();

        html.push_str("<!DOCTYPE html>\n<html>\n<head>\n");
        html.push_str("<title>Token Visualization</title>\n");
        html.push_str("<style>\n");
        html.push_str(Self::get_css());
        html.push_str("</style>\n</head>\n<body>\n");

        html.push_str("<h1>Token Visualization</h1>\n");

        // Original text
        html.push_str("<div class='section'>\n");
        html.push_str("<h2>Original Text</h2>\n");
        html.push_str(&format!(
            "<div class='original-text'>{}</div>\n",
            html_escape(&visualization.original_text)
        ));
        html.push_str("</div>\n");

        // Tokens
        html.push_str("<div class='section'>\n");
        html.push_str("<h2>Tokens</h2>\n");
        html.push_str("<div class='tokens'>\n");

        for (i, token) in visualization.tokens.iter().enumerate() {
            let class = if token.is_special { "token special" } else { "token" };
            let color = self.get_token_color(token);

            html.push_str(&format!(
                "<span class='{}' style='background-color: {}' title='ID: {}, Pos: {}'>",
                class, color, token.token_id, token.position
            ));
            html.push_str(&html_escape(&token.token));

            if self.config.show_token_ids {
                html.push_str(&format!("<sub>{}</sub>", token.token_id));
            }

            html.push_str("</span>");

            if i < visualization.tokens.len() - 1 {
                html.push(' ');
            }
        }

        html.push_str("</div>\n</div>\n");

        // Statistics
        html.push_str("<div class='section'>\n");
        html.push_str("<h2>Statistics</h2>\n");
        html.push_str("<table class='stats-table'>\n");

        let stats = &visualization.statistics;
        html.push_str(&format!(
            "<tr><td>Total Tokens</td><td>{}</td></tr>\n",
            stats.total_tokens
        ));
        html.push_str(&format!(
            "<tr><td>Unique Tokens</td><td>{}</td></tr>\n",
            stats.unique_tokens
        ));
        html.push_str(&format!(
            "<tr><td>Special Tokens</td><td>{}</td></tr>\n",
            stats.special_tokens_count
        ));
        html.push_str(&format!(
            "<tr><td>Average Token Length</td><td>{:.2}</td></tr>\n",
            stats.average_token_length
        ));
        html.push_str(&format!(
            "<tr><td>Compression Ratio</td><td>{:.4}</td></tr>\n",
            stats.compression_ratio
        ));
        html.push_str(&format!(
            "<tr><td>OOV Count</td><td>{}</td></tr>\n",
            stats.oov_count
        ));

        if let Some(longest) = &stats.longest_token {
            html.push_str(&format!(
                "<tr><td>Longest Token</td><td>{}</td></tr>\n",
                html_escape(longest)
            ));
        }

        if let Some(shortest) = &stats.shortest_token {
            html.push_str(&format!(
                "<tr><td>Shortest Token</td><td>{}</td></tr>\n",
                html_escape(shortest)
            ));
        }

        html.push_str("</table>\n</div>\n");

        // Token type distribution
        html.push_str("<div class='section'>\n");
        html.push_str("<h2>Token Type Distribution</h2>\n");
        html.push_str("<table class='stats-table'>\n");

        for (token_type, count) in &stats.token_type_distribution {
            html.push_str(&format!(
                "<tr><td>{}</td><td>{}</td></tr>\n",
                token_type, count
            ));
        }

        html.push_str("</table>\n</div>\n");

        html.push_str("</body>\n</html>");
        html
    }

    /// Get CSS styles for HTML visualization
    fn get_css() -> &'static str {
        r#"
body {
    font-family: Arial, sans-serif;
    max-width: 1200px;
    margin: 0 auto;
    padding: 20px;
    background-color: #f5f5f5;
}

.section {
    background: white;
    margin: 20px 0;
    padding: 20px;
    border-radius: 8px;
    box-shadow: 0 2px 4px rgba(0,0,0,0.1);
}

h1, h2 {
    color: #333;
}

.original-text {
    background: #f8f9fa;
    padding: 15px;
    border-radius: 4px;
    font-family: monospace;
    border-left: 4px solid #007bff;
}

.tokens {
    font-family: monospace;
    line-height: 2;
    word-wrap: break-word;
}

.token {
    display: inline-block;
    padding: 2px 4px;
    margin: 1px;
    border-radius: 3px;
    border: 1px solid #ddd;
    background-color: #e9ecef;
    position: relative;
}

.token.special {
    background-color: #fff3cd;
    border-color: #ffeaa7;
    font-weight: bold;
}

.token:hover {
    box-shadow: 0 2px 8px rgba(0,0,0,0.15);
    z-index: 10;
}

.stats-table {
    width: 100%;
    border-collapse: collapse;
}

.stats-table td {
    padding: 8px 12px;
    border-bottom: 1px solid #eee;
}

.stats-table td:first-child {
    font-weight: bold;
    color: #555;
}

sub {
    font-size: 0.7em;
    color: #666;
}
"#
    }

    /// Get color for a token
    fn get_token_color(&self, token: &TokenInfo) -> String {
        if let Some(color) = self.config.custom_token_colors.get(&token.token) {
            return color.clone();
        }

        if token.is_special {
            return "#fff3cd".to_string();
        }

        // Generate color based on token type or content. Note that
        // `extract_token_info` currently leaves `token_type` as `None`, so
        // these branches only apply if the field is populated elsewhere.
        match token.token_type.as_deref() {
            Some("numeric") => "#d1ecf1".to_string(),
            Some("alphabetic") => "#d4edda".to_string(),
            Some("punctuation") => "#f8d7da".to_string(),
            Some("whitespace") => "#f1f3f4".to_string(),
            _ => "#e9ecef".to_string(),
        }
    }

    /// Generate plain text visualization
    pub fn to_text(&self, visualization: &TokenVisualization) -> String {
        let mut text = String::new();

        text.push_str("=== Token Visualization ===\n\n");

        text.push_str("Original Text:\n");
        text.push_str(&visualization.original_text);
        text.push_str("\n\n");

        text.push_str("Tokens:\n");
        for (i, token) in visualization.tokens.iter().enumerate() {
            text.push_str(&format!("{:3}: ", i));
            if self.config.show_token_ids {
                text.push_str(&format!("[{}] ", token.token_id));
            }
            text.push_str(&format!("\"{}\"", token.token));
            if token.is_special {
                text.push_str(" (SPECIAL)");
            }
            text.push('\n');
        }

        text.push_str("\nStatistics:\n");
        let stats = &visualization.statistics;
        text.push_str(&format!("  Total Tokens: {}\n", stats.total_tokens));
        text.push_str(&format!("  Unique Tokens: {}\n", stats.unique_tokens));
        text.push_str(&format!(
            "  Special Tokens: {}\n",
            stats.special_tokens_count
        ));
        text.push_str(&format!(
            "  Average Token Length: {:.2}\n",
            stats.average_token_length
        ));
        text.push_str(&format!(
            "  Compression Ratio: {:.4}\n",
            stats.compression_ratio
        ));
        text.push_str(&format!("  OOV Count: {}\n", stats.oov_count));

        if !stats.token_type_distribution.is_empty() {
            text.push_str("\nToken Type Distribution:\n");
            for (token_type, count) in &stats.token_type_distribution {
                text.push_str(&format!("  {}: {}\n", token_type, count));
            }
        }

        text
    }

    /// Export visualization to JSON
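    ///
    /// Since `TokenVisualization` also derives `Deserialize`, the output
    /// round-trips through serde, e.g. (sketch):
    /// ```ignore
    /// let json = visualizer.to_json(&viz)?;
    /// let back: TokenVisualization = serde_json::from_str(&json)?;
    /// ```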
    pub fn to_json(&self, visualization: &TokenVisualization) -> Result<String> {
        serde_json::to_string_pretty(visualization)
            .map_err(|e| TrustformersError::other(format!("Failed to serialize to JSON: {}", e)))
    }

    /// Generate comparison report
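    ///
    /// A sketch (`comparison` as returned by `compare_tokenizers`):
    /// ```ignore
    /// let report = visualizer.comparison_report(&comparison);
    /// println!("{}", report);
    /// ```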
    pub fn comparison_report(&self, comparison: &TokenizerComparison) -> String {
        let mut report = String::new();

        report.push_str("=== Tokenizer Comparison Report ===\n\n");

        report.push_str("Original Text:\n");
        report.push_str(&comparison.original_text);
        report.push_str("\n\n");

        report.push_str("Tokenization Results:\n");
        for (name, viz) in &comparison.tokenizations {
            report.push_str(&format!(
                "\n{} ({} tokens):\n",
                name, viz.statistics.total_tokens
            ));
            for token in &viz.tokens {
                report.push_str(&format!("  \"{}\"", token.token));
            }
            report.push('\n');
        }

        report.push_str("\nComparison Statistics:\n");
        let stats = &comparison.comparison_stats;
        report.push_str(&format!(
            "  Token Count Variance: {:.2}\n",
            stats.token_count_variance
        ));
        report.push_str(&format!("  Common Tokens: {}\n", stats.common_tokens.len()));

        if !stats.common_tokens.is_empty() {
            report.push_str("    ");
            // List at most ten common tokens; only add an ellipsis when
            // some were actually omitted.
            for (i, token) in stats.common_tokens.iter().take(10).enumerate() {
                if i > 0 {
                    report.push_str(", ");
                }
                report.push_str(&format!("\"{}\"", token));
            }
            if stats.common_tokens.len() > 10 {
                report.push_str(", ...");
            }
            report.push('\n');
        }

        report.push_str("\nSimilarity Scores (Jaccard):\n");
        for (name1, scores) in &stats.similarity_scores {
            for (name2, score) in scores {
                report.push_str(&format!("  {} vs {}: {:.3}\n", name1, name2, score));
            }
        }

        report
    }
}

/// HTML escape function
fn html_escape(text: &str) -> String {
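    // '&' must be escaped first: the later replacements introduce new '&'
    // characters (in "&lt;", "&gt;", ...) that must not be escaped again.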
    text.replace('&', "&amp;")
        .replace('<', "&lt;")
        .replace('>', "&gt;")
        .replace('"', "&quot;")
        .replace('\'', "&#x27;")
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::char::CharTokenizer;

    fn create_test_char_tokenizer() -> CharTokenizer {
        let mut vocab = HashMap::new();
        vocab.insert("[PAD]".to_string(), 0);
        vocab.insert("[UNK]".to_string(), 1);
        vocab.insert("[CLS]".to_string(), 2);
        vocab.insert("[SEP]".to_string(), 3);
        vocab.insert("h".to_string(), 4);
        vocab.insert("e".to_string(), 5);
        vocab.insert("l".to_string(), 6);
        vocab.insert("o".to_string(), 7);
        vocab.insert("w".to_string(), 8);
        vocab.insert("r".to_string(), 9);
        vocab.insert("d".to_string(), 10);
        vocab.insert(" ".to_string(), 11);
        vocab.insert("t".to_string(), 12);
        vocab.insert("s".to_string(), 13);
        CharTokenizer::new(vocab)
    }

    #[test]
    fn test_visualization_creation() {
        let tokenizer = create_test_char_tokenizer();
        let visualizer = TokenVisualizer::default();

        let result = visualizer
            .visualize(&tokenizer, "Hello world!")
            .expect("Operation failed in test");

        assert_eq!(result.original_text, "Hello world!");
        assert!(!result.tokens.is_empty());
        assert!(result.statistics.total_tokens > 0);
    }

    #[test]
    fn test_html_generation() {
        let tokenizer = create_test_char_tokenizer();
        let visualizer = TokenVisualizer::default();

        let visualization =
            visualizer.visualize(&tokenizer, "Hello").expect("Operation failed in test");
        let html = visualizer.to_html(&visualization);

        assert!(html.contains("<!DOCTYPE html>"));
        assert!(html.contains("Token Visualization"));
        assert!(html.contains("Hello"));
    }

    #[test]
    fn test_text_generation() {
        let tokenizer = create_test_char_tokenizer();
        let visualizer = TokenVisualizer::default();

        let visualization =
            visualizer.visualize(&tokenizer, "Hello").expect("Operation failed in test");
        let text = visualizer.to_text(&visualization);

        assert!(text.contains("=== Token Visualization ==="));
        assert!(text.contains("Hello"));
        assert!(text.contains("Statistics:"));
    }

    #[test]
    fn test_json_export() {
        let tokenizer = create_test_char_tokenizer();
        let visualizer = TokenVisualizer::default();

        let visualization =
            visualizer.visualize(&tokenizer, "Hi").expect("Operation failed in test");
        let json = visualizer.to_json(&visualization).expect("Operation failed in test");

        assert!(json.contains("original_text"));
        assert!(json.contains("tokens"));
        assert!(json.contains("statistics"));
    }

    #[test]
    fn test_tokenizer_comparison() {
        let char_tokenizer = create_test_char_tokenizer();
        let tokenizer2 = create_test_char_tokenizer();

        let mut tokenizers = HashMap::new();
        tokenizers.insert("char1".to_string(), &char_tokenizer);
        tokenizers.insert("char2".to_string(), &tokenizer2);

        let visualizer = TokenVisualizer::default();
        let comparison = visualizer
            .compare_tokenizers(tokenizers, "Hello")
            .expect("Operation failed in test");

        assert_eq!(comparison.original_text, "Hello");
        assert_eq!(comparison.tokenizations.len(), 2);
        assert!(comparison.comparison_stats.similarity_scores.contains_key("char1"));
    }
}