1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3use trustformers_core::errors::{Result, TrustformersError};
4use trustformers_core::traits::{TokenizedInput, Tokenizer};
5
/// Display options controlling how tokenization results are rendered.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VisualizationConfig {
    /// Append each token's numeric id to its rendered text (used by the
    /// HTML and plain-text renderers in this file).
    pub show_token_ids: bool,
    /// Whether to display attention-mask values.
    // NOTE(review): not consulted by the renderers in this file — confirm use elsewhere.
    pub show_attention_mask: bool,
    /// Whether special tokens should be shown.
    // NOTE(review): not consulted by the renderers in this file.
    pub show_special_tokens: bool,
    /// Whether per-token position info should be shown.
    // NOTE(review): not consulted by the renderers in this file.
    pub show_position_info: bool,
    /// Enable colored output.
    // NOTE(review): not consulted by the renderers in this file.
    pub use_colors: bool,
    /// Maximum number of tokens to display, if capped.
    // NOTE(review): not enforced by the renderers in this file.
    pub max_display_length: Option<usize>,
    /// Substrings/patterns to highlight.
    // NOTE(review): not used in this file.
    pub highlight_patterns: Vec<String>,
    /// Background-color overrides, keyed by exact token text (CSS color values).
    pub custom_token_colors: HashMap<String, String>,
}
18
19impl Default for VisualizationConfig {
20 fn default() -> Self {
21 Self {
22 show_token_ids: true,
23 show_attention_mask: false,
24 show_special_tokens: true,
25 show_position_info: false,
26 use_colors: true,
27 max_display_length: Some(100),
28 highlight_patterns: Vec::new(),
29 custom_token_colors: HashMap::new(),
30 }
31 }
32}
33
/// Aggregate statistics computed over a single tokenization.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenizationStats {
    /// Total number of tokens produced.
    pub total_tokens: usize,
    /// Number of distinct token strings.
    pub unique_tokens: usize,
    /// Number of tokens flagged as special.
    pub special_tokens_count: usize,
    /// Mean token length in bytes.
    pub average_token_length: f64,
    /// Tokens per byte of the original text (0.0 for empty input).
    pub compression_ratio: f64,
    /// Count of unknown / out-of-vocabulary tokens.
    pub oov_count: usize,
    /// Token count per coarse category (e.g. "numeric", "alphabetic").
    pub token_type_distribution: HashMap<String, usize>,
    /// Longest token text, if any tokens exist.
    pub longest_token: Option<String>,
    /// Shortest non-special token text, if any.
    pub shortest_token: Option<String>,
}
47
/// Detailed information about a single token in a tokenized sequence.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenInfo {
    /// Decoded token text.
    pub token: String,
    /// Vocabulary id of the token.
    pub token_id: u32,
    /// Zero-based index of the token within the sequence.
    pub position: usize,
    /// Start character offset in the original text, when known.
    pub start_char: Option<usize>,
    /// End character offset in the original text, when known.
    pub end_char: Option<usize>,
    /// True for special/marker tokens (e.g. [CLS], [SEP]).
    pub is_special: bool,
    /// Attention-mask value at this position (0 when the mask is absent).
    pub attention_value: u8,
    /// Coarse token category (e.g. "numeric"), when available.
    pub token_type: Option<String>,
    /// Relative token frequency, when available.
    // NOTE(review): never populated in this file — confirm a caller fills it in.
    pub frequency: Option<f64>,
}
61
/// A complete visualization: the input text, per-token details, the
/// aggregate statistics, and the configuration used to render them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenVisualization {
    /// The text that was tokenized.
    pub original_text: String,
    /// Per-token details in sequence order.
    pub tokens: Vec<TokenInfo>,
    /// Aggregate statistics over `tokens`.
    pub statistics: TokenizationStats,
    /// Snapshot of the visualizer configuration used.
    pub config: VisualizationConfig,
}
70
/// Result of running several tokenizers over the same input text.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenizerComparison {
    /// The shared input text.
    pub original_text: String,
    /// Per-tokenizer visualizations, keyed by tokenizer name.
    pub tokenizations: HashMap<String, TokenVisualization>,
    /// Cross-tokenizer comparison statistics.
    pub comparison_stats: ComparisonStats,
}
78
/// Statistics comparing multiple tokenizations of the same text.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ComparisonStats {
    /// Population variance of per-tokenizer token counts.
    pub token_count_variance: f64,
    /// Token strings produced by every tokenizer.
    pub common_tokens: Vec<String>,
    /// Token strings produced by exactly one tokenizer, keyed by name.
    pub unique_tokens_by_tokenizer: HashMap<String, Vec<String>>,
    /// Compression ratio per tokenizer name.
    pub compression_ratio_comparison: HashMap<String, f64>,
    /// Pairwise Jaccard similarity of token sets: name -> (other name -> score).
    pub similarity_scores: HashMap<String, HashMap<String, f64>>,
}
88
/// Renders tokenizations as plain text, HTML, or JSON, and compares the
/// output of multiple tokenizers on the same input.
pub struct TokenVisualizer {
    /// Display options applied to all renderings.
    config: VisualizationConfig,
    /// Special-token vocabulary (token text -> id) used to flag tokens.
    special_tokens: HashMap<String, u32>,
}
94
95impl TokenVisualizer {
96 pub fn new(config: VisualizationConfig) -> Self {
98 Self {
99 config,
100 special_tokens: HashMap::new(),
101 }
102 }
103
104 pub fn default() -> Self {
106 Self::new(VisualizationConfig::default())
107 }
108
109 pub fn with_special_tokens(mut self, special_tokens: HashMap<String, u32>) -> Self {
111 self.special_tokens = special_tokens;
112 self
113 }
114
115 pub fn visualize<T: Tokenizer>(&self, tokenizer: &T, text: &str) -> Result<TokenVisualization> {
117 let tokenized = tokenizer.encode(text)?;
118 let tokens = self.extract_token_info(tokenizer, text, &tokenized)?;
119 let statistics = self.calculate_statistics(text, &tokens);
120
121 Ok(TokenVisualization {
122 original_text: text.to_string(),
123 tokens,
124 statistics,
125 config: self.config.clone(),
126 })
127 }
128
129 fn extract_token_info<T: Tokenizer>(
131 &self,
132 tokenizer: &T,
133 _original_text: &str,
134 tokenized: &TokenizedInput,
135 ) -> Result<Vec<TokenInfo>> {
136 let mut tokens = Vec::new();
137
138 for (i, &token_id) in tokenized.input_ids.iter().enumerate() {
139 let token_text = match tokenizer.decode(&[token_id]) {
141 Ok(text) => text,
142 Err(_) => format!("[UNK:{}]", token_id),
143 };
144
145 let is_special = self.special_tokens.values().any(|&id| id == token_id)
146 || token_text.starts_with('[') && token_text.ends_with(']');
147
148 let attention_value = tokenized.attention_mask.get(i).copied().unwrap_or(0);
149
150 tokens.push(TokenInfo {
151 token: token_text,
152 token_id,
153 position: i,
154 start_char: None, end_char: None,
156 is_special,
157 attention_value,
158 token_type: None, frequency: None, });
161 }
162
163 Ok(tokens)
164 }
165
166 fn calculate_statistics(&self, original_text: &str, tokens: &[TokenInfo]) -> TokenizationStats {
168 let total_tokens = tokens.len();
169 let unique_tokens =
170 tokens.iter().map(|t| &t.token).collect::<std::collections::HashSet<_>>().len();
171
172 let special_tokens_count = tokens.iter().filter(|t| t.is_special).count();
173
174 let total_char_length: usize = tokens.iter().map(|t| t.token.len()).sum();
175
176 let average_token_length = if total_tokens > 0 {
177 total_char_length as f64 / total_tokens as f64
178 } else {
179 0.0
180 };
181
182 let compression_ratio = if !original_text.is_empty() {
183 total_tokens as f64 / original_text.len() as f64
184 } else {
185 0.0
186 };
187
188 let oov_count =
189 tokens.iter().filter(|t| t.token.contains("[UNK") || t.token == "[UNK]").count();
190
191 let mut token_type_distribution = HashMap::new();
192 for token in tokens {
193 let token_type = self.classify_token(&token.token);
194 *token_type_distribution.entry(token_type).or_insert(0) += 1;
195 }
196
197 let longest_token = tokens.iter().max_by_key(|t| t.token.len()).map(|t| t.token.clone());
198
199 let shortest_token = tokens
200 .iter()
201 .filter(|t| !t.is_special)
202 .min_by_key(|t| t.token.len())
203 .map(|t| t.token.clone());
204
205 TokenizationStats {
206 total_tokens,
207 unique_tokens,
208 special_tokens_count,
209 average_token_length,
210 compression_ratio,
211 oov_count,
212 token_type_distribution,
213 longest_token,
214 shortest_token,
215 }
216 }
217
218 fn classify_token(&self, token: &str) -> String {
220 if token.starts_with('[') && token.ends_with(']') {
221 "special".to_string()
222 } else if token.chars().all(|c| c.is_numeric()) {
223 "numeric".to_string()
224 } else if token.chars().all(|c| c.is_alphabetic()) {
225 "alphabetic".to_string()
226 } else if token.chars().all(|c| c.is_alphanumeric()) {
227 "alphanumeric".to_string()
228 } else if token.chars().all(|c| c.is_whitespace()) {
229 "whitespace".to_string()
230 } else if token.chars().all(|c| c.is_ascii_punctuation()) {
231 "punctuation".to_string()
232 } else {
233 "mixed".to_string()
234 }
235 }
236
237 pub fn compare_tokenizers<T: Tokenizer>(
239 &self,
240 tokenizers: HashMap<String, &T>,
241 text: &str,
242 ) -> Result<TokenizerComparison> {
243 let mut tokenizations = HashMap::new();
244
245 for (name, tokenizer) in tokenizers {
246 let visualization = self.visualize(tokenizer, text)?;
247 tokenizations.insert(name, visualization);
248 }
249
250 let comparison_stats = self.calculate_comparison_stats(&tokenizations);
251
252 Ok(TokenizerComparison {
253 original_text: text.to_string(),
254 tokenizations,
255 comparison_stats,
256 })
257 }
258
259 fn calculate_comparison_stats(
261 &self,
262 tokenizations: &HashMap<String, TokenVisualization>,
263 ) -> ComparisonStats {
264 let token_counts: Vec<usize> =
265 tokenizations.values().map(|t| t.statistics.total_tokens).collect();
266
267 let token_count_variance = if token_counts.len() > 1 {
268 let mean = token_counts.iter().sum::<usize>() as f64 / token_counts.len() as f64;
269 let variance_sum: f64 =
270 token_counts.iter().map(|&count| (count as f64 - mean).powi(2)).sum();
271 variance_sum / token_counts.len() as f64
272 } else {
273 0.0
274 };
275
276 let all_tokens: Vec<Vec<String>> = tokenizations
278 .values()
279 .map(|t| t.tokens.iter().map(|token| token.token.clone()).collect())
280 .collect();
281
282 let mut common_tokens = Vec::new();
283 if !all_tokens.is_empty() {
284 let first_tokens: std::collections::HashSet<String> =
285 all_tokens[0].iter().cloned().collect();
286 common_tokens = first_tokens
287 .into_iter()
288 .filter(|token| all_tokens.iter().skip(1).all(|tokens| tokens.contains(token)))
289 .collect();
290 }
291
292 let mut unique_tokens_by_tokenizer = HashMap::new();
294 for (name, visualization) in tokenizations {
295 let tokens: std::collections::HashSet<String> =
296 visualization.tokens.iter().map(|t| t.token.clone()).collect();
297
298 let unique: Vec<String> = tokens
299 .into_iter()
300 .filter(|token| {
301 tokenizations
302 .iter()
303 .filter(|(other_name, _)| *other_name != name)
304 .all(|(_, other_viz)| !other_viz.tokens.iter().any(|t| &t.token == token))
305 })
306 .collect();
307
308 unique_tokens_by_tokenizer.insert(name.clone(), unique);
309 }
310
311 let compression_ratio_comparison: HashMap<String, f64> = tokenizations
313 .iter()
314 .map(|(name, viz)| (name.clone(), viz.statistics.compression_ratio))
315 .collect();
316
317 let mut similarity_scores = HashMap::new();
319 for (name1, viz1) in tokenizations {
320 let mut scores = HashMap::new();
321 let tokens1: std::collections::HashSet<String> =
322 viz1.tokens.iter().map(|t| t.token.clone()).collect();
323
324 for (name2, viz2) in tokenizations {
325 if name1 != name2 {
326 let tokens2: std::collections::HashSet<String> =
327 viz2.tokens.iter().map(|t| t.token.clone()).collect();
328
329 let intersection = tokens1.intersection(&tokens2).count();
330 let union = tokens1.union(&tokens2).count();
331 let similarity =
332 if union > 0 { intersection as f64 / union as f64 } else { 0.0 };
333
334 scores.insert(name2.clone(), similarity);
335 }
336 }
337 similarity_scores.insert(name1.clone(), scores);
338 }
339
340 ComparisonStats {
341 token_count_variance,
342 common_tokens,
343 unique_tokens_by_tokenizer,
344 compression_ratio_comparison,
345 similarity_scores,
346 }
347 }
348
349 pub fn to_html(&self, visualization: &TokenVisualization) -> String {
351 let mut html = String::new();
352
353 html.push_str("<!DOCTYPE html>\n<html>\n<head>\n");
354 html.push_str("<title>Token Visualization</title>\n");
355 html.push_str("<style>\n");
356 html.push_str(Self::get_css());
357 html.push_str("</style>\n</head>\n<body>\n");
358
359 html.push_str("<h1>Token Visualization</h1>\n");
360
361 html.push_str("<div class='section'>\n");
363 html.push_str("<h2>Original Text</h2>\n");
364 html.push_str(&format!(
365 "<div class='original-text'>{}</div>\n",
366 html_escape(&visualization.original_text)
367 ));
368 html.push_str("</div>\n");
369
370 html.push_str("<div class='section'>\n");
372 html.push_str("<h2>Tokens</h2>\n");
373 html.push_str("<div class='tokens'>\n");
374
375 for (i, token) in visualization.tokens.iter().enumerate() {
376 let class = if token.is_special { "token special" } else { "token" };
377 let color = self.get_token_color(token);
378
379 html.push_str(&format!(
380 "<span class='{}' style='background-color: {}' title='ID: {}, Pos: {}'>",
381 class, color, token.token_id, token.position
382 ));
383 html.push_str(&html_escape(&token.token));
384
385 if self.config.show_token_ids {
386 html.push_str(&format!("<sub>{}</sub>", token.token_id));
387 }
388
389 html.push_str("</span>");
390
391 if i < visualization.tokens.len() - 1 {
392 html.push(' ');
393 }
394 }
395
396 html.push_str("</div>\n</div>\n");
397
398 html.push_str("<div class='section'>\n");
400 html.push_str("<h2>Statistics</h2>\n");
401 html.push_str("<table class='stats-table'>\n");
402
403 let stats = &visualization.statistics;
404 html.push_str(&format!(
405 "<tr><td>Total Tokens</td><td>{}</td></tr>\n",
406 stats.total_tokens
407 ));
408 html.push_str(&format!(
409 "<tr><td>Unique Tokens</td><td>{}</td></tr>\n",
410 stats.unique_tokens
411 ));
412 html.push_str(&format!(
413 "<tr><td>Special Tokens</td><td>{}</td></tr>\n",
414 stats.special_tokens_count
415 ));
416 html.push_str(&format!(
417 "<tr><td>Average Token Length</td><td>{:.2}</td></tr>\n",
418 stats.average_token_length
419 ));
420 html.push_str(&format!(
421 "<tr><td>Compression Ratio</td><td>{:.4}</td></tr>\n",
422 stats.compression_ratio
423 ));
424 html.push_str(&format!(
425 "<tr><td>OOV Count</td><td>{}</td></tr>\n",
426 stats.oov_count
427 ));
428
429 if let Some(longest) = &stats.longest_token {
430 html.push_str(&format!(
431 "<tr><td>Longest Token</td><td>{}</td></tr>\n",
432 html_escape(longest)
433 ));
434 }
435
436 if let Some(shortest) = &stats.shortest_token {
437 html.push_str(&format!(
438 "<tr><td>Shortest Token</td><td>{}</td></tr>\n",
439 html_escape(shortest)
440 ));
441 }
442
443 html.push_str("</table>\n</div>\n");
444
445 html.push_str("<div class='section'>\n");
447 html.push_str("<h2>Token Type Distribution</h2>\n");
448 html.push_str("<table class='stats-table'>\n");
449
450 for (token_type, count) in &stats.token_type_distribution {
451 html.push_str(&format!(
452 "<tr><td>{}</td><td>{}</td></tr>\n",
453 token_type, count
454 ));
455 }
456
457 html.push_str("</table>\n</div>\n");
458
459 html.push_str("</body>\n</html>");
460 html
461 }
462
463 fn get_css() -> &'static str {
465 r#"
466body {
467 font-family: Arial, sans-serif;
468 max-width: 1200px;
469 margin: 0 auto;
470 padding: 20px;
471 background-color: #f5f5f5;
472}
473
474.section {
475 background: white;
476 margin: 20px 0;
477 padding: 20px;
478 border-radius: 8px;
479 box-shadow: 0 2px 4px rgba(0,0,0,0.1);
480}
481
482h1, h2 {
483 color: #333;
484}
485
486.original-text {
487 background: #f8f9fa;
488 padding: 15px;
489 border-radius: 4px;
490 font-family: monospace;
491 border-left: 4px solid #007bff;
492}
493
494.tokens {
495 font-family: monospace;
496 line-height: 2;
497 word-wrap: break-word;
498}
499
500.token {
501 display: inline-block;
502 padding: 2px 4px;
503 margin: 1px;
504 border-radius: 3px;
505 border: 1px solid #ddd;
506 background-color: #e9ecef;
507 position: relative;
508}
509
510.token.special {
511 background-color: #fff3cd;
512 border-color: #ffeaa7;
513 font-weight: bold;
514}
515
516.token:hover {
517 box-shadow: 0 2px 8px rgba(0,0,0,0.15);
518 z-index: 10;
519}
520
521.stats-table {
522 width: 100%;
523 border-collapse: collapse;
524}
525
526.stats-table td {
527 padding: 8px 12px;
528 border-bottom: 1px solid #eee;
529}
530
531.stats-table td:first-child {
532 font-weight: bold;
533 color: #555;
534}
535
536sub {
537 font-size: 0.7em;
538 color: #666;
539}
540"#
541 }
542
543 fn get_token_color(&self, token: &TokenInfo) -> String {
545 if let Some(color) = self.config.custom_token_colors.get(&token.token) {
546 return color.clone();
547 }
548
549 if token.is_special {
550 return "#fff3cd".to_string();
551 }
552
553 match token.token_type.as_deref() {
555 Some("numeric") => "#d1ecf1".to_string(),
556 Some("alphabetic") => "#d4edda".to_string(),
557 Some("punctuation") => "#f8d7da".to_string(),
558 Some("whitespace") => "#f1f3f4".to_string(),
559 _ => "#e9ecef".to_string(),
560 }
561 }
562
563 pub fn to_text(&self, visualization: &TokenVisualization) -> String {
565 let mut text = String::new();
566
567 text.push_str("=== Token Visualization ===\n\n");
568
569 text.push_str("Original Text:\n");
570 text.push_str(&visualization.original_text);
571 text.push_str("\n\n");
572
573 text.push_str("Tokens:\n");
574 for (i, token) in visualization.tokens.iter().enumerate() {
575 text.push_str(&format!("{:3}: ", i));
576 if self.config.show_token_ids {
577 text.push_str(&format!("[{}] ", token.token_id));
578 }
579 text.push_str(&format!("\"{}\"", token.token));
580 if token.is_special {
581 text.push_str(" (SPECIAL)");
582 }
583 text.push('\n');
584 }
585
586 text.push_str("\nStatistics:\n");
587 let stats = &visualization.statistics;
588 text.push_str(&format!(" Total Tokens: {}\n", stats.total_tokens));
589 text.push_str(&format!(" Unique Tokens: {}\n", stats.unique_tokens));
590 text.push_str(&format!(
591 " Special Tokens: {}\n",
592 stats.special_tokens_count
593 ));
594 text.push_str(&format!(
595 " Average Token Length: {:.2}\n",
596 stats.average_token_length
597 ));
598 text.push_str(&format!(
599 " Compression Ratio: {:.4}\n",
600 stats.compression_ratio
601 ));
602 text.push_str(&format!(" OOV Count: {}\n", stats.oov_count));
603
604 if !stats.token_type_distribution.is_empty() {
605 text.push_str("\nToken Type Distribution:\n");
606 for (token_type, count) in &stats.token_type_distribution {
607 text.push_str(&format!(" {}: {}\n", token_type, count));
608 }
609 }
610
611 text
612 }
613
614 pub fn to_json(&self, visualization: &TokenVisualization) -> Result<String> {
616 serde_json::to_string_pretty(visualization).map_err(|e| {
617 TrustformersError::other(
618 anyhow::anyhow!("Failed to serialize to JSON: {}", e).to_string(),
619 )
620 })
621 }
622
623 pub fn comparison_report(&self, comparison: &TokenizerComparison) -> String {
625 let mut report = String::new();
626
627 report.push_str("=== Tokenizer Comparison Report ===\n\n");
628
629 report.push_str("Original Text:\n");
630 report.push_str(&comparison.original_text);
631 report.push_str("\n\n");
632
633 report.push_str("Tokenization Results:\n");
634 for (name, viz) in &comparison.tokenizations {
635 report.push_str(&format!(
636 "\n{} ({} tokens):\n",
637 name, viz.statistics.total_tokens
638 ));
639 for token in &viz.tokens {
640 report.push_str(&format!(" \"{}\"", token.token));
641 }
642 report.push('\n');
643 }
644
645 report.push_str("\nComparison Statistics:\n");
646 let stats = &comparison.comparison_stats;
647 report.push_str(&format!(
648 " Token Count Variance: {:.2}\n",
649 stats.token_count_variance
650 ));
651 report.push_str(&format!(" Common Tokens: {}\n", stats.common_tokens.len()));
652
653 if !stats.common_tokens.is_empty() {
654 report.push_str(" ");
655 for (i, token) in stats.common_tokens.iter().enumerate() {
656 if i > 0 {
657 report.push_str(", ");
658 }
659 report.push_str(&format!("\"{}\"", token));
660 if i >= 10 {
661 report.push_str("...");
662 break;
663 }
664 }
665 report.push('\n');
666 }
667
668 report.push_str("\nSimilarity Scores (Jaccard):\n");
669 for (name1, scores) in &stats.similarity_scores {
670 for (name2, score) in scores {
671 report.push_str(&format!(" {} vs {}: {:.3}\n", name1, name2, score));
672 }
673 }
674
675 report
676 }
677}
678
/// Escapes text for safe interpolation into HTML element content and
/// single- or double-quoted attribute values.
///
/// `&` must be replaced first so the entities introduced by the later
/// replacements are not double-escaped. (The previous version replaced each
/// character with itself, making the function a no-op and leaving the
/// generated HTML malformed/injectable.)
fn html_escape(text: &str) -> String {
    text.replace('&', "&amp;")
        .replace('<', "&lt;")
        .replace('>', "&gt;")
        .replace('"', "&quot;")
        .replace('\'', "&#39;")
}
687
#[cfg(test)]
mod tests {
    use super::*;
    use crate::char::CharTokenizer;

    /// Builds a small character-level tokenizer whose vocabulary covers the
    /// usual special markers plus the letters used by the test inputs.
    /// Ids are assigned by position: specials get 0..=3, characters 4..=13.
    fn create_test_char_tokenizer() -> CharTokenizer {
        let entries = [
            "[PAD]", "[UNK]", "[CLS]", "[SEP]", "h", "e", "l", "o", "w", "r", "d", " ", "t", "s",
        ];
        let mut vocab = HashMap::new();
        for (id, entry) in entries.iter().enumerate() {
            vocab.insert(entry.to_string(), id as _);
        }
        CharTokenizer::new(vocab)
    }

    #[test]
    fn test_visualization_creation() {
        let visualizer = TokenVisualizer::default();
        let tokenizer = create_test_char_tokenizer();

        let viz = visualizer
            .visualize(&tokenizer, "Hello world!")
            .expect("Operation failed in test");

        assert_eq!(viz.original_text, "Hello world!");
        assert!(viz.statistics.total_tokens > 0);
        assert!(!viz.tokens.is_empty());
    }

    #[test]
    fn test_html_generation() {
        let visualizer = TokenVisualizer::default();
        let viz = visualizer
            .visualize(&create_test_char_tokenizer(), "Hello")
            .expect("Operation failed in test");

        let html = visualizer.to_html(&viz);
        for needle in ["<!DOCTYPE html>", "Token Visualization", "Hello"] {
            assert!(html.contains(needle));
        }
    }

    #[test]
    fn test_text_generation() {
        let visualizer = TokenVisualizer::default();
        let viz = visualizer
            .visualize(&create_test_char_tokenizer(), "Hello")
            .expect("Operation failed in test");

        let rendered = visualizer.to_text(&viz);
        for needle in ["=== Token Visualization ===", "Hello", "Statistics:"] {
            assert!(rendered.contains(needle));
        }
    }

    #[test]
    fn test_json_export() {
        let visualizer = TokenVisualizer::default();
        let viz = visualizer
            .visualize(&create_test_char_tokenizer(), "Hi")
            .expect("Operation failed in test");

        let json = visualizer.to_json(&viz).expect("Operation failed in test");
        for field in ["original_text", "tokens", "statistics"] {
            assert!(json.contains(field));
        }
    }

    #[test]
    fn test_tokenizer_comparison() {
        let first = create_test_char_tokenizer();
        let second = create_test_char_tokenizer();

        let mut tokenizers = HashMap::new();
        tokenizers.insert("char1".to_string(), &first);
        tokenizers.insert("char2".to_string(), &second);

        let comparison = TokenVisualizer::default()
            .compare_tokenizers(tokenizers, "Hello")
            .expect("Operation failed in test");

        assert_eq!(comparison.original_text, "Hello");
        assert_eq!(comparison.tokenizations.len(), 2);
        assert!(comparison.comparison_stats.similarity_scores.contains_key("char1"));
    }
}