Skip to main content

scirs2_text/tokenization/
unicode_bpe.rs

//! Language-agnostic BPE tokenizer with lightweight Unicode text cleanup.
//!
//! Implements the standard BPE merge algorithm over Unicode scalar values
//! (`char`s) rather than raw bytes, so it works across scripts. Characters
//! missing from the vocabulary can optionally fall back to `<0xHH>` byte
//! tokens when the vocabulary contains them.
6
7use crate::error::{Result, TextError};
8use std::collections::HashMap;
9
10// ---------------------------------------------------------------------------
11// Configuration
12// ---------------------------------------------------------------------------
13
/// Configuration for the Unicode-aware BPE tokenizer.
///
/// Start from [`Default::default`] and override individual fields with
/// struct-update syntax.
#[non_exhaustive]
#[derive(Debug, Clone)]
pub struct UnicodeBpeConfig {
    /// Target vocabulary size.
    ///
    /// The count includes the special tokens and the base symbols seen in the
    /// training corpus, as well as every token created by a merge. Merging
    /// stops before the vocabulary exceeds this size. The base alphabet alone
    /// may already be larger than this value, in which case no merges are
    /// learned.
    pub vocab_size: usize,
    /// Minimum frequency-weighted count a symbol pair needs to be merged.
    pub min_frequency: usize,
    /// Strip control characters (other than whitespace) from the text before
    /// it is split into words. This applies both when training and when
    /// encoding.
    ///
    /// This is NOT full Unicode NFC normalization. Precomposed and decomposed
    /// forms of the same character are left untouched.
    pub normalize: bool,
    /// When a character is missing from the vocabulary, encode it with
    /// `<0xHH>` UTF-8 byte tokens if the vocabulary contains them, and with
    /// `<unk>` otherwise.
    ///
    /// Training does not add `<0xHH>` tokens by itself.
    pub byte_fallback: bool,
}
27
28impl Default for UnicodeBpeConfig {
29    fn default() -> Self {
30        Self {
31            vocab_size: 32_000,
32            min_frequency: 2,
33            normalize: true,
34            byte_fallback: true,
35        }
36    }
37}
38
// ---------------------------------------------------------------------------
// Helper: input cleanup (control-character stripping, not full NFC)
// ---------------------------------------------------------------------------
42
/// Lightweight input cleanup, used when `normalize` is enabled.
///
/// Drops every `char` for which `char::is_control` is true, unless it is also
/// whitespace. Tabs and newlines therefore survive, while NUL, DEL and other
/// non-whitespace control characters are removed.
///
/// Despite its name, this is NOT Unicode NFC. No canonical composition or
/// decomposition happens (for example, `"e\u{301}"` is returned unchanged).
/// Whitespace is also not rewritten here.
fn nfc_normalize(s: &str) -> String {
    let mut cleaned = String::with_capacity(s.len());
    for ch in s.chars() {
        // Whitespace controls (\t, \n, \r, ...) act as word separators later,
        // so they are kept.
        if ch.is_whitespace() || !ch.is_control() {
            cleaned.push(ch);
        }
    }
    cleaned
}
52
53// ---------------------------------------------------------------------------
54// BPE implementation
55// ---------------------------------------------------------------------------
56
57/// Unicode-normalized BPE tokenizer that trains on a raw text corpus.
58pub struct UnicodeBpeTokenizer {
59    config: UnicodeBpeConfig,
60    /// token → id mapping (populated after training).
61    vocab: HashMap<String, u32>,
62    /// id → token reverse mapping.
63    id_to_token: Vec<String>,
64    /// Ordered list of merge operations (pair → merged token).
65    merges: Vec<(String, String)>,
66    /// Special tokens always added to the vocabulary.
67    special_tokens: Vec<String>,
68}
69
/// One merge chosen by a training step.
///
/// Records the symbol pair that was merged, how often it occurred in the
/// frequency-weighted corpus, and the token the merge produced.
struct MergeResult {
    /// Concatenation of `pair.0` and `pair.1`.
    new_token: String,
    /// Frequency-weighted occurrence count of `pair` when it was selected.
    freq: usize,
    /// Left and right symbols that were merged.
    pair: (String, String),
}
76
77impl UnicodeBpeTokenizer {
78    /// Create an untrained tokenizer with the given configuration.
79    pub fn new(config: UnicodeBpeConfig) -> Self {
80        Self {
81            config,
82            vocab: HashMap::new(),
83            id_to_token: Vec::new(),
84            merges: Vec::new(),
85            special_tokens: vec!["<unk>".into(), "<s>".into(), "</s>".into(), "<pad>".into()],
86        }
87    }
88
89    // -----------------------------------------------------------------------
90    // Training
91    // -----------------------------------------------------------------------
92
93    /// Train the BPE vocabulary on a corpus of strings.
94    pub fn train(&mut self, corpus: &[&str]) -> Result<()> {
95        if corpus.is_empty() {
96            return Err(TextError::InvalidInput(
97                "BPE training corpus must not be empty".into(),
98            ));
99        }
100
101        // ---- 1. Pre-tokenize corpus into words, normalize ----
102        let words: Vec<String> = corpus
103            .iter()
104            .flat_map(|doc| {
105                let normalized = if self.config.normalize {
106                    nfc_normalize(doc)
107                } else {
108                    doc.to_string()
109                };
110                normalized
111                    .split_whitespace()
112                    .map(|w| w.to_owned())
113                    .collect::<Vec<_>>()
114            })
115            .filter(|w| !w.is_empty())
116            .collect();
117
118        if words.is_empty() {
119            return Err(TextError::InvalidInput(
120                "corpus has no words after split".into(),
121            ));
122        }
123
124        // ---- 2. Count word frequencies ----
125        let mut word_freq: HashMap<String, usize> = HashMap::new();
126        for word in &words {
127            *word_freq.entry(word.clone()).or_insert(0) += 1;
128        }
129
130        // ---- 3. Represent each word as a sequence of chars (+ </w> end marker) ----
131        // word_splits: word → Vec<String> of character tokens
132        let mut word_splits: HashMap<String, Vec<String>> = word_freq
133            .keys()
134            .map(|w| {
135                let chars: Vec<String> = w.chars().map(|c| c.to_string()).collect();
136                (w.clone(), chars)
137            })
138            .collect();
139
140        // ---- 4. Collect base character vocabulary ----
141        let mut base_chars: std::collections::BTreeSet<String> = std::collections::BTreeSet::new();
142        for chars in word_splits.values() {
143            for c in chars {
144                base_chars.insert(c.clone());
145            }
146        }
147
148        // ---- 5. Initialise vocabulary with special tokens + base chars ----
149        self.vocab.clear();
150        self.id_to_token.clear();
151        self.merges.clear();
152
153        for sp in &self.special_tokens {
154            let id = self.id_to_token.len() as u32;
155            self.vocab.insert(sp.clone(), id);
156            self.id_to_token.push(sp.clone());
157        }
158        for c in &base_chars {
159            if !self.vocab.contains_key(c) {
160                let id = self.id_to_token.len() as u32;
161                self.vocab.insert(c.clone(), id);
162                self.id_to_token.push(c.clone());
163            }
164        }
165
166        // ---- 6. BPE merge loop ----
167        let max_merges = self.config.vocab_size.saturating_sub(self.vocab.len());
168
169        for _ in 0..max_merges {
170            // Count bigram frequencies
171            let mut pair_freq: HashMap<(String, String), usize> = HashMap::new();
172            for (word, freq) in &word_freq {
173                let chars = match word_splits.get(word) {
174                    Some(c) => c,
175                    None => continue,
176                };
177                for window in chars.windows(2) {
178                    *pair_freq
179                        .entry((window[0].clone(), window[1].clone()))
180                        .or_insert(0) += freq;
181                }
182            }
183
184            // Find best merge (highest frequency, tie-break by lexicographic order)
185            let best = pair_freq
186                .iter()
187                .filter(|(_, &freq)| freq >= self.config.min_frequency)
188                .max_by_key(|((a, b), &freq)| (freq, std::cmp::Reverse((a.clone(), b.clone()))));
189
190            let merge = match best {
191                Some(((a, b), &freq)) => MergeResult {
192                    pair: (a.clone(), b.clone()),
193                    freq,
194                    new_token: format!("{}{}", a, b),
195                },
196                None => break, // no more eligible merges
197            };
198
199            if merge.freq < self.config.min_frequency {
200                break;
201            }
202
203            // Register new token
204            if !self.vocab.contains_key(&merge.new_token) {
205                let id = self.id_to_token.len() as u32;
206                self.vocab.insert(merge.new_token.clone(), id);
207                self.id_to_token.push(merge.new_token.clone());
208            }
209            self.merges.push(merge.pair.clone());
210
211            // Apply merge to all word splits
212            let (ref left, ref right) = merge.pair;
213            for chars in word_splits.values_mut() {
214                let mut i = 0;
215                while i + 1 < chars.len() {
216                    if chars[i] == *left && chars[i + 1] == *right {
217                        let merged = format!("{}{}", chars[i], chars[i + 1]);
218                        chars.splice(i..=i + 1, std::iter::once(merged));
219                        // Don't advance i — the newly merged token might pair again
220                    } else {
221                        i += 1;
222                    }
223                }
224            }
225        }
226
227        Ok(())
228    }
229
230    // -----------------------------------------------------------------------
231    // Encoding
232    // -----------------------------------------------------------------------
233
234    /// Tokenize a string to token IDs.
235    pub fn encode(&self, text: &str) -> Result<Vec<u32>> {
236        if self.vocab.is_empty() {
237            return Err(TextError::ModelNotFitted(
238                "BPE tokenizer has not been trained".into(),
239            ));
240        }
241
242        let normalized = if self.config.normalize {
243            nfc_normalize(text)
244        } else {
245            text.to_string()
246        };
247
248        let unk_id = self.vocab.get("<unk>").copied().unwrap_or(0);
249
250        let mut ids = Vec::new();
251
252        for word in normalized.split_whitespace() {
253            // Split word into individual characters
254            let mut chars: Vec<String> = word.chars().map(|c| c.to_string()).collect();
255
256            // Apply merges in training order
257            for (left, right) in &self.merges {
258                let mut i = 0;
259                while i + 1 < chars.len() {
260                    if chars[i] == *left && chars[i + 1] == *right {
261                        let merged = format!("{}{}", chars[i], chars[i + 1]);
262                        chars.splice(i..=i + 1, std::iter::once(merged));
263                    } else {
264                        i += 1;
265                    }
266                }
267            }
268
269            for tok in chars {
270                if let Some(&id) = self.vocab.get(&tok) {
271                    ids.push(id);
272                } else if self.config.byte_fallback {
273                    // Encode as individual UTF-8 bytes: <0xHH>
274                    for byte in tok.as_bytes() {
275                        let byte_tok = format!("<0x{:02X}>", byte);
276                        let id = self.vocab.get(&byte_tok).copied().unwrap_or(unk_id);
277                        ids.push(id);
278                    }
279                } else {
280                    ids.push(unk_id);
281                }
282            }
283        }
284
285        Ok(ids)
286    }
287
288    // -----------------------------------------------------------------------
289    // Decoding
290    // -----------------------------------------------------------------------
291
292    /// Decode token IDs back to a string.
293    pub fn decode(&self, ids: &[u32]) -> Result<String> {
294        if self.id_to_token.is_empty() {
295            return Err(TextError::ModelNotFitted(
296                "BPE tokenizer has not been trained".into(),
297            ));
298        }
299        let mut parts = Vec::new();
300        for &id in ids {
301            let idx = id as usize;
302            if idx >= self.id_to_token.len() {
303                return Err(TextError::InvalidInput(format!(
304                    "token id {} out of vocabulary range {}",
305                    id,
306                    self.id_to_token.len()
307                )));
308            }
309            parts.push(self.id_to_token[idx].clone());
310        }
311        Ok(parts.join(" "))
312    }
313
314    /// Current vocabulary size (number of entries in the token → id mapping).
315    pub fn vocab_size(&self) -> usize {
316        self.vocab.len()
317    }
318
319    /// Number of merge operations learned during training.
320    pub fn n_merges(&self) -> usize {
321        self.merges.len()
322    }
323
324    /// Access the raw vocabulary map.
325    pub fn vocab(&self) -> &HashMap<String, u32> {
326        &self.vocab
327    }
328}
329
330// ---------------------------------------------------------------------------
331// Tests
332// ---------------------------------------------------------------------------
333
#[cfg(test)]
mod tests {
    use super::*;

    fn small_corpus() -> Vec<&'static str> {
        vec![
            "low lower lowest",
            "new newer newest",
            "low new lower newest",
            "the lowest number",
        ]
    }

    /// Tokenizer trained on `small_corpus` with every pair eligible for merging.
    fn trained_min_freq_one() -> UnicodeBpeTokenizer {
        let mut tok = UnicodeBpeTokenizer::new(UnicodeBpeConfig {
            min_frequency: 1,
            ..Default::default()
        });
        tok.train(&small_corpus()).expect("train failed");
        tok
    }

    #[test]
    fn test_default_config() {
        let cfg = UnicodeBpeConfig::default();
        assert_eq!(cfg.vocab_size, 32_000);
        assert_eq!(cfg.min_frequency, 2);
        assert!(cfg.normalize);
        assert!(cfg.byte_fallback);
    }

    #[test]
    fn test_train_empty_corpus_error() {
        let mut tok = UnicodeBpeTokenizer::new(UnicodeBpeConfig::default());
        let result = tok.train(&[]);
        assert!(result.is_err(), "empty corpus must return error");
    }

    #[test]
    fn test_train_whitespace_only_corpus_error() {
        let mut tok = UnicodeBpeTokenizer::new(UnicodeBpeConfig::default());
        let result = tok.train(&["   ", "\t\n"]);
        assert!(result.is_err(), "corpus without words must return error");
    }

    #[test]
    fn test_train_succeeds() {
        let mut tok = UnicodeBpeTokenizer::new(UnicodeBpeConfig::default());
        tok.train(&small_corpus()).expect("train failed");
        assert!(
            tok.vocab_size() > 0,
            "vocab should be non-empty after training"
        );
    }

    #[test]
    fn test_vocab_size_bounded() {
        let config = UnicodeBpeConfig {
            vocab_size: 20,
            min_frequency: 1,
            ..Default::default()
        };
        let mut tok = UnicodeBpeTokenizer::new(config);
        tok.train(&small_corpus()).expect("train failed");
        assert!(
            tok.vocab_size() <= 20,
            "vocab size {} must be <= 20",
            tok.vocab_size()
        );
    }

    #[test]
    fn test_encode_returns_ids() {
        let tok = trained_min_freq_one();
        let ids = tok.encode("low").expect("encode failed");
        assert!(
            !ids.is_empty(),
            "encoding 'low' should produce at least one id"
        );
    }

    #[test]
    fn test_known_words_have_no_unk() {
        let tok = trained_min_freq_one();
        let unk_id = tok.vocab()["<unk>"];
        let ids = tok.encode("lower newest").expect("encode failed");
        assert!(
            !ids.contains(&unk_id),
            "words built from training characters must not produce <unk>"
        );
    }

    #[test]
    fn test_unknown_char_maps_to_unk_without_byte_fallback() {
        let mut tok = UnicodeBpeTokenizer::new(UnicodeBpeConfig {
            min_frequency: 1,
            byte_fallback: false,
            ..Default::default()
        });
        tok.train(&small_corpus()).expect("train failed");
        let unk_id = tok.vocab()["<unk>"];
        let ids = tok.encode("é").expect("encode failed");
        assert_eq!(ids.first(), Some(&unk_id));
    }

    #[test]
    fn test_encode_before_train_error() {
        let tok = UnicodeBpeTokenizer::new(UnicodeBpeConfig::default());
        let result = tok.encode("hello");
        assert!(result.is_err(), "encode before train must return error");
    }

    #[test]
    fn test_decode_before_train_error() {
        let tok = UnicodeBpeTokenizer::new(UnicodeBpeConfig::default());
        let result = tok.decode(&[0, 1]);
        assert!(result.is_err(), "decode before train must return error");
    }

    #[test]
    fn test_decode_out_of_range_id_error() {
        let mut tok = UnicodeBpeTokenizer::new(UnicodeBpeConfig::default());
        tok.train(&small_corpus()).expect("train failed");
        let result = tok.decode(&[u32::MAX]);
        assert!(result.is_err(), "out-of-range id must return error");
    }

    #[test]
    fn test_n_merges_increases_with_training() {
        let mut tok = UnicodeBpeTokenizer::new(UnicodeBpeConfig {
            vocab_size: 50,
            min_frequency: 1,
            ..Default::default()
        });
        tok.train(&small_corpus()).expect("train failed");
        assert!(tok.n_merges() > 0, "should have at least one merge");
    }

    #[test]
    fn test_retraining_is_deterministic() {
        let mut tok = UnicodeBpeTokenizer::new(UnicodeBpeConfig {
            vocab_size: 30,
            min_frequency: 1,
            ..Default::default()
        });
        tok.train(&small_corpus()).expect("first train failed");
        let first_vocab = tok.vocab().clone();
        let first_merges = tok.n_merges();
        tok.train(&small_corpus()).expect("second train failed");
        assert_eq!(tok.vocab(), &first_vocab);
        assert_eq!(tok.n_merges(), first_merges);
    }

    #[test]
    fn test_special_tokens_in_vocab() {
        let mut tok = UnicodeBpeTokenizer::new(UnicodeBpeConfig::default());
        tok.train(&small_corpus()).expect("train failed");
        assert!(tok.vocab().contains_key("<unk>"));
        assert!(tok.vocab().contains_key("<s>"));
        assert!(tok.vocab().contains_key("</s>"));
    }

    #[test]
    fn test_decode_special_token() {
        let tok = trained_min_freq_one();
        let unk_id = tok.vocab()["<unk>"];
        let decoded = tok.decode(&[unk_id]).expect("decode failed");
        assert_eq!(decoded, "<unk>");
    }
}
444        let unk_id = tok.vocab()["<unk>"];
445        let decoded = tok.decode(&[unk_id]).expect("decode failed");
446        assert_eq!(decoded, "<unk>");
447    }
448}