scirs2_text/tokenize/
bpe.rs

1//! Byte Pair Encoding (BPE) tokenizer implementation
2//!
3//! This module provides a BPE tokenizer which can learn and apply
4//! subword tokenization based on the most frequent byte pairs.
5
6use crate::error::{Result, TextError};
7use crate::tokenize::Tokenizer;
8use std::collections::HashMap;
9use std::fmt;
10use std::fs::File;
11use std::io::{BufReader, BufWriter, Read, Write};
12use std::path::Path;
13
/// A pair of adjacent tokens (left, right) used as the key of a merge rule.
type TokenPair = (String, String);

/// A vocabulary for BPE tokenization
///
/// Bundles the forward and reverse token/ID tables with the learned merge
/// rules. Two vocabularies compare equal when all three maps contain the same
/// entries, which makes it straightforward to verify a save/load round trip.
#[derive(Clone, PartialEq, Eq)]
pub struct BpeVocabulary {
    /// Token to ID mapping
    pub token_to_id: HashMap<String, usize>,
    /// ID to token mapping
    pub id_to_token: HashMap<usize, String>,
    /// Merge rules (token pair -> merged token)
    pub merges: HashMap<TokenPair, String>,
}
27
28impl fmt::Debug for BpeVocabulary {
29    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
30        f.debug_struct("BpeVocabulary")
31            .field("vocab_size", &self.token_to_id.len())
32            .field("num_merges", &self.merges.len())
33            .finish()
34    }
35}
36
37impl BpeVocabulary {
38    /// Create a new empty BPE vocabulary
39    pub fn new() -> Self {
40        Self {
41            token_to_id: HashMap::new(),
42            id_to_token: HashMap::new(),
43            merges: HashMap::new(),
44        }
45    }
46
47    /// Add a token to the vocabulary
48    pub fn add_token(&mut self, token: &str) -> usize {
49        if let Some(&id) = self.token_to_id.get(token) {
50            return id;
51        }
52
53        let id = self.token_to_id.len();
54        self.token_to_id.insert(token.to_string(), id);
55        self.id_to_token.insert(id, token.to_string());
56        id
57    }
58
59    /// Add a merge rule to the vocabulary
60    pub fn add_merge(&mut self, pair: TokenPair, merged: String) {
61        self.merges.insert(pair, merged);
62    }
63
64    /// Save the vocabulary to a file
65    pub fn save(&self, path: impl AsRef<Path>) -> Result<()> {
66        let file = File::create(path).map_err(|e| TextError::IoError(e.to_string()))?;
67        let mut writer = BufWriter::new(file);
68
69        // Write the vocabulary size
70        writeln!(writer, "{}", self.token_to_id.len())
71            .map_err(|e| TextError::IoError(e.to_string()))?;
72
73        // Write the tokens and their IDs
74        for (token, id) in &self.token_to_id {
75            writeln!(writer, "{token}\t{id}").map_err(|e| TextError::IoError(e.to_string()))?;
76        }
77
78        // Write the number of merges
79        writeln!(writer, "{}", self.merges.len()).map_err(|e| TextError::IoError(e.to_string()))?;
80
81        // Write the merge rules
82        for ((first, second), merged) in &self.merges {
83            writeln!(writer, "{first}\t{second}\t{merged}")
84                .map_err(|e| TextError::IoError(e.to_string()))?;
85        }
86
87        Ok(())
88    }
89
90    /// Load a vocabulary from a file
91    pub fn load(path: impl AsRef<Path>) -> Result<Self> {
92        let file = File::open(path).map_err(|e| TextError::IoError(e.to_string()))?;
93        let mut reader = BufReader::new(file);
94        let mut content = String::new();
95        reader
96            .read_to_string(&mut content)
97            .map_err(|e| TextError::IoError(e.to_string()))?;
98
99        let mut lines = content.lines();
100
101        // Read the vocabulary size
102        let vocab_size: usize = lines
103            .next()
104            .ok_or_else(|| TextError::IoError("Unexpected end of file".to_string()))?
105            .parse()
106            .map_err(|e| TextError::IoError(format!("Invalid vocabulary size: {e}")))?;
107
108        let mut vocabulary = Self::new();
109
110        // Read the tokens and their IDs
111        for _ in 0..vocab_size {
112            let line = lines
113                .next()
114                .ok_or_else(|| TextError::IoError("Unexpected end of file".to_string()))?;
115            let parts: Vec<&str> = line.split('\t').collect();
116
117            if parts.len() != 2 {
118                return Err(TextError::IoError(format!(
119                    "Invalid vocabulary entry: {line}"
120                )));
121            }
122
123            let token = parts[0].to_string();
124            let id: usize = parts[1]
125                .parse()
126                .map_err(|e| TextError::IoError(format!("Invalid token ID: {e}")))?;
127
128            vocabulary.token_to_id.insert(token.clone(), id);
129            vocabulary.id_to_token.insert(id, token);
130        }
131
132        // Read the number of merges
133        let num_merges: usize = lines
134            .next()
135            .ok_or_else(|| TextError::IoError("Unexpected end of file".to_string()))?
136            .parse()
137            .map_err(|e| TextError::IoError(format!("Invalid number of merges: {e}")))?;
138
139        // Read the merge rules
140        for _ in 0..num_merges {
141            let line = lines
142                .next()
143                .ok_or_else(|| TextError::IoError("Unexpected end of file".to_string()))?;
144            let parts: Vec<&str> = line.split('\t').collect();
145
146            if parts.len() != 3 {
147                return Err(TextError::IoError(format!("Invalid merge rule: {line}")));
148            }
149
150            let first = parts[0].to_string();
151            let second = parts[1].to_string();
152            let merged = parts[2].to_string();
153
154            vocabulary.merges.insert((first, second), merged);
155        }
156
157        Ok(vocabulary)
158    }
159}
160
161impl Default for BpeVocabulary {
162    fn default() -> Self {
163        Self::new()
164    }
165}
166
/// Configuration for BPE tokenizer
///
/// Configurations can be compared with `==`, e.g. to check that two
/// tokenizers were set up identically.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BpeConfig {
    /// The maximum vocabulary size
    ///
    /// Training stops adding merged tokens once the vocabulary (special
    /// tokens and base characters included) reaches this size.
    pub vocab_size: usize,
    /// The minimum frequency for a token to be included in the vocabulary
    ///
    /// Applied to the base characters counted from the training corpus.
    pub min_frequency: usize,
    /// Special tokens to add to the vocabulary
    ///
    /// They are inserted first, so they receive the lowest IDs.
    pub special_tokens: Vec<String>,
    /// Whether to treat characters as the base tokens
    ///
    /// When `true` each text is handled as one character sequence (whitespace
    /// included); when `false` the text is split on whitespace and every word
    /// is handled as its own character sequence.
    pub character_level: bool,
    /// Whether to lowercase the input text
    pub lowercase: bool,
}

impl Default for BpeConfig {
    /// Defaults: 30 000 tokens, minimum frequency 2, no special tokens,
    /// character-level processing, lowercased input.
    fn default() -> Self {
        Self {
            vocab_size: 30000,
            min_frequency: 2,
            special_tokens: Vec::new(),
            character_level: true,
            lowercase: true,
        }
    }
}
193
194/// A Byte Pair Encoding (BPE) tokenizer
195///
196/// BPE is a subword tokenization algorithm that iteratively merges the most
197/// frequent pairs of tokens (bytes or characters) to form new tokens.
198#[derive(Debug, Clone)]
199pub struct BpeTokenizer {
200    /// Tokenizer configuration
201    config: BpeConfig,
202    /// The vocabulary for the tokenizer
203    vocabulary: Option<BpeVocabulary>,
204}
205
206impl BpeTokenizer {
207    /// Create a new BPE tokenizer with the given configuration
208    pub fn new(config: BpeConfig) -> Self {
209        Self {
210            config,
211            vocabulary: Some(BpeVocabulary::new()),
212        }
213    }
214
215    /// Create a new BPE tokenizer with default configuration
216    pub fn with_defaults() -> Self {
217        Self::new(BpeConfig::default())
218    }
219
220    /// Get the vocabulary size
221    pub fn vocab_size(&self) -> usize {
222        match &self.vocabulary {
223            Some(vocab) => vocab.token_to_id.len(),
224            None => 0,
225        }
226    }
227
228    /// Check if the tokenizer has a vocabulary
229    pub fn has_vocabulary(&self) -> bool {
230        self.vocabulary.is_some()
231    }
232
233    /// Get a reference to the tokenizer's vocabulary
234    pub fn vocabulary(&self) -> Option<&BpeVocabulary> {
235        self.vocabulary.as_ref()
236    }
237
238    /// Set the tokenizer's vocabulary
239    pub fn set_vocabulary(&mut self, vocabulary: BpeVocabulary) {
240        self.vocabulary = Some(vocabulary);
241    }
242
243    /// Save the tokenizer's vocabulary to a file
244    pub fn save_vocabulary(&self, path: impl AsRef<Path>) -> Result<()> {
245        match &self.vocabulary {
246            Some(vocab) => vocab.save(path),
247            None => Err(TextError::TokenizationError(
248                "No vocabulary available to save".to_string(),
249            )),
250        }
251    }
252
253    /// Load the tokenizer's vocabulary from a file
254    pub fn load_vocabulary(&mut self, path: impl AsRef<Path>) -> Result<()> {
255        self.vocabulary = Some(BpeVocabulary::load(path)?);
256        Ok(())
257    }
258
259    /// Train the BPE tokenizer on a corpus
260    pub fn train(&mut self, corpus: &[&str]) -> Result<()> {
261        if corpus.is_empty() {
262            return Err(TextError::TokenizationError(
263                "Cannot train on empty corpus".to_string(),
264            ));
265        }
266
267        let mut vocabulary = BpeVocabulary::new();
268
269        // Initialize vocabulary with special tokens
270        for token in &self.config.special_tokens {
271            vocabulary.add_token(token);
272        }
273
274        // Count initial tokens (characters or words)
275        let mut token_counts = HashMap::new();
276        let mut all_tokens = Vec::new();
277
278        for text in corpus {
279            let processedtext = if self.config.lowercase {
280                text.to_lowercase()
281            } else {
282                text.to_string()
283            };
284
285            // For character-level tokenization, we operate directly on characters
286            // For word-level tokenization, we need to process each word separately
287            if self.config.character_level {
288                // Character-level tokenization
289                let initial_tokens: Vec<String> =
290                    processedtext.chars().map(|c| c.to_string()).collect();
291                // Add character sequence directly
292                for token in &initial_tokens {
293                    *token_counts.entry(token.clone()).or_insert(0) += 1;
294                }
295                all_tokens.push(initial_tokens);
296            } else {
297                // Word-level tokenization with characters as base tokens
298                for word in processedtext.split_whitespace() {
299                    let chars: Vec<String> = word.chars().map(|c| c.to_string()).collect();
300                    // Count individual characters
301                    for token in &chars {
302                        *token_counts.entry(token.clone()).or_insert(0) += 1;
303                    }
304                    all_tokens.push(chars);
305                }
306            };
307
308            // The token counting is now handled in the previous block
309        }
310
311        // Add initial tokens to vocabulary
312        for (token, &count) in &token_counts {
313            if count >= self.config.min_frequency {
314                vocabulary.add_token(token);
315            }
316        }
317
318        // Train BPE on the corpus
319        let mut merges = Vec::new();
320        let max_merges = self.config.vocab_size - vocabulary.token_to_id.len();
321
322        for _ in 0..max_merges {
323            // Count token pairs
324            let mut pair_counts = HashMap::new();
325            let mut pair_to_merged = HashMap::new();
326
327            for tokens in &all_tokens {
328                for window in tokens.windows(2) {
329                    if window.len() < 2 {
330                        continue;
331                    }
332
333                    let pair = (window[0].clone(), window[1].clone());
334                    let pair_0 = &pair.0;
335                    let pair_1 = &pair.1;
336                    let merged = format!("{pair_0}{pair_1}");
337                    *pair_counts.entry(pair.clone()).or_insert(0) += 1;
338                    pair_to_merged.insert(pair, merged);
339                }
340            }
341
342            // Find the most frequent pair
343            let best_pair = pair_counts
344                .iter()
345                .max_by_key(|&(_, count)| count)
346                .map(|(pair_, _)| pair_.clone());
347
348            if let Some(pair) = best_pair {
349                let merged = pair_to_merged[&pair].clone();
350
351                // Add the merged token to the vocabulary
352                vocabulary.add_token(&merged);
353
354                // Add the merge rule
355                vocabulary.add_merge(pair.clone(), merged.clone());
356                merges.push((pair.clone(), merged.clone()));
357
358                // Update tokens by applying the merge
359                for tokens in &mut all_tokens {
360                    let mut i = 0;
361                    while i < tokens.len() - 1 {
362                        if i < tokens.len() - 1 && tokens[i] == pair.0 && tokens[i + 1] == pair.1 {
363                            tokens[i] = merged.clone();
364                            tokens.remove(i + 1);
365                        } else {
366                            i += 1;
367                        }
368                    }
369                }
370            } else {
371                // No more pairs to merge
372                break;
373            }
374        }
375
376        self.vocabulary = Some(vocabulary);
377        Ok(())
378    }
379
380    /// Apply BPE to tokenize a word
381    fn tokenize_word(&self, word: &str) -> Result<Vec<String>> {
382        let vocab = match &self.vocabulary {
383            Some(v) => v,
384            None => {
385                return Err(TextError::TokenizationError(
386                    "Tokenizer vocabulary not initialized. Call train() first".to_string(),
387                ))
388            }
389        };
390
391        // Split word into characters
392        let mut tokens: Vec<String> = word.chars().map(|c| c.to_string()).collect();
393
394        // Apply merges
395        let mut has_changes = true;
396        while has_changes {
397            has_changes = false;
398
399            let mut i = 0;
400            while i < tokens.len() - 1 {
401                let pair = (tokens[i].clone(), tokens[i + 1].clone());
402                if let Some(merged) = vocab.merges.get(&pair) {
403                    tokens[i] = merged.clone();
404                    tokens.remove(i + 1);
405                    has_changes = true;
406                } else {
407                    i += 1;
408                }
409            }
410        }
411
412        Ok(tokens)
413    }
414}
415
416impl Tokenizer for BpeTokenizer {
417    fn tokenize(&self, text: &str) -> Result<Vec<String>> {
418        if text.trim().is_empty() {
419            return Ok(Vec::new());
420        }
421
422        if !self.has_vocabulary() {
423            return Err(TextError::TokenizationError(
424                "Tokenizer vocabulary not initialized. Call train() first".to_string(),
425            ));
426        }
427
428        let processedtext = if self.config.lowercase {
429            text.to_lowercase()
430        } else {
431            text.to_string()
432        };
433
434        let mut tokens = Vec::new();
435
436        if self.config.character_level {
437            // Tokenize as a single sequence
438            tokens = self.tokenize_word(&processedtext)?;
439        } else {
440            // Tokenize each word separately
441            for word in processedtext.split_whitespace() {
442                let word_tokens = self.tokenize_word(word)?;
443                tokens.extend(word_tokens);
444            }
445        }
446
447        Ok(tokens)
448    }
449
450    fn clone_box(&self) -> Box<dyn Tokenizer + Send + Sync> {
451        Box::new(self.clone())
452    }
453}
454
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Small corpus shared by the training-based tests.
    const SAMPLE_CORPUS: [&str; 4] = [
        "this is a test",
        "another test",
        "more tests for testing",
        "test the tokenizer",
    ];

    /// Build a tokenizer from `config` and train it on [`SAMPLE_CORPUS`].
    fn trained_on_sample(config: BpeConfig) -> BpeTokenizer {
        let mut tokenizer = BpeTokenizer::new(config);
        tokenizer.train(&SAMPLE_CORPUS).unwrap();
        tokenizer
    }

    #[test]
    fn test_bpe_tokenizer_train() {
        let tokenizer = trained_on_sample(BpeConfig::default());

        assert!(tokenizer.has_vocabulary());
        assert!(tokenizer.vocab_size() > 0);
    }

    #[test]
    fn test_bpe_tokenizer_tokenize() {
        let tokenizer = trained_on_sample(BpeConfig::default());

        let tokens = tokenizer.tokenize("this is a tokenizer test").unwrap();
        assert!(!tokens.is_empty());
    }

    #[test]
    fn test_bpe_vocabulary_save_load() {
        let original = trained_on_sample(BpeConfig::default());

        // Persist the vocabulary inside a throw-away directory
        let scratch = tempdir().unwrap();
        let vocab_file = scratch.path().join("vocab.bpe");
        original.save_vocabulary(&vocab_file).unwrap();

        // A fresh tokenizer picks the vocabulary up from disk
        let mut restored = BpeTokenizer::with_defaults();
        restored.load_vocabulary(&vocab_file).unwrap();

        // Both tokenizers should produce the same tokens
        let sample = "this is a tokenizer test";
        let expected = original.tokenize(sample).unwrap();
        let actual = restored.tokenize(sample).unwrap();

        assert_eq!(expected, actual);
    }

    #[test]
    fn test_bpe_tokenizer_with_special_tokens() {
        let tokenizer = trained_on_sample(BpeConfig {
            special_tokens: vec!["<pad>".to_string(), "<unk>".to_string()],
            ..Default::default()
        });

        let vocab = tokenizer.vocabulary().unwrap();
        for special in ["<pad>", "<unk>"] {
            assert!(vocab.token_to_id.contains_key(special));
        }
    }

    #[test]
    fn test_bpe_tokenizer_emptytext() {
        let mut tokenizer = BpeTokenizer::with_defaults();
        tokenizer.train(&["this is a test"]).unwrap();

        let tokens = tokenizer.tokenize("").unwrap();
        assert_eq!(tokens.len(), 0);
    }

    #[test]
    fn test_bpe_tokenizer_case_sensitivity() {
        let corpus = ["This IS a TEST"];
        let query = "THIS is A test";

        // lowercase=true (default)
        let mut folding = BpeTokenizer::with_defaults();
        folding.train(&corpus).unwrap();
        let folded_tokens = folding.tokenize(query).unwrap();

        // lowercase=false
        let mut exact = BpeTokenizer::new(BpeConfig {
            lowercase: false,
            ..Default::default()
        });
        exact.train(&corpus).unwrap();
        let exact_tokens = exact.tokenize(query).unwrap();

        // The lowercasing tokenizer should never need more tokens than the
        // case-sensitive one
        assert!(folded_tokens.len() <= exact_tokens.len());
    }

    #[test]
    fn test_bpe_tokenizer_no_vocabulary() {
        // Strip the vocabulary that `with_defaults` installs
        let mut tokenizer = BpeTokenizer::with_defaults();
        tokenizer.vocabulary = None;

        // Without a vocabulary, tokenizing non-blank text must fail
        let outcome = tokenizer.tokenize("test");
        assert!(outcome.is_err());
    }
}