// smoltok_core/regex/tokenizer_parallel.rs

use crate::regex::{GPT4_SPLIT_PATTERN, save_regex_tokenizer, split_text};
use crate::tokenizer::{UnknownTokenId, bpe_encode, build_merge_lookup, build_vocab, save_vocab};
use crate::{MergeRule, Serializable, TokenId, TokenPair, Tokenizer};
use fancy_regex::Regex;
use rayon::prelude::*;
use std::collections::HashMap;
use std::io::Error;
use std::path::Path;

/// Regex-based BPE tokenizer that encodes text chunks in parallel.
///
/// This tokenizer first splits text using a regex pattern,
/// then applies BPE encoding to each chunk independently and in parallel (via rayon).
/// It also supports special tokens that are handled separately from regular text.
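///
/// # Example
///
/// A minimal usage sketch (assuming the tokenizer and the `Tokenizer` trait are
/// re-exported at the crate root, and that an empty merge list degenerates to
/// plain byte-level encoding):
///
/// ```ignore
/// use smoltok_core::{ParallelRegexBPETokenizer, Tokenizer};
///
/// let tokenizer = ParallelRegexBPETokenizer::from_merges(vec![]);
/// let ids = tokenizer.encode("hello world");
/// assert_eq!(tokenizer.decode(&ids).unwrap(), "hello world");
/// ```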
#[derive(Debug)]
pub struct ParallelRegexBPETokenizer {
    /// The learned merge rules, in order of application.
    merges: Vec<MergeRule>,
    /// Maps token pairs to (rank, new_id) for fast lookup during encoding.
    merge_lookup: HashMap<TokenPair, (usize, TokenId)>,
    /// Maps token IDs to their byte sequences.
    vocab: HashMap<TokenId, Vec<u8>>,
    /// The regex pattern used for splitting text.
    pattern: String,
    /// The compiled regex pattern.
    compiled_pattern: Regex,
    /// Special tokens mapping (token string -> token ID).
    special_tokens: HashMap<String, TokenId>,
}

impl ParallelRegexBPETokenizer {
    // todo: avoid code duplication

    /// Creates a new `ParallelRegexBPETokenizer` from merge rules and a pattern.
    ///
    /// If `pattern` fails to compile, the tokenizer falls back to the default
    /// GPT-4 split pattern.
    ///
    /// # Arguments
    ///
    /// * `merges` - A vector of `MergeRule` representing the merge operations.
    /// * `pattern` - The regex pattern for splitting text.
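    ///
    /// # Example
    ///
    /// A sketch with a simple custom split pattern (the pattern shown here is
    /// hypothetical; any pattern `fancy_regex` accepts works, and it is assumed
    /// the type is re-exported at the crate root):
    ///
    /// ```ignore
    /// use smoltok_core::ParallelRegexBPETokenizer;
    ///
    /// let pattern = r"\w+|\s+|[^\w\s]+".to_string();
    /// let tokenizer = ParallelRegexBPETokenizer::new(vec![], pattern);
    /// assert_eq!(tokenizer.pattern(), r"\w+|\s+|[^\w\s]+");
    /// ```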
    pub fn new(merges: Vec<MergeRule>, pattern: String) -> Self {
        // If the supplied pattern does not compile, fall back to the GPT-4
        // split pattern instead of propagating the error.
        let compiled_pattern = Regex::new(&pattern).unwrap_or_else(|_| {
            Regex::new(GPT4_SPLIT_PATTERN).expect("Both primary and fallback patterns failed to compile")
        });

        Self {
            vocab: build_vocab(merges.as_slice()),
            merge_lookup: build_merge_lookup(merges.as_slice()),
            merges,
            pattern,
            compiled_pattern,
            special_tokens: HashMap::new(),
        }
    }

    /// Creates a new `ParallelRegexBPETokenizer` from merge rules using the default GPT-4 pattern.
    ///
    /// # Arguments
    ///
    /// * `merges` - A vector of `MergeRule` representing the merge operations.
    pub fn from_merges(merges: Vec<MergeRule>) -> Self {
        Self::new(merges, GPT4_SPLIT_PATTERN.to_string())
    }

    /// Returns the number of merge rules learned during training.
    pub fn num_merges(&self) -> usize {
        self.merges.len()
    }

    /// Returns the vocabulary size (base 256 bytes + merged tokens).
    pub fn vocab_size(&self) -> usize {
        self.vocab.len()
    }

    /// Returns the regex pattern used for splitting text.
    pub fn pattern(&self) -> &str {
        &self.pattern
    }

    /// Registers special tokens with the tokenizer.
    ///
    /// Special tokens are handled separately during encoding and are not
    /// subject to the BPE merge process.
    ///
    /// # Arguments
    ///
    /// * `special_tokens` - A map of special token strings to their token IDs.
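    ///
    /// # Example
    ///
    /// A sketch assuming `TokenId` is a plain integer alias; the id used here
    /// is hypothetical and would normally sit just past the learned vocabulary:
    ///
    /// ```ignore
    /// use std::collections::HashMap;
    /// use smoltok_core::ParallelRegexBPETokenizer;
    ///
    /// let mut tokenizer = ParallelRegexBPETokenizer::from_merges(vec![]);
    /// let mut specials = HashMap::new();
    /// specials.insert("<|endoftext|>".to_string(), 256);
    /// tokenizer.register_special_tokens(specials);
    /// assert_eq!(tokenizer.special_tokens().len(), 1);
    /// ```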
    pub fn register_special_tokens(&mut self, special_tokens: HashMap<String, TokenId>) {
        self.special_tokens.extend(special_tokens);
    }

    /// Adds a single special token to the tokenizer.
    ///
    /// # Arguments
    ///
    /// * `token` - The special token string.
    /// * `token_id` - The token ID to assign.
    pub fn add_special_token(&mut self, token: String, token_id: TokenId) {
        self.special_tokens.insert(token, token_id);
    }

    /// Returns the special tokens registered with this tokenizer.
    pub fn special_tokens(&self) -> &HashMap<String, TokenId> {
        &self.special_tokens
    }

    /// Encode a single chunk using BPE.
    fn encode_chunk(&self, chunk: &str) -> Vec<TokenId> {
        bpe_encode(chunk, &self.merge_lookup)
    }

    /// Encode text without handling special tokens (parallel).
    fn encode_ordinary(&self, text: &str) -> Vec<TokenId> {
        let chunks = split_text(text, &self.compiled_pattern);
        chunks
            .par_iter()
            .map(|chunk| self.encode_chunk(chunk))
            .flatten()
            .collect()
    }
}

/// Helper enum for parallel encoding with special tokens.
enum Segment<'a> {
    Text(&'a str),
    SpecialToken(&'a str),
}

impl Tokenizer for ParallelRegexBPETokenizer {
    type DecodingError = UnknownTokenId;

    fn merges(&self) -> &[MergeRule] {
        self.merges.as_slice()
    }

    fn vocab(&self) -> &HashMap<TokenId, Vec<u8>> {
        &self.vocab
    }

    fn encode(&self, text: &str) -> Vec<TokenId> {
        if self.special_tokens.is_empty() {
            return self.encode_ordinary(text);
        }

        // build a regex pattern to split on special tokens
        let special_tokens_pattern = self
            .special_tokens
            .keys()
            .map(|token| fancy_regex::escape(token))
            .collect::<Vec<_>>()
            .join("|");

        let split_pattern = match Regex::new(&format!("({})", special_tokens_pattern)) {
            Ok(p) => p,
            Err(_) => return self.encode_ordinary(text),
        };

        // collect all segments (text between special tokens + special tokens themselves)
        let mut segments: Vec<Segment> = Vec::new();
        let mut last_end = 0;

        // find all special tokens and split accordingly
        for m in split_pattern.find_iter(text).flatten() {
            // collect the text before this special token
            if m.start() > last_end {
                segments.push(Segment::Text(&text[last_end..m.start()]));
            }
            segments.push(Segment::SpecialToken(m.as_str()));
            last_end = m.end();
        }

        // collect any remaining text after the last special token
        if last_end < text.len() {
            segments.push(Segment::Text(&text[last_end..]));
        }

        // encode segments in parallel
        segments
            .par_iter()
            .map(|segment| match segment {
                Segment::Text(text) => self.encode_ordinary(text),
                Segment::SpecialToken(special_token) => self
                    .special_tokens
                    .get(*special_token)
                    .map(|&id| vec![id])
                    .unwrap_or_default(),
            })
            .flatten()
            .collect()
    }

    fn decode(&self, tokens: &[TokenId]) -> Result<String, UnknownTokenId> {
        // build inverted special tokens map for decoding
        let special_tokens_inverted: HashMap<TokenId, Vec<u8>> = self
            .special_tokens
            .iter()
            .map(|(token, &id)| (id, token.as_bytes().to_vec()))
            .collect();

        // look up bytes for each token in parallel
        let byte_chunks: Result<Vec<Vec<u8>>, UnknownTokenId> = tokens
            .par_iter()
            .map(|token_id| {
                if let Some(token_bytes) = self.vocab.get(token_id) {
                    Ok(token_bytes.clone())
                } else if let Some(token_bytes) = special_tokens_inverted.get(token_id) {
                    Ok(token_bytes.clone())
                } else {
                    Err(UnknownTokenId(*token_id))
                }
            })
            .collect();

        // concatenate all byte chunks
        let bytes: Vec<u8> = byte_chunks?.into_iter().flatten().collect();

        Ok(String::from_utf8_lossy(&bytes).into_owned())
    }
}

impl Serializable for ParallelRegexBPETokenizer {
    fn save(&self, path: &Path) -> Result<(), Error> {
        save_regex_tokenizer(path, self.pattern.as_str(), self.merges())?;
        save_vocab(&path.with_extension("vocab"), self.merges.as_slice(), self.vocab())?;

        Ok(())
    }
}
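// A hedged round-trip sketch (not part of the original file) exercising the
// special-token path in `encode` and the fallback lookup in `decode`. It
// assumes `TokenId` is a plain integer alias and that `build_vocab` seeds the
// 256 base byte tokens even when no merge rules are supplied, so an
// empty-merge tokenizer degenerates to byte-level encoding.
#[cfg(test)]
mod tests {
    use super::ParallelRegexBPETokenizer;
    use crate::{TokenId, Tokenizer};

    #[test]
    fn special_token_round_trip() {
        // No learned merges: ordinary text should encode to raw byte ids.
        let mut tokenizer = ParallelRegexBPETokenizer::from_merges(vec![]);
        let eot: TokenId = 256; // hypothetical id just past the byte range
        tokenizer.add_special_token("<|endoftext|>".to_string(), eot);

        let text = "hello<|endoftext|>world";
        let ids = tokenizer.encode(text);

        // The special token should surface as a single id, not as its bytes.
        assert!(ids.contains(&eot));
        // Byte tokens decode through `vocab`; the special token decodes
        // through the inverted special-token map, restoring the input.
        assert_eq!(tokenizer.decode(&ids).unwrap(), text);
    }
}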