1use crate::error::{Result, TextError};
12use std::collections::HashMap;
13
/// Return type of `MultilingualBpeTokenizer::init_base`:
/// `(byte_encoder, byte_decoder, vocab, id_to_token)`.
type InitBaseMaps = (
    std::collections::HashMap<u8, char>,
    std::collections::HashMap<char, u8>,
    std::collections::HashMap<String, u32>,
    Vec<String>,
);
21
/// Training texts for one language together with an unnormalised sampling weight.
#[derive(Clone, Debug)]
pub struct LanguageCorpus {
    /// Language identifier; used as the key of the language-probability map.
    pub language: String,
    /// Raw training texts; words are taken by splitting on whitespace.
    pub texts: Vec<String>,
    /// Unnormalised weight, raised to `alpha` when language probabilities are computed.
    pub weight: f64,
}

impl LanguageCorpus {
    /// Creates a corpus with an explicitly chosen weight.
    pub fn new(language: impl Into<String>, texts: Vec<String>, weight: f64) -> Self {
        Self {
            language: language.into(),
            texts,
            weight,
        }
    }

    /// Creates a corpus weighted by its total whitespace-delimited word count.
    ///
    /// The weight is never below `1.0`, so an empty corpus still gets weight `1.0`.
    pub fn from_texts(language: impl Into<String>, texts: Vec<String>) -> Self {
        let word_count = texts
            .iter()
            .map(|text| text.split_whitespace().count())
            .sum::<usize>() as f64;
        Self::new(language, texts, word_count.max(1.0))
    }
}
59
/// Hyper-parameters for `MultilingualBpeTokenizer::train`.
#[derive(Clone, Debug)]
pub struct MultilingualBpeConfig {
    /// Training stops once the vocabulary reaches this many entries.
    pub vocab_size: usize,
    /// Exponent applied to corpus weights before they are normalised into language
    /// probabilities (`1.0` = proportional to weight, `0.0` = uniform).
    pub alpha: f64,
    /// A pair is only merged if its language-weighted frequency is at least this value.
    pub min_frequency: usize,
    /// Whether training prefixes every word except the first of each text with a space.
    pub add_prefix_space: bool,
}

impl Default for MultilingualBpeConfig {
    /// 250k-token vocabulary, `alpha = 0.5`, `min_frequency = 5`, prefix spaces enabled.
    fn default() -> Self {
        Self {
            add_prefix_space: true,
            min_frequency: 5,
            alpha: 0.5,
            vocab_size: 250_000,
        }
    }
}
87
/// Byte-level BPE tokenizer trained jointly on several languages.
///
/// Built by `MultilingualBpeTokenizer::train`; every field is public so a trained
/// model can be inspected directly.
#[derive(Clone, Debug)]
pub struct MultilingualBpeTokenizer {
    /// Token string -> token id.
    pub vocab: HashMap<String, u32>,
    /// Token id -> token string (indexed by id).
    pub id_to_token: Vec<String>,
    /// Merge rules in the order they were learned; earlier rules are applied first when encoding.
    pub merges: Vec<(String, String)>,
    /// Maps each byte value to the character that represents it inside token strings.
    pub byte_encoder: HashMap<u8, char>,
    /// Inverse of `byte_encoder`, used when decoding ids back to bytes.
    pub byte_decoder: HashMap<char, u8>,
    /// Per-language sampling probabilities computed at training time.
    pub language_probs: HashMap<String, f64>,
}
110
111impl MultilingualBpeTokenizer {
112 fn init_base() -> InitBaseMaps {
114 use super::byte_level_bpe::bytes_to_unicode;
115 let byte_encoder = bytes_to_unicode();
116 let byte_decoder: HashMap<char, u8> = byte_encoder.iter().map(|(&b, &c)| (c, b)).collect();
117
118 let mut vocab: HashMap<String, u32> = HashMap::new();
119 let mut id_to_token: Vec<String> = Vec::new();
120
121 for b in 0u8..=255u8 {
122 let ch = byte_encoder[&b];
123 let tok = ch.to_string();
124 if !vocab.contains_key(&tok) {
125 let id = id_to_token.len() as u32;
126 vocab.insert(tok.clone(), id);
127 id_to_token.push(tok);
128 }
129 }
130 (byte_encoder, byte_decoder, vocab, id_to_token)
131 }
132
133 pub fn compute_language_probs(
139 corpora: &[LanguageCorpus],
140 alpha: f64,
141 ) -> Option<HashMap<String, f64>> {
142 if corpora.is_empty() {
143 return None;
144 }
145 let powered: Vec<f64> = corpora.iter().map(|c| c.weight.powf(alpha)).collect();
146 let z: f64 = powered.iter().sum();
147 if z == 0.0 {
148 let p = 1.0 / corpora.len() as f64;
150 return Some(corpora.iter().map(|c| (c.language.clone(), p)).collect());
151 }
152 Some(
153 corpora
154 .iter()
155 .zip(powered.iter())
156 .map(|(c, &pw)| (c.language.clone(), pw / z))
157 .collect(),
158 )
159 }
160
161 fn byte_encode(byte_encoder: &HashMap<u8, char>, s: &str) -> Vec<String> {
163 s.bytes()
164 .map(|b| {
165 byte_encoder
166 .get(&b)
167 .copied()
168 .unwrap_or('\u{FFFD}')
169 .to_string()
170 })
171 .collect()
172 }
173
174 fn apply_merges(merges: &[(String, String)], mut word: Vec<String>) -> Vec<String> {
176 let merge_rank: HashMap<(String, String), usize> = merges
177 .iter()
178 .enumerate()
179 .map(|(i, (a, b))| ((a.clone(), b.clone()), i))
180 .collect();
181 loop {
182 if word.len() < 2 {
183 break;
184 }
185 let mut best_rank = usize::MAX;
186 let mut best_idx = usize::MAX;
187 for i in 0..word.len() - 1 {
188 let pair = (word[i].clone(), word[i + 1].clone());
189 if let Some(&rank) = merge_rank.get(&pair) {
190 if rank < best_rank {
191 best_rank = rank;
192 best_idx = i;
193 }
194 }
195 }
196 if best_idx == usize::MAX {
197 break;
198 }
199 let merged = format!("{}{}", word[best_idx], word[best_idx + 1]);
200 word.remove(best_idx + 1);
201 word[best_idx] = merged;
202 }
203 word
204 }
205
206 pub fn train(corpora: &[LanguageCorpus], config: MultilingualBpeConfig) -> Self {
212 let (byte_encoder, byte_decoder, mut vocab, mut id_to_token) = Self::init_base();
213
214 let lang_probs = Self::compute_language_probs(corpora, config.alpha).unwrap_or_default();
215
216 let mut lang_word_freq: Vec<(f64, HashMap<Vec<String>, usize>)> =
218 Vec::with_capacity(corpora.len());
219
220 for corpus in corpora {
221 let prob = lang_probs.get(&corpus.language).copied().unwrap_or(0.0);
222 let mut word_freq: HashMap<Vec<String>, usize> = HashMap::new();
223 for text in &corpus.texts {
224 let mut first = true;
225 for word in text.split_whitespace() {
226 let prefixed = if first || !config.add_prefix_space {
227 word.to_string()
228 } else {
229 format!("\u{0120}{}", word)
230 };
231 first = false;
232 let encoded = Self::byte_encode(&byte_encoder, &prefixed);
233 *word_freq.entry(encoded).or_insert(0) += 1;
234 }
235 }
236 lang_word_freq.push((prob, word_freq));
237 }
238
239 let mut merges: Vec<(String, String)> = Vec::new();
240
241 while vocab.len() < config.vocab_size {
243 let mut pair_freq: HashMap<(String, String), f64> = HashMap::new();
244
245 for (prob, word_freq) in &lang_word_freq {
246 for (word, &count) in word_freq {
247 let weighted = count as f64 * prob;
248 for i in 0..word.len().saturating_sub(1) {
249 let pair = (word[i].clone(), word[i + 1].clone());
250 *pair_freq.entry(pair).or_insert(0.0) += weighted;
251 }
252 }
253 }
254
255 let best = pair_freq
256 .iter()
257 .filter(|(_, &f)| f >= config.min_frequency as f64)
258 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
259
260 let (left, right) = match best {
261 Some(((l, r), _)) => (l.clone(), r.clone()),
262 None => break,
263 };
264
265 merges.push((left.clone(), right.clone()));
266 let merged = format!("{}{}", left, right);
267 let new_id = id_to_token.len() as u32;
268 vocab.insert(merged.clone(), new_id);
269 id_to_token.push(merged.clone());
270
271 for (_, word_freq) in &mut lang_word_freq {
273 let updated: HashMap<Vec<String>, usize> = word_freq
274 .drain()
275 .map(|(word, freq)| (merge_pair(&word, &left, &right), freq))
276 .collect();
277 *word_freq = updated;
278 }
279 }
280
281 MultilingualBpeTokenizer {
282 vocab,
283 id_to_token,
284 merges,
285 byte_encoder,
286 byte_decoder,
287 language_probs: lang_probs,
288 }
289 }
290
291 pub fn encode_with_language(&self, text: &str, _lang: &str) -> Vec<u32> {
296 self.encode(text)
297 }
298
299 pub fn encode(&self, text: &str) -> Vec<u32> {
301 let mut ids = Vec::new();
302 let mut first = true;
303 for word in text.split_whitespace() {
304 let prefixed = if first {
305 word.to_string()
306 } else {
307 format!("\u{0120}{}", word)
308 };
309 first = false;
310 let chars = Self::byte_encode(&self.byte_encoder, &prefixed);
311 let merged = Self::apply_merges(&self.merges, chars);
312 for tok in merged {
313 if let Some(&id) = self.vocab.get(&tok) {
314 ids.push(id);
315 }
316 }
317 }
318 ids
319 }
320
321 pub fn decode(&self, ids: &[u32]) -> String {
323 let mut bytes: Vec<u8> = Vec::new();
324 for &id in ids {
325 if let Some(tok) = self.id_to_token.get(id as usize) {
326 for ch in tok.chars() {
327 if let Some(&b) = self.byte_decoder.get(&ch) {
328 bytes.push(b);
329 }
330 }
331 }
332 }
333 String::from_utf8_lossy(&bytes).into_owned()
334 }
335
336 pub fn vocabulary_coverage(&self, texts: &[&str]) -> f64 {
343 let mut total_words = 0usize;
344 let mut single_token_words = 0usize;
345 for text in texts {
346 for word in text.split_whitespace() {
347 total_words += 1;
348 let chars = Self::byte_encode(&self.byte_encoder, word);
349 let merged = Self::apply_merges(&self.merges, chars);
350 if merged.len() == 1 {
351 single_token_words += 1;
352 }
353 }
354 }
355 if total_words == 0 {
356 return 0.0;
357 }
358 single_token_words as f64 / total_words as f64
359 }
360
361 pub fn vocab_size(&self) -> usize {
363 self.vocab.len()
364 }
365}
366
/// Replaces every non-overlapping, left-to-right occurrence of the adjacent pair
/// `(left, right)` in `word` with the single symbol `left + right`.
fn merge_pair(word: &[String], left: &str, right: &str) -> Vec<String> {
    let joined = format!("{}{}", left, right);
    let mut out = Vec::with_capacity(word.len());
    let mut rest = word;
    while let Some((head, tail)) = rest.split_first() {
        match tail.split_first() {
            Some((next, after)) if head == left && next == right => {
                out.push(joined.clone());
                rest = after;
            }
            _ => {
                out.push(head.clone());
                rest = tail;
            }
        }
    }
    out
}
382
#[cfg(test)]
mod tests {
    use super::*;

    /// Converts string slices into the owned `Vec<String>` the corpus constructors take.
    fn owned(lines: &[&str]) -> Vec<String> {
        lines.iter().map(|line| line.to_string()).collect()
    }

    /// Three corpora of unequal size (en: 4 texts, de: 2, fr: 2), weighted by word count.
    fn sample_corpora() -> Vec<LanguageCorpus> {
        let english = owned(&[
            "hello world the quick brown fox",
            "rust is a great language for systems programming",
            "more english text for training the tokenizer",
            "the tokenizer should learn common english word pieces",
        ]);
        let german = owned(&[
            "hallo welt schnell braun fuchs",
            "rust ist eine großartige sprache",
        ]);
        let french = owned(&[
            "bonjour monde renard brun rapide",
            "rust est un langage de programmation",
        ]);
        vec![
            LanguageCorpus::from_texts("en", english),
            LanguageCorpus::from_texts("de", german),
            LanguageCorpus::from_texts("fr", french),
        ]
    }

    /// Small-vocabulary config shared by the training tests (alpha 0.5, min_frequency 1).
    fn small_config(vocab_size: usize, add_prefix_space: bool) -> MultilingualBpeConfig {
        MultilingualBpeConfig {
            vocab_size,
            alpha: 0.5,
            min_frequency: 1,
            add_prefix_space,
        }
    }

    #[test]
    fn test_language_probs_sum_to_one() {
        let corpora = sample_corpora();
        let total: f64 = MultilingualBpeTokenizer::compute_language_probs(&corpora, 0.5)
            .expect("should compute probs")
            .values()
            .sum();
        assert!(
            (total - 1.0).abs() < 1e-9,
            "language probs should sum to 1.0, got {}",
            total
        );
    }

    #[test]
    fn test_alpha_zero_uniform() {
        let corpora = sample_corpora();
        let uniform = 1.0 / corpora.len() as f64;
        let probs = MultilingualBpeTokenizer::compute_language_probs(&corpora, 0.0)
            .expect("should compute probs");
        for (lang, p) in probs.iter() {
            assert!(
                (p - uniform).abs() < 1e-9,
                "lang {} prob {} != uniform {}",
                lang,
                p,
                uniform
            );
        }
    }

    #[test]
    fn test_alpha_one_proportional() {
        let corpora = sample_corpora();
        let probs = MultilingualBpeTokenizer::compute_language_probs(&corpora, 1.0)
            .expect("should compute probs");
        let total_weight = corpora.iter().fold(0.0, |acc, c| acc + c.weight);
        for corpus in corpora.iter() {
            let want = corpus.weight / total_weight;
            let have = probs[&corpus.language];
            assert!(
                (have - want).abs() < 1e-9,
                "lang {} prob {} != proportional {}",
                corpus.language,
                have,
                want
            );
        }
    }

    #[test]
    fn test_train_vocab_size() {
        let tok = MultilingualBpeTokenizer::train(&sample_corpora(), small_config(400, true));
        assert!(tok.vocab_size() <= 400);
        assert!(tok.vocab_size() >= 256);
    }

    #[test]
    fn test_encode_with_language() {
        let tok = MultilingualBpeTokenizer::train(&sample_corpora(), small_config(400, true));
        let text = "hello world";
        assert_eq!(
            tok.encode_with_language(text, "en"),
            tok.encode_with_language(text, "de")
        );
    }

    #[test]
    fn test_vocabulary_coverage() {
        let tok = MultilingualBpeTokenizer::train(&sample_corpora(), small_config(500, false));
        let coverage = tok.vocabulary_coverage(&["hello", "rust", "world"]);
        assert!(
            (0.0..=1.0).contains(&coverage),
            "coverage should be in [0,1]"
        );
    }
}