Skip to main content

scirs2_text/tokenization/
wordpiece.rs

1//! WordPiece tokenizer — the subword tokenization used by BERT.
2//!
3//! Implements:
4//! - [`BasicTokenizer`]: whitespace + punctuation splitting with optional
5//!   lower-casing and accent stripping.
6//! - [`WordPieceTokenizer`]: greedy longest-match subword tokenisation.
7
8use std::collections::HashMap;
9
10use crate::error::{Result, TextError};
11
12// ─── BasicTokenizer ───────────────────────────────────────────────────────────
13
/// First-stage ("basic") tokenizer used before WordPiece.
///
/// Text is broken apart on whitespace, and every punctuation character or CJK
/// ideograph is emitted as a token of its own. Lowercasing and accent removal
/// are controlled by the two flags below.
#[derive(Debug, Clone)]
pub struct BasicTokenizer {
    /// If `true`, the input is converted to lowercase before it is split.
    pub do_lower_case: bool,
    /// If `true`, combining accent marks are removed after NFD decomposition.
    pub strip_accents: bool,
}
23
24impl BasicTokenizer {
25    /// Create a new [`BasicTokenizer`].
26    pub fn new(do_lower_case: bool, strip_accents: bool) -> Self {
27        BasicTokenizer {
28            do_lower_case,
29            strip_accents,
30        }
31    }
32
33    /// Tokenize `text` into a list of token strings.
34    pub fn tokenize(&self, text: &str) -> Vec<String> {
35        let text = if self.do_lower_case {
36            text.to_lowercase()
37        } else {
38            text.to_string()
39        };
40
41        // Strip accent marks (combining Unicode characters, category Mn)
42        let text = if self.strip_accents {
43            strip_accents_str(&text)
44        } else {
45            text
46        };
47
48        // Insert whitespace around punctuation and split
49        let mut spaced = String::with_capacity(text.len() + 16);
50        for ch in text.chars() {
51            if ch.is_whitespace() {
52                spaced.push(' ');
53            } else if is_punctuation(ch) || is_chinese_char(ch) {
54                spaced.push(' ');
55                spaced.push(ch);
56                spaced.push(' ');
57            } else {
58                spaced.push(ch);
59            }
60        }
61
62        spaced
63            .split_whitespace()
64            .filter(|s| !s.is_empty())
65            .map(|s| s.to_string())
66            .collect()
67    }
68}
69
70impl Default for BasicTokenizer {
71    fn default() -> Self {
72        BasicTokenizer::new(true, true)
73    }
74}
75
/// Approximate membership test for Unicode general category Mn
/// (non-spacing combining marks).
///
/// Only three blocks are recognised: Combining Diacritical Marks
/// (U+0300–U+036F), its Extended block (U+1AB0–U+1AFF) and its Supplement
/// (U+1DC0–U+1DFF). Non-spacing marks that live elsewhere are not detected.
fn is_combining_mark(ch: char) -> bool {
    matches!(
        u32::from(ch),
        0x0300..=0x036F | 0x1AB0..=0x1AFF | 0x1DC0..=0x1DFF
    )
}
87
88/// Decompose `s` to NFD then drop combining characters.
89fn strip_accents_str(s: &str) -> String {
90    // Manual NFD-like decomposition: use unicode_normalization if available, else
91    // do a best-effort strip of common combining marks.
92    use unicode_normalization::UnicodeNormalization;
93    s.nfd().filter(|&ch| !is_combining_mark(ch)).collect()
94}
95
/// `true` if `ch` should be split off as a standalone punctuation token.
///
/// Every printable, non-alphanumeric ASCII character (`!`–`/`, `:`–`@`,
/// `[`–`` ` ``, `{`–`~`) counts as punctuation, including symbols such as `$`
/// and `^`. Control characters and space are *not* punctuation.
///
/// Rust's standard library has no Unicode general-category lookup, so
/// non-ASCII punctuation is recognised from a fixed table. The table covers
/// Latin-1 punctuation, the General Punctuation block (dashes, quotes,
/// ellipsis, …), CJK punctuation and brackets, and the full-width forms of the
/// ASCII punctuation above. It approximates Unicode category P but is not exact.
fn is_punctuation(ch: char) -> bool {
    matches!(
        u32::from(ch),
        // Printable ASCII punctuation/symbols (0..=32 are controls and space).
        33..=47 | 58..=64 | 91..=96 | 123..=126
        // Latin-1: ¡ § « ¶ · » ¿
        | 0xA1 | 0xA7 | 0xAB | 0xB6 | 0xB7 | 0xBB | 0xBF
        // General Punctuation: dashes, quotes, daggers, bullets, ellipsis,
        // per-mille, primes, ‹ ›, ※ ‼ ‽ … (U+2044 FRACTION SLASH and
        // U+2052 are maths symbols and are excluded).
        | 0x2010..=0x2027 | 0x2030..=0x2043 | 0x2045..=0x2051 | 0x2053..=0x205E
        // CJK Symbols and Punctuation: 、 。 〃, angle/corner/lenticular
        // brackets, 〜 〝 〞 〟, 〰 and 〽.
        | 0x3001..=0x3003 | 0x3008..=0x3011 | 0x3014..=0x301F | 0x3030 | 0x303D
        // Katakana middle dot ・
        | 0x30FB
        // Full-width ASCII punctuation (ASCII + 0xFEE0) and half-width CJK
        // punctuation ｟ ｠ ｡ ｢ ｣ ､ ･
        | 0xFF01..=0xFF0F | 0xFF1A..=0xFF20 | 0xFF3B..=0xFF40 | 0xFF5B..=0xFF65
    )
}
107
/// `true` if `ch` is a CJK ideograph, i.e. falls in one of the CJK Unified
/// Ideographs blocks (base block and Extensions A–E) or one of the two CJK
/// Compatibility Ideographs blocks.
fn is_chinese_char(ch: char) -> bool {
    const CJK_RANGES: [(u32, u32); 8] = [
        (0x4E00, 0x9FFF),   // CJK Unified Ideographs
        (0x3400, 0x4DBF),   // Extension A
        (0x20000, 0x2A6DF), // Extension B
        (0x2A700, 0x2B73F), // Extension C
        (0x2B740, 0x2B81F), // Extension D
        (0x2B820, 0x2CEAF), // Extension E
        (0xF900, 0xFAFF),   // CJK Compatibility Ideographs
        (0x2F800, 0x2FA1F), // CJK Compatibility Ideographs Supplement
    ];
    let code = u32::from(ch);
    CJK_RANGES
        .iter()
        .any(|&(low, high)| low <= code && code <= high)
}
120
121// ─── WordPieceTokenizer ───────────────────────────────────────────────────────
122
123/// BERT-style WordPiece tokenizer.
124///
125/// Words are split by [`BasicTokenizer`] first, then each word is broken into
126/// the longest matching subwords from the vocabulary.  Continuation subwords
127/// are prefixed with `##`.
128#[derive(Debug, Clone)]
129pub struct WordPieceTokenizer {
130    vocab: HashMap<String, u32>,
131    id_to_token: Vec<String>,
132    unk_id: u32,
133    max_input_chars_per_word: usize,
134    basic: BasicTokenizer,
135}
136
137impl WordPieceTokenizer {
138    // Special token constants
139    const UNK_TOKEN: &'static str = "[UNK]";
140    const CLS_TOKEN: &'static str = "[CLS]";
141    const SEP_TOKEN: &'static str = "[SEP]";
142    const PAD_TOKEN: &'static str = "[PAD]";
143    const MASK_TOKEN: &'static str = "[MASK]";
144
145    /// Build from an existing `token → id` vocabulary map.
146    ///
147    /// The vocabulary must contain at least `[UNK]`.  If `[UNK]` is missing
148    /// it is added with ID `vocab.len()`.
149    pub fn from_vocab(mut vocab: HashMap<String, u32>) -> Self {
150        // Ensure [UNK] exists
151        if !vocab.contains_key(Self::UNK_TOKEN) {
152            let next_id = vocab.len() as u32;
153            vocab.insert(Self::UNK_TOKEN.to_string(), next_id);
154        }
155        let unk_id = vocab[Self::UNK_TOKEN];
156
157        // Build id→token from max ID
158        let max_id = vocab.values().copied().max().unwrap_or(0) as usize;
159        let mut id_to_token = vec![String::new(); max_id + 1];
160        for (tok, &id) in &vocab {
161            if let Some(slot) = id_to_token.get_mut(id as usize) {
162                *slot = tok.clone();
163            }
164        }
165
166        WordPieceTokenizer {
167            vocab,
168            id_to_token,
169            unk_id,
170            max_input_chars_per_word: 200,
171            basic: BasicTokenizer::default(),
172        }
173    }
174
175    /// Build a minimal tokenizer from a plain vocabulary list (one token per
176    /// entry; ID = index).
177    pub fn from_vocab_list(tokens: &[impl AsRef<str>]) -> Self {
178        let vocab: HashMap<String, u32> = tokens
179            .iter()
180            .enumerate()
181            .map(|(i, t)| (t.as_ref().to_string(), i as u32))
182            .collect();
183        Self::from_vocab(vocab)
184    }
185
186    /// Set the maximum number of input characters per word before falling back
187    /// to `[UNK]`.
188    pub fn with_max_input_chars(mut self, n: usize) -> Self {
189        self.max_input_chars_per_word = n;
190        self
191    }
192
193    // ── Core subword splitting ──────────────────────────────────────────
194
195    /// Greedy longest-match WordPiece segmentation of a single `word`.
196    fn wordpiece_word(&self, word: &str) -> Vec<String> {
197        let chars: Vec<char> = word.chars().collect();
198        if chars.len() > self.max_input_chars_per_word {
199            return vec![Self::UNK_TOKEN.to_string()];
200        }
201
202        let mut sub_tokens: Vec<String> = Vec::new();
203        let mut start = 0usize;
204        let n = chars.len();
205        let mut is_bad = false;
206
207        while start < n {
208            let mut end = n;
209            let mut found: Option<String> = None;
210
211            while start < end {
212                let substr: String = chars[start..end].iter().collect();
213                let candidate = if start == 0 {
214                    substr.clone()
215                } else {
216                    format!("##{}", substr)
217                };
218
219                if self.vocab.contains_key(&candidate) {
220                    found = Some(candidate);
221                    break;
222                }
223                if end == start + 1 {
224                    // Single character not in vocab → whole word is unknown
225                    is_bad = true;
226                    break;
227                }
228                end -= 1;
229            }
230
231            if is_bad {
232                break;
233            }
234
235            match found {
236                Some(tok) => {
237                    sub_tokens.push(tok);
238                    start = end;
239                }
240                None => {
241                    is_bad = true;
242                    break;
243                }
244            }
245        }
246
247        if is_bad {
248            vec![Self::UNK_TOKEN.to_string()]
249        } else {
250            sub_tokens
251        }
252    }
253
254    // ── Public API ──────────────────────────────────────────────────────
255
256    /// Tokenize `text` to subword token IDs.
257    pub fn tokenize(&self, text: &str) -> Vec<u32> {
258        self.tokenize_to_strings(text)
259            .iter()
260            .map(|tok| self.vocab.get(tok.as_str()).copied().unwrap_or(self.unk_id))
261            .collect()
262    }
263
264    /// Tokenize `text` to subword token strings.
265    pub fn tokenize_to_strings(&self, text: &str) -> Vec<String> {
266        let words = self.basic.tokenize(text);
267        words.iter().flat_map(|w| self.wordpiece_word(w)).collect()
268    }
269
270    /// Decode a sequence of token IDs back to text.
271    pub fn decode(&self, ids: &[u32]) -> String {
272        let mut out = String::new();
273        for &id in ids {
274            let tok = self
275                .id_to_token
276                .get(id as usize)
277                .map(|s| s.as_str())
278                .unwrap_or("[UNK]");
279
280            // Skip padding
281            if tok == Self::PAD_TOKEN {
282                continue;
283            }
284
285            if tok.starts_with("##") {
286                out.push_str(&tok[2..]);
287            } else if !out.is_empty() && tok != Self::CLS_TOKEN && tok != Self::SEP_TOKEN {
288                out.push(' ');
289                out.push_str(tok);
290            } else {
291                out.push_str(tok);
292            }
293        }
294        out
295    }
296
297    /// Encode `text` with optional special tokens, padding/truncation to
298    /// `max_length`.
299    ///
300    /// Returns `(input_ids, attention_mask)` where `attention_mask[i] = 1`
301    /// for real tokens and `0` for padding.
302    pub fn encode(
303        &self,
304        text: &str,
305        max_length: usize,
306        add_special_tokens: bool,
307    ) -> Result<(Vec<u32>, Vec<u8>)> {
308        if max_length == 0 {
309            return Err(TextError::InvalidInput(
310                "max_length must be > 0".to_string(),
311            ));
312        }
313
314        let cls_id = self
315            .vocab
316            .get(Self::CLS_TOKEN)
317            .copied()
318            .unwrap_or(self.unk_id);
319        let sep_id = self
320            .vocab
321            .get(Self::SEP_TOKEN)
322            .copied()
323            .unwrap_or(self.unk_id);
324        let pad_id = self
325            .vocab
326            .get(Self::PAD_TOKEN)
327            .copied()
328            .unwrap_or(self.unk_id);
329
330        let token_ids = self.tokenize(text);
331
332        // Reserve space for [CLS] and [SEP] when using special tokens
333        let reserve = if add_special_tokens { 2 } else { 0 };
334        let content_budget = max_length.saturating_sub(reserve);
335        let truncated: Vec<u32> = token_ids.into_iter().take(content_budget).collect();
336
337        let mut ids: Vec<u32> = Vec::with_capacity(max_length);
338        if add_special_tokens {
339            ids.push(cls_id);
340        }
341        ids.extend_from_slice(&truncated);
342        if add_special_tokens {
343            ids.push(sep_id);
344        }
345
346        let real_len = ids.len();
347        // Pad to max_length
348        while ids.len() < max_length {
349            ids.push(pad_id);
350        }
351
352        let mut mask: Vec<u8> = vec![0u8; max_length];
353        for m in mask.iter_mut().take(real_len) {
354            *m = 1;
355        }
356
357        Ok((ids, mask))
358    }
359
360    /// Vocabulary size.
361    pub fn vocab_size(&self) -> usize {
362        self.vocab.len()
363    }
364
365    /// Return a cloned snapshot of the `token → id` vocabulary map.
366    ///
367    /// Useful for serialisation (e.g. building a HuggingFace `tokenizers.json`).
368    pub fn vocab_snapshot(&self) -> HashMap<String, u32> {
369        self.vocab.clone()
370    }
371}
372
373// ─── Tests ────────────────────────────────────────────────────────────────────
374
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Small BERT-like vocabulary without duplicates; each ID is the token's
    /// position in the list, so IDs are contiguous.
    fn mini_vocab() -> HashMap<String, u32> {
        [
            "[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]", "he", "llo", "##llo", "world", "##world",
            "want", "##ed", "to", "un", "##want", "low", "##er", "##est", "new", "h", "e", "l",
            "o", "w", "r", "d",
        ]
        .iter()
        .enumerate()
        .map(|(i, tok)| (tok.to_string(), i as u32))
        .collect()
    }

    /// ID of `tok` in `wp`'s vocabulary; panics if the test vocab lacks it.
    fn id(wp: &WordPieceTokenizer, tok: &str) -> u32 {
        *wp.vocab
            .get(tok)
            .unwrap_or_else(|| panic!("{} not in test vocab", tok))
    }

    #[test]
    fn test_basic_tokenizer_lower() {
        let tok = BasicTokenizer::new(true, false);
        assert_eq!(tok.tokenize("Hello, World!"), vec!["hello", ",", "world", "!"]);
    }

    #[test]
    fn test_basic_tokenizer_no_lower() {
        let tok = BasicTokenizer::new(false, false);
        assert_eq!(tok.tokenize("Hello World"), vec!["Hello", "World"]);
    }

    #[test]
    fn test_basic_tokenizer_punctuation_isolation() {
        let tok = BasicTokenizer::new(false, false);
        assert_eq!(tok.tokenize("It's fine."), vec!["It", "'", "s", "fine", "."]);
    }

    #[test]
    fn test_basic_tokenizer_splits_cjk_chars() {
        let tok = BasicTokenizer::new(false, false);
        assert_eq!(tok.tokenize("你好"), vec!["你", "好"]);
    }

    #[test]
    fn test_basic_tokenizer_strips_accents() {
        let tok = BasicTokenizer::new(true, true);
        assert_eq!(tok.tokenize("Café"), vec!["cafe"]);
    }

    #[test]
    fn test_basic_tokenizer_blank_input() {
        let tok = BasicTokenizer::default();
        assert!(tok.tokenize("").is_empty());
        assert!(tok.tokenize(" \t\n ").is_empty());
    }

    #[test]
    fn test_wordpiece_tokenize_to_strings_known() {
        let wp = WordPieceTokenizer::from_vocab(mini_vocab());
        assert_eq!(wp.tokenize_to_strings("low"), vec!["low"]);
        assert_eq!(wp.tokenize_to_strings("lower"), vec!["low", "##er"]);
        assert_eq!(wp.tokenize_to_strings("unwanted"), vec!["un", "##want", "##ed"]);
    }

    #[test]
    fn test_wordpiece_unknown_word_becomes_unk() {
        let wp = WordPieceTokenizer::from_vocab(mini_vocab());
        // "low" matches, but "##x" does not, so the whole word is unknown.
        assert_eq!(wp.tokenize_to_strings("lowx"), vec!["[UNK]"]);
        assert_eq!(wp.tokenize("lowx"), vec![id(&wp, "[UNK]")]);
    }

    #[test]
    fn test_wordpiece_max_input_chars() {
        let wp = WordPieceTokenizer::from_vocab(mini_vocab()).with_max_input_chars(2);
        assert_eq!(wp.tokenize_to_strings("low"), vec!["[UNK]"]);
        assert_eq!(wp.tokenize_to_strings("he"), vec!["he"]);
    }

    #[test]
    fn test_wordpiece_encode_length() {
        let wp = WordPieceTokenizer::from_vocab(mini_vocab());
        let (ids, mask) = wp.encode("low", 8, true).expect("encode failed");
        let (cls, low, sep, pad) = (
            id(&wp, "[CLS]"),
            id(&wp, "low"),
            id(&wp, "[SEP]"),
            id(&wp, "[PAD]"),
        );
        assert_eq!(ids, vec![cls, low, sep, pad, pad, pad, pad, pad]);
        assert_eq!(mask, vec![1, 1, 1, 0, 0, 0, 0, 0]);
    }

    #[test]
    fn test_wordpiece_encode_truncation() {
        let wp = WordPieceTokenizer::from_vocab(mini_vocab());
        let (ids, mask) = wp
            .encode("low low low low", 4, true)
            .expect("encode failed");
        let (cls, low, sep) = (id(&wp, "[CLS]"), id(&wp, "low"), id(&wp, "[SEP]"));
        assert_eq!(ids, vec![cls, low, low, sep]);
        assert_eq!(mask, vec![1; 4]);
    }

    #[test]
    fn test_wordpiece_encode_no_special_tokens() {
        let wp = WordPieceTokenizer::from_vocab(mini_vocab());
        let (ids, mask) = wp.encode("low", 4, false).expect("encode failed");
        let (low, pad) = (id(&wp, "low"), id(&wp, "[PAD]"));
        assert_eq!(ids, vec![low, pad, pad, pad]);
        assert_eq!(mask, vec![1, 0, 0, 0]);
    }

    #[test]
    fn test_wordpiece_encode_rejects_zero_length() {
        let wp = WordPieceTokenizer::from_vocab(mini_vocab());
        assert!(wp.encode("low", 0, true).is_err());
        assert!(wp.encode("low", 0, false).is_err());
    }

    #[test]
    fn test_wordpiece_decode_strips_double_hash() {
        let wp = WordPieceTokenizer::from_vocab(mini_vocab());
        let decoded = wp.decode(&[id(&wp, "low"), id(&wp, "##er")]);
        assert_eq!(decoded, "lower");
    }

    #[test]
    fn test_wordpiece_decode_skips_padding_and_spaces_words() {
        let wp = WordPieceTokenizer::from_vocab(mini_vocab());
        let pad = id(&wp, "[PAD]");
        let decoded = wp.decode(&[id(&wp, "low"), pad, id(&wp, "##er"), id(&wp, "new"), pad]);
        assert_eq!(decoded, "lower new");
    }

    #[test]
    fn test_from_vocab_adds_missing_unk() {
        let mut vocab = HashMap::new();
        vocab.insert("a".to_string(), 0);
        vocab.insert("b".to_string(), 1);
        let wp = WordPieceTokenizer::from_vocab(vocab);
        assert_eq!(wp.vocab_size(), 3);
        assert_eq!(wp.tokenize("z"), vec![2]);
    }

    #[test]
    fn test_from_vocab_list_uses_positions_as_ids() {
        let wp = WordPieceTokenizer::from_vocab_list(&["[UNK]", "low", "##er"][..]);
        assert_eq!(wp.tokenize("lower"), vec![1, 2]);
        assert_eq!(wp.decode(&[1, 2]), "lower");
    }

    #[test]
    fn test_vocab_snapshot_round_trips() {
        let wp = WordPieceTokenizer::from_vocab(mini_vocab());
        assert_eq!(wp.vocab_snapshot(), mini_vocab());
        assert_eq!(wp.vocab_size(), mini_vocab().len());
    }
}