Skip to main content

scirs2_text/tokenizers/
bert.rs

1//! BERT-style WordPiece tokenizer with special tokens.
2//!
//! This module provides a BERT-compatible WordPiece tokenizer implementing:
4//! - Basic tokenization (whitespace + punctuation splitting, optional lowercasing)
5//! - WordPiece subword segmentation (greedy longest-match with `##` continuations)
6//! - Special token management: `[CLS]`, `[SEP]`, `[MASK]`, `[PAD]`, `[UNK]`
7//! - Single- and pair-sentence encoding with token type IDs
8//! - Batched encoding with padding and truncation
9
10use crate::error::{Result, TextError};
11use std::collections::HashMap;
12use std::fs::File;
13use std::io::{BufRead, BufReader};
14use unicode_normalization::UnicodeNormalization;
15
16// ─── Constants ────────────────────────────────────────────────────────────────
17
/// Padding token used to fill sequences up to a target length.
const PAD_TOKEN: &str = "[PAD]";
/// Substitute for any word that cannot be segmented with the vocabulary.
const UNK_TOKEN: &str = "[UNK]";
/// Classification token placed at the start of every encoded sequence.
const CLS_TOKEN: &str = "[CLS]";
/// Separator token placed after each text segment.
const SEP_TOKEN: &str = "[SEP]";
/// Placeholder token for masked-language-model inputs.
const MASK_TOKEN: &str = "[MASK]";

/// Words longer than this many characters are mapped straight to `[UNK]`
/// instead of being segmented.
const MAX_WORD_CHARS: usize = 200;
26
27// ─── BertEncoding ─────────────────────────────────────────────────────────────
28
/// A single encoded BERT input sequence.
///
/// The three vectors are parallel: index `i` of each one describes the same
/// token position. The input IDs include special tokens and padding. The
/// attention mask is `1` for real tokens and `0` for padding. The token type
/// IDs are `0` for the first segment and `1` for the second.
#[derive(Debug, Clone, PartialEq)]
pub struct BertEncoding {
    /// Token IDs including special tokens.
    pub input_ids: Vec<u32>,
    /// Attention mask: `1` for real tokens, `0` for padding.
    pub attention_mask: Vec<u32>,
    /// Segment indicator: `0` for text_a, `1` for text_b.
    pub token_type_ids: Vec<u32>,
}

impl BertEncoding {
    /// Create an encoding from three parallel vectors.
    ///
    /// # Panics
    ///
    /// In debug builds, panics if `attention_mask` or `token_type_ids` does
    /// not have the same length as `input_ids`. Mismatched vectors would
    /// silently misalign positions downstream.
    pub fn new(input_ids: Vec<u32>, attention_mask: Vec<u32>, token_type_ids: Vec<u32>) -> Self {
        debug_assert_eq!(
            input_ids.len(),
            attention_mask.len(),
            "attention_mask must have the same length as input_ids"
        );
        debug_assert_eq!(
            input_ids.len(),
            token_type_ids.len(),
            "token_type_ids must have the same length as input_ids"
        );
        BertEncoding {
            input_ids,
            attention_mask,
            token_type_ids,
        }
    }

    /// Returns the sequence length (number of tokens including padding).
    pub fn len(&self) -> usize {
        self.input_ids.len()
    }

    /// Returns `true` if there are no tokens.
    pub fn is_empty(&self) -> bool {
        self.input_ids.is_empty()
    }
}
63
64// ─── BatchEncoding ────────────────────────────────────────────────────────────
65
66/// A batch of [`BertEncoding`] instances with consistent (padded) lengths.
67#[derive(Debug, Clone)]
68pub struct BatchEncoding {
69    /// Individual encodings, all of the same length.
70    pub encodings: Vec<BertEncoding>,
71}
72
73impl BatchEncoding {
74    /// Create from a vector of encodings.
75    pub fn new(encodings: Vec<BertEncoding>) -> Self {
76        BatchEncoding { encodings }
77    }
78
79    /// Number of sequences in the batch.
80    pub fn len(&self) -> usize {
81        self.encodings.len()
82    }
83
84    /// Returns `true` if the batch is empty.
85    pub fn is_empty(&self) -> bool {
86        self.encodings.is_empty()
87    }
88
89    /// Collect all `input_ids` into a 2-D vector `[batch, seq_len]`.
90    pub fn input_ids(&self) -> Vec<Vec<u32>> {
91        self.encodings.iter().map(|e| e.input_ids.clone()).collect()
92    }
93
94    /// Collect all `attention_mask` into a 2-D vector.
95    pub fn attention_masks(&self) -> Vec<Vec<u32>> {
96        self.encodings
97            .iter()
98            .map(|e| e.attention_mask.clone())
99            .collect()
100    }
101
102    /// Collect all `token_type_ids` into a 2-D vector.
103    pub fn token_type_ids(&self) -> Vec<Vec<u32>> {
104        self.encodings
105            .iter()
106            .map(|e| e.token_type_ids.clone())
107            .collect()
108    }
109}
110
111// ─── BasicTokenizer ───────────────────────────────────────────────────────────
112
113/// Whitespace + punctuation tokenizer (BERT pre-tokenization step).
114///
115/// Optionally lowercases and strips combining Unicode marks (accents).
116#[derive(Debug, Clone)]
117struct BasicTokenizer {
118    do_lower_case: bool,
119}
120
121impl BasicTokenizer {
122    fn new(do_lower_case: bool) -> Self {
123        BasicTokenizer { do_lower_case }
124    }
125
126    fn tokenize(&self, text: &str) -> Vec<String> {
127        let text = if self.do_lower_case {
128            text.to_lowercase()
129        } else {
130            text.to_string()
131        };
132
133        // Strip combining characters by NFD decomposition.
134        let text: String = text.nfd().filter(|c| !is_combining_mark(*c)).collect();
135
136        // Insert spaces around punctuation and CJK characters.
137        let mut spaced = String::with_capacity(text.len() + 32);
138        for ch in text.chars() {
139            if ch.is_whitespace() {
140                spaced.push(' ');
141            } else if is_punctuation_char(ch) || is_chinese_char(ch) {
142                spaced.push(' ');
143                spaced.push(ch);
144                spaced.push(' ');
145            } else {
146                spaced.push(ch);
147            }
148        }
149
150        spaced
151            .split_whitespace()
152            .filter(|s| !s.is_empty())
153            .map(|s| s.to_string())
154            .collect()
155    }
156}
157
/// Returns `true` for characters in the Unicode combining-diacritic blocks.
///
/// These are the marks that NFD decomposition splits off from accented
/// letters, plus the standalone combining blocks:
/// - Combining Diacritical Marks (U+0300–036F)
/// - Combining Diacritical Marks Extended (U+1AB0–1AFF)
/// - Combining Diacritical Marks Supplement (U+1DC0–1DFF)
/// - Combining Diacritical Marks for Symbols (U+20D0–20FF)
/// - Combining Half Marks (U+FE20–FE2F)
///
/// Script-specific marks (e.g. Indic vowel signs) are intentionally not
/// included.
fn is_combining_mark(ch: char) -> bool {
    matches!(
        u32::from(ch),
        0x0300..=0x036F | 0x1AB0..=0x1AFF | 0x1DC0..=0x1DFF | 0x20D0..=0x20FF | 0xFE20..=0xFE2F
    )
}
166
/// Returns `true` if `ch` should be split off as a token of its own.
///
/// * ASCII: every code point in 0–47, 58–64, 91–96 and 123–126. That is all
///   printable non-alphanumeric characters (as BERT does) plus the C0 control
///   characters. Whitespace is handled by the caller before this check.
/// * Non-ASCII: an approximation of the Unicode punctuation categories built
///   from fixed code-point ranges:
///   - Latin-1 punctuation (`¡ § « ¶ · » ¿`)
///   - General Punctuation: dashes, quotes, bullets, ellipsis, primes and
///     guillemets (excluding the math symbols U+2044 and U+2052)
///   - CJK punctuation: `、。〃` and the corner/angle/lenticular brackets
///   - Fullwidth forms that mirror the ASCII ranges above, plus the halfwidth
///     CJK punctuation U+FF5F–FF65
fn is_punctuation_char(ch: char) -> bool {
    let cp = u32::from(ch);
    if cp < 0x80 {
        return cp <= 47
            || (58..=64).contains(&cp)
            || (91..=96).contains(&cp)
            || (123..=126).contains(&cp);
    }
    matches!(
        cp,
        // Latin-1 Supplement punctuation.
        0x00A1 | 0x00A7 | 0x00AB | 0x00B6 | 0x00B7 | 0x00BB | 0x00BF
            // General Punctuation.
            | 0x2010..=0x2027
            | 0x2030..=0x2043
            | 0x2045..=0x2051
            | 0x2053..=0x205E
            // CJK Symbols and Punctuation.
            | 0x3001..=0x3003
            | 0x3008..=0x3011
            | 0x3014..=0x301F
            | 0x3030
            | 0x303D
            // Halfwidth and Fullwidth Forms.
            | 0xFF01..=0xFF0F
            | 0xFF1A..=0xFF20
            | 0xFF3B..=0xFF40
            | 0xFF5B..=0xFF65
    )
}
192
/// Returns `true` if `ch` belongs to one of the CJK ideograph blocks.
///
/// The range list is the same one BERT uses for its "Chinese character"
/// check. Note that it does not cover Hiragana, Katakana or Hangul.
fn is_chinese_char(ch: char) -> bool {
    const IDEOGRAPH_BLOCKS: [(u32, u32); 8] = [
        (0x4E00, 0x9FFF),   // CJK Unified Ideographs
        (0x3400, 0x4DBF),   // Extension A
        (0x20000, 0x2A6DF), // Extension B
        (0x2A700, 0x2B73F), // Extension C
        (0x2B740, 0x2B81F), // Extension D
        (0x2B820, 0x2CEAF), // Extension E
        (0xF900, 0xFAFF),   // Compatibility Ideographs
        (0x2F800, 0x2FA1F), // Compatibility Ideographs Supplement
    ];
    let code = u32::from(ch);
    IDEOGRAPH_BLOCKS
        .iter()
        .any(|&(first, last)| first <= code && code <= last)
}
205
206// ─── WordPiece helper ─────────────────────────────────────────────────────────
207
208/// Greedy longest-match WordPiece segmentation for a single `word`.
209///
210/// Returns a list of subword strings; continuation pieces are prefixed with
211/// `"##"`.  If the word cannot be fully segmented, returns `["[UNK]"]`.
212fn wordpiece_segment(word: &str, vocab: &HashMap<String, u32>) -> Vec<String> {
213    let chars: Vec<char> = word.chars().collect();
214    if chars.len() > MAX_WORD_CHARS {
215        return vec![UNK_TOKEN.to_string()];
216    }
217
218    let n = chars.len();
219    let mut sub_tokens: Vec<String> = Vec::new();
220    let mut start = 0usize;
221
222    while start < n {
223        let mut end = n;
224        let mut found_tok: Option<String> = None;
225
226        while start < end {
227            let substr: String = chars[start..end].iter().collect();
228            let candidate = if start == 0 {
229                substr.clone()
230            } else {
231                format!("##{}", substr)
232            };
233
234            if vocab.contains_key(&candidate) {
235                found_tok = Some(candidate);
236                break;
237            }
238
239            if end == start + 1 {
240                // Single character not in vocab: whole word is unknown.
241                return vec![UNK_TOKEN.to_string()];
242            }
243            end -= 1;
244        }
245
246        match found_tok {
247            Some(tok) => {
248                sub_tokens.push(tok);
249                start = end;
250            }
251            None => {
252                return vec![UNK_TOKEN.to_string()];
253            }
254        }
255    }
256
257    sub_tokens
258}
259
260// ─── BertTokenizer ────────────────────────────────────────────────────────────
261
/// BERT-style tokenizer combining basic tokenization and WordPiece subword
/// segmentation.
///
/// Special tokens:
/// - `[CLS]` (classification): prepended to every encoded sequence
/// - `[SEP]` (separator): appended after each segment
/// - `[MASK]` (masking): placeholder for masked-language-model pre-training
/// - `[PAD]` (padding): used to fill sequences to a target length
/// - `[UNK]` (unknown): substituted for tokens not present in the vocabulary
///
/// # Example
///
/// ```rust
/// use std::collections::HashMap;
/// use scirs2_text::tokenizers::bert::BertTokenizer;
///
/// let mut vocab: HashMap<String, u32> = HashMap::new();
/// for (i, tok) in ["[PAD]","[UNK]","[CLS]","[SEP]","[MASK]",
///                   "hello","world","##ing","play","##ed"].iter().enumerate() {
///     vocab.insert(tok.to_string(), i as u32);
/// }
/// let tokenizer = BertTokenizer::new(vocab, true);
/// let ids = tokenizer.encode("Hello World").unwrap();
/// assert_eq!(ids[0], tokenizer.cls_token_id());
/// ```
#[derive(Debug, Clone)]
pub struct BertTokenizer {
    /// `token → id` vocabulary; the constructor guarantees it contains all
    /// five special tokens.
    vocab: HashMap<String, u32>,
    /// Reverse lookup (`id → token`) built from `vocab`, used when decoding.
    ids_to_tokens: HashMap<u32, String>,
    /// Cached ID of `[CLS]`.
    cls_token_id: u32,
    /// Cached ID of `[SEP]`.
    sep_token_id: u32,
    /// Cached ID of `[PAD]`.
    pad_token_id: u32,
    /// Cached ID of `[MASK]`.
    mask_token_id: u32,
    /// Cached ID of `[UNK]`; the fallback for out-of-vocabulary lookups.
    unk_token_id: u32,
    /// Maximum-sequence-length setting (512 unless changed with
    /// `with_max_len`). The `encode*` methods take an explicit `max_length`
    /// argument instead of reading this field.
    max_len: usize,
    /// Whether input text is lowercased before segmentation.
    lowercase: bool,
    /// Pre-tokenizer performing whitespace/punctuation splitting.
    basic: BasicTokenizer,
}
300
301impl BertTokenizer {
302    // ── Construction ──────────────────────────────────────────────────────
303
304    /// Build a `BertTokenizer` from a `token → id` vocabulary map.
305    ///
306    /// All five special tokens (`[PAD]`, `[UNK]`, `[CLS]`, `[SEP]`, `[MASK]`)
307    /// are inserted into the vocabulary if absent.
308    pub fn new(mut vocab: HashMap<String, u32>, lowercase: bool) -> Self {
309        // Ensure all required special tokens exist.
310        let specials = [PAD_TOKEN, UNK_TOKEN, CLS_TOKEN, SEP_TOKEN, MASK_TOKEN];
311        for tok in &specials {
312            if !vocab.contains_key(*tok) {
313                let next_id = vocab.len() as u32;
314                vocab.insert(tok.to_string(), next_id);
315            }
316        }
317
318        let cls_token_id = vocab[CLS_TOKEN];
319        let sep_token_id = vocab[SEP_TOKEN];
320        let pad_token_id = vocab[PAD_TOKEN];
321        let mask_token_id = vocab[MASK_TOKEN];
322        let unk_token_id = vocab[UNK_TOKEN];
323
324        let ids_to_tokens: HashMap<u32, String> =
325            vocab.iter().map(|(k, &v)| (v, k.clone())).collect();
326
327        BertTokenizer {
328            vocab,
329            ids_to_tokens,
330            cls_token_id,
331            sep_token_id,
332            pad_token_id,
333            mask_token_id,
334            unk_token_id,
335            max_len: 512,
336            lowercase,
337            basic: BasicTokenizer::new(lowercase),
338        }
339    }
340
341    /// Load a tokenizer from a `vocab.txt` file (one token per line; line
342    /// index = token ID, 0-based).
343    ///
344    /// Returns an error if the file cannot be read or if the resulting
345    /// vocabulary is missing required special tokens after auto-insertion.
346    pub fn from_vocab_file(path: &str) -> Result<Self> {
347        let file = File::open(path).map_err(|e| TextError::IoError(e.to_string()))?;
348        let reader = BufReader::new(file);
349
350        let mut vocab = HashMap::new();
351        for (idx, line) in reader.lines().enumerate() {
352            let token = line.map_err(|e| TextError::IoError(e.to_string()))?;
353            let token = token.trim().to_string();
354            if !token.is_empty() {
355                vocab.insert(token, idx as u32);
356            }
357        }
358
359        if vocab.is_empty() {
360            return Err(TextError::VocabularyError(
361                "Vocabulary file is empty".to_string(),
362            ));
363        }
364
365        Ok(Self::new(vocab, true))
366    }
367
368    /// Override the maximum sequence length (default 512).
369    pub fn with_max_len(mut self, max_len: usize) -> Self {
370        self.max_len = max_len;
371        self
372    }
373
374    // ── Accessors ─────────────────────────────────────────────────────────
375
376    /// Returns the `[CLS]` token ID.
377    pub fn cls_token_id(&self) -> u32 {
378        self.cls_token_id
379    }
380
381    /// Returns the `[SEP]` token ID.
382    pub fn sep_token_id(&self) -> u32 {
383        self.sep_token_id
384    }
385
386    /// Returns the `[PAD]` token ID.
387    pub fn pad_token_id(&self) -> u32 {
388        self.pad_token_id
389    }
390
391    /// Returns the `[MASK]` token ID.
392    pub fn mask_token_id(&self) -> u32 {
393        self.mask_token_id
394    }
395
396    /// Returns the `[UNK]` token ID.
397    pub fn unk_token_id(&self) -> u32 {
398        self.unk_token_id
399    }
400
401    /// Vocabulary size.
402    pub fn vocab_size(&self) -> usize {
403        self.vocab.len()
404    }
405
406    /// Return a reference to the full `token → id` vocabulary map.
407    pub fn vocab(&self) -> &HashMap<String, u32> {
408        &self.vocab
409    }
410
411    /// Return whether this tokenizer lowercases input text.
412    pub fn lowercase(&self) -> bool {
413        self.lowercase
414    }
415
416    // ── Core tokenization ─────────────────────────────────────────────────
417
418    /// Tokenize `text` into a list of subword strings.
419    ///
420    /// Applies basic tokenization (whitespace + punctuation split, optional
421    /// lowercasing) followed by WordPiece subword segmentation.  Unknown
422    /// characters/words map to `"[UNK]"`.
423    pub fn tokenize(&self, text: &str) -> Vec<String> {
424        if text.is_empty() {
425            return Vec::new();
426        }
427        let words = self.basic.tokenize(text);
428        words
429            .iter()
430            .flat_map(|w| wordpiece_segment(w, &self.vocab))
431            .collect()
432    }
433
434    /// Convert a token string to its vocabulary ID.
435    fn token_to_id(&self, token: &str) -> u32 {
436        self.vocab.get(token).copied().unwrap_or(self.unk_token_id)
437    }
438
439    // ── Encoding API ──────────────────────────────────────────────────────
440
441    /// Encode a single text segment as `[CLS] tokens [SEP]`.
442    ///
443    /// Returns the flat sequence of token IDs.  Use `encode_pair` for
444    /// two-segment inputs (e.g. question + context).
445    pub fn encode(&self, text: &str) -> Result<Vec<u32>> {
446        let sub_tokens = self.tokenize(text);
447        let mut ids = Vec::with_capacity(sub_tokens.len() + 2);
448        ids.push(self.cls_token_id);
449        ids.extend(sub_tokens.iter().map(|t| self.token_to_id(t)));
450        ids.push(self.sep_token_id);
451        Ok(ids)
452    }
453
454    /// Encode a pair of text segments (e.g. sentence A and sentence B).
455    ///
456    /// Layout: `[CLS] A-tokens [SEP] B-tokens [SEP]`
457    ///
458    /// Returns `(token_ids, token_type_ids)` where `token_type_ids[i]` is `0`
459    /// for the first segment and `1` for the second.
460    pub fn encode_pair(&self, text_a: &str, text_b: &str) -> Result<(Vec<u32>, Vec<u32>)> {
461        let tokens_a = self.tokenize(text_a);
462        let tokens_b = self.tokenize(text_b);
463
464        let total = 1 + tokens_a.len() + 1 + tokens_b.len() + 1; // [CLS]+A+[SEP]+B+[SEP]
465        let mut ids = Vec::with_capacity(total);
466        let mut type_ids = Vec::with_capacity(total);
467
468        // Segment A: [CLS] ... [SEP]  → type 0
469        ids.push(self.cls_token_id);
470        type_ids.push(0u32);
471
472        for tok in &tokens_a {
473            ids.push(self.token_to_id(tok));
474            type_ids.push(0);
475        }
476
477        ids.push(self.sep_token_id);
478        type_ids.push(0);
479
480        // Segment B: ... [SEP]  → type 1
481        for tok in &tokens_b {
482            ids.push(self.token_to_id(tok));
483            type_ids.push(1);
484        }
485
486        ids.push(self.sep_token_id);
487        type_ids.push(1);
488
489        Ok((ids, type_ids))
490    }
491
492    /// Build a single [`BertEncoding`] for `text`, with optional padding and
493    /// truncation to `max_length`.
494    ///
495    /// If `padding` is `true`, short sequences are padded with `[PAD]` to
496    /// reach `max_length`.  If `truncation` is `true`, long sequences are
497    /// trimmed (preserving `[CLS]` and `[SEP]`).
498    pub fn encode_single(
499        &self,
500        text: &str,
501        max_length: usize,
502        padding: bool,
503        truncation: bool,
504    ) -> Result<BertEncoding> {
505        if max_length == 0 {
506            return Err(TextError::InvalidInput(
507                "max_length must be greater than 0".to_string(),
508            ));
509        }
510
511        let sub_tokens = self.tokenize(text);
512        // Budget for content tokens: max_length - [CLS] - [SEP]
513        let budget = max_length.saturating_sub(2);
514
515        let content: Vec<u32> = if truncation && sub_tokens.len() > budget {
516            sub_tokens[..budget]
517                .iter()
518                .map(|t| self.token_to_id(t))
519                .collect()
520        } else {
521            sub_tokens.iter().map(|t| self.token_to_id(t)).collect()
522        };
523
524        let mut ids = Vec::with_capacity(max_length);
525        ids.push(self.cls_token_id);
526        ids.extend_from_slice(&content);
527        ids.push(self.sep_token_id);
528
529        let real_len = ids.len();
530
531        if padding && ids.len() < max_length {
532            let pad_count = max_length - ids.len();
533            ids.extend(std::iter::repeat_n(self.pad_token_id, pad_count));
534        }
535
536        let seq_len = ids.len();
537        let mut mask = vec![0u32; seq_len];
538        for m in mask.iter_mut().take(real_len) {
539            *m = 1;
540        }
541        let type_ids = vec![0u32; seq_len];
542
543        Ok(BertEncoding::new(ids, mask, type_ids))
544    }
545
546    /// Encode a batch of texts with consistent sequence length.
547    ///
548    /// When `padding` is `true`, all sequences in the batch are padded to the
549    /// longest (or to `max_length`, whichever is smaller).  When `truncation`
550    /// is `true`, sequences exceeding `max_length` are truncated.
551    pub fn encode_batch(
552        &self,
553        texts: &[&str],
554        max_length: usize,
555        padding: bool,
556        truncation: bool,
557    ) -> Result<BatchEncoding> {
558        if max_length == 0 {
559            return Err(TextError::InvalidInput(
560                "max_length must be greater than 0".to_string(),
561            ));
562        }
563
564        // First pass: build raw ids for each text (before padding).
565        let mut raw_encodings: Vec<(Vec<u32>, usize)> = Vec::with_capacity(texts.len());
566
567        for text in texts {
568            let sub_tokens = self.tokenize(text);
569            let budget = max_length.saturating_sub(2);
570            let content: Vec<u32> = if truncation && sub_tokens.len() > budget {
571                sub_tokens[..budget]
572                    .iter()
573                    .map(|t| self.token_to_id(t))
574                    .collect()
575            } else {
576                sub_tokens.iter().map(|t| self.token_to_id(t)).collect()
577            };
578
579            let mut ids = Vec::with_capacity(content.len() + 2);
580            ids.push(self.cls_token_id);
581            ids.extend_from_slice(&content);
582            ids.push(self.sep_token_id);
583            let real_len = ids.len();
584            raw_encodings.push((ids, real_len));
585        }
586
587        // Determine target length for padding.
588        let target_len = if padding {
589            let max_real = raw_encodings
590                .iter()
591                .map(|(ids, _)| ids.len())
592                .max()
593                .unwrap_or(0);
594            max_real.min(max_length)
595        } else {
596            max_length
597        };
598
599        // Second pass: pad and build BertEncoding.
600        let encodings = raw_encodings
601            .into_iter()
602            .map(|(mut ids, real_len)| {
603                if padding && ids.len() < target_len {
604                    let pad_count = target_len - ids.len();
605                    ids.extend(std::iter::repeat_n(self.pad_token_id, pad_count));
606                }
607
608                let seq_len = ids.len();
609                let mut mask = vec![0u32; seq_len];
610                for m in mask.iter_mut().take(real_len) {
611                    *m = 1;
612                }
613                let type_ids = vec![0u32; seq_len];
614                BertEncoding::new(ids, mask, type_ids)
615            })
616            .collect();
617
618        Ok(BatchEncoding::new(encodings))
619    }
620
621    // ── Decoding ──────────────────────────────────────────────────────────
622
623    /// Decode a sequence of token IDs back to a human-readable string.
624    ///
625    /// Special tokens (`[CLS]`, `[SEP]`, `[PAD]`, `[MASK]`) are skipped.
626    /// WordPiece continuation tokens (prefixed with `##`) are merged directly
627    /// onto the preceding piece without a space.
628    pub fn decode(&self, ids: &[u32]) -> String {
629        let special_ids: [u32; 4] = [
630            self.cls_token_id,
631            self.sep_token_id,
632            self.pad_token_id,
633            self.mask_token_id,
634        ];
635
636        let mut out = String::new();
637        for &id in ids {
638            if special_ids.contains(&id) {
639                continue;
640            }
641
642            let tok = match self.ids_to_tokens.get(&id) {
643                Some(t) => t.as_str(),
644                None => UNK_TOKEN,
645            };
646
647            if tok == UNK_TOKEN {
648                if !out.is_empty() {
649                    out.push(' ');
650                }
651                out.push_str(tok);
652                continue;
653            }
654
655            if let Some(cont) = tok.strip_prefix("##") {
656                // Continuation: append directly (no space).
657                out.push_str(cont);
658            } else {
659                if !out.is_empty() {
660                    out.push(' ');
661                }
662                out.push_str(tok);
663            }
664        }
665        out
666    }
667
668    /// Convert token string to its ID (exposed for testing / downstream use).
669    pub fn convert_token_to_id(&self, token: &str) -> Option<u32> {
670        self.vocab.get(token).copied()
671    }
672
673    /// Convert token ID to its string representation.
674    pub fn convert_id_to_token(&self, id: u32) -> Option<&str> {
675        self.ids_to_tokens.get(&id).map(|s| s.as_str())
676    }
677}
678
679// ─── Tests ────────────────────────────────────────────────────────────────────
680
681#[cfg(test)]
682mod tests {
683    use super::*;
684    use std::collections::HashMap;
685
686    // ── Vocabulary helpers ──────────────────────────────────────────────
687
688    /// Minimal vocabulary with special tokens and a handful of content tokens.
689    fn base_vocab() -> HashMap<String, u32> {
690        let tokens = [
691            "[PAD]",   // 0
692            "[UNK]",   // 1
693            "[CLS]",   // 2
694            "[SEP]",   // 3
695            "[MASK]",  // 4
696            "hello",   // 5
697            "world",   // 6
698            "play",    // 7
699            "##ing",   // 8
700            "##ed",    // 9
701            "good",    // 10
702            "morning", // 11
703            "the",     // 12
704            "quick",   // 13
705            "brown",   // 14
706            "fox",     // 15
707            ",",       // 16
708            "!",       // 17
709        ];
710        tokens
711            .iter()
712            .enumerate()
713            .map(|(i, t)| (t.to_string(), i as u32))
714            .collect()
715    }
716
717    fn make_tokenizer() -> BertTokenizer {
718        BertTokenizer::new(base_vocab(), true)
719    }
720
721    // ── Test: basic tokenization ────────────────────────────────────────
722
723    #[test]
724    fn test_bert_tokenize_basic() {
725        let tok = make_tokenizer();
726        let tokens = tok.tokenize("Hello, World!");
727        // lowercase=true → "hello" and "world"
728        assert!(
729            tokens.contains(&"hello".to_string()),
730            "expected 'hello' in {:?}",
731            tokens
732        );
733        assert!(
734            tokens.contains(&"world".to_string()),
735            "expected 'world' in {:?}",
736            tokens
737        );
738        // punctuation should be isolated tokens
739        assert!(
740            tokens.contains(&",".to_string()),
741            "expected ',' in {:?}",
742            tokens
743        );
744        assert!(
745            tokens.contains(&"!".to_string()),
746            "expected '!' in {:?}",
747            tokens
748        );
749    }
750
751    // ── Test: special tokens added in encode ────────────────────────────
752
753    #[test]
754    fn test_bert_special_tokens() {
755        let tok = make_tokenizer();
756        let ids = tok.encode("hello world").expect("encode failed");
757        // [CLS] at start, [SEP] at end
758        assert_eq!(ids[0], tok.cls_token_id(), "first token should be [CLS]");
759        assert_eq!(
760            *ids.last().expect("non-empty"),
761            tok.sep_token_id(),
762            "last token should be [SEP]"
763        );
764    }
765
766    // ── Test: WordPiece subword segmentation ────────────────────────────
767
768    #[test]
769    fn test_bert_wordpiece() {
770        let tok = make_tokenizer();
771        // "playing" → play + ##ing
772        let tokens = tok.tokenize("playing");
773        assert_eq!(tokens, vec!["play", "##ing"]);
774    }
775
776    // ── Test: unknown tokens ────────────────────────────────────────────
777
778    #[test]
779    fn test_bert_unknown() {
780        let tok = make_tokenizer();
781        // "xyzzy" is not in vocab; should map to [UNK]
782        let ids = tok.encode("xyzzy").expect("encode failed");
783        // [CLS] [UNK] [SEP]
784        assert_eq!(ids.len(), 3);
785        assert_eq!(ids[1], tok.unk_token_id(), "OOV token should map to [UNK]");
786    }
787
788    // ── Test: pair encoding ─────────────────────────────────────────────
789
790    #[test]
791    fn test_bert_encode_pair() {
792        let tok = make_tokenizer();
793        let (ids, type_ids) = tok
794            .encode_pair("hello", "world")
795            .expect("encode_pair failed");
796        // Layout: [CLS](0) hello(0) [SEP](0) world(1) [SEP](1)
797        assert_eq!(ids[0], tok.cls_token_id());
798        // type_id for [CLS] should be 0
799        assert_eq!(type_ids[0], 0);
800        // Last token should be [SEP]
801        assert_eq!(*ids.last().expect("non-empty"), tok.sep_token_id());
802        // Last type_id should be 1 (segment B)
803        assert_eq!(*type_ids.last().expect("non-empty"), 1);
804
805        // Verify segment boundary: find first SEP
806        let first_sep_pos = ids
807            .iter()
808            .position(|&id| id == tok.sep_token_id())
809            .expect("has SEP");
810        // Everything up to and including first SEP is type 0
811        for i in 0..=first_sep_pos {
812            assert_eq!(type_ids[i], 0, "position {} should be type 0", i);
813        }
814        // Everything after first SEP is type 1
815        for i in (first_sep_pos + 1)..type_ids.len() {
816            assert_eq!(type_ids[i], 1, "position {} should be type 1", i);
817        }
818    }
819
820    // ── Test: decode skips special tokens ───────────────────────────────
821
822    #[test]
823    fn test_bert_decode_skips_special() {
824        let tok = make_tokenizer();
825        let ids = tok.encode("hello world").expect("encode failed");
826        let decoded = tok.decode(&ids);
827        // [CLS] and [SEP] should not appear in decoded output
828        assert!(
829            !decoded.contains("[CLS]"),
830            "decoded should not contain [CLS]: {:?}",
831            decoded
832        );
833        assert!(
834            !decoded.contains("[SEP]"),
835            "decoded should not contain [SEP]: {:?}",
836            decoded
837        );
838        assert!(
839            decoded.contains("hello"),
840            "decoded should contain 'hello': {:?}",
841            decoded
842        );
843        assert!(
844            decoded.contains("world"),
845            "decoded should contain 'world': {:?}",
846            decoded
847        );
848    }
849
850    // ── Test: batch padding ─────────────────────────────────────────────
851
852    #[test]
853    fn test_bert_batch_padding() {
854        let tok = make_tokenizer();
855        let texts = vec!["hello", "hello world"];
856        let batch = tok
857            .encode_batch(&texts, 10, true, false)
858            .expect("encode_batch failed");
859
860        assert_eq!(batch.len(), 2);
861        // All sequences must have the same length after padding
862        let len0 = batch.encodings[0].len();
863        let len1 = batch.encodings[1].len();
864        assert_eq!(len0, len1, "padded lengths must be equal");
865
866        // The shorter sequence should have padding tokens
867        let short_enc = &batch.encodings[0];
868        let has_pad = short_enc
869            .input_ids
870            .iter()
871            .any(|&id| id == tok.pad_token_id());
872        let longer_real = batch.encodings[1]
873            .attention_mask
874            .iter()
875            .filter(|&&m| m == 1)
876            .count();
877        let shorter_real = batch.encodings[0]
878            .attention_mask
879            .iter()
880            .filter(|&&m| m == 1)
881            .count();
882        assert!(
883            has_pad,
884            "shorter sequence should have padding; ids={:?}",
885            short_enc.input_ids
886        );
887        assert!(
888            shorter_real < longer_real,
889            "shorter text should have fewer real tokens"
890        );
891        // Padding positions should have attention_mask = 0
892        for (id, mask) in short_enc
893            .input_ids
894            .iter()
895            .zip(short_enc.attention_mask.iter())
896        {
897            if *id == tok.pad_token_id() {
898                assert_eq!(*mask, 0, "padding token must have mask 0");
899            }
900        }
901    }
902
903    // ── Test: batch truncation ──────────────────────────────────────────
904
905    #[test]
906    fn test_bert_batch_truncation() {
907        let tok = make_tokenizer();
908        // "the quick brown fox" has 4 real tokens; max_length=4 forces truncation
909        // (budget: 4 - 2 = 2 content tokens)
910        let texts = vec!["the quick brown fox"];
911        let batch = tok
912            .encode_batch(&texts, 4, false, true)
913            .expect("encode_batch failed");
914
915        let enc = &batch.encodings[0];
916        // ids: [CLS] + up-to-2 content tokens + [SEP] = 4
917        assert_eq!(enc.input_ids.len(), 4);
918        assert_eq!(enc.input_ids[0], tok.cls_token_id());
919        assert_eq!(
920            *enc.input_ids.last().expect("non-empty"),
921            tok.sep_token_id()
922        );
923    }
924
925    // ── Test: lowercase folding ──────────────────────────────────────────
926
927    #[test]
928    fn test_bert_lowercase() {
929        let tok_lower = BertTokenizer::new(base_vocab(), true);
930        let tok_cased = BertTokenizer::new(base_vocab(), false);
931
932        // With lowercase=true, "HELLO" should hit "hello" in vocab
933        let lower_tokens = tok_lower.tokenize("HELLO");
934        assert!(
935            lower_tokens.contains(&"hello".to_string()),
936            "lowercase should map HELLO→hello: {:?}",
937            lower_tokens
938        );
939
940        // With lowercase=false, "HELLO" is not in vocab → [UNK]
941        let cased_tokens = tok_cased.tokenize("HELLO");
942        assert!(
943            cased_tokens.contains(&"[UNK]".to_string()),
944            "cased tokenizer should map HELLO to [UNK]: {:?}",
945            cased_tokens
946        );
947    }
948
949    // ── Test: build from in-memory vocab ─────────────────────────────────
950
951    #[test]
952    fn test_bert_from_vocab_string() {
953        // Build vocab directly from a list of strings (simulating a vocab.txt load)
954        let token_list: &[&str] = &[
955            "[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]", "rust", "is", "great",
956        ];
957        let vocab: HashMap<String, u32> = token_list
958            .iter()
959            .enumerate()
960            .map(|(i, t)| (t.to_string(), i as u32))
961            .collect();
962        let tokenizer = BertTokenizer::new(vocab, true);
963        let ids = tokenizer.encode("rust is great").expect("encode failed");
964        // [CLS] rust is great [SEP] = 5 tokens
965        assert_eq!(ids.len(), 5);
966        assert_eq!(ids[0], tokenizer.cls_token_id());
967    }
968
969    // ── Test: empty input ────────────────────────────────────────────────
970
971    #[test]
972    fn test_bert_empty_input() {
973        let tok = make_tokenizer();
974        let ids = tok.encode("").expect("encode empty");
975        // [CLS] [SEP]
976        assert_eq!(ids.len(), 2);
977        assert_eq!(ids[0], tok.cls_token_id());
978        assert_eq!(ids[1], tok.sep_token_id());
979    }
980
981    // ── Test: all-OOV input ───────────────────────────────────────────────
982
983    #[test]
984    fn test_bert_all_oov() {
985        let tok = make_tokenizer();
986        // Three OOV words, each → [UNK]
987        let ids = tok.encode("zzz yyy xxx").expect("encode all-OOV");
988        // [CLS] [UNK] [UNK] [UNK] [SEP]
989        assert_eq!(ids.len(), 5);
990        for &id in &ids[1..4] {
991            assert_eq!(id, tok.unk_token_id());
992        }
993    }
994
995    // ── Test: max_len=1 edge case ─────────────────────────────────────────
996
997    #[test]
998    fn test_bert_max_len_one_truncation() {
999        let tok = make_tokenizer();
1000        // With max_length=1, budget=0 content tokens (1 - 2 saturates to 0).
1001        // The encoder always emits [CLS] + content + [SEP]; with zero content
1002        // budget the result is [CLS] [SEP] (length 2), which exceeds max_length=1
1003        // but is the minimal valid BERT sequence.  The implementation does not
1004        // truncate the mandatory special tokens themselves.
1005        let enc = tok
1006            .encode_single("hello world", 1, false, true)
1007            .expect("encode_single");
1008        // Always at least [CLS] + [SEP].
1009        assert!(
1010            enc.input_ids.len() >= 2,
1011            "must contain at least [CLS] and [SEP]"
1012        );
1013        assert_eq!(enc.input_ids[0], tok.cls_token_id());
1014        assert_eq!(
1015            *enc.input_ids.last().expect("non-empty"),
1016            tok.sep_token_id()
1017        );
1018        // No content tokens (budget was 0).
1019        assert_eq!(enc.input_ids.len(), 2, "only [CLS] and [SEP] expected");
1020    }
1021
1022    // ── Test: decode WordPiece continuations ─────────────────────────────
1023
1024    #[test]
1025    fn test_bert_decode_wordpiece_merge() {
1026        let tok = make_tokenizer();
1027        // play(7) + ##ing(8) → "playing"
1028        let decoded = tok.decode(&[7, 8]);
1029        assert_eq!(decoded, "playing", "expected 'playing', got '{}'", decoded);
1030    }
1031
1032    // ── Test: from_vocab_file round-trip ─────────────────────────────────
1033
1034    #[test]
1035    fn test_bert_from_vocab_file() {
1036        use std::io::Write;
1037
1038        let mut tmp = std::env::temp_dir();
1039        tmp.push("scirs2_bert_vocab_test.txt");
1040        {
1041            let mut f = std::fs::File::create(&tmp).expect("create temp file");
1042            writeln!(f, "[PAD]").expect("write");
1043            writeln!(f, "[UNK]").expect("write");
1044            writeln!(f, "[CLS]").expect("write");
1045            writeln!(f, "[SEP]").expect("write");
1046            writeln!(f, "[MASK]").expect("write");
1047            writeln!(f, "hello").expect("write");
1048            writeln!(f, "world").expect("write");
1049        }
1050        let path = tmp.to_str().expect("valid path");
1051        let tokenizer = BertTokenizer::from_vocab_file(path).expect("from_vocab_file");
1052        assert_eq!(tokenizer.convert_token_to_id("[CLS]"), Some(2));
1053        assert_eq!(tokenizer.convert_token_to_id("hello"), Some(5));
1054        let _ = std::fs::remove_file(&tmp);
1055    }
1056}