smoltok_core/regex/tokenizer.rs

use crate::regex::{GPT4_SPLIT_PATTERN, save_regex_tokenizer, split_text};
use crate::tokenizer::{UnknownTokenId, bpe_encode, build_merge_lookup, build_vocab, save_vocab};
use crate::{MergeRule, Serializable, TokenId, TokenPair, Tokenizer};
use fancy_regex::Regex;
use std::collections::HashMap;
use std::io::Error;
use std::path::Path;

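/// Byte-pair-encoding tokenizer that first splits text with a regex pattern
/// (the GPT-4 split pattern by default) and then BPE-encodes each chunk
/// independently. Registered special tokens are matched before ordinary
/// encoding and mapped directly to their ids.
///
/// Illustrative usage sketch (marked `ignore`; it assumes a `merges` list
/// produced elsewhere and treats the chosen special-token id as a free slot):
///
/// ```ignore
/// let mut tokenizer = RegexBPETokenizer::from_merges(merges);
/// tokenizer.add_special_token("<|endoftext|>".to_string(), 100_257);
/// let ids = tokenizer.encode("hello world<|endoftext|>");
/// let text = tokenizer.decode(&ids).unwrap();
/// ```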
#[derive(Debug)]
pub struct RegexBPETokenizer {
    /// Ordered BPE merge rules.
    merges: Vec<MergeRule>,
    /// Pair-to-merge lookup derived from `merges`, used by `bpe_encode`.
    merge_lookup: HashMap<TokenPair, (usize, TokenId)>,
    /// Mapping from token id to the bytes it represents.
    vocab: HashMap<TokenId, Vec<u8>>,
    /// Regex pattern used to pre-split text before BPE encoding.
    pattern: String,
    /// Compiled form of `pattern`.
    compiled_pattern: Regex,
    /// Registered special tokens and their assigned ids.
    special_tokens: HashMap<String, TokenId>,
}

impl RegexBPETokenizer {
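    /// Builds a tokenizer from learned merge rules and a regex split pattern.
    /// If `pattern` fails to compile, the tokenizer falls back to
    /// `GPT4_SPLIT_PATTERN`.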
    pub fn new(merges: Vec<MergeRule>, pattern: String) -> Self {
        let compiled_pattern = Regex::new(&pattern).unwrap_or_else(|_| {
            Regex::new(GPT4_SPLIT_PATTERN)
                .expect("Both primary and fallback patterns failed to compile")
        });

        Self {
            vocab: build_vocab(merges.as_slice()),
            merge_lookup: build_merge_lookup(merges.as_slice()),
            merges,
            pattern,
            compiled_pattern,
            special_tokens: HashMap::new(),
        }
    }

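    /// Builds a tokenizer that splits text with the GPT-4 split pattern.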
    pub fn from_merges(merges: Vec<MergeRule>) -> Self {
        Self::new(merges, GPT4_SPLIT_PATTERN.to_string())
    }

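    /// Returns the number of merge rules.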
    pub fn num_merges(&self) -> usize {
        self.merges.len()
    }

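    /// Returns the number of entries in the vocabulary.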
    pub fn vocab_size(&self) -> usize {
        self.vocab.len()
    }

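    /// Returns the regex split pattern as a string.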
    pub fn pattern(&self) -> &str {
        &self.pattern
    }

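    /// Registers a batch of special tokens, extending any previously
    /// registered ones.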
    pub fn register_special_tokens(&mut self, special_tokens: HashMap<String, TokenId>) {
        self.special_tokens.extend(special_tokens);
    }

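    /// Registers a single special token under the given id.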
    pub fn add_special_token(&mut self, token: String, token_id: TokenId) {
        self.special_tokens.insert(token, token_id);
    }

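    /// Returns the registered special tokens and their ids.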
    pub fn special_tokens(&self) -> &HashMap<String, TokenId> {
        &self.special_tokens
    }

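    /// BPE-encodes a single chunk produced by the split pattern.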
    fn encode_chunk(&self, chunk: &str) -> Vec<TokenId> {
        bpe_encode(chunk, &self.merge_lookup)
    }

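    /// Encodes text without special-token handling: split with the regex
    /// pattern, then BPE-encode each chunk.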
    fn encode_ordinary(&self, text: &str) -> Vec<TokenId> {
        let chunks = split_text(text, &self.compiled_pattern);
        chunks.iter().flat_map(|chunk| self.encode_chunk(chunk)).collect()
    }
}

impl Tokenizer for RegexBPETokenizer {
    type DecodingError = UnknownTokenId;

    fn merges(&self) -> &[MergeRule] {
        self.merges.as_slice()
    }

    fn vocab(&self) -> &HashMap<TokenId, Vec<u8>> {
        &self.vocab
    }

    fn encode(&self, text: &str) -> Vec<TokenId> {
        if self.special_tokens.is_empty() {
            return self.encode_ordinary(text);
        }

        // Build an alternation of all registered special tokens, escaping any
        // regex metacharacters they contain.
        let special_tokens_pattern = self
            .special_tokens
            .keys()
            .map(|token| fancy_regex::escape(token))
            .collect::<Vec<_>>()
            .join("|");

        let split_pattern = match Regex::new(&format!("({})", special_tokens_pattern)) {
            Ok(p) => p,
            Err(_) => return self.encode_ordinary(text),
        };

        let mut tokens = Vec::new();
        let mut last_end = 0;

        // Walk the special-token matches in order, encoding the ordinary text
        // between them and emitting the special token ids directly.
        for m in split_pattern.find_iter(text).flatten() {
            if m.start() > last_end {
                tokens.extend(self.encode_ordinary(&text[last_end..m.start()]));
            }

            let special_token = m.as_str();
            if let Some(&token_id) = self.special_tokens.get(special_token) {
                tokens.push(token_id);
            }

            last_end = m.end();
        }

        // Encode any trailing text after the last special token.
        if last_end < text.len() {
            tokens.extend(self.encode_ordinary(&text[last_end..]));
        }

        tokens
    }

    fn decode(&self, tokens: &[TokenId]) -> Result<String, UnknownTokenId> {
        // Invert the special-token map so ids can be resolved back to their
        // UTF-8 byte representation.
        let special_tokens_inverted: HashMap<TokenId, Vec<u8>> = self
            .special_tokens
            .iter()
            .map(|(token, &id)| (id, token.as_bytes().to_vec()))
            .collect();

        let mut bytes = Vec::new();

        for token_id in tokens {
            if let Some(token_bytes) = self.vocab.get(token_id) {
                bytes.extend_from_slice(token_bytes);
            } else if let Some(token_bytes) = special_tokens_inverted.get(token_id) {
                bytes.extend_from_slice(token_bytes);
            } else {
                return Err(UnknownTokenId(*token_id));
            }
        }

        // Invalid UTF-8 sequences are replaced rather than treated as errors.
        Ok(String::from_utf8_lossy(&bytes).into_owned())
    }
}

impl Serializable for RegexBPETokenizer {
    fn save(&self, path: &Path) -> Result<(), Error> {
        // The split pattern and merges go to `path`; the vocabulary goes to a
        // sibling file with the `.vocab` extension.
        save_regex_tokenizer(path, self.pattern.as_str(), self.merges())?;
        save_vocab(&path.with_extension("vocab"), self.merges.as_slice(), self.vocab())?;

        Ok(())
    }
}
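
// A minimal smoke-test sketch rather than the crate's real test suite; it
// assumes that constructing a tokenizer from an empty merge list is valid and
// that `TokenId` is a plain integer alias.
#[cfg(test)]
mod regex_tokenizer_sketch_tests {
    use super::*;
    use crate::regex::GPT4_SPLIT_PATTERN;

    #[test]
    fn defaults_and_special_token_registration() {
        // With no merges, only the bookkeeping accessors are exercised here.
        let mut tokenizer = RegexBPETokenizer::from_merges(Vec::new());
        assert_eq!(tokenizer.num_merges(), 0);
        assert_eq!(tokenizer.pattern(), GPT4_SPLIT_PATTERN);

        tokenizer.add_special_token("<|eot|>".to_string(), 1_000);
        assert_eq!(tokenizer.special_tokens().len(), 1);
    }
}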