// scirs2_text/tokenization/unicode_bpe.rs
use crate::error::{Result, TextError};
use std::collections::HashMap;
9
/// Settings that control how a `UnicodeBpeTokenizer` is trained and how it
/// encodes text.
#[non_exhaustive]
#[derive(Debug, Clone)]
pub struct UnicodeBpeConfig {
    /// Target vocabulary size. Training performs at most
    /// `vocab_size - (special tokens + base characters)` merges; when the
    /// initial vocabulary is already that large, no merges are learned.
    pub vocab_size: usize,
    /// Minimum corpus frequency an adjacent symbol pair must reach to be
    /// merged during training.
    pub min_frequency: usize,
    /// When `true`, text is passed through `nfc_normalize` before being split
    /// on whitespace, both in training and in encoding.
    pub normalize: bool,
    /// When `true`, a token missing from the vocabulary is encoded one byte at
    /// a time via `<0xNN>` lookups; when `false`, it becomes a single `<unk>`.
    pub byte_fallback: bool,
}
27
28impl Default for UnicodeBpeConfig {
29 fn default() -> Self {
30 Self {
31 vocab_size: 32_000,
32 min_frequency: 2,
33 normalize: true,
34 byte_fallback: true,
35 }
36 }
37}
38
/// Lightweight text clean-up applied before tokenization.
///
/// Removes every control character that is not also whitespace (so tabs and
/// newlines survive, while e.g. NUL and BEL are dropped). All other characters
/// are copied through unchanged.
///
/// NOTE(review): despite the name, this does not perform Unicode NFC
/// composition — combining sequences such as `e` + U+0301 are left as-is.
fn nfc_normalize(s: &str) -> String {
    // The result can never be longer than the input, so one allocation suffices.
    let mut cleaned = String::with_capacity(s.len());
    cleaned.extend(s.chars().filter(|ch| ch.is_whitespace() || !ch.is_control()));
    cleaned
}
52
53pub struct UnicodeBpeTokenizer {
59 config: UnicodeBpeConfig,
60 vocab: HashMap<String, u32>,
62 id_to_token: Vec<String>,
64 merges: Vec<(String, String)>,
66 special_tokens: Vec<String>,
68}
69
/// The merge selected in one training iteration.
#[derive(Debug)]
struct MergeResult {
    /// The `(left, right)` symbols being merged.
    pair: (String, String),
    /// Corpus frequency of `pair` at the time it was selected.
    freq: usize,
    /// Concatenation of `left` and `right`; the token added to the vocabulary.
    new_token: String,
}
76
77impl UnicodeBpeTokenizer {
78 pub fn new(config: UnicodeBpeConfig) -> Self {
80 Self {
81 config,
82 vocab: HashMap::new(),
83 id_to_token: Vec::new(),
84 merges: Vec::new(),
85 special_tokens: vec!["<unk>".into(), "<s>".into(), "</s>".into(), "<pad>".into()],
86 }
87 }
88
89 pub fn train(&mut self, corpus: &[&str]) -> Result<()> {
95 if corpus.is_empty() {
96 return Err(TextError::InvalidInput(
97 "BPE training corpus must not be empty".into(),
98 ));
99 }
100
101 let words: Vec<String> = corpus
103 .iter()
104 .flat_map(|doc| {
105 let normalized = if self.config.normalize {
106 nfc_normalize(doc)
107 } else {
108 doc.to_string()
109 };
110 normalized
111 .split_whitespace()
112 .map(|w| w.to_owned())
113 .collect::<Vec<_>>()
114 })
115 .filter(|w| !w.is_empty())
116 .collect();
117
118 if words.is_empty() {
119 return Err(TextError::InvalidInput(
120 "corpus has no words after split".into(),
121 ));
122 }
123
124 let mut word_freq: HashMap<String, usize> = HashMap::new();
126 for word in &words {
127 *word_freq.entry(word.clone()).or_insert(0) += 1;
128 }
129
130 let mut word_splits: HashMap<String, Vec<String>> = word_freq
133 .keys()
134 .map(|w| {
135 let chars: Vec<String> = w.chars().map(|c| c.to_string()).collect();
136 (w.clone(), chars)
137 })
138 .collect();
139
140 let mut base_chars: std::collections::BTreeSet<String> = std::collections::BTreeSet::new();
142 for chars in word_splits.values() {
143 for c in chars {
144 base_chars.insert(c.clone());
145 }
146 }
147
148 self.vocab.clear();
150 self.id_to_token.clear();
151 self.merges.clear();
152
153 for sp in &self.special_tokens {
154 let id = self.id_to_token.len() as u32;
155 self.vocab.insert(sp.clone(), id);
156 self.id_to_token.push(sp.clone());
157 }
158 for c in &base_chars {
159 if !self.vocab.contains_key(c) {
160 let id = self.id_to_token.len() as u32;
161 self.vocab.insert(c.clone(), id);
162 self.id_to_token.push(c.clone());
163 }
164 }
165
166 let max_merges = self.config.vocab_size.saturating_sub(self.vocab.len());
168
169 for _ in 0..max_merges {
170 let mut pair_freq: HashMap<(String, String), usize> = HashMap::new();
172 for (word, freq) in &word_freq {
173 let chars = match word_splits.get(word) {
174 Some(c) => c,
175 None => continue,
176 };
177 for window in chars.windows(2) {
178 *pair_freq
179 .entry((window[0].clone(), window[1].clone()))
180 .or_insert(0) += freq;
181 }
182 }
183
184 let best = pair_freq
186 .iter()
187 .filter(|(_, &freq)| freq >= self.config.min_frequency)
188 .max_by_key(|((a, b), &freq)| (freq, std::cmp::Reverse((a.clone(), b.clone()))));
189
190 let merge = match best {
191 Some(((a, b), &freq)) => MergeResult {
192 pair: (a.clone(), b.clone()),
193 freq,
194 new_token: format!("{}{}", a, b),
195 },
196 None => break, };
198
199 if merge.freq < self.config.min_frequency {
200 break;
201 }
202
203 if !self.vocab.contains_key(&merge.new_token) {
205 let id = self.id_to_token.len() as u32;
206 self.vocab.insert(merge.new_token.clone(), id);
207 self.id_to_token.push(merge.new_token.clone());
208 }
209 self.merges.push(merge.pair.clone());
210
211 let (ref left, ref right) = merge.pair;
213 for chars in word_splits.values_mut() {
214 let mut i = 0;
215 while i + 1 < chars.len() {
216 if chars[i] == *left && chars[i + 1] == *right {
217 let merged = format!("{}{}", chars[i], chars[i + 1]);
218 chars.splice(i..=i + 1, std::iter::once(merged));
219 } else {
221 i += 1;
222 }
223 }
224 }
225 }
226
227 Ok(())
228 }
229
230 pub fn encode(&self, text: &str) -> Result<Vec<u32>> {
236 if self.vocab.is_empty() {
237 return Err(TextError::ModelNotFitted(
238 "BPE tokenizer has not been trained".into(),
239 ));
240 }
241
242 let normalized = if self.config.normalize {
243 nfc_normalize(text)
244 } else {
245 text.to_string()
246 };
247
248 let unk_id = self.vocab.get("<unk>").copied().unwrap_or(0);
249
250 let mut ids = Vec::new();
251
252 for word in normalized.split_whitespace() {
253 let mut chars: Vec<String> = word.chars().map(|c| c.to_string()).collect();
255
256 for (left, right) in &self.merges {
258 let mut i = 0;
259 while i + 1 < chars.len() {
260 if chars[i] == *left && chars[i + 1] == *right {
261 let merged = format!("{}{}", chars[i], chars[i + 1]);
262 chars.splice(i..=i + 1, std::iter::once(merged));
263 } else {
264 i += 1;
265 }
266 }
267 }
268
269 for tok in chars {
270 if let Some(&id) = self.vocab.get(&tok) {
271 ids.push(id);
272 } else if self.config.byte_fallback {
273 for byte in tok.as_bytes() {
275 let byte_tok = format!("<0x{:02X}>", byte);
276 let id = self.vocab.get(&byte_tok).copied().unwrap_or(unk_id);
277 ids.push(id);
278 }
279 } else {
280 ids.push(unk_id);
281 }
282 }
283 }
284
285 Ok(ids)
286 }
287
288 pub fn decode(&self, ids: &[u32]) -> Result<String> {
294 if self.id_to_token.is_empty() {
295 return Err(TextError::ModelNotFitted(
296 "BPE tokenizer has not been trained".into(),
297 ));
298 }
299 let mut parts = Vec::new();
300 for &id in ids {
301 let idx = id as usize;
302 if idx >= self.id_to_token.len() {
303 return Err(TextError::InvalidInput(format!(
304 "token id {} out of vocabulary range {}",
305 id,
306 self.id_to_token.len()
307 )));
308 }
309 parts.push(self.id_to_token[idx].clone());
310 }
311 Ok(parts.join(" "))
312 }
313
314 pub fn vocab_size(&self) -> usize {
316 self.vocab.len()
317 }
318
319 pub fn n_merges(&self) -> usize {
321 self.merges.len()
322 }
323
324 pub fn vocab(&self) -> &HashMap<String, u32> {
326 &self.vocab
327 }
328}
329
#[cfg(test)]
mod tests {
    use super::*;

    /// Small English corpus with repeated stems so that merges are learnable.
    fn small_corpus() -> Vec<&'static str> {
        vec![
            "low lower lowest",
            "new newer newest",
            "low new lower newest",
            "the lowest number",
        ]
    }

    /// Builds a tokenizer with `config` and trains it on `small_corpus`.
    fn trained(config: UnicodeBpeConfig) -> UnicodeBpeTokenizer {
        let mut tok = UnicodeBpeTokenizer::new(config);
        tok.train(&small_corpus()).expect("train failed");
        tok
    }

    #[test]
    fn test_default_config() {
        let cfg = UnicodeBpeConfig::default();
        assert_eq!(cfg.vocab_size, 32_000);
        assert_eq!(cfg.min_frequency, 2);
        assert!(cfg.normalize);
        assert!(cfg.byte_fallback);
    }

    #[test]
    fn test_train_empty_corpus_error() {
        let mut tok = UnicodeBpeTokenizer::new(UnicodeBpeConfig::default());
        let result = tok.train(&[]);
        assert!(result.is_err(), "empty corpus must return error");
    }

    #[test]
    fn test_train_whitespace_only_corpus_error() {
        let mut tok = UnicodeBpeTokenizer::new(UnicodeBpeConfig::default());
        let result = tok.train(&["   ", "\t\n"]);
        assert!(result.is_err(), "corpus without words must return error");
    }

    #[test]
    fn test_train_succeeds() {
        let tok = trained(UnicodeBpeConfig::default());
        assert!(
            tok.vocab_size() > 0,
            "vocab should be non-empty after training"
        );
    }

    #[test]
    fn test_vocab_size_bounded() {
        let tok = trained(UnicodeBpeConfig {
            vocab_size: 20,
            min_frequency: 1,
            ..Default::default()
        });
        assert!(
            tok.vocab_size() <= 20,
            "vocab size {} must be <= 20",
            tok.vocab_size()
        );
    }

    #[test]
    fn test_encode_returns_ids() {
        let tok = trained(UnicodeBpeConfig {
            min_frequency: 1,
            ..Default::default()
        });
        let ids = tok.encode("low").expect("encode failed");
        assert!(
            !ids.is_empty(),
            "encoding 'low' should produce at least one id"
        );
    }

    #[test]
    fn test_encode_before_train_error() {
        let tok = UnicodeBpeTokenizer::new(UnicodeBpeConfig::default());
        let result = tok.encode("hello");
        assert!(result.is_err(), "encode before train must return error");
    }

    #[test]
    fn test_decode_before_train_error() {
        let tok = UnicodeBpeTokenizer::new(UnicodeBpeConfig::default());
        let result = tok.decode(&[0, 1]);
        assert!(result.is_err(), "decode before train must return error");
    }

    #[test]
    fn test_decode_out_of_range_error() {
        let tok = trained(UnicodeBpeConfig::default());
        let result = tok.decode(&[u32::MAX]);
        assert!(result.is_err(), "out-of-range id must return error");
    }

    #[test]
    fn test_n_merges_increases_with_training() {
        let tok = trained(UnicodeBpeConfig {
            vocab_size: 50,
            min_frequency: 1,
            ..Default::default()
        });
        assert!(tok.n_merges() > 0, "should have at least one merge");
    }

    #[test]
    fn test_special_tokens_in_vocab() {
        let tok = trained(UnicodeBpeConfig::default());
        assert!(tok.vocab().contains_key("<unk>"));
        assert!(tok.vocab().contains_key("<s>"));
        assert!(tok.vocab().contains_key("</s>"));
        assert!(tok.vocab().contains_key("<pad>"));
    }

    #[test]
    fn test_decode_special_token() {
        let tok = trained(UnicodeBpeConfig {
            min_frequency: 1,
            ..Default::default()
        });
        let unk_id = tok.vocab()["<unk>"];
        let decoded = tok.decode(&[unk_id]).expect("decode failed");
        assert_eq!(decoded, "<unk>");
    }

    #[test]
    fn test_training_is_deterministic() {
        let config = || UnicodeBpeConfig {
            vocab_size: 40,
            min_frequency: 1,
            ..Default::default()
        };
        let first = trained(config());
        let second = trained(config());
        assert_eq!(
            first.encode("lowest newer").expect("encode failed"),
            second.encode("lowest newer").expect("encode failed"),
            "two runs on the same corpus must agree"
        );
    }
}
448}