// smoltok_core/simple/tokenizer.rs

//! Simple BPE tokenizer implementation.

use crate::tokenizer::{
    MergeRule, Serializable, TokenId, TokenPair, Tokenizer, UnknownTokenId, bpe_encode, build_merge_lookup,
    build_vocab, save_merges, save_vocab, verify_stok_extension,
};
use std::collections::HashMap;
use std::fs::File;
use std::io::{BufWriter, Error, Write};
use std::path::Path;

13/// Simple BPE tokenizer. Does not support regex patterns and special tokens.
14#[derive(Debug)]
15pub struct SimpleBPETokenizer {
16    /// The learned merge rules, in order of application.
17    merges: Vec<MergeRule>,
18    /// Maps token pairs to (rank, new_id) for fast lookup during encoding.
19    merge_lookup: HashMap<TokenPair, (usize, TokenId)>,
20    /// Maps token IDs to their byte sequences.
21    vocab: HashMap<TokenId, Vec<u8>>,
22}
23
24impl SimpleBPETokenizer {
25    /// Creates a new `SimpleBPETokenizer` from a list of merge operations.
26    ///
27    /// The vocabulary is automatically reconstructed from the merges.
28    ///
29    /// # Arguments
30    ///
31    /// * `merges` - A vector of `MergeRule` representing the merge operations.
32    pub fn from_merges(merges: Vec<MergeRule>) -> Self {
33        Self {
34            vocab: build_vocab(merges.as_slice()),
35            merge_lookup: build_merge_lookup(merges.as_slice()),
36            merges,
37        }
38    }
39}
40
41impl Tokenizer for SimpleBPETokenizer {
42    type DecodingError = UnknownTokenId;
43
44    fn merges(&self) -> &[MergeRule] {
45        self.merges.as_slice()
46    }
47
48    fn vocab(&self) -> &HashMap<TokenId, Vec<u8>> {
49        &self.vocab
50    }
51
52    fn encode(&self, text: &str) -> Vec<TokenId> {
53        bpe_encode(text, &self.merge_lookup)
54    }
55
56    fn decode(&self, tokens: &[TokenId]) -> Result<String, UnknownTokenId> {
57        let mut bytes = Vec::new();
58
59        for token_id in tokens {
60            if let Some(token_bytes) = self.vocab.get(token_id) {
61                bytes.extend_from_slice(token_bytes);
62            } else {
63                return Err(UnknownTokenId(*token_id));
64            }
65        }
66
67        Ok(String::from_utf8_lossy(&bytes).into_owned())
68    }
69}
70
71impl Serializable for SimpleBPETokenizer {
72    fn save(&self, path: &Path) -> Result<(), Error> {
73        verify_stok_extension(path)?;
74
75        let file = File::create(path)?;
76        let mut writer = BufWriter::new(file);
77
78        save_merges(&mut writer, self.merges.as_slice())?;
79        save_vocab(&path.with_extension("vocab"), self.merges.as_slice(), self.vocab())?;
80
81        Ok(())
82    }
83}