smoltok_core/regex/tokenizer.rs

use crate::regex::{GPT4_SPLIT_PATTERN, save_regex_tokenizer, split_text};
use crate::tokenizer::{UnknownTokenId, bpe_encode, build_merge_lookup, build_vocab, save_vocab};
use crate::{MergeRule, Serializable, TokenId, TokenPair, Tokenizer};
use fancy_regex::Regex;
use std::collections::HashMap;
use std::io::Error;
use std::path::Path;

/// Regex-based BPE tokenizer with support for special tokens.
///
/// This tokenizer first splits text into chunks using a regex pattern,
/// then applies BPE encoding to each chunk independently.
/// Special tokens are handled separately from regular text.
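///
/// # Examples
///
/// A minimal round-trip sketch (module paths are assumed from this file's layout;
/// with no merges the vocabulary is just the 256 base bytes):
///
/// ```ignore
/// use smoltok_core::Tokenizer;
/// use smoltok_core::regex::tokenizer::RegexBPETokenizer; // assumed path
///
/// let tokenizer = RegexBPETokenizer::from_merges(Vec::new());
/// let ids = tokenizer.encode("hello world");
/// assert_eq!(tokenizer.decode(&ids).unwrap(), "hello world");
/// ```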
#[derive(Debug)]
pub struct RegexBPETokenizer {
    /// The learned merge rules, in order of application.
    merges: Vec<MergeRule>,
    /// Maps token pairs to (rank, new_id) for fast lookup during encoding.
    merge_lookup: HashMap<TokenPair, (usize, TokenId)>,
    /// Maps token IDs to their byte sequences.
    vocab: HashMap<TokenId, Vec<u8>>,
    /// The regex pattern used for splitting text.
    pattern: String,
    /// The compiled regex pattern.
    compiled_pattern: Regex,
    /// Special tokens mapping (token string -> token ID).
    special_tokens: HashMap<String, TokenId>,
}

impl RegexBPETokenizer {
    /// Creates a new `RegexBPETokenizer` from merge rules and a pattern.
    ///
    /// # Arguments
    ///
    /// * `merges` - A vector of `MergeRule` representing the merge operations.
    /// * `pattern` - The regex pattern for splitting text.
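    ///
    /// If `pattern` fails to compile, the tokenizer falls back to the GPT-4 split pattern.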
    pub fn new(merges: Vec<MergeRule>, pattern: String) -> Self {
        // Fall back to the GPT-4 split pattern if the provided pattern fails to compile.
        let compiled_pattern = Regex::new(&pattern).unwrap_or_else(|_| {
            Regex::new(GPT4_SPLIT_PATTERN).expect("Both primary and fallback patterns failed to compile")
        });

        Self {
            vocab: build_vocab(merges.as_slice()),
            merge_lookup: build_merge_lookup(merges.as_slice()),
            merges,
            pattern,
            compiled_pattern,
            special_tokens: HashMap::new(),
        }
    }

    /// Creates a new `RegexBPETokenizer` from merge rules using the default GPT-4 pattern.
    ///
    /// # Arguments
    ///
    /// * `merges` - A vector of `MergeRule` representing the merge operations.
    pub fn from_merges(merges: Vec<MergeRule>) -> Self {
        Self::new(merges, GPT4_SPLIT_PATTERN.to_string())
    }

    /// Returns the number of merge rules learned during training.
    pub fn num_merges(&self) -> usize {
        self.merges.len()
    }

    /// Returns the vocabulary size (base 256 bytes + merged tokens).
    pub fn vocab_size(&self) -> usize {
        self.vocab.len()
    }

    /// Returns the regex pattern used for splitting text.
    pub fn pattern(&self) -> &str {
        &self.pattern
    }

    /// Registers special tokens with the tokenizer.
    ///
    /// Special tokens are handled separately during encoding and are not
    /// subject to the BPE merge process.
    ///
    /// # Arguments
    ///
    /// * `special_tokens` - A map of special token strings to their token IDs.
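    ///
    /// # Examples
    ///
    /// A sketch assuming `TokenId` is a plain integer alias (the token string is illustrative):
    ///
    /// ```ignore
    /// use std::collections::HashMap;
    ///
    /// let mut tokenizer = RegexBPETokenizer::from_merges(Vec::new());
    /// // Give the special token an ID just past the byte-level vocabulary.
    /// let eot_id = tokenizer.vocab_size() as TokenId;
    /// let mut special = HashMap::new();
    /// special.insert("<|endoftext|>".to_string(), eot_id);
    /// tokenizer.register_special_tokens(special);
    /// assert_eq!(tokenizer.special_tokens().len(), 1);
    /// ```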
    pub fn register_special_tokens(&mut self, special_tokens: HashMap<String, TokenId>) {
        self.special_tokens.extend(special_tokens);
    }

    /// Adds a single special token to the tokenizer.
    ///
    /// # Arguments
    ///
    /// * `token` - The special token string.
    /// * `token_id` - The token ID to assign.
    pub fn add_special_token(&mut self, token: String, token_id: TokenId) {
        self.special_tokens.insert(token, token_id);
    }

    /// Returns the special tokens registered with this tokenizer.
    pub fn special_tokens(&self) -> &HashMap<String, TokenId> {
        &self.special_tokens
    }

    /// Encode a single chunk using BPE.
    fn encode_chunk(&self, chunk: &str) -> Vec<TokenId> {
        bpe_encode(chunk, &self.merge_lookup)
    }

    /// Encode text without handling special tokens (sequential).
    fn encode_ordinary(&self, text: &str) -> Vec<TokenId> {
        let chunks = split_text(text, &self.compiled_pattern);
        chunks.iter().flat_map(|chunk| self.encode_chunk(chunk)).collect()
    }
}

impl Tokenizer for RegexBPETokenizer {
    type DecodingError = UnknownTokenId;

    fn merges(&self) -> &[MergeRule] {
        self.merges.as_slice()
    }

    fn vocab(&self) -> &HashMap<TokenId, Vec<u8>> {
        &self.vocab
    }

    fn encode(&self, text: &str) -> Vec<TokenId> {
        if self.special_tokens.is_empty() {
            return self.encode_ordinary(text);
        }

        // build a regex pattern to split on special tokens
        let special_tokens_pattern = self
            .special_tokens
            .keys()
            .map(|token| fancy_regex::escape(token))
            .collect::<Vec<_>>()
            .join("|");

        let split_pattern = match Regex::new(&format!("({})", special_tokens_pattern)) {
            Ok(p) => p,
            Err(_) => return self.encode_ordinary(text),
        };

        let mut tokens = Vec::new();
        let mut last_end = 0;

        // find all special tokens and split accordingly
        for m in split_pattern.find_iter(text).flatten() {
            // encode the text before this special token
            if m.start() > last_end {
                tokens.extend(self.encode_ordinary(&text[last_end..m.start()]));
            }

            // add the special token
            let special_token = m.as_str();
            if let Some(&token_id) = self.special_tokens.get(special_token) {
                tokens.push(token_id);
            }

            last_end = m.end();
        }

        // encode any remaining text after the last special token
        if last_end < text.len() {
            tokens.extend(self.encode_ordinary(&text[last_end..]));
        }

        tokens
    }

    fn decode(&self, tokens: &[TokenId]) -> Result<String, UnknownTokenId> {
        // build inverted special tokens map for decoding
        let special_tokens_inverted: HashMap<TokenId, Vec<u8>> = self
            .special_tokens
            .iter()
            .map(|(token, &id)| (id, token.as_bytes().to_vec()))
            .collect();

        let mut bytes = Vec::new();

        for token_id in tokens {
            if let Some(token_bytes) = self.vocab.get(token_id) {
                bytes.extend_from_slice(token_bytes);
            } else if let Some(token_bytes) = special_tokens_inverted.get(token_id) {
                bytes.extend_from_slice(token_bytes);
            } else {
                return Err(UnknownTokenId(*token_id));
            }
        }

        // Invalid UTF-8 sequences are replaced with U+FFFD rather than failing.
        Ok(String::from_utf8_lossy(&bytes).into_owned())
    }
}

impl Serializable for RegexBPETokenizer {
    fn save(&self, path: &Path) -> Result<(), Error> {
        save_regex_tokenizer(path, self.pattern.as_str(), self.merges())?;
        save_vocab(&path.with_extension("vocab"), self.merges.as_slice(), self.vocab())?;

        Ok(())
    }
}