smoltok_core/regex/tokenizer_parallel.rs

use crate::regex::{GPT4_SPLIT_PATTERN, save_regex_tokenizer, split_text};
use crate::tokenizer::{UnknownTokenId, bpe_encode, build_merge_lookup, build_vocab, save_vocab};
use crate::{MergeRule, Serializable, TokenId, TokenPair, Tokenizer};
use fancy_regex::Regex;
use rayon::prelude::*;
use std::collections::HashMap;
use std::io::Error;
use std::path::Path;

/// Regex-based BPE tokenizer that splits text with a regex pattern and encodes
/// the resulting chunks in parallel with rayon.
#[derive(Debug)]
pub struct ParallelRegexBPETokenizer {
    /// Learned merge rules, in the order they were learned.
    merges: Vec<MergeRule>,
    /// Lookup from a token pair to its merge rank and merged token id.
    merge_lookup: HashMap<TokenPair, (usize, TokenId)>,
    /// Mapping from each token id to the byte sequence it represents.
    vocab: HashMap<TokenId, Vec<u8>>,
    /// Source string of the split pattern, kept for serialization.
    pattern: String,
    /// Compiled split pattern applied to text before BPE encoding.
    compiled_pattern: Regex,
    /// Special tokens and their ids, handled outside the regular BPE merges.
    special_tokens: HashMap<String, TokenId>,
}

impl ParallelRegexBPETokenizer {
    /// Creates a tokenizer from merge rules and a split pattern. If `pattern`
    /// fails to compile, falls back to `GPT4_SPLIT_PATTERN`.
    pub fn new(merges: Vec<MergeRule>, pattern: String) -> Self {
        let compiled_pattern = Regex::new(&pattern).unwrap_or_else(|_| {
            Regex::new(GPT4_SPLIT_PATTERN)
                .expect("Both primary and fallback patterns failed to compile")
        });

        Self {
            vocab: build_vocab(merges.as_slice()),
            merge_lookup: build_merge_lookup(merges.as_slice()),
            merges,
            pattern,
            compiled_pattern,
            special_tokens: HashMap::new(),
        }
    }

    /// Creates a tokenizer from merge rules using the default `GPT4_SPLIT_PATTERN`.
    pub fn from_merges(merges: Vec<MergeRule>) -> Self {
        Self::new(merges, GPT4_SPLIT_PATTERN.to_string())
    }

    /// Number of learned merge rules.
    pub fn num_merges(&self) -> usize {
        self.merges.len()
    }

    /// Number of entries in the vocabulary.
    pub fn vocab_size(&self) -> usize {
        self.vocab.len()
    }

    /// The split pattern as a string.
    pub fn pattern(&self) -> &str {
        &self.pattern
    }

    /// Registers multiple special tokens, extending any already registered.
    pub fn register_special_tokens(&mut self, special_tokens: HashMap<String, TokenId>) {
        self.special_tokens.extend(special_tokens);
    }

    /// Registers a single special token under the given id.
    pub fn add_special_token(&mut self, token: String, token_id: TokenId) {
        self.special_tokens.insert(token, token_id);
    }

    /// The currently registered special tokens.
    pub fn special_tokens(&self) -> &HashMap<String, TokenId> {
        &self.special_tokens
    }

    /// Encodes a single pattern-split chunk with the learned merges.
    fn encode_chunk(&self, chunk: &str) -> Vec<TokenId> {
        bpe_encode(chunk, &self.merge_lookup)
    }

    /// Encodes text without special-token handling, processing the
    /// pattern-split chunks in parallel.
    fn encode_ordinary(&self, text: &str) -> Vec<TokenId> {
        let chunks = split_text(text, &self.compiled_pattern);
        chunks
            .par_iter()
            .map(|chunk| self.encode_chunk(chunk))
            .flatten()
            .collect()
    }
}
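
// A minimal construction sketch, not part of the original file: a test module
// exercising only behavior visible above, namely that `from_merges` uses
// `GPT4_SPLIT_PATTERN` as its default split pattern, keeps the merge count,
// and starts with no special tokens registered.
#[cfg(test)]
mod construction_sketch {
    use super::*;

    #[test]
    fn from_merges_uses_the_default_split_pattern() {
        let tokenizer = ParallelRegexBPETokenizer::from_merges(Vec::new());
        assert_eq!(tokenizer.pattern(), GPT4_SPLIT_PATTERN);
        assert_eq!(tokenizer.num_merges(), 0);
        assert!(tokenizer.special_tokens().is_empty());
    }
}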

/// A piece of the input text: either ordinary text to be BPE-encoded or a
/// matched special token.
enum Segment<'a> {
    Text(&'a str),
    SpecialToken(&'a str),
}

impl Tokenizer for ParallelRegexBPETokenizer {
    type DecodingError = UnknownTokenId;

    fn merges(&self) -> &[MergeRule] {
        self.merges.as_slice()
    }

    fn vocab(&self) -> &HashMap<TokenId, Vec<u8>> {
        &self.vocab
    }

    fn encode(&self, text: &str) -> Vec<TokenId> {
        if self.special_tokens.is_empty() {
            return self.encode_ordinary(text);
        }

        // Build a pattern that matches any registered special token literally.
        let special_tokens_pattern = self
            .special_tokens
            .keys()
            .map(|token| fancy_regex::escape(token))
            .collect::<Vec<_>>()
            .join("|");

        let split_pattern = match Regex::new(&format!("({})", special_tokens_pattern)) {
            Ok(p) => p,
            Err(_) => return self.encode_ordinary(text),
        };

        // Split the input into ordinary-text and special-token segments.
        let mut segments: Vec<Segment> = Vec::new();
        let mut last_end = 0;

        for m in split_pattern.find_iter(text).flatten() {
            if m.start() > last_end {
                segments.push(Segment::Text(&text[last_end..m.start()]));
            }
            segments.push(Segment::SpecialToken(m.as_str()));
            last_end = m.end();
        }

        if last_end < text.len() {
            segments.push(Segment::Text(&text[last_end..]));
        }

        // Encode segments in parallel: ordinary text goes through BPE, while
        // special tokens map directly to their registered ids.
        segments
            .par_iter()
            .map(|segment| match segment {
                Segment::Text(text) => self.encode_ordinary(text),
                Segment::SpecialToken(special_token) => self
                    .special_tokens
                    .get(*special_token)
                    .map(|&id| vec![id])
                    .unwrap_or_default(),
            })
            .flatten()
            .collect()
    }

    fn decode(&self, tokens: &[TokenId]) -> Result<String, UnknownTokenId> {
        // Invert the special-token map so ids can be resolved back to bytes.
        let special_tokens_inverted: HashMap<TokenId, Vec<u8>> = self
            .special_tokens
            .iter()
            .map(|(token, &id)| (id, token.as_bytes().to_vec()))
            .collect();

        // Resolve each token id in parallel, preferring the regular vocab and
        // falling back to special tokens.
        let byte_chunks: Result<Vec<Vec<u8>>, UnknownTokenId> = tokens
            .par_iter()
            .map(|token_id| {
                if let Some(token_bytes) = self.vocab.get(token_id) {
                    Ok(token_bytes.clone())
                } else if let Some(token_bytes) = special_tokens_inverted.get(token_id) {
                    Ok(token_bytes.clone())
                } else {
                    Err(UnknownTokenId(*token_id))
                }
            })
            .collect();

        let bytes: Vec<u8> = byte_chunks?.into_iter().flatten().collect();

        Ok(String::from_utf8_lossy(&bytes).into_owned())
    }
}

impl Serializable for ParallelRegexBPETokenizer {
    fn save(&self, path: &Path) -> Result<(), Error> {
        save_regex_tokenizer(path, self.pattern.as_str(), self.merges())?;
        save_vocab(&path.with_extension("vocab"), self.merges.as_slice(), self.vocab())?;

        Ok(())
    }
}
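
// A minimal usage sketch, not present in the original file: it assumes the base
// vocabulary built by `build_vocab` contains all 256 single-byte tokens (the
// usual byte-level BPE convention), so text survives an encode/decode round
// trip even with zero learned merges. The special-token string and the id 300
// are arbitrary choices for illustration; the id only needs to stay clear of
// ordinary token ids.
#[cfg(test)]
mod usage_sketch {
    use super::*;

    #[test]
    fn round_trips_text_containing_a_special_token() {
        let mut tokenizer = ParallelRegexBPETokenizer::from_merges(Vec::new());
        tokenizer.add_special_token("<|endoftext|>".to_string(), 300);

        let text = "hello world<|endoftext|>";
        let tokens = tokenizer.encode(text);

        // The special token is emitted as its registered id rather than being
        // split into byte-level tokens.
        assert!(tokens.contains(&300));
        assert_eq!(tokenizer.decode(&tokens).unwrap(), text);
    }
}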