use crate::error::{Result, TextError};
use crate::tokenize::Tokenizer;
use std::collections::HashMap;
use std::fmt;
use std::fs::File;
use std::io::{BufReader, BufWriter, Read, Write};
use std::path::Path;

/// A pair of adjacent tokens that may be merged into a single token.
type TokenPair = (String, String);

/// Vocabulary produced by BPE training: token/ID mappings plus merge rules.
#[derive(Clone)]
pub struct BpeVocabulary {
    /// Mapping from token string to token ID.
    pub token_to_id: HashMap<String, usize>,
    /// Mapping from token ID back to token string.
    pub id_to_token: HashMap<usize, String>,
    /// Learned merge rules, keyed by the pair of tokens being merged.
    pub merges: HashMap<TokenPair, String>,
}

impl fmt::Debug for BpeVocabulary {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("BpeVocabulary")
            .field("vocab_size", &self.token_to_id.len())
            .field("num_merges", &self.merges.len())
            .finish()
    }
}

impl BpeVocabulary {
    /// Creates an empty vocabulary.
    pub fn new() -> Self {
        Self {
            token_to_id: HashMap::new(),
            id_to_token: HashMap::new(),
            merges: HashMap::new(),
        }
    }

    /// Adds a token to the vocabulary and returns its ID; if the token is
    /// already present, its existing ID is returned.
    pub fn add_token(&mut self, token: &str) -> usize {
        if let Some(&id) = self.token_to_id.get(token) {
            return id;
        }

        let id = self.token_to_id.len();
        self.token_to_id.insert(token.to_string(), id);
        self.id_to_token.insert(id, token.to_string());
        id
    }

    /// Records a merge rule mapping a token pair to its merged form.
    pub fn add_merge(&mut self, pair: TokenPair, merged: String) {
        self.merges.insert(pair, merged);
    }

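    /// Serializes the vocabulary to a plain-text, tab-separated file: the
    /// vocabulary size, one `token\tid` line per token, the number of merge
    /// rules, then one `first\tsecond\tmerged` line per rule.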
    pub fn save(&self, path: impl AsRef<Path>) -> Result<()> {
        let file = File::create(path).map_err(|e| TextError::IoError(e.to_string()))?;
        let mut writer = BufWriter::new(file);

        writeln!(writer, "{}", self.token_to_id.len())
            .map_err(|e| TextError::IoError(e.to_string()))?;

        for (token, id) in &self.token_to_id {
            writeln!(writer, "{token}\t{id}").map_err(|e| TextError::IoError(e.to_string()))?;
        }

        writeln!(writer, "{}", self.merges.len()).map_err(|e| TextError::IoError(e.to_string()))?;

        for ((first, second), merged) in &self.merges {
            writeln!(writer, "{first}\t{second}\t{merged}")
                .map_err(|e| TextError::IoError(e.to_string()))?;
        }

        Ok(())
    }

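    /// Loads a vocabulary previously written by [`BpeVocabulary::save`].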
    pub fn load(path: impl AsRef<Path>) -> Result<Self> {
        let file = File::open(path).map_err(|e| TextError::IoError(e.to_string()))?;
        let mut reader = BufReader::new(file);
        let mut content = String::new();
        reader
            .read_to_string(&mut content)
            .map_err(|e| TextError::IoError(e.to_string()))?;

        let mut lines = content.lines();

        // First line: number of vocabulary entries.
        let vocab_size: usize = lines
            .next()
            .ok_or_else(|| TextError::IoError("Unexpected end of file".to_string()))?
            .parse()
            .map_err(|e| TextError::IoError(format!("Invalid vocabulary size: {e}")))?;

        let mut vocabulary = Self::new();

        // Read the `token\tid` entries.
        for _ in 0..vocab_size {
            let line = lines
                .next()
                .ok_or_else(|| TextError::IoError("Unexpected end of file".to_string()))?;
            let parts: Vec<&str> = line.split('\t').collect();

            if parts.len() != 2 {
                return Err(TextError::IoError(format!(
                    "Invalid vocabulary entry: {line}"
                )));
            }

            let token = parts[0].to_string();
            let id: usize = parts[1]
                .parse()
                .map_err(|e| TextError::IoError(format!("Invalid token ID: {e}")))?;

            vocabulary.token_to_id.insert(token.clone(), id);
            vocabulary.id_to_token.insert(id, token);
        }

        // Next line: number of merge rules.
        let num_merges: usize = lines
            .next()
            .ok_or_else(|| TextError::IoError("Unexpected end of file".to_string()))?
            .parse()
            .map_err(|e| TextError::IoError(format!("Invalid number of merges: {e}")))?;

        // Read the `first\tsecond\tmerged` merge rules.
        for _ in 0..num_merges {
            let line = lines
                .next()
                .ok_or_else(|| TextError::IoError("Unexpected end of file".to_string()))?;
            let parts: Vec<&str> = line.split('\t').collect();

            if parts.len() != 3 {
                return Err(TextError::IoError(format!("Invalid merge rule: {line}")));
            }

            let first = parts[0].to_string();
            let second = parts[1].to_string();
            let merged = parts[2].to_string();

            vocabulary.merges.insert((first, second), merged);
        }

        Ok(vocabulary)
    }
}

impl Default for BpeVocabulary {
    fn default() -> Self {
        Self::new()
    }
}

/// Configuration options for BPE training and tokenization.
#[derive(Debug, Clone)]
pub struct BpeConfig {
    /// Target vocabulary size, including special tokens.
    pub vocab_size: usize,
    /// Minimum frequency a token must reach to enter the initial vocabulary.
    pub min_frequency: usize,
    /// Special tokens (e.g. `<pad>`, `<unk>`) added to the vocabulary first.
    pub special_tokens: Vec<String>,
    /// If `true`, tokenize the whole text character by character instead of
    /// splitting on whitespace first.
    pub character_level: bool,
    /// If `true`, lowercase text before training and tokenization.
    pub lowercase: bool,
}

impl Default for BpeConfig {
    fn default() -> Self {
        Self {
            vocab_size: 30000,
            min_frequency: 2,
            special_tokens: vec![],
            character_level: true,
            lowercase: true,
        }
    }
}

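/// Byte pair encoding (BPE) tokenizer driven by a [`BpeConfig`].
///
/// # Examples
///
/// A minimal usage sketch; the crate name `text_lib` below is illustrative
/// and should be adjusted to this crate's actual name and module layout:
///
/// ```ignore
/// use text_lib::bpe::BpeTokenizer;
/// use text_lib::tokenize::Tokenizer;
///
/// let mut tokenizer = BpeTokenizer::with_defaults();
/// tokenizer.train(&["low lower lowest", "new newer newest"]).unwrap();
/// let tokens = tokenizer.tokenize("lowest newest").unwrap();
/// assert!(!tokens.is_empty());
/// ```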
#[derive(Debug, Clone)]
pub struct BpeTokenizer {
    /// Tokenizer configuration.
    config: BpeConfig,
    /// Trained vocabulary; starts out empty and is filled by `train` or
    /// `load_vocabulary`.
    vocabulary: Option<BpeVocabulary>,
}

impl BpeTokenizer {
    /// Creates a tokenizer with the given configuration and an empty vocabulary.
    pub fn new(config: BpeConfig) -> Self {
        Self {
            config,
            vocabulary: Some(BpeVocabulary::new()),
        }
    }

    /// Creates a tokenizer with the default configuration.
    pub fn with_defaults() -> Self {
        Self::new(BpeConfig::default())
    }

    /// Returns the number of tokens in the vocabulary, or 0 if there is none.
    pub fn vocab_size(&self) -> usize {
        match &self.vocabulary {
            Some(vocab) => vocab.token_to_id.len(),
            None => 0,
        }
    }

    /// Returns `true` if a vocabulary is present.
    pub fn has_vocabulary(&self) -> bool {
        self.vocabulary.is_some()
    }

    /// Returns a reference to the vocabulary, if any.
    pub fn vocabulary(&self) -> Option<&BpeVocabulary> {
        self.vocabulary.as_ref()
    }

    /// Replaces the current vocabulary.
    pub fn set_vocabulary(&mut self, vocabulary: BpeVocabulary) {
        self.vocabulary = Some(vocabulary);
    }

    /// Saves the current vocabulary to disk.
    pub fn save_vocabulary(&self, path: impl AsRef<Path>) -> Result<()> {
        match &self.vocabulary {
            Some(vocab) => vocab.save(path),
            None => Err(TextError::TokenizationError(
                "No vocabulary available to save".to_string(),
            )),
        }
    }

    /// Loads a vocabulary from disk, replacing any existing one.
    pub fn load_vocabulary(&mut self, path: impl AsRef<Path>) -> Result<()> {
        self.vocabulary = Some(BpeVocabulary::load(path)?);
        Ok(())
    }

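    /// Trains the BPE vocabulary on a corpus: texts are split into initial
    /// character tokens, then the most frequent adjacent pair is repeatedly
    /// merged until the target vocabulary size is reached or no adjacent
    /// pairs remain.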
    pub fn train(&mut self, corpus: &[&str]) -> Result<()> {
        if corpus.is_empty() {
            return Err(TextError::TokenizationError(
                "Cannot train on empty corpus".to_string(),
            ));
        }

        let mut vocabulary = BpeVocabulary::new();

        // Special tokens are added first so they receive the lowest IDs.
        for token in &self.config.special_tokens {
            vocabulary.add_token(token);
        }

        // Count initial character tokens and keep every sequence for merging.
        let mut token_counts = HashMap::new();
        let mut all_tokens = Vec::new();

        for text in corpus {
            let processed_text = if self.config.lowercase {
                text.to_lowercase()
            } else {
                text.to_string()
            };

            if self.config.character_level {
                // Treat the whole text as one sequence of single characters.
                let initial_tokens: Vec<String> =
                    processed_text.chars().map(|c| c.to_string()).collect();
                for token in &initial_tokens {
                    *token_counts.entry(token.clone()).or_insert(0) += 1;
                }
                all_tokens.push(initial_tokens);
            } else {
                // Split on whitespace and treat each word as its own sequence.
                for word in processed_text.split_whitespace() {
                    let chars: Vec<String> = word.chars().map(|c| c.to_string()).collect();
                    for token in &chars {
                        *token_counts.entry(token.clone()).or_insert(0) += 1;
                    }
                    all_tokens.push(chars);
                }
            }
        }

        // Seed the vocabulary with tokens that meet the frequency threshold.
        for (token, &count) in &token_counts {
            if count >= self.config.min_frequency {
                vocabulary.add_token(token);
            }
        }

        // Learn merges until the target vocabulary size is reached.
        let mut merges = Vec::new();
        let max_merges = self
            .config
            .vocab_size
            .saturating_sub(vocabulary.token_to_id.len());

        for _ in 0..max_merges {
            let mut pair_counts = HashMap::new();
            let mut pair_to_merged = HashMap::new();

            // Count every adjacent pair across all token sequences.
            for tokens in &all_tokens {
                for window in tokens.windows(2) {
                    let pair = (window[0].clone(), window[1].clone());
                    let merged = format!("{}{}", pair.0, pair.1);
                    *pair_counts.entry(pair.clone()).or_insert(0) += 1;
                    pair_to_merged.insert(pair, merged);
                }
            }

            // Pick the most frequent pair; ties are broken arbitrarily.
            let best_pair = pair_counts
                .iter()
                .max_by_key(|&(_, count)| count)
                .map(|(pair, _)| pair.clone());

            if let Some(pair) = best_pair {
                let merged = pair_to_merged[&pair].clone();

                vocabulary.add_token(&merged);
                vocabulary.add_merge(pair.clone(), merged.clone());
                merges.push((pair.clone(), merged.clone()));

                // Apply the new merge to every token sequence.
                for tokens in &mut all_tokens {
                    let mut i = 0;
                    while i + 1 < tokens.len() {
                        if tokens[i] == pair.0 && tokens[i + 1] == pair.1 {
                            tokens[i] = merged.clone();
                            tokens.remove(i + 1);
                        } else {
                            i += 1;
                        }
                    }
                }
            } else {
                // No adjacent pairs remain, so no further merges are possible.
                break;
            }
        }

        self.vocabulary = Some(vocabulary);
        Ok(())
    }

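    /// Tokenizes a single word by starting from individual characters and
    /// greedily applying learned merge rules until none apply.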
    fn tokenize_word(&self, word: &str) -> Result<Vec<String>> {
        let vocab = match &self.vocabulary {
            Some(v) => v,
            None => {
                return Err(TextError::TokenizationError(
                    "Tokenizer vocabulary not initialized. Call train() first".to_string(),
                ))
            }
        };

        // Start from individual characters.
        let mut tokens: Vec<String> = word.chars().map(|c| c.to_string()).collect();

        // Repeatedly apply merge rules until a full pass makes no changes.
        let mut has_changes = true;
        while has_changes {
            has_changes = false;

            let mut i = 0;
            while i + 1 < tokens.len() {
                let pair = (tokens[i].clone(), tokens[i + 1].clone());
                if let Some(merged) = vocab.merges.get(&pair) {
                    tokens[i] = merged.clone();
                    tokens.remove(i + 1);
                    has_changes = true;
                } else {
                    i += 1;
                }
            }
        }

        Ok(tokens)
    }
}

impl Tokenizer for BpeTokenizer {
    fn tokenize(&self, text: &str) -> Result<Vec<String>> {
        if text.trim().is_empty() {
            return Ok(Vec::new());
        }

        if !self.has_vocabulary() {
            return Err(TextError::TokenizationError(
                "Tokenizer vocabulary not initialized. Call train() first".to_string(),
            ));
        }

        let processed_text = if self.config.lowercase {
            text.to_lowercase()
        } else {
            text.to_string()
        };

        let mut tokens = Vec::new();

        if self.config.character_level {
            // Merge across the whole text, including whitespace characters.
            tokens = self.tokenize_word(&processed_text)?;
        } else {
            // Tokenize each whitespace-separated word independently.
            for word in processed_text.split_whitespace() {
                let word_tokens = self.tokenize_word(word)?;
                tokens.extend(word_tokens);
            }
        }

        Ok(tokens)
    }

    fn clone_box(&self) -> Box<dyn Tokenizer + Send + Sync> {
        Box::new(self.clone())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    #[test]
    fn test_bpe_tokenizer_train() {
        let corpus = [
            "this is a test",
            "another test",
            "more tests for testing",
            "test the tokenizer",
        ];

        let mut tokenizer = BpeTokenizer::with_defaults();
        tokenizer.train(&corpus).unwrap();

        assert!(tokenizer.has_vocabulary());
        assert!(tokenizer.vocab_size() > 0);
    }

    #[test]
    fn test_bpe_tokenizer_tokenize() {
        let corpus = [
            "this is a test",
            "another test",
            "more tests for testing",
            "test the tokenizer",
        ];

        let mut tokenizer = BpeTokenizer::with_defaults();
        tokenizer.train(&corpus).unwrap();

        let tokens = tokenizer.tokenize("this is a tokenizer test").unwrap();
        assert!(!tokens.is_empty());
    }

    #[test]
    fn test_bpe_vocabulary_save_load() {
        let corpus = [
            "this is a test",
            "another test",
            "more tests for testing",
            "test the tokenizer",
        ];

        let mut tokenizer = BpeTokenizer::with_defaults();
        tokenizer.train(&corpus).unwrap();

        // Write the trained vocabulary to a temporary file.
        let temp_dir = tempdir().unwrap();
        let vocab_path = temp_dir.path().join("vocab.bpe");

        tokenizer.save_vocabulary(&vocab_path).unwrap();

        // Load it into a fresh tokenizer.
        let mut new_tokenizer = BpeTokenizer::with_defaults();
        new_tokenizer.load_vocabulary(&vocab_path).unwrap();

        // Both tokenizers should produce identical output.
        let text = "this is a tokenizer test";
        let tokens1 = tokenizer.tokenize(text).unwrap();
        let tokens2 = new_tokenizer.tokenize(text).unwrap();

        assert_eq!(tokens1, tokens2);
    }

    #[test]
    fn test_bpe_tokenizer_with_special_tokens() {
        let config = BpeConfig {
            special_tokens: vec!["<pad>".to_string(), "<unk>".to_string()],
            ..Default::default()
        };

        let corpus = [
            "this is a test",
            "another test",
            "more tests for testing",
            "test the tokenizer",
        ];

        let mut tokenizer = BpeTokenizer::new(config);
        tokenizer.train(&corpus).unwrap();

        let vocab = tokenizer.vocabulary().unwrap();
        assert!(vocab.token_to_id.contains_key("<pad>"));
        assert!(vocab.token_to_id.contains_key("<unk>"));
    }

    #[test]
    fn test_bpe_tokenizer_empty_text() {
        let corpus = ["this is a test"];
        let mut tokenizer = BpeTokenizer::with_defaults();
        tokenizer.train(&corpus).unwrap();

        let tokens = tokenizer.tokenize("").unwrap();
        assert_eq!(tokens.len(), 0);
    }

    #[test]
    fn test_bpe_tokenizer_case_sensitivity() {
        let corpus = ["This IS a TEST"];

        // Default configuration lowercases input before tokenizing.
        let mut tokenizer1 = BpeTokenizer::with_defaults();
        tokenizer1.train(&corpus).unwrap();
        let tokens1 = tokenizer1.tokenize("THIS is A test").unwrap();

        // Case-sensitive configuration keeps the original casing.
        let config = BpeConfig {
            lowercase: false,
            ..Default::default()
        };
        let mut tokenizer2 = BpeTokenizer::new(config);
        tokenizer2.train(&corpus).unwrap();
        let tokens2 = tokenizer2.tokenize("THIS is A test").unwrap();

        // The lowercasing tokenizer should merge at least as aggressively.
        assert!(tokens1.len() <= tokens2.len());
    }

    #[test]
    fn test_bpe_tokenizer_no_vocabulary() {
        let mut tokenizer = BpeTokenizer::with_defaults();
        tokenizer.vocabulary = None;

        let result = tokenizer.tokenize("test");
        assert!(result.is_err());
    }
}