Skip to main content

scirs2_text/tokenization/
multilingual_bpe.rs

1//! Multilingual BPE tokenizer — shared vocabulary across 50+ languages.
2//!
3//! Implements temperature-based language-balanced corpus sampling (mBERT/XLM-R
4//! style) where the probability of sampling from language `l` is
5//! `p_l = (size_l / total_size)^alpha / Z`.
6//!
7//! - `alpha = 1.0` gives proportional (natural) sampling.
8//! - `alpha = 0.0` gives uniform sampling across languages.
9//! - Intermediate values up-sample low-resource languages.
10
11use crate::error::{Result, TextError};
12use std::collections::HashMap;
13
/// The four lookup structures built by `init_base`, in return order:
/// byte→char table, char→byte table, token→id vocabulary, id→token list.
type InitBaseMaps = (
    HashMap<u8, char>,    // byte -> printable char
    HashMap<char, u8>,    // printable char -> byte
    HashMap<String, u32>, // token string -> id
    Vec<String>,          // id -> token string
);
21
22// ─── LanguageCorpus ───────────────────────────────────────────────────────────
23
/// A single-language corpus together with its (un-normalized) sampling weight.
#[derive(Debug, Clone)]
pub struct LanguageCorpus {
    /// BCP-47 language tag or arbitrary identifier.
    pub language: String,
    /// Raw text documents.
    pub texts: Vec<String>,
    /// Un-normalized sampling weight; always present (this is not optional).
    ///
    /// [`LanguageCorpus::new`] takes it explicitly, while
    /// [`LanguageCorpus::from_texts`] sets it to the total whitespace-split
    /// word count, floored at `1.0`. When language probabilities are computed
    /// it is raised to the power `alpha` and normalized, so it should be a
    /// finite, non-negative number.
    pub weight: f64,
}
34
35impl LanguageCorpus {
36    /// Construct a corpus with a manual weight.
37    pub fn new(language: impl Into<String>, texts: Vec<String>, weight: f64) -> Self {
38        LanguageCorpus {
39            language: language.into(),
40            texts,
41            weight,
42        }
43    }
44
45    /// Construct a corpus whose weight is derived from the number of tokens
46    /// (whitespace-split words) in all texts.
47    pub fn from_texts(language: impl Into<String>, texts: Vec<String>) -> Self {
48        let size: f64 = texts
49            .iter()
50            .map(|t| t.split_whitespace().count() as f64)
51            .sum();
52        LanguageCorpus {
53            language: language.into(),
54            texts,
55            weight: size.max(1.0),
56        }
57    }
58}
59
60// ─── MultilingualBpeConfig ────────────────────────────────────────────────────
61
/// Configuration for [`MultilingualBpeTokenizer`] training.
#[derive(Debug, Clone)]
pub struct MultilingualBpeConfig {
    /// Target vocabulary size, counting the 256 byte-level base tokens.
    ///
    /// Training stops earlier if no remaining pair meets `min_frequency`.
    pub vocab_size: usize,
    /// Temperature exponent for language-balanced sampling:
    /// `p_l = weight_l^alpha / sum_k weight_k^alpha`.
    ///
    /// `alpha = 1.0` → proportional; `alpha = 0.0` → uniform; values in
    /// between up-sample low-resource languages.
    pub alpha: f64,
    /// Minimum *weighted* pair frequency for a merge to be accepted.
    ///
    /// The threshold is compared against `sum_l count_l * p_l`, not the raw
    /// count. Because every `p_l <= 1`, a pair's raw total count must be at
    /// least this value and is typically well above it.
    pub min_frequency: usize,
    /// Whether training marks every non-initial word of a text with the
    /// GPT-2 style `Ġ` word-boundary prefix.
    ///
    /// Note: only training consults this flag; encoding always marks
    /// non-initial words regardless of its value.
    pub add_prefix_space: bool,
}
76
77impl Default for MultilingualBpeConfig {
78    fn default() -> Self {
79        MultilingualBpeConfig {
80            vocab_size: 250_000,
81            alpha: 0.5,
82            min_frequency: 5,
83            add_prefix_space: true,
84        }
85    }
86}
87
88// ─── MultilingualBpeTokenizer ─────────────────────────────────────────────────
89
/// Byte-level BPE tokenizer whose vocabulary and merge table are shared by
/// every language it was trained on.
///
/// During training each language's pair counts are scaled by a
/// temperature-smoothed sampling probability, so high-resource languages do
/// not dominate the learned merges.
#[derive(Debug, Clone)]
pub struct MultilingualBpeTokenizer {
    /// Token string → integer id.
    pub vocab: HashMap<String, u32>,
    /// Token strings indexed by integer id.
    pub id_to_token: Vec<String>,
    /// Merge rules in the order they were learned; earlier rules take
    /// priority when a word is encoded.
    pub merges: Vec<(String, String)>,
    /// Byte → printable unicode char (GPT-2 style byte table).
    pub byte_encoder: HashMap<u8, char>,
    /// Printable unicode char → original byte (inverse of `byte_encoder`).
    pub byte_decoder: HashMap<char, u8>,
    /// Per-language sampling probabilities computed during training.
    pub language_probs: HashMap<String, f64>,
}
110
111impl MultilingualBpeTokenizer {
112    /// Build base vocabulary (256 byte-level characters).
113    fn init_base() -> InitBaseMaps {
114        use super::byte_level_bpe::bytes_to_unicode;
115        let byte_encoder = bytes_to_unicode();
116        let byte_decoder: HashMap<char, u8> = byte_encoder.iter().map(|(&b, &c)| (c, b)).collect();
117
118        let mut vocab: HashMap<String, u32> = HashMap::new();
119        let mut id_to_token: Vec<String> = Vec::new();
120
121        for b in 0u8..=255u8 {
122            let ch = byte_encoder[&b];
123            let tok = ch.to_string();
124            if !vocab.contains_key(&tok) {
125                let id = id_to_token.len() as u32;
126                vocab.insert(tok.clone(), id);
127                id_to_token.push(tok);
128            }
129        }
130        (byte_encoder, byte_decoder, vocab, id_to_token)
131    }
132
133    /// Compute temperature-smoothed language sampling probabilities.
134    ///
135    /// `p_l = weight_l^alpha / sum_k(weight_k^alpha)`
136    ///
137    /// Returns `None` only when `corpora` is empty.
138    pub fn compute_language_probs(
139        corpora: &[LanguageCorpus],
140        alpha: f64,
141    ) -> Option<HashMap<String, f64>> {
142        if corpora.is_empty() {
143            return None;
144        }
145        let powered: Vec<f64> = corpora.iter().map(|c| c.weight.powf(alpha)).collect();
146        let z: f64 = powered.iter().sum();
147        if z == 0.0 {
148            // Uniform fallback
149            let p = 1.0 / corpora.len() as f64;
150            return Some(corpora.iter().map(|c| (c.language.clone(), p)).collect());
151        }
152        Some(
153            corpora
154                .iter()
155                .zip(powered.iter())
156                .map(|(c, &pw)| (c.language.clone(), pw / z))
157                .collect(),
158        )
159    }
160
161    /// Byte-encode a single string into a sequence of unicode-char tokens.
162    fn byte_encode(byte_encoder: &HashMap<u8, char>, s: &str) -> Vec<String> {
163        s.bytes()
164            .map(|b| {
165                byte_encoder
166                    .get(&b)
167                    .copied()
168                    .unwrap_or('\u{FFFD}')
169                    .to_string()
170            })
171            .collect()
172    }
173
174    /// Apply all known merges (priority-ordered) to a word token sequence.
175    fn apply_merges(merges: &[(String, String)], mut word: Vec<String>) -> Vec<String> {
176        let merge_rank: HashMap<(String, String), usize> = merges
177            .iter()
178            .enumerate()
179            .map(|(i, (a, b))| ((a.clone(), b.clone()), i))
180            .collect();
181        loop {
182            if word.len() < 2 {
183                break;
184            }
185            let mut best_rank = usize::MAX;
186            let mut best_idx = usize::MAX;
187            for i in 0..word.len() - 1 {
188                let pair = (word[i].clone(), word[i + 1].clone());
189                if let Some(&rank) = merge_rank.get(&pair) {
190                    if rank < best_rank {
191                        best_rank = rank;
192                        best_idx = i;
193                    }
194                }
195            }
196            if best_idx == usize::MAX {
197                break;
198            }
199            let merged = format!("{}{}", word[best_idx], word[best_idx + 1]);
200            word.remove(best_idx + 1);
201            word[best_idx] = merged;
202        }
203        word
204    }
205
206    /// Train a new [`MultilingualBpeTokenizer`] from a set of language corpora.
207    ///
208    /// Language probabilities are computed via temperature smoothing
209    /// (`alpha` parameter in config).  Pair frequencies are accumulated as a
210    /// weighted sum: `count += freq * p_l` for each language `l`.
211    pub fn train(corpora: &[LanguageCorpus], config: MultilingualBpeConfig) -> Self {
212        let (byte_encoder, byte_decoder, mut vocab, mut id_to_token) = Self::init_base();
213
214        let lang_probs = Self::compute_language_probs(corpora, config.alpha).unwrap_or_default();
215
216        // Build per-language word-frequency maps
217        let mut lang_word_freq: Vec<(f64, HashMap<Vec<String>, usize>)> =
218            Vec::with_capacity(corpora.len());
219
220        for corpus in corpora {
221            let prob = lang_probs.get(&corpus.language).copied().unwrap_or(0.0);
222            let mut word_freq: HashMap<Vec<String>, usize> = HashMap::new();
223            for text in &corpus.texts {
224                let mut first = true;
225                for word in text.split_whitespace() {
226                    let prefixed = if first || !config.add_prefix_space {
227                        word.to_string()
228                    } else {
229                        format!("\u{0120}{}", word)
230                    };
231                    first = false;
232                    let encoded = Self::byte_encode(&byte_encoder, &prefixed);
233                    *word_freq.entry(encoded).or_insert(0) += 1;
234                }
235            }
236            lang_word_freq.push((prob, word_freq));
237        }
238
239        let mut merges: Vec<(String, String)> = Vec::new();
240
241        // BPE merge loop with language-weighted pair counting
242        while vocab.len() < config.vocab_size {
243            let mut pair_freq: HashMap<(String, String), f64> = HashMap::new();
244
245            for (prob, word_freq) in &lang_word_freq {
246                for (word, &count) in word_freq {
247                    let weighted = count as f64 * prob;
248                    for i in 0..word.len().saturating_sub(1) {
249                        let pair = (word[i].clone(), word[i + 1].clone());
250                        *pair_freq.entry(pair).or_insert(0.0) += weighted;
251                    }
252                }
253            }
254
255            let best = pair_freq
256                .iter()
257                .filter(|(_, &f)| f >= config.min_frequency as f64)
258                .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
259
260            let (left, right) = match best {
261                Some(((l, r), _)) => (l.clone(), r.clone()),
262                None => break,
263            };
264
265            merges.push((left.clone(), right.clone()));
266            let merged = format!("{}{}", left, right);
267            let new_id = id_to_token.len() as u32;
268            vocab.insert(merged.clone(), new_id);
269            id_to_token.push(merged.clone());
270
271            // Apply merge to all language word maps
272            for (_, word_freq) in &mut lang_word_freq {
273                let updated: HashMap<Vec<String>, usize> = word_freq
274                    .drain()
275                    .map(|(word, freq)| (merge_pair(&word, &left, &right), freq))
276                    .collect();
277                *word_freq = updated;
278            }
279        }
280
281        MultilingualBpeTokenizer {
282            vocab,
283            id_to_token,
284            merges,
285            byte_encoder,
286            byte_decoder,
287            language_probs: lang_probs,
288        }
289    }
290
291    /// Encode text using this tokenizer.
292    ///
293    /// Language is accepted as an argument for API symmetry but the encoding
294    /// is purely language-agnostic (same BPE merges for all languages).
295    pub fn encode_with_language(&self, text: &str, _lang: &str) -> Vec<u32> {
296        self.encode(text)
297    }
298
299    /// Encode text to token IDs.
300    pub fn encode(&self, text: &str) -> Vec<u32> {
301        let mut ids = Vec::new();
302        let mut first = true;
303        for word in text.split_whitespace() {
304            let prefixed = if first {
305                word.to_string()
306            } else {
307                format!("\u{0120}{}", word)
308            };
309            first = false;
310            let chars = Self::byte_encode(&self.byte_encoder, &prefixed);
311            let merged = Self::apply_merges(&self.merges, chars);
312            for tok in merged {
313                if let Some(&id) = self.vocab.get(&tok) {
314                    ids.push(id);
315                }
316            }
317        }
318        ids
319    }
320
321    /// Decode token IDs back to a UTF-8 string.
322    pub fn decode(&self, ids: &[u32]) -> String {
323        let mut bytes: Vec<u8> = Vec::new();
324        for &id in ids {
325            if let Some(tok) = self.id_to_token.get(id as usize) {
326                for ch in tok.chars() {
327                    if let Some(&b) = self.byte_decoder.get(&ch) {
328                        bytes.push(b);
329                    }
330                }
331            }
332        }
333        String::from_utf8_lossy(&bytes).into_owned()
334    }
335
336    /// Compute vocabulary coverage: fraction of encoded tokens that are not
337    /// out-of-vocabulary.
338    ///
339    /// With byte-level encoding there is no true OOV, but this metric measures
340    /// the fraction of whitespace-delimited words that are represented as a
341    /// single token (perfectly compressed) versus multiple sub-tokens.
342    pub fn vocabulary_coverage(&self, texts: &[&str]) -> f64 {
343        let mut total_words = 0usize;
344        let mut single_token_words = 0usize;
345        for text in texts {
346            for word in text.split_whitespace() {
347                total_words += 1;
348                let chars = Self::byte_encode(&self.byte_encoder, word);
349                let merged = Self::apply_merges(&self.merges, chars);
350                if merged.len() == 1 {
351                    single_token_words += 1;
352                }
353            }
354        }
355        if total_words == 0 {
356            return 0.0;
357        }
358        single_token_words as f64 / total_words as f64
359    }
360
361    /// Return the vocabulary size.
362    pub fn vocab_size(&self) -> usize {
363        self.vocab.len()
364    }
365}
366
/// Replace every non-overlapping adjacent `(left, right)` pair in `word`
/// with the concatenated token, scanning left to right.
fn merge_pair(word: &[String], left: &str, right: &str) -> Vec<String> {
    let mut merged = Vec::with_capacity(word.len());
    let mut rest = word;
    while let Some((head, tail)) = rest.split_first() {
        match tail.first() {
            Some(next) if head == left && next == right => {
                merged.push(format!("{}{}", left, right));
                rest = &tail[1..];
            }
            _ => {
                merged.push(head.clone());
                rest = tail;
            }
        }
    }
    merged
}
382
383// ─── Tests ───────────────────────────────────────────────────────────────────
384
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for floating-point probability comparisons.
    const TOLERANCE: f64 = 1e-9;

    /// Convert borrowed sentences into the owned form `LanguageCorpus` stores.
    fn owned(sentences: &[&str]) -> Vec<String> {
        sentences.iter().map(|s| s.to_string()).collect()
    }

    /// Three small corpora (English, German, French) of differing sizes.
    fn sample_corpora() -> Vec<LanguageCorpus> {
        let english = owned(&[
            "hello world the quick brown fox",
            "rust is a great language for systems programming",
            "more english text for training the tokenizer",
            "the tokenizer should learn common english word pieces",
        ]);
        let german = owned(&[
            "hallo welt schnell braun fuchs",
            "rust ist eine großartige sprache",
        ]);
        let french = owned(&[
            "bonjour monde renard brun rapide",
            "rust est un langage de programmation",
        ]);
        vec![
            LanguageCorpus::from_texts("en", english),
            LanguageCorpus::from_texts("de", german),
            LanguageCorpus::from_texts("fr", french),
        ]
    }

    /// Small, permissive training configuration shared by the training tests.
    fn small_config(vocab_size: usize, add_prefix_space: bool) -> MultilingualBpeConfig {
        MultilingualBpeConfig {
            vocab_size,
            alpha: 0.5,
            min_frequency: 1,
            add_prefix_space,
        }
    }

    #[test]
    fn test_language_probs_sum_to_one() {
        let probs = MultilingualBpeTokenizer::compute_language_probs(&sample_corpora(), 0.5)
            .expect("should compute probs");
        let total = probs.values().sum::<f64>();
        assert!(
            (total - 1.0).abs() < TOLERANCE,
            "language probs should sum to 1.0, got {}",
            total
        );
    }

    #[test]
    fn test_alpha_zero_uniform() {
        let corpora = sample_corpora();
        let probs = MultilingualBpeTokenizer::compute_language_probs(&corpora, 0.0)
            .expect("should compute probs");
        // weight^0 == 1 for every corpus, so each language gets an equal share.
        let uniform = 1.0 / corpora.len() as f64;
        for (lang, p) in probs.iter() {
            assert!(
                (p - uniform).abs() < TOLERANCE,
                "lang {} prob {} != uniform {}",
                lang,
                p,
                uniform
            );
        }
    }

    #[test]
    fn test_alpha_one_proportional() {
        let corpora = sample_corpora();
        let weight_sum = corpora.iter().map(|c| c.weight).sum::<f64>();
        let probs = MultilingualBpeTokenizer::compute_language_probs(&corpora, 1.0)
            .expect("should compute probs");
        for corpus in corpora.iter() {
            let proportional = corpus.weight / weight_sum;
            let actual = probs[&corpus.language];
            assert!(
                (actual - proportional).abs() < TOLERANCE,
                "lang {} prob {} != proportional {}",
                corpus.language,
                actual,
                proportional
            );
        }
    }

    #[test]
    fn test_train_vocab_size() {
        let tok = MultilingualBpeTokenizer::train(&sample_corpora(), small_config(400, true));
        assert!(tok.vocab_size() <= 400);
        assert!(tok.vocab_size() >= 256);
    }

    #[test]
    fn test_encode_with_language() {
        let tok = MultilingualBpeTokenizer::train(&sample_corpora(), small_config(400, true));
        // The language tag is ignored, so both calls must agree.
        let ids_en = tok.encode_with_language("hello world", "en");
        let ids_de = tok.encode_with_language("hello world", "de");
        assert_eq!(ids_en, ids_de);
    }

    #[test]
    fn test_vocabulary_coverage() {
        let tok = MultilingualBpeTokenizer::train(&sample_corpora(), small_config(500, false));
        let coverage = tok.vocabulary_coverage(&["hello", "rust", "world"]);
        assert!(
            (0.0..=1.0).contains(&coverage),
            "coverage should be in [0,1]"
        );
    }
}