Skip to main content

scirs2_text/
tokenizer.rs

1//! Tokenization utilities for transformer models
2//!
3//! This module provides production-grade tokenizer implementations designed for
4//! use with transformer-based models such as BERT, GPT, and T5. It includes:
5//!
6//! - **BPE (Byte-Pair Encoding)**: Subword tokenization used by GPT-2, RoBERTa
7//! - **WordPiece**: BERT-style subword tokenization with `##` continuation prefix
8//! - **SimpleWhitespace**: Vocabulary-aware whitespace tokenizer with UNK handling
9//! - **SimpleChar**: Character-level tokenizer mapping each char to an ID
10//!
11//! All tokenizers implement the [`TransformerTokenizer`] trait which provides
12//! `encode` (text to token IDs) and `decode` (token IDs to text) operations,
13//! as well as `vocab_size` for embedding layer configuration.
14//!
15//! # Example
16//!
17//! ```rust
18//! use scirs2_text::tokenizer::{BPETokenizer, TransformerTokenizer};
19//!
20//! // Train a BPE tokenizer on a small corpus
21//! let corpus = &["the cat sat on the mat", "the dog sat on the log"];
22//! let tokenizer = BPETokenizer::train(corpus, 100).expect("training failed");
23//!
24//! // Encode and decode
25//! let ids = tokenizer.encode("the cat");
26//! let text = tokenizer.decode(&ids);
27//! assert!(text.contains("the"));
28//! ```
29
30use crate::error::{Result, TextError};
31use std::collections::HashMap;
32use std::fs::File;
33use std::io::{BufRead, BufReader, BufWriter, Write as IoWrite};
34use std::path::Path;
35
36// ---------------------------------------------------------------------------
37// TransformerTokenizer trait
38// ---------------------------------------------------------------------------
39
/// Common interface for tokenizers that feed transformer models.
///
/// Where [`crate::tokenize::Tokenizer`] produces string tokens, implementors of
/// this trait translate text into `u32` token IDs that can index an embedding
/// table, and translate such ID sequences back into text.
pub trait TransformerTokenizer {
    /// Convert `text` into its sequence of token IDs.
    fn encode(&self, text: &str) -> Vec<u32>;

    /// Convert a sequence of token IDs back into text.
    ///
    /// Implementations may be lossy; `decode(encode(s)) == s` is not promised.
    fn decode(&self, ids: &[u32]) -> String;

    /// How many distinct token IDs this tokenizer knows about.
    ///
    /// Typically used to size an embedding layer.
    fn vocab_size(&self) -> usize;
}
55
56// ---------------------------------------------------------------------------
57// Pre-tokenisation helpers (shared across tokenizer implementations)
58// ---------------------------------------------------------------------------
59
/// Normalise text before tokenisation.
///
/// The text is lowercased, then every run of whitespace is collapsed to a
/// single ASCII space; leading and trailing whitespace is removed.
fn pre_tokenize(text: &str) -> String {
    let lowered = text.to_lowercase();
    let mut normalized = String::with_capacity(lowered.len());
    for piece in lowered.split_whitespace() {
        // `split_whitespace` never yields empty pieces, so a non-empty buffer
        // means this is not the first word.
        if !normalized.is_empty() {
            normalized.push(' ');
        }
        normalized.push_str(piece);
    }
    normalized
}
66
/// Break `text` into owned words, using any whitespace run as the separator.
///
/// Empty words are never produced.
fn split_words(text: &str) -> Vec<String> {
    let mut words = Vec::new();
    for word in text.split_whitespace() {
        words.push(String::from(word));
    }
    words
}
71
72// ===========================================================================
73// BPETokenizer
74// ===========================================================================
75
/// A Byte-Pair Encoding tokenizer for transformer models.
///
/// BPE iteratively merges the most frequent pair of symbols (initially
/// characters) to form a vocabulary of subword units. This implementation
/// supports:
///
/// - Training from a text corpus with a target vocabulary size
/// - Encoding text to `u32` token IDs
/// - Decoding token IDs back to text
/// - Special tokens (`[PAD]`, `[UNK]`, `[CLS]`, `[SEP]`, `[MASK]`)
/// - JSON-based persistence (save/load)
///
/// Text is lowercased and whitespace-normalised before both training and
/// encoding, and each whitespace-separated word is encoded independently.
///
/// # Training
///
/// ```rust
/// use scirs2_text::tokenizer::{BPETokenizer, TransformerTokenizer};
///
/// let corpus = &["hello world", "hello there"];
/// let tok = BPETokenizer::train(corpus, 80).expect("train");
/// assert!(tok.vocab_size() > 0);
/// ```
#[derive(Debug, Clone)]
pub struct BPETokenizer {
    /// token string -> token ID
    vocab: HashMap<String, u32>,
    /// token ID -> token string (inverse of `vocab`; used by `decode`)
    id_to_token: HashMap<u32, String>,
    /// Ordered list of merge pairs learned during training.
    /// Encoding replays these merges in this exact order, so order matters.
    merges: Vec<(String, String)>,
    /// Special tokens with reserved IDs (entries created by `new` or
    /// `add_special_token` are also present in `vocab`)
    special_tokens: HashMap<String, u32>,
}
108
/// Default special tokens used by BPE when constructed via `new()`.
///
/// `BPETokenizer::new` gives each token its index in this array as its ID, so
/// `[PAD]` = 0, `[UNK]` = 1, `[CLS]` = 2, `[SEP]` = 3 and `[MASK]` = 4.
/// `BPETokenizer::train` also uses the array length to validate `vocab_size`.
const DEFAULT_SPECIAL_TOKENS: &[&str] = &["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"];
111
112impl BPETokenizer {
113    /// Create a new, empty BPE tokenizer pre-loaded with standard special tokens.
114    pub fn new() -> Self {
115        let mut vocab = HashMap::new();
116        let mut id_to_token = HashMap::new();
117        let mut special_tokens = HashMap::new();
118
119        for (i, &tok) in DEFAULT_SPECIAL_TOKENS.iter().enumerate() {
120            let id = i as u32;
121            vocab.insert(tok.to_string(), id);
122            id_to_token.insert(id, tok.to_string());
123            special_tokens.insert(tok.to_string(), id);
124        }
125
126        Self {
127            vocab,
128            id_to_token,
129            merges: Vec::new(),
130            special_tokens,
131        }
132    }
133
134    /// Train a BPE tokenizer on a text corpus.
135    ///
136    /// The tokenizer learns `vocab_size` total tokens (including special tokens
137    /// and individual characters that appear in the corpus).
138    ///
139    /// # Arguments
140    ///
141    /// * `texts` - Slice of text strings to train on.
142    /// * `vocab_size` - Target vocabulary size (must be > number of special tokens).
143    ///
144    /// # Errors
145    ///
146    /// Returns `TextError::TokenizationError` if the corpus is empty or
147    /// `vocab_size` is too small.
148    pub fn train(texts: &[&str], vocab_size: usize) -> Result<Self> {
149        if texts.is_empty() {
150            return Err(TextError::TokenizationError(
151                "Cannot train on empty corpus".to_string(),
152            ));
153        }
154        if vocab_size < DEFAULT_SPECIAL_TOKENS.len() + 1 {
155            return Err(TextError::TokenizationError(format!(
156                "vocab_size must be at least {} (special tokens + 1)",
157                DEFAULT_SPECIAL_TOKENS.len() + 1
158            )));
159        }
160
161        let mut tokenizer = Self::new();
162
163        // Step 1: Collect all unique characters and add them to the vocab.
164        let mut char_set: Vec<char> = Vec::new();
165        for text in texts {
166            let normalized = pre_tokenize(text);
167            for ch in normalized.chars() {
168                if !char_set.contains(&ch) {
169                    char_set.push(ch);
170                }
171            }
172        }
173        char_set.sort();
174
175        for ch in &char_set {
176            let s = ch.to_string();
177            if !tokenizer.vocab.contains_key(&s) {
178                let id = tokenizer.vocab.len() as u32;
179                tokenizer.vocab.insert(s.clone(), id);
180                tokenizer.id_to_token.insert(id, s);
181            }
182        }
183
184        // Step 2: Build word-frequency table.
185        // Each word is represented as a sequence of symbols (initially chars).
186        let mut word_freqs: HashMap<Vec<String>, u64> = HashMap::new();
187        for text in texts {
188            let normalized = pre_tokenize(text);
189            for word in split_words(&normalized) {
190                let symbols: Vec<String> = word.chars().map(|c| c.to_string()).collect();
191                *word_freqs.entry(symbols).or_insert(0) += 1;
192            }
193        }
194
195        // Step 3: Iteratively merge the most frequent pair.
196        let max_merges = vocab_size.saturating_sub(tokenizer.vocab.len());
197
198        for _ in 0..max_merges {
199            // Count bigram frequencies across all words.
200            let mut pair_counts: HashMap<(String, String), u64> = HashMap::new();
201            for (symbols, freq) in &word_freqs {
202                if symbols.len() < 2 {
203                    continue;
204                }
205                for pair in symbols.windows(2) {
206                    let key = (pair[0].clone(), pair[1].clone());
207                    *pair_counts.entry(key).or_insert(0) += freq;
208                }
209            }
210
211            // Find the most frequent pair.
212            let best = pair_counts
213                .iter()
214                .max_by_key(|&(_, &count)| count)
215                .map(|(pair, _)| pair.clone());
216
217            let best = match best {
218                Some(p) => p,
219                None => break, // No pairs left to merge.
220            };
221
222            let merged = format!("{}{}", best.0, best.1);
223
224            // Register merge & new token.
225            tokenizer.merges.push(best.clone());
226            if !tokenizer.vocab.contains_key(&merged) {
227                let id = tokenizer.vocab.len() as u32;
228                tokenizer.vocab.insert(merged.clone(), id);
229                tokenizer.id_to_token.insert(id, merged.clone());
230            }
231
232            // Apply merge to word table.
233            let mut new_word_freqs: HashMap<Vec<String>, u64> = HashMap::new();
234            for (symbols, freq) in &word_freqs {
235                let updated = apply_merge(symbols, &best.0, &best.1, &merged);
236                *new_word_freqs.entry(updated).or_insert(0) += freq;
237            }
238            word_freqs = new_word_freqs;
239        }
240
241        Ok(tokenizer)
242    }
243
244    /// Return the UNK token ID.
245    fn unk_id(&self) -> u32 {
246        self.special_tokens.get("[UNK]").copied().unwrap_or(1)
247    }
248
249    /// Encode a single word (no whitespace) into token IDs using learned merges.
250    fn encode_word(&self, word: &str) -> Vec<u32> {
251        if word.is_empty() {
252            return Vec::new();
253        }
254
255        let mut symbols: Vec<String> = word.chars().map(|c| c.to_string()).collect();
256
257        // Apply merges in order.
258        for (left, right) in &self.merges {
259            let merged = format!("{}{}", left, right);
260            symbols = apply_merge(&symbols, left, right, &merged);
261        }
262
263        // Map symbols to IDs.
264        let unk = self.unk_id();
265        symbols
266            .iter()
267            .map(|s| self.vocab.get(s).copied().unwrap_or(unk))
268            .collect()
269    }
270
271    /// Get a reference to the special tokens map.
272    pub fn special_tokens(&self) -> &HashMap<String, u32> {
273        &self.special_tokens
274    }
275
276    /// Get the token ID for a given special token name, e.g. `"[CLS]"`.
277    pub fn special_token_id(&self, name: &str) -> Option<u32> {
278        self.special_tokens.get(name).copied()
279    }
280
281    /// Add a custom special token. Returns the assigned ID.
282    pub fn add_special_token(&mut self, token: &str) -> u32 {
283        if let Some(&id) = self.vocab.get(token) {
284            self.special_tokens.insert(token.to_string(), id);
285            return id;
286        }
287        let id = self.vocab.len() as u32;
288        self.vocab.insert(token.to_string(), id);
289        self.id_to_token.insert(id, token.to_string());
290        self.special_tokens.insert(token.to_string(), id);
291        id
292    }
293
294    /// Save the tokenizer to a JSON file.
295    ///
296    /// The JSON format stores `vocab`, `merges`, and `special_tokens`.
297    pub fn save_json(&self, path: &Path) -> Result<()> {
298        let file = File::create(path).map_err(|e| TextError::IoError(format!("save_json: {e}")))?;
299        let writer = BufWriter::new(file);
300
301        // Build a serializable structure manually (no serde required at runtime
302        // if serde-support feature is off -- we write JSON by hand).
303        write_bpe_json(writer, &self.vocab, &self.merges, &self.special_tokens)
304    }
305
306    /// Load a BPE tokenizer from a JSON file previously written by [`Self::save_json`].
307    pub fn load_json(path: &Path) -> Result<Self> {
308        let file = File::open(path).map_err(|e| TextError::IoError(format!("load_json: {e}")))?;
309        let reader = BufReader::new(file);
310        read_bpe_json(reader)
311    }
312}
313
314impl Default for BPETokenizer {
315    fn default() -> Self {
316        Self::new()
317    }
318}
319
320impl TransformerTokenizer for BPETokenizer {
321    fn encode(&self, text: &str) -> Vec<u32> {
322        let normalized = pre_tokenize(text);
323        let words = split_words(&normalized);
324        let mut ids = Vec::new();
325        for word in &words {
326            ids.extend(self.encode_word(word));
327        }
328        ids
329    }
330
331    fn decode(&self, ids: &[u32]) -> String {
332        let tokens: Vec<String> = ids
333            .iter()
334            .filter_map(|&id| self.id_to_token.get(&id).cloned())
335            .collect();
336        // Join tokens -- BPE tokens are subword pieces so just concatenate.
337        // We insert a space before tokens that start a new word. As a heuristic
338        // we detect word boundaries when the token does not start with a
339        // continuation indicator. For a simple BPE without explicit word boundary
340        // markers we just concatenate with spaces between independent words.
341        // Since our training splits on whitespace and encodes each word separately
342        // we do a rough reconstruction.
343        rejoin_bpe_tokens(&tokens)
344    }
345
346    fn vocab_size(&self) -> usize {
347        self.vocab.len()
348    }
349}
350
351// ---------------------------------------------------------------------------
352// BPE helper functions
353// ---------------------------------------------------------------------------
354
/// Apply one merge rule to a symbol sequence.
///
/// The sequence is scanned left to right. Each non-overlapping occurrence of
/// `left` immediately followed by `right` is replaced by `merged`; all other
/// symbols are copied unchanged.
fn apply_merge(symbols: &[String], left: &str, right: &str, merged: &str) -> Vec<String> {
    let mut output = Vec::with_capacity(symbols.len());
    let mut rest = symbols;
    loop {
        match rest {
            [first, second, tail @ ..] if first == left && second == right => {
                output.push(merged.to_owned());
                rest = tail;
            }
            [first, tail @ ..] => {
                output.push(first.clone());
                rest = tail;
            }
            [] => break,
        }
    }
    output
}
370
/// Reassemble decoded BPE pieces into one string.
///
/// Pieces are concatenated verbatim with no separator. The output contains a
/// space only where a piece itself contains one, so callers that need word
/// boundaries must restore them separately.
fn rejoin_bpe_tokens(tokens: &[String]) -> String {
    tokens.iter().map(String::as_str).collect()
}
391
392/// Write BPE tokenizer state as JSON to a writer.
393fn write_bpe_json<W: IoWrite>(
394    mut w: W,
395    vocab: &HashMap<String, u32>,
396    merges: &[(String, String)],
397    special_tokens: &HashMap<String, u32>,
398) -> Result<()> {
399    let write_err = |e: std::io::Error| TextError::IoError(format!("write_bpe_json: {e}"));
400
401    w.write_all(b"{\n").map_err(write_err)?;
402
403    // vocab
404    w.write_all(b"  \"vocab\": {\n").map_err(write_err)?;
405    let mut sorted_vocab: Vec<(&String, &u32)> = vocab.iter().collect();
406    sorted_vocab.sort_by_key(|&(_, id)| *id);
407    for (idx, (token, id)) in sorted_vocab.iter().enumerate() {
408        let comma = if idx + 1 < sorted_vocab.len() {
409            ","
410        } else {
411            ""
412        };
413        let escaped = escape_json_string(token);
414        writeln!(w, "    \"{}\": {}{}", escaped, id, comma).map_err(write_err)?;
415    }
416    w.write_all(b"  },\n").map_err(write_err)?;
417
418    // merges
419    w.write_all(b"  \"merges\": [\n").map_err(write_err)?;
420    for (idx, (left, right)) in merges.iter().enumerate() {
421        let comma = if idx + 1 < merges.len() { "," } else { "" };
422        let left_esc = escape_json_string(left);
423        let right_esc = escape_json_string(right);
424        writeln!(w, "    [\"{}\", \"{}\"]{}", left_esc, right_esc, comma).map_err(write_err)?;
425    }
426    w.write_all(b"  ],\n").map_err(write_err)?;
427
428    // special_tokens
429    w.write_all(b"  \"special_tokens\": {\n")
430        .map_err(write_err)?;
431    let mut sorted_special: Vec<(&String, &u32)> = special_tokens.iter().collect();
432    sorted_special.sort_by_key(|&(_, id)| *id);
433    for (idx, (token, id)) in sorted_special.iter().enumerate() {
434        let comma = if idx + 1 < sorted_special.len() {
435            ","
436        } else {
437            ""
438        };
439        let escaped = escape_json_string(token);
440        writeln!(w, "    \"{}\": {}{}", escaped, id, comma).map_err(write_err)?;
441    }
442    w.write_all(b"  }\n").map_err(write_err)?;
443
444    w.write_all(b"}\n").map_err(write_err)?;
445    Ok(())
446}
447
448/// Read BPE tokenizer state from a JSON reader.
449///
450/// This is a minimal hand-rolled JSON parser sufficient for our format.
451/// It avoids requiring serde at runtime.
452fn read_bpe_json<R: BufRead>(reader: R) -> Result<BPETokenizer> {
453    let content: String = reader
454        .lines()
455        .collect::<std::result::Result<Vec<_>, _>>()
456        .map_err(|e| TextError::IoError(format!("read_bpe_json: {e}")))?
457        .join("\n");
458
459    let mut vocab: HashMap<String, u32> = HashMap::new();
460    let mut id_to_token: HashMap<u32, String> = HashMap::new();
461    let mut merges: Vec<(String, String)> = Vec::new();
462    let mut special_tokens: HashMap<String, u32> = HashMap::new();
463
464    // Parse vocab section
465    if let Some(vocab_section) = extract_json_object(&content, "vocab") {
466        for (key, val) in parse_string_int_pairs(&vocab_section) {
467            vocab.insert(key.clone(), val);
468            id_to_token.insert(val, key);
469        }
470    }
471
472    // Parse merges section
473    if let Some(merges_section) = extract_json_array(&content, "merges") {
474        merges = parse_merge_pairs(&merges_section);
475    }
476
477    // Parse special_tokens section
478    if let Some(special_section) = extract_json_object(&content, "special_tokens") {
479        for (key, val) in parse_string_int_pairs(&special_section) {
480            special_tokens.insert(key, val);
481        }
482    }
483
484    Ok(BPETokenizer {
485        vocab,
486        id_to_token,
487        merges,
488        special_tokens,
489    })
490}
491
/// Escape `s` so it can be placed between double quotes in JSON output.
///
/// Quotes, backslashes, `\n`, `\r` and `\t` get their short escapes. Any other
/// control character is written as a lowercase `\uXXXX` escape. Everything
/// else, including non-ASCII text, is copied unchanged.
fn escape_json_string(s: &str) -> String {
    let mut escaped = String::with_capacity(s.len() + 2);
    for c in s.chars() {
        let short_escape: Option<&str> = match c {
            '"' => Some("\\\""),
            '\\' => Some("\\\\"),
            '\n' => Some("\\n"),
            '\r' => Some("\\r"),
            '\t' => Some("\\t"),
            _ => None,
        };
        match short_escape {
            Some(seq) => escaped.push_str(seq),
            None if c.is_control() => escaped.push_str(&format!("\\u{:04x}", c as u32)),
            None => escaped.push(c),
        }
    }
    escaped
}
510
/// Unescape the body of a JSON string literal (the text between the quotes).
///
/// Supports the JSON escapes `\"`, `\\`, `\/`, `\b`, `\f`, `\n`, `\r`, `\t`
/// and `\uXXXX`, including UTF-16 surrogate pairs (`\ud83d\ude00` -> U+1F600).
/// Malformed input is handled leniently:
/// - an unknown escape is kept verbatim, backslash included;
/// - a trailing lone backslash is kept;
/// - a `\u` escape that is not a valid `char` (e.g. an unpaired surrogate) is
///   dropped.
fn unescape_json_string(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    let mut chars = s.chars();
    while let Some(ch) = chars.next() {
        if ch != '\\' {
            out.push(ch);
            continue;
        }
        match chars.next() {
            Some('"') => out.push('"'),
            Some('\\') => out.push('\\'),
            Some('/') => out.push('/'),
            Some('b') => out.push('\u{8}'),
            Some('f') => out.push('\u{c}'),
            Some('n') => out.push('\n'),
            Some('r') => out.push('\r'),
            Some('t') => out.push('\t'),
            Some('u') => {
                let code = match read_json_hex4(&mut chars) {
                    Some(code) => code,
                    None => continue,
                };
                if (0xD800..0xDC00).contains(&code) {
                    // High surrogate: JSON encodes characters outside the BMP as
                    // a `\uD8xx\uDCxx` pair, so try to consume the low half.
                    let mut lookahead = chars.clone();
                    if lookahead.next() == Some('\\') && lookahead.next() == Some('u') {
                        if let Some(low) = read_json_hex4(&mut lookahead) {
                            if (0xDC00..0xE000).contains(&low) {
                                let combined = 0x10000 + ((code - 0xD800) << 10) + (low - 0xDC00);
                                if let Some(c) = char::from_u32(combined) {
                                    out.push(c);
                                }
                                chars = lookahead;
                            }
                        }
                    }
                    // An unpaired high surrogate is not a valid char and is dropped.
                } else if let Some(c) = char::from_u32(code) {
                    out.push(c);
                }
            }
            Some(other) => {
                out.push('\\');
                out.push(other);
            }
            None => out.push('\\'),
        }
    }
    out
}

/// Consume up to four characters from `chars` and parse them as a hexadecimal
/// UTF-16 code unit.
fn read_json_hex4(chars: &mut std::str::Chars<'_>) -> Option<u32> {
    let hex: String = chars.by_ref().take(4).collect();
    u32::from_str_radix(&hex, 16).ok()
}
543
/// Extract a JSON object value by key name, returning the content between braces.
///
/// Only an occurrence of `"key"` followed by `:` and an opening `{` counts, so
/// an unrelated token that happens to equal `key` is skipped. Braces inside
/// JSON strings are ignored. Returns `None` if no such object is found or it
/// is unterminated.
fn extract_json_object(json: &str, key: &str) -> Option<String> {
    extract_json_delimited(json, key, '{', '}')
}

/// Extract a JSON array value by key name, returning the content between brackets.
///
/// Matching rules are the same as for [`extract_json_object`], with `[`/`]`
/// as delimiters.
fn extract_json_array(json: &str, key: &str) -> Option<String> {
    extract_json_delimited(json, key, '[', ']')
}

/// Find the first `"key": <open>...<close>` value in `json` and return the
/// text strictly between its delimiters.
fn extract_json_delimited(json: &str, key: &str, open: char, close: char) -> Option<String> {
    let pattern = format!("\"{}\"", key);
    let mut search_from = 0;
    while let Some(rel) = json[search_from..].find(pattern.as_str()) {
        let key_start = search_from + rel;
        // `json[key_start]` is an ASCII quote, so `key_start + 1` is a char boundary.
        search_from = key_start + 1;

        let after_key = json[key_start + pattern.len()..].trim_start();
        let value = match after_key.strip_prefix(':') {
            Some(rest) => rest.trim_start(),
            None => continue,
        };
        if value.starts_with(open) {
            // `value` is a suffix of `json`, so its byte offset follows from the lengths.
            let value_start = json.len() - value.len();
            return scan_balanced_json(json, value_start, open, close);
        }
    }
    None
}

/// Return the text between the `open` delimiter at byte offset `start` and its
/// matching `close`, ignoring delimiters inside JSON strings.
///
/// Works in byte offsets (via `char_indices`), so non-ASCII content slices
/// correctly.
fn scan_balanced_json(json: &str, start: usize, open: char, close: char) -> Option<String> {
    let mut depth: usize = 0;
    let mut in_string = false;
    let mut escaped = false;
    for (offset, ch) in json[start..].char_indices() {
        if in_string {
            if escaped {
                escaped = false;
            } else if ch == '\\' {
                escaped = true;
            } else if ch == '"' {
                in_string = false;
            }
        } else if ch == '"' {
            in_string = true;
        } else if ch == open {
            depth += 1;
        } else if ch == close {
            // Never underflows: the scan starts on `open`, so depth >= 1 here.
            depth -= 1;
            if depth == 0 {
                return Some(json[start + open.len_utf8()..start + offset].to_string());
            }
        }
    }
    None
}
593
594/// Parse `"key": int` pairs from the inside of a JSON object.
595fn parse_string_int_pairs(content: &str) -> Vec<(String, u32)> {
596    let mut pairs = Vec::new();
597    let mut remaining = content.trim();
598
599    while !remaining.is_empty() {
600        // Find next quoted key
601        let q1 = match remaining.find('"') {
602            Some(pos) => pos,
603            None => break,
604        };
605        let after_q1 = &remaining[q1 + 1..];
606        // Find closing quote (handle escapes)
607        let q2 = match find_unescaped_quote(after_q1) {
608            Some(pos) => pos,
609            None => break,
610        };
611        let key = unescape_json_string(&after_q1[..q2]);
612        let after_key = &after_q1[q2 + 1..];
613
614        // Find colon then the integer value
615        let colon = match after_key.find(':') {
616            Some(pos) => pos,
617            None => break,
618        };
619        let after_colon = after_key[colon + 1..].trim_start();
620
621        // Read integer
622        let num_end = after_colon
623            .find(|c: char| !c.is_ascii_digit())
624            .unwrap_or(after_colon.len());
625        if let Ok(val) = after_colon[..num_end].parse::<u32>() {
626            pairs.push((key, val));
627        }
628
629        // Advance past comma or end
630        let consumed = after_colon[num_end..].trim_start();
631        remaining = if consumed.starts_with(',') {
632            &consumed[1..]
633        } else {
634            consumed
635        };
636    }
637    pairs
638}
639
640/// Parse merge pairs `["left", "right"]` from inside a JSON array.
641fn parse_merge_pairs(content: &str) -> Vec<(String, String)> {
642    let mut pairs = Vec::new();
643    let mut remaining = content.trim();
644
645    while !remaining.is_empty() {
646        // Find next inner array '['
647        let bracket = match remaining.find('[') {
648            Some(pos) => pos,
649            None => break,
650        };
651        let end_bracket = match remaining[bracket..].find(']') {
652            Some(pos) => bracket + pos,
653            None => break,
654        };
655        let inner = &remaining[bracket + 1..end_bracket];
656
657        // Extract two quoted strings from inner
658        let mut strings = Vec::new();
659        let mut inner_rem = inner.trim();
660        for _ in 0..2 {
661            let q1 = match inner_rem.find('"') {
662                Some(pos) => pos,
663                None => break,
664            };
665            let after_q1 = &inner_rem[q1 + 1..];
666            let q2 = match find_unescaped_quote(after_q1) {
667                Some(pos) => pos,
668                None => break,
669            };
670            strings.push(unescape_json_string(&after_q1[..q2]));
671            inner_rem = &after_q1[q2 + 1..];
672            inner_rem = inner_rem.trim_start();
673            if inner_rem.starts_with(',') {
674                inner_rem = &inner_rem[1..];
675            }
676        }
677
678        if strings.len() == 2 {
679            pairs.push((strings[0].clone(), strings[1].clone()));
680        }
681
682        remaining = &remaining[end_bracket + 1..];
683        remaining = remaining.trim_start();
684        if remaining.starts_with(',') {
685            remaining = &remaining[1..];
686        }
687    }
688    pairs
689}
690
/// Find the byte offset of the first unescaped double quote in `s`.
///
/// A backslash escapes the character after it, so `\"` is skipped while `\\"`
/// stops at the quote. The result is a byte offset, not a character count, so
/// callers can slice `s` with it safely even when it contains non-ASCII text.
fn find_unescaped_quote(s: &str) -> Option<usize> {
    let mut escaped = false;
    for (offset, ch) in s.char_indices() {
        if escaped {
            escaped = false;
        } else if ch == '\\' {
            escaped = true;
        } else if ch == '"' {
            return Some(offset);
        }
    }
    None
}
709
710// ===========================================================================
711// WordPieceTokenizer
712// ===========================================================================
713
/// A WordPiece tokenizer (BERT-style).
///
/// WordPiece is a subword tokenization algorithm that greedily matches the
/// longest prefix of a word from the vocabulary, using a continuation prefix
/// (typically `"##"`) for non-initial subwords.
///
/// Input text is lowercased and whitespace-normalised before segmentation, so
/// vocabulary entries containing uppercase letters can never be matched from
/// input text.
///
/// # Example
///
/// ```rust
/// use scirs2_text::tokenizer::WordPieceTokenizer;
/// use std::collections::HashMap;
///
/// let mut vocab = HashMap::new();
/// vocab.insert("[UNK]".to_string(), 0);
/// vocab.insert("hello".to_string(), 1);
/// vocab.insert("world".to_string(), 2);
/// vocab.insert("hel".to_string(), 3);
/// vocab.insert("##lo".to_string(), 4);
///
/// let tokenizer = WordPieceTokenizer::new(vocab);
/// let tokens = tokenizer.tokenize("hello world");
/// assert!(tokens.contains(&"hello".to_string()) || tokens.contains(&"hel".to_string()));
/// ```
#[derive(Debug, Clone)]
pub struct WordPieceTokenizer {
    /// Token string -> token ID
    vocab: HashMap<String, u32>,
    /// Token ID -> token string (inverse of `vocab`, derived in `new`)
    id_to_token: HashMap<u32, String>,
    /// Maximum word length to consider; longer words are replaced by a single
    /// `unk_token` instead of being segmented
    max_word_len: usize,
    /// Token emitted for over-long words and for positions where no vocabulary
    /// entry matches
    unk_token: String,
    /// The prefix prepended to continuation subwords (default: `"##"`)
    continuing_prefix: String,
}
750
751impl WordPieceTokenizer {
752    /// Create a new WordPiece tokenizer from a vocabulary map.
753    ///
754    /// The vocabulary must contain at least `[UNK]`. The continuation prefix
755    /// defaults to `"##"` and max word length defaults to 200.
756    pub fn new(vocab: HashMap<String, u32>) -> Self {
757        let id_to_token: HashMap<u32, String> =
758            vocab.iter().map(|(k, &v)| (v, k.clone())).collect();
759        Self {
760            vocab,
761            id_to_token,
762            max_word_len: 200,
763            unk_token: "[UNK]".to_string(),
764            continuing_prefix: "##".to_string(),
765        }
766    }
767
768    /// Set the maximum word length.
769    pub fn with_max_word_len(mut self, max_len: usize) -> Self {
770        self.max_word_len = max_len;
771        self
772    }
773
774    /// Set the unknown token string.
775    pub fn with_unk_token(mut self, unk: &str) -> Self {
776        self.unk_token = unk.to_string();
777        self
778    }
779
780    /// Set the continuation prefix (default `"##"`).
781    pub fn with_continuing_prefix(mut self, prefix: &str) -> Self {
782        self.continuing_prefix = prefix.to_string();
783        self
784    }
785
786    /// Load a WordPiece vocabulary from a text file (one token per line).
787    ///
788    /// Token IDs are assigned sequentially starting from 0.
789    pub fn from_vocab_file(path: &Path) -> Result<Self> {
790        let file =
791            File::open(path).map_err(|e| TextError::IoError(format!("from_vocab_file: {e}")))?;
792        let reader = BufReader::new(file);
793
794        let mut vocab = HashMap::new();
795        for (id, line) in reader.lines().enumerate() {
796            let line =
797                line.map_err(|e| TextError::IoError(format!("from_vocab_file read: {e}")))?;
798            let token = line.trim().to_string();
799            if !token.is_empty() {
800                vocab.insert(token, id as u32);
801            }
802        }
803
804        if vocab.is_empty() {
805            return Err(TextError::TokenizationError(
806                "Vocabulary file is empty".to_string(),
807            ));
808        }
809
810        Ok(Self::new(vocab))
811    }
812
813    /// Tokenize text into a list of token strings (subword pieces).
814    ///
815    /// Each word is greedily segmented into the longest matching vocab entries.
816    /// Non-initial pieces are prefixed with `##`.
817    pub fn tokenize(&self, text: &str) -> Vec<String> {
818        let normalized = pre_tokenize(text);
819        let words = split_words(&normalized);
820        let mut tokens = Vec::new();
821
822        for word in &words {
823            if word.len() > self.max_word_len {
824                tokens.push(self.unk_token.clone());
825                continue;
826            }
827            let word_tokens = self.tokenize_word(word);
828            tokens.extend(word_tokens);
829        }
830        tokens
831    }
832
833    /// Segment a single word into WordPiece tokens.
834    fn tokenize_word(&self, word: &str) -> Vec<String> {
835        let chars: Vec<char> = word.chars().collect();
836        let mut tokens = Vec::new();
837        let mut start = 0;
838
839        while start < chars.len() {
840            let mut end = chars.len();
841            let mut found = false;
842
843            while start < end {
844                let substr: String = chars[start..end].iter().collect();
845                let candidate = if start == 0 {
846                    substr.clone()
847                } else {
848                    format!("{}{}", self.continuing_prefix, substr)
849                };
850
851                if self.vocab.contains_key(&candidate) {
852                    tokens.push(candidate);
853                    found = true;
854                    break;
855                }
856                end -= 1;
857            }
858
859            if !found {
860                // Could not match any subword starting at `start`
861                tokens.push(self.unk_token.clone());
862                start += 1;
863            } else {
864                start = end;
865            }
866        }
867
868        tokens
869    }
870
871    /// Get the UNK token ID.
872    fn unk_id(&self) -> u32 {
873        self.vocab.get(&self.unk_token).copied().unwrap_or(0)
874    }
875}
876
877impl TransformerTokenizer for WordPieceTokenizer {
878    fn encode(&self, text: &str) -> Vec<u32> {
879        let tokens = self.tokenize(text);
880        let unk = self.unk_id();
881        tokens
882            .iter()
883            .map(|t| self.vocab.get(t).copied().unwrap_or(unk))
884            .collect()
885    }
886
887    fn decode(&self, ids: &[u32]) -> String {
888        let mut result = String::new();
889        let mut first_in_word = true;
890
891        for &id in ids {
892            let token = match self.id_to_token.get(&id) {
893                Some(t) => t.as_str(),
894                None => &self.unk_token,
895            };
896
897            if let Some(stripped) = token.strip_prefix(&self.continuing_prefix) {
898                // Continuation piece: append directly (no space)
899                result.push_str(stripped);
900            } else {
901                // New word
902                if !first_in_word {
903                    result.push(' ');
904                }
905                result.push_str(token);
906            }
907            first_in_word = false;
908        }
909        result
910    }
911
912    fn vocab_size(&self) -> usize {
913        self.vocab.len()
914    }
915}
916
917// ===========================================================================
918// SimpleWhitespaceTokenizer
919// ===========================================================================
920
921/// A vocabulary-aware whitespace tokenizer.
922///
923/// Splits text on whitespace and maps each word to an integer token ID.
924/// Unknown words map to a reserved UNK ID.
925///
926/// # Example
927///
928/// ```rust
929/// use scirs2_text::tokenizer::{SimpleWhitespaceTokenizer, TransformerTokenizer};
930///
931/// let texts = &["hello world", "hello there"];
932/// let tok = SimpleWhitespaceTokenizer::build(texts, 100);
933/// let ids = tok.encode("hello world");
934/// let decoded = tok.decode(&ids);
935/// assert_eq!(decoded, "hello world");
936/// ```
937#[derive(Debug, Clone)]
938pub struct SimpleWhitespaceTokenizer {
939    /// word -> token ID
940    vocab: HashMap<String, u32>,
941    /// token ID -> word
942    id_to_token: HashMap<u32, String>,
943    /// Token ID for unknown words
944    unk_id: u32,
945}
946
947impl SimpleWhitespaceTokenizer {
948    /// Build a whitespace tokenizer from training texts.
949    ///
950    /// Counts word frequencies and keeps the top `max_vocab` words.
951    /// Token ID 0 is reserved for `[UNK]`.
952    pub fn build(texts: &[&str], max_vocab: usize) -> Self {
953        let mut word_counts: HashMap<String, u64> = HashMap::new();
954        for text in texts {
955            let normalized = pre_tokenize(text);
956            for word in split_words(&normalized) {
957                *word_counts.entry(word).or_insert(0) += 1;
958            }
959        }
960
961        // Sort by frequency (descending), then alphabetically for determinism.
962        let mut sorted: Vec<(String, u64)> = word_counts.into_iter().collect();
963        sorted.sort_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
964
965        let mut vocab = HashMap::new();
966        let mut id_to_token = HashMap::new();
967
968        // Reserve ID 0 for UNK
969        vocab.insert("[UNK]".to_string(), 0);
970        id_to_token.insert(0, "[UNK]".to_string());
971
972        let limit = max_vocab.saturating_sub(1); // -1 for UNK
973        for (word, _) in sorted.into_iter().take(limit) {
974            let id = vocab.len() as u32;
975            id_to_token.insert(id, word.clone());
976            vocab.insert(word, id);
977        }
978
979        Self {
980            vocab,
981            id_to_token,
982            unk_id: 0,
983        }
984    }
985
986    /// Create from an existing vocabulary.
987    pub fn from_vocab(vocab: HashMap<String, u32>, unk_id: u32) -> Self {
988        let id_to_token: HashMap<u32, String> =
989            vocab.iter().map(|(k, &v)| (v, k.clone())).collect();
990        Self {
991            vocab,
992            id_to_token,
993            unk_id,
994        }
995    }
996}
997
998impl TransformerTokenizer for SimpleWhitespaceTokenizer {
999    fn encode(&self, text: &str) -> Vec<u32> {
1000        let normalized = pre_tokenize(text);
1001        split_words(&normalized)
1002            .iter()
1003            .map(|w| self.vocab.get(w).copied().unwrap_or(self.unk_id))
1004            .collect()
1005    }
1006
1007    fn decode(&self, ids: &[u32]) -> String {
1008        ids.iter()
1009            .filter_map(|&id| self.id_to_token.get(&id).cloned())
1010            .collect::<Vec<String>>()
1011            .join(" ")
1012    }
1013
1014    fn vocab_size(&self) -> usize {
1015        self.vocab.len()
1016    }
1017}
1018
1019// ===========================================================================
1020// SimpleCharTokenizer
1021// ===========================================================================
1022
/// A character-level tokenizer that maps each unique character to a token ID.
///
/// Useful as a baseline or for character-level transformer models.
///
/// # Example
///
/// ```rust
/// use scirs2_text::tokenizer::{SimpleCharTokenizer, TransformerTokenizer};
///
/// let texts = &["abc", "bcd"];
/// let tok = SimpleCharTokenizer::build(texts);
/// let ids = tok.encode("abc");
/// assert_eq!(ids.len(), 3);
/// let decoded = tok.decode(&ids);
/// assert_eq!(decoded, "abc");
/// ```
#[derive(Debug, Clone)]
pub struct SimpleCharTokenizer {
    /// char -> token ID
    vocab: HashMap<char, u32>,
    /// token ID -> char
    id_to_char: HashMap<u32, char>,
    /// Token ID for unknown characters
    unk_id: u32,
}

impl SimpleCharTokenizer {
    /// Build a character tokenizer from training texts.
    ///
    /// Every distinct character in `texts` gets an ID. IDs are assigned in
    /// ascending `char` order starting from 1. Token ID 0 is reserved for
    /// unknown characters (those not seen in training).
    pub fn build(texts: &[&str]) -> Self {
        // A BTreeSet deduplicates and sorts in one pass, in O(n log k) time. A
        // linear `contains` scan per character is O(n * k) for n corpus
        // characters and k unique characters.
        let char_set: std::collections::BTreeSet<char> =
            texts.iter().flat_map(|text| text.chars()).collect();

        // Reserve ID 0 for UNK
        let unk_id = 0_u32;
        let mut vocab = HashMap::with_capacity(char_set.len());
        let mut id_to_char = HashMap::with_capacity(char_set.len());

        // Start from 1 so no character collides with the UNK ID.
        for (id, ch) in (1_u32..).zip(char_set) {
            vocab.insert(ch, id);
            id_to_char.insert(id, ch);
        }

        Self {
            vocab,
            id_to_char,
            unk_id,
        }
    }

    /// Create from an existing character vocabulary.
    ///
    /// The reverse `id -> char` map is derived from `vocab`. `unk_id` is
    /// returned by `encode` for characters missing from `vocab`.
    pub fn from_vocab(vocab: HashMap<char, u32>, unk_id: u32) -> Self {
        let id_to_char: HashMap<u32, char> = vocab.iter().map(|(&c, &id)| (id, c)).collect();
        Self {
            vocab,
            id_to_char,
            unk_id,
        }
    }
}
1092
1093impl TransformerTokenizer for SimpleCharTokenizer {
1094    fn encode(&self, text: &str) -> Vec<u32> {
1095        text.chars()
1096            .map(|ch| self.vocab.get(&ch).copied().unwrap_or(self.unk_id))
1097            .collect()
1098    }
1099
1100    fn decode(&self, ids: &[u32]) -> String {
1101        ids.iter()
1102            .filter_map(|&id| self.id_to_char.get(&id).copied())
1103            .collect()
1104    }
1105
1106    fn vocab_size(&self) -> usize {
1107        self.vocab.len() + 1 // +1 for the UNK slot
1108    }
1109}
1110
1111// ===========================================================================
1112// Tests
1113// ===========================================================================
1114
// Unit tests for the BPE, WordPiece, whitespace and character tokenizers, the
// shared `TransformerTokenizer` trait, and the JSON string escape helpers.
#[cfg(test)]
mod tests {
    use super::*;

    // --- BPE tests ---

    #[test]
    fn test_bpe_train_basic() {
        let corpus = &[
            "the cat sat on the mat",
            "the dog sat on the log",
            "the cat and the dog",
        ];
        let tok = BPETokenizer::train(corpus, 80).expect("train should succeed");
        assert!(tok.vocab_size() > 0);
        // Training must respect the requested vocabulary budget.
        assert!(tok.vocab_size() <= 80);
    }

    #[test]
    fn test_bpe_train_empty_corpus_error() {
        let result = BPETokenizer::train(&[], 100);
        assert!(result.is_err());
    }

    #[test]
    fn test_bpe_train_small_vocab_error() {
        let corpus = &["hello"];
        // Vocab size too small (less than special tokens + 1)
        let result = BPETokenizer::train(corpus, 3);
        assert!(result.is_err());
    }

    #[test]
    fn test_bpe_encode_decode_roundtrip() {
        let corpus = &["hello world", "hello there world", "the world is great"];
        let tok = BPETokenizer::train(corpus, 100).expect("train");

        let text = "hello world";
        let ids = tok.encode(text);
        assert!(!ids.is_empty());

        let decoded = tok.decode(&ids);
        // The decoded text should contain the original characters
        // (spaces are part of training so they appear as tokens)
        assert!(decoded.contains("hello"));
        assert!(decoded.contains("world"));
    }

    #[test]
    fn test_bpe_encode_empty() {
        let corpus = &["hello world"];
        let tok = BPETokenizer::train(corpus, 50).expect("train");
        let ids = tok.encode("");
        assert!(ids.is_empty());
    }

    #[test]
    fn test_bpe_special_tokens() {
        // A fresh tokenizer must already know all default special tokens.
        let tok = BPETokenizer::new();
        assert!(tok.special_token_id("[PAD]").is_some());
        assert!(tok.special_token_id("[UNK]").is_some());
        assert!(tok.special_token_id("[CLS]").is_some());
        assert!(tok.special_token_id("[SEP]").is_some());
        assert!(tok.special_token_id("[MASK]").is_some());
    }

    #[test]
    fn test_bpe_add_special_token() {
        let mut tok = BPETokenizer::new();
        let id = tok.add_special_token("[BOS]");
        assert_eq!(tok.special_token_id("[BOS]"), Some(id));
        // Adding one special token grows the vocabulary by exactly one.
        assert_eq!(tok.vocab_size(), DEFAULT_SPECIAL_TOKENS.len() + 1);
    }

    #[test]
    fn test_bpe_save_load_json() {
        let corpus = &["the cat sat on the mat", "the dog sat on the log"];
        let tok = BPETokenizer::train(corpus, 60).expect("train");

        // NOTE(review): fixed file name in the shared temp dir; concurrent runs
        // of this suite could collide, and the file is left behind on failure.
        let dir = std::env::temp_dir();
        let path = dir.join("test_bpe_tokenizer.json");

        tok.save_json(&path).expect("save");
        let loaded = BPETokenizer::load_json(&path).expect("load");

        assert_eq!(tok.vocab_size(), loaded.vocab_size());
        assert_eq!(tok.merges.len(), loaded.merges.len());

        // Encoding should produce same results
        let text = "the cat sat";
        let ids1 = tok.encode(text);
        let ids2 = loaded.encode(text);
        assert_eq!(ids1, ids2);

        // Clean up
        let _ = std::fs::remove_file(&path);
    }

    #[test]
    fn test_bpe_unknown_chars() {
        let corpus = &["abc"];
        let tok = BPETokenizer::train(corpus, 30).expect("train");

        // Encode text with characters not in training corpus
        let ids = tok.encode("xyz");
        // Unknown chars should map to UNK
        let unk = tok.unk_id();
        assert!(ids.iter().all(|&id| id == unk));
    }

    #[test]
    fn test_bpe_default_constructor() {
        // `Default` yields only the special tokens and no learned merges.
        let tok = BPETokenizer::default();
        assert_eq!(tok.vocab_size(), DEFAULT_SPECIAL_TOKENS.len());
        assert!(tok.merges.is_empty());
    }

    #[test]
    fn test_bpe_vocab_size_trait() {
        let corpus = &["hello world hello"];
        let tok = BPETokenizer::train(corpus, 50).expect("train");
        let trait_ref: &dyn TransformerTokenizer = &tok;
        assert!(trait_ref.vocab_size() > 0);
    }

    // --- WordPiece tests ---

    #[test]
    fn test_wordpiece_basic() {
        let mut vocab = HashMap::new();
        vocab.insert("[UNK]".to_string(), 0);
        vocab.insert("hello".to_string(), 1);
        vocab.insert("world".to_string(), 2);
        vocab.insert("hel".to_string(), 3);
        vocab.insert("##lo".to_string(), 4);
        vocab.insert("wor".to_string(), 5);
        vocab.insert("##ld".to_string(), 6);

        let tok = WordPieceTokenizer::new(vocab);
        let tokens = tok.tokenize("hello world");

        // "hello" should be found as-is (full match)
        assert!(tokens.contains(&"hello".to_string()));
        assert!(tokens.contains(&"world".to_string()));
    }

    #[test]
    fn test_wordpiece_subword_split() {
        let mut vocab = HashMap::new();
        vocab.insert("[UNK]".to_string(), 0);
        vocab.insert("un".to_string(), 1);
        vocab.insert("##aff".to_string(), 2);
        vocab.insert("##able".to_string(), 3);

        let tok = WordPieceTokenizer::new(vocab);
        let tokens = tok.tokenize("unaffable");
        // "unaffable" -> "un" + "##aff" + "##able"
        assert_eq!(tokens, vec!["un", "##aff", "##able"]);
    }

    #[test]
    fn test_wordpiece_unknown_word() {
        let mut vocab = HashMap::new();
        vocab.insert("[UNK]".to_string(), 0);
        vocab.insert("hello".to_string(), 1);

        let tok = WordPieceTokenizer::new(vocab);
        let tokens = tok.tokenize("xyz");
        // "xyz" has no matching subwords -> all characters become UNK
        assert!(tokens.contains(&"[UNK]".to_string()));
    }

    #[test]
    fn test_wordpiece_encode_decode() {
        let mut vocab = HashMap::new();
        vocab.insert("[UNK]".to_string(), 0);
        vocab.insert("play".to_string(), 1);
        vocab.insert("##ing".to_string(), 2);
        vocab.insert("##er".to_string(), 3);
        vocab.insert("##s".to_string(), 4);
        vocab.insert("the".to_string(), 5);

        let tok = WordPieceTokenizer::new(vocab);

        let ids = tok.encode("the playing");
        assert!(!ids.is_empty());

        // Continuation pieces are re-glued on decode, so the surface words return.
        let decoded = tok.decode(&ids);
        assert!(decoded.contains("the"));
        assert!(decoded.contains("play"));
        assert!(decoded.contains("ing"));
    }

    #[test]
    fn test_wordpiece_max_word_len() {
        let mut vocab = HashMap::new();
        vocab.insert("[UNK]".to_string(), 0);
        vocab.insert("a".to_string(), 1);

        let tok = WordPieceTokenizer::new(vocab).with_max_word_len(5);

        // Word longer than max_word_len -> UNK
        let tokens = tok.tokenize("toolongword");
        assert_eq!(tokens, vec!["[UNK]"]);
    }

    #[test]
    fn test_wordpiece_custom_prefix() {
        let mut vocab = HashMap::new();
        vocab.insert("[UNK]".to_string(), 0);
        vocab.insert("hel".to_string(), 1);
        vocab.insert("@@lo".to_string(), 2);

        // A non-default continuation prefix must be used when building pieces.
        let tok = WordPieceTokenizer::new(vocab).with_continuing_prefix("@@");
        let tokens = tok.tokenize("hello");
        assert_eq!(tokens, vec!["hel", "@@lo"]);
    }

    #[test]
    fn test_wordpiece_from_vocab_file() {
        // NOTE(review): fixed file name in the shared temp dir; concurrent runs
        // of this suite could collide.
        let dir = std::env::temp_dir();
        let path = dir.join("test_wp_vocab.txt");

        // Write a vocab file
        {
            let mut f = File::create(&path).expect("create vocab file");
            writeln!(f, "[UNK]").expect("write");
            writeln!(f, "[PAD]").expect("write");
            writeln!(f, "hello").expect("write");
            writeln!(f, "world").expect("write");
            writeln!(f, "##ing").expect("write");
        }

        let tok = WordPieceTokenizer::from_vocab_file(&path).expect("load vocab");
        // One entry per non-empty line.
        assert_eq!(tok.vocab_size(), 5);

        let _ = std::fs::remove_file(&path);
    }

    #[test]
    fn test_wordpiece_vocab_size() {
        let mut vocab = HashMap::new();
        vocab.insert("[UNK]".to_string(), 0);
        vocab.insert("a".to_string(), 1);
        vocab.insert("b".to_string(), 2);

        let tok = WordPieceTokenizer::new(vocab);
        assert_eq!(tok.vocab_size(), 3);
    }

    // --- SimpleWhitespace tests ---

    #[test]
    fn test_whitespace_build_and_encode() {
        let texts = &["hello world", "hello there", "world peace"];
        let tok = SimpleWhitespaceTokenizer::build(texts, 100);

        let ids = tok.encode("hello world");
        assert_eq!(ids.len(), 2);

        // hello and world should have different IDs
        assert_ne!(ids[0], ids[1]);
    }

    #[test]
    fn test_whitespace_decode() {
        let texts = &["hello world", "foo bar"];
        let tok = SimpleWhitespaceTokenizer::build(texts, 100);

        let ids = tok.encode("hello world");
        let decoded = tok.decode(&ids);
        assert_eq!(decoded, "hello world");
    }

    #[test]
    fn test_whitespace_unknown_word() {
        let texts = &["hello world"];
        let tok = SimpleWhitespaceTokenizer::build(texts, 100);

        let ids = tok.encode("hello xyz");
        // "xyz" is unknown -> maps to unk_id (0)
        assert_eq!(ids[1], 0);
    }

    #[test]
    fn test_whitespace_max_vocab_limit() {
        let texts = &["a b c d e f g"];
        let tok = SimpleWhitespaceTokenizer::build(texts, 4); // 3 words + 1 UNK
        assert!(tok.vocab_size() <= 4);
    }

    #[test]
    fn test_whitespace_vocab_size() {
        let texts = &["one two three"];
        let tok = SimpleWhitespaceTokenizer::build(texts, 100);
        // 3 words + 1 UNK = 4
        assert_eq!(tok.vocab_size(), 4);
    }

    // --- SimpleChar tests ---

    #[test]
    fn test_char_build_and_encode() {
        let texts = &["abc", "bcd"];
        let tok = SimpleCharTokenizer::build(texts);

        let ids = tok.encode("abc");
        assert_eq!(ids.len(), 3);
        // All IDs should be non-zero (0 is UNK)
        assert!(ids.iter().all(|&id| id > 0));
    }

    #[test]
    fn test_char_decode() {
        let texts = &["hello"];
        let tok = SimpleCharTokenizer::build(texts);

        let ids = tok.encode("hello");
        let decoded = tok.decode(&ids);
        assert_eq!(decoded, "hello");
    }

    #[test]
    fn test_char_unknown_char() {
        let texts = &["abc"];
        let tok = SimpleCharTokenizer::build(texts);

        let ids = tok.encode("xyz");
        // All unknown -> UNK id (0)
        assert!(ids.iter().all(|&id| id == 0));
    }

    #[test]
    fn test_char_vocab_size() {
        let texts = &["ab", "bc"];
        let tok = SimpleCharTokenizer::build(texts);
        // 3 unique chars (a, b, c) + 1 UNK slot = 4
        assert_eq!(tok.vocab_size(), 4);
    }

    #[test]
    fn test_char_roundtrip() {
        // Every character of the input was seen in training, so decode is exact.
        let texts = &["The quick brown fox!"];
        let tok = SimpleCharTokenizer::build(texts);

        let original = "The quick brown fox!";
        let ids = tok.encode(original);
        let decoded = tok.decode(&ids);
        assert_eq!(decoded, original);
    }

    // --- Cross-tokenizer trait tests ---

    #[test]
    fn test_trait_object_dispatch() {
        let corpus = &["hello world hello"];
        let bpe = BPETokenizer::train(corpus, 50).expect("train");

        let mut vocab = HashMap::new();
        vocab.insert("[UNK]".to_string(), 0);
        vocab.insert("hello".to_string(), 1);
        vocab.insert("world".to_string(), 2);
        let wp = WordPieceTokenizer::new(vocab);

        let ws_texts = &["hello world"];
        let ws = SimpleWhitespaceTokenizer::build(ws_texts, 50);

        let char_texts = &["hello world"];
        let ch = SimpleCharTokenizer::build(char_texts);

        // All implement TransformerTokenizer and can be used as trait objects.
        let tokenizers: Vec<&dyn TransformerTokenizer> = vec![&bpe, &wp, &ws, &ch];
        for tok in tokenizers {
            assert!(tok.vocab_size() > 0);
            let ids = tok.encode("hello");
            assert!(!ids.is_empty());
            let _ = tok.decode(&ids);
        }
    }

    // --- JSON escape/unescape tests ---

    #[test]
    fn test_json_escape_roundtrip() {
        // Covers quotes, newline, backslash and tab escapes.
        let original = "hello \"world\"\nnewline\\backslash\ttab";
        let escaped = escape_json_string(original);
        let unescaped = unescape_json_string(&escaped);
        assert_eq!(original, unescaped);
    }

    #[test]
    fn test_bpe_multiple_sentences() {
        let corpus = &[
            "machine learning is transforming the world",
            "deep learning models use transformers",
            "natural language processing with transformers",
            "the transformer architecture is powerful",
        ];
        let tok = BPETokenizer::train(corpus, 120).expect("train");

        let text = "learning transformers";
        let ids = tok.encode(text);
        assert!(!ids.is_empty());

        // Verify no UNK tokens for known text
        let unk = tok.unk_id();
        // Characters in "learning transformers" should all be in training data
        assert!(ids.iter().all(|&id| id != unk));
    }

    #[test]
    fn test_bpe_merges_reduce_token_count() {
        let corpus = &["aaaa aaaa aaaa aaaa aaaa", "aaaa aaaa aaaa aaaa aaaa"];
        let tok = BPETokenizer::train(corpus, 50).expect("train");

        // "aaaa" should be encoded in fewer tokens than 4 chars
        // because BPE should merge "a"+"a" -> "aa", etc.
        let ids = tok.encode("aaaa");
        assert!(
            ids.len() < 4,
            "BPE should merge repeated chars: got {} tokens",
            ids.len()
        );
    }

    #[test]
    fn test_wordpiece_empty_input() {
        let mut vocab = HashMap::new();
        vocab.insert("[UNK]".to_string(), 0);
        let tok = WordPieceTokenizer::new(vocab);

        // Empty input must yield empty output at every stage.
        let tokens = tok.tokenize("");
        assert!(tokens.is_empty());

        let ids = tok.encode("");
        assert!(ids.is_empty());

        let decoded = tok.decode(&[]);
        assert!(decoded.is_empty());
    }
}