tiktoken_rust/
lib.rs

1//! # tiktoken_rust
2//!
3//! This crate is a tokeniser for use with OpenAI's models.
4
5mod core;
6pub use crate::core::{encoding_for_model, get_encoding, Encoding, Result};
7
8mod model;
9pub use model::{AllowedSpecial, DecodeMode, DisallowedSpecial, EncodeError};
10
11mod load;
12mod openai_public;
13pub use openai_public::list_encoding_names;
14
15use std::collections::HashSet;
16use std::thread;
17
18use fancy_regex::Regex;
19use rustc_hash::FxHashMap as HashMap;
20
/// Runs byte-pair merging over `piece` and maps each resulting part through `f`.
///
/// `ranks` gives the merge priority of mergeable byte sequences (lower ranks merge first).
/// At every step, the adjacent pair of parts whose concatenation has the lowest rank is merged,
/// with ties going to the leftmost pair. Merging stops when no adjacent concatenation appears in
/// `ranks`. `f` receives the byte range of each final part within `piece`, in order.
///
/// An empty `piece` yields an empty `Vec`; a one-byte `piece` yields a single `0..1` part.
fn _byte_pair_merge<T>(
    piece: &[u8],
    ranks: &HashMap<Vec<u8>, usize>,
    f: impl Fn(std::ops::Range<usize>) -> T,
) -> Vec<T> {
    // `parts[i] = (start, rank)`: `start` is the byte offset where part `i` begins, and `rank`
    // is the rank of the concatenation of part `i` and part `i + 1` (`usize::MAX` if that pair
    // is not in `ranks`). The final entry only marks the end of `piece`; its rank is meaningless.
    let mut parts: Vec<(usize, usize)> = (0..=piece.len()).map(|i| (i, usize::MAX)).collect();

    // Rank of the bytes spanning `parts[start_idx].0 .. parts[start_idx + skip + 2].0`.
    // `skip = 1` looks past `parts[start_idx + 1]`, which lets us compute post-merge ranks
    // before the merged-away entry is actually removed.
    let get_rank = {
        #[inline(always)]
        |parts: &[(usize, usize)], start_idx: usize, skip: usize| {
            let end_idx = start_idx + skip + 2;
            if end_idx < parts.len() {
                ranks
                    .get(&piece[parts[start_idx].0..parts[end_idx].0])
                    .copied()
            } else {
                None
            }
        }
    };

    // Look up every initial pair rank once; the merge loop below only refreshes the neighbours
    // of each merged pair, which reduces the number of rank lookups. `saturating_sub` keeps this
    // well-defined for empty pieces, where `parts` has a single entry.
    for i in 0..parts.len().saturating_sub(2) {
        if let Some(rank) = get_rank(&parts, i, 0) {
            // usize::MAX is a sentinel value and cannot be a valid rank
            debug_assert!(rank != usize::MAX);
            parts[i].1 = rank;
        }
    }

    // With n parts and m merges this does O(mn) work. A heap would give O(m log n), but n is
    // usually small (<100), so the cache locality of a flat Vec wins in practice.
    //
    // Note that we hash bytes, not token pairs. As long as we train BPE the way we currently
    // do, this is equivalent. An easy way to break this would be to decouple merge priority
    // from token index or to prevent specific token merges.
    while parts.len() > 1 {
        // Find the lowest-ranked pair; `<` keeps the leftmost one on ties. usize::MAX doubles
        // as "nothing mergeable".
        let mut min_rank: (usize, usize) = (usize::MAX, 0);
        for (i, &(_, rank)) in parts[..parts.len() - 1].iter().enumerate() {
            if rank < min_rank.0 {
                min_rank = (rank, i);
            }
        }
        if min_rank.0 == usize::MAX {
            break;
        }
        let i = min_rank.1;

        // We are about to remove parts[i + 1]. Refresh the ranks of parts[i] and parts[i - 1]
        // first (skipping over parts[i + 1] via `skip = 1`), which is friendlier to the cache
        // than removing first.
        parts[i].1 = get_rank(&parts, i, 1).unwrap_or(usize::MAX);
        if i > 0 {
            parts[i - 1].1 = get_rank(&parts, i - 1, 1).unwrap_or(usize::MAX);
        }
        parts.remove(i + 1);
    }

    // Consecutive part starts delimit the final parts.
    parts.windows(2).map(|w| f(w[0].0..w[1].0)).collect()
}
106
107fn byte_pair_encode(piece: &[u8], ranks: &HashMap<Vec<u8>, usize>) -> Vec<usize> {
108    if piece.len() == 1 {
109        return vec![ranks[piece]];
110    }
111    _byte_pair_merge(piece, ranks, |p| ranks[&piece[p.start..p.end]])
112}
113
114#[allow(dead_code)]
115fn byte_pair_split<'a>(piece: &'a [u8], ranks: &HashMap<Vec<u8>, usize>) -> Vec<&'a [u8]> {
116    if piece.len() == 1 {
117        return vec![piece];
118    }
119    _byte_pair_merge(piece, ranks, |p| &piece[p.start..p.end])
120}
121
122// Various performance notes:
123//
124// Regex
125// =====
126// Most of the time is spent in regex. The easiest way to speed this up is by using less fancy
127// regex features. For instance, using a regex parse-able by the `regex` crate is 3x faster than
128// the usual regex we use.
129//
130// However, given that we're using a regex parse-able by `regex`, there isn't much difference
131// between using the `regex` crate and using the `fancy_regex` crate.
132//
133// There is an important interaction between threading, `regex` and `fancy_regex`.
134// When using `fancy_regex`, we hit `regex.find_at`. It turns out that this causes contention on
135// some mutable scratch space inside of `regex`. This absolutely kills performance. When using plain
136// old `regex`, we don't hit this, because `find_iter` has a different code path.
137// Related: https://github.com/rust-lang/regex/blob/master/PERFORMANCE.md
138// Anyway, the way we get around this is with having a (mostly) thread local clone of the regex for
139// each thread.
140//
141// Threading
142// =========
143// I tried using `rayon`. It wasn't really faster than using Python threads and releasing the GIL.
144// So goodbye `rayon`! Let thread count etc be in control of our Python users.
145//
146// Caching
147// =======
148// The reference tokeniser has an lru cache over the equivalent of `byte_pair_encode`.
149// Originally, we had one too! Without it, we were only vaguely faster than Python.
150// I used an RWLock to protect the cache. This didn't seem to hurt single threaded performance
151// noticeably, but it did affect multi-threaded performance. Weirdly, it seemed to affect
152// multi-threaded performance even when I only had readers (maybe I messed something up?).
153// Anyway, I realised that we could get rid of the cache, if we treat the set of tokens as a cache!
154// These are exactly the set of merges that are likely to be hot. And now we don't have to think
155// about interior mutability, memory use, or cloning.
156//
157// Hashing
158// =======
159// We use FxHashMap instead of the standard HashMap. This is maybe like a 5-10% win?
160// The current implementation ends up doing a lot of hashing of bytes. In theory, this could be made
161// to be hashing of two-tuples of ints, which looks like it may also be a couple percent faster.
162
163use std::num::NonZeroU64;
/// Returns a small integer identifying the calling thread, used to pick a per-thread regex slot.
///
/// The first time a thread calls this, it takes the next value of a process-wide counter and
/// keeps it for the rest of its life. Threads therefore get distinct, consecutive ids in
/// first-call order, so up to `MAX_NUM_THREADS` of them land in distinct
/// `id % MAX_NUM_THREADS` slots.
///
/// This deliberately avoids transmuting `std::thread::ThreadId` to read its internal counter.
/// That type's representation is private and unspecified, so such a transmute is unsound if
/// the standard library ever changes it. A thread-local counter gives the same distribution
/// safely.
fn hash_current_thread() -> usize {
    use std::sync::atomic::{AtomicUsize, Ordering};

    // Relaxed suffices: we only need every fetch_add to hand out a distinct value.
    static NEXT_THREAD_INDEX: AtomicUsize = AtomicUsize::new(0);

    thread_local! {
        static THREAD_INDEX: usize = NEXT_THREAD_INDEX.fetch_add(1, Ordering::Relaxed);
    }

    THREAD_INDEX.with(|index| *index)
}

/// Number of per-thread regex clones kept by `CoreBPE`.
///
/// Threads are mapped onto these slots by `hash_current_thread() % MAX_NUM_THREADS`, so threads
/// whose ids collide modulo this value share a clone (and may contend on it).
const MAX_NUM_THREADS: usize = 128;
180
181struct CoreBPE {
182    encoder: HashMap<Vec<u8>, usize>,
183    special_tokens_encoder: HashMap<String, usize>,
184    decoder: HashMap<usize, Vec<u8>>,
185    special_tokens_decoder: HashMap<usize, Vec<u8>>,
186    regex_tls: Vec<Regex>,
187    special_regex_tls: Vec<Regex>,
188    sorted_token_bytes: Vec<Vec<u8>>,
189}
190
191impl CoreBPE {
192    fn _get_tl_regex(&self) -> &Regex {
193        // See performance notes above for what this is about
194        // It's also a little janky, please make a better version of it!
195        // However, it's nice that this doesn't leak memory to short-lived threads
196        &self.regex_tls[hash_current_thread() % MAX_NUM_THREADS]
197    }
198
199    fn _get_tl_special_regex(&self) -> &Regex {
200        &self.special_regex_tls[hash_current_thread() % MAX_NUM_THREADS]
201    }
202
203    fn _decode_native(&self, tokens: &[usize]) -> Vec<u8> {
204        let mut ret = Vec::with_capacity(tokens.len() * 2);
205        for token in tokens {
206            let token_bytes = self
207                .decoder
208                .get(token)
209                .unwrap_or_else(|| &self.special_tokens_decoder[token]);
210            ret.extend(token_bytes);
211        }
212        ret
213    }
214
215    fn _encode_ordinary_native(&self, text: &str) -> Vec<usize> {
216        // This is the core of the encoding logic; the other functions in here
217        // just make things complicated :-)
218        let regex = self._get_tl_regex();
219        let mut ret = vec![];
220        for mat in regex.find_iter(text) {
221            let piece = mat.unwrap().as_str().as_bytes();
222            if let Some(token) = self.encoder.get(piece) {
223                ret.push(*token);
224                continue;
225            }
226            ret.extend(&byte_pair_encode(piece, &self.encoder));
227        }
228        ret
229    }
230
231    fn _encode_native(&self, text: &str, allowed_special: &HashSet<&str>) -> (Vec<usize>, usize) {
232        let special_regex = self._get_tl_special_regex();
233        let regex = self._get_tl_regex();
234        let mut ret = vec![];
235
236        let mut start = 0;
237        let mut last_piece_token_len = 0;
238        loop {
239            let mut next_special;
240            let mut start_find = start;
241            loop {
242                // Find the next allowed special token, if any
243                next_special = special_regex.find_from_pos(text, start_find).unwrap();
244                match next_special {
245                    Some(m) => {
246                        if allowed_special.contains(&text[m.start()..m.end()]) {
247                            break;
248                        }
249                        start_find = m.start() + 1;
250                    }
251                    None => break,
252                }
253            }
254            let end = next_special.map_or(text.len(), |m| m.start());
255
256            // Okay, here we go, compare this logic to _encode_ordinary_native
257            for mat in regex.find_iter(&text[start..end]) {
258                let piece = mat.unwrap().as_str().as_bytes();
259                if let Some(token) = self.encoder.get(piece) {
260                    last_piece_token_len = 1;
261                    ret.push(*token);
262                    continue;
263                }
264                let tokens = byte_pair_encode(piece, &self.encoder);
265                last_piece_token_len = tokens.len();
266                ret.extend(&tokens);
267            }
268
269            match next_special {
270                // And here we push the special token
271                Some(m) => {
272                    let piece = m.as_str();
273                    let token = self.special_tokens_encoder[piece];
274                    ret.push(token);
275                    start = m.end();
276                    last_piece_token_len = 0;
277                }
278                None => break,
279            }
280        }
281
282        // last_piece_token_len is how many tokens came from the last regex split. This is used
283        // for determining unstable tokens, since you can't merge across (stable) regex splits
284        (ret, last_piece_token_len)
285    }
286
287    fn _increase_last_piece_token_len(
288        &self,
289        tokens: Vec<usize>,
290        mut last_piece_token_len: usize,
291    ) -> (Vec<usize>, usize) {
292        // Unfortunately, the locations where our regex splits can be unstable.
293        // For the purposes of determining unstable tokens, unstable regex splitting
294        // is only a problem if a split that was present disappears, since this can
295        // lead to merging of tokens otherwise thought to be stable.
296        // cl100k_base makes our life hard by including the \s*[\r\n]+
297        // pattern. This can e.g. cause "\n" + " " to become "\n \n".
298        // Here is a quick and dirty fix:
299        {
300            let token_is_all_space = |token| {
301                self.decoder
302                    .get(token)
303                    .map(|token_bytes| {
304                        token_bytes
305                            .iter()
306                            .rev()
307                            .all(|&b| [b' ', b'\n', b'\t'].contains(&b))
308                    })
309                    .unwrap_or(false)
310            };
311            if last_piece_token_len > 0
312                && token_is_all_space(&tokens[tokens.len() - last_piece_token_len])
313            {
314                while (last_piece_token_len < tokens.len())
315                    && token_is_all_space(&tokens[tokens.len() - last_piece_token_len - 1])
316                {
317                    last_piece_token_len += 1;
318                }
319            }
320        }
321        debug_assert!(last_piece_token_len <= tokens.len());
322
323        (tokens, last_piece_token_len)
324    }
325
326    fn _encode_unstable_native(
327        &self,
328        text: &str,
329        allowed_special: &HashSet<&str>,
330    ) -> (Vec<usize>, HashSet<Vec<usize>>) {
331        let (tokens, last_piece_token_len) = self._encode_native(text, allowed_special);
332        if last_piece_token_len == 0 {
333            // If last_piece_token_len is zero, the last token was a special token and we have
334            // no unstable bytes
335            return (tokens, HashSet::new());
336        }
337        let (mut tokens, last_piece_token_len) =
338            self._increase_last_piece_token_len(tokens, last_piece_token_len);
339
340        let unstable_bytes = self._decode_native(&tokens[tokens.len() - last_piece_token_len..]);
341        tokens.truncate(tokens.len() - last_piece_token_len);
342
343        // TODO: we should try harder to find additional stable tokens
344        // This would reduce the amount of retokenising when determining completions
345        // Refer to the logic in an older version of this file
346
347        let mut completions = HashSet::new();
348        if unstable_bytes.is_empty() {
349            return (tokens, completions);
350        }
351
352        // This is the easy bit. Just find all single tokens that start with unstable_bytes
353        // (including tokens that exactly match unstable_bytes)
354        // Separating this from the loop below helps with performance in a common case.
355        let mut point = self
356            .sorted_token_bytes
357            .partition_point(|x| x.as_slice() < unstable_bytes.as_slice());
358        while point < self.sorted_token_bytes.len()
359            && self.sorted_token_bytes[point].starts_with(&unstable_bytes)
360        {
361            completions.insert(vec![
362                self.encoder[self.sorted_token_bytes[point].as_slice()],
363            ]);
364            point += 1;
365        }
366
367        // Now apply even more brute force. At every (other) possible position for the straddling
368        // token, concatenate additional bytes from that token (if any) to unstable_bytes,
369        // and retokenise the whole thing and see what we get.
370        for i in 1..unstable_bytes.len() {
371            let prefix = &unstable_bytes[..i];
372            let suffix = &unstable_bytes[i..];
373            let mut point = self
374                .sorted_token_bytes
375                .partition_point(|x| x.as_slice() < suffix);
376            // TODO: Perf optimisation if suffix starts with " "?
377            while point < self.sorted_token_bytes.len()
378                && self.sorted_token_bytes[point].starts_with(suffix)
379            {
380                let possibility = [prefix, self.sorted_token_bytes[point].as_slice()].concat();
381                let encoded = match std::str::from_utf8(&possibility) {
382                    // Morally, this is byte_pair_encode(&possibility, &self.encoder)
383                    // But we might have introduced a regex split which would prevent merges.
384                    // (particularly possible in the presence of unstable regex splits)
385                    // So convert to UTF-8 and do regex splitting.
386                    // E.g. with cl100k_base "  !" gets split to " " + " !",
387                    // but byte_pair_encode("  !") != byte_pair_encode(" ")
388                    Ok(s) => self._encode_ordinary_native(s),
389
390                    // Technically, whether or not this arm is correct depends on whether there
391                    // would be a regex split before the UTF-8 truncation point.
392                    // Probably niche enough that no one will ever notice (after all, people didn't
393                    // notice all the big holes in the previous unstable token implementation)
394                    Err(_) => byte_pair_encode(&possibility, &self.encoder),
395                    // Something like the following is intriguing but incorrect:
396                    // Err(e) => self._encode_ordinary_native(unsafe {
397                    //     std::str::from_utf8_unchecked(&possibility[..e.valid_up_to()])
398                    // }),
399                };
400                let mut seq = Vec::new();
401                let mut seq_len = 0;
402                for token in encoded {
403                    seq.push(token);
404                    seq_len += self.decoder[&token].len();
405                    if seq_len >= unstable_bytes.len() {
406                        break;
407                    }
408                }
409                completions.insert(seq);
410                point += 1;
411            }
412        }
413
414        // This is also not straightforward. While we generally assume that regex splits are stable,
415        // unfortunately, they are not. That is, if adding bytes were to make a split appear in
416        // unstable_bytes, this could make tokens possible which our logic would otherwise think
417        // would be merged.
418        // For example, with gpt2, the use of \s+(?!\S) means that "\n\n" could
419        // develop a split, e.g. "\n\n0" splits into "\n"+"\n"+"0", making "\n" a possible token.
420        // Here is a quick and dirty fix:
421        // This isn't right if we ever remove \s+(?!\S)
422        if unstable_bytes.len() > 1 {
423            let last_decoded = bstr::decode_last_utf8(unstable_bytes.as_slice());
424            if unstable_bytes.len() - last_decoded.1 > 0
425                && last_decoded.0.map_or(false, |c| c.is_whitespace())
426            {
427                let mut reencoded = byte_pair_encode(
428                    &unstable_bytes[..unstable_bytes.len() - last_decoded.1],
429                    &self.encoder,
430                );
431                reencoded.extend(byte_pair_encode(
432                    &unstable_bytes[unstable_bytes.len() - last_decoded.1..],
433                    &self.encoder,
434                ));
435                completions.insert(reencoded);
436            }
437        }
438
439        (tokens, completions)
440    }
441}
442
443impl CoreBPE {
444    fn new(
445        encoder: HashMap<Vec<u8>, usize>,
446        special_tokens_encoder: HashMap<String, usize>,
447        pattern: &str,
448    ) -> Result<Self> {
449        let regex = Regex::new(pattern)?;
450
451        let special_regex = {
452            let _parts = special_tokens_encoder
453                .keys()
454                .map(|s| fancy_regex::escape(s))
455                .collect::<Vec<_>>();
456            Regex::new(&_parts.join("|"))?
457        };
458
459        let decoder: HashMap<usize, Vec<u8>> =
460            encoder.iter().map(|(k, v)| (*v, k.clone())).collect();
461
462        assert_eq!(encoder.len(),
463                   decoder.len(),
464                   "Encoder and decoder must be of equal length; maybe you had duplicate token indices in your encoder?");
465
466        let special_tokens_decoder: HashMap<usize, Vec<u8>> = special_tokens_encoder
467            .iter()
468            .map(|(k, v)| (*v, k.as_bytes().to_vec()))
469            .collect();
470
471        // Clone because I don't know how to tell Rust I'm not going to change the map
472        let mut sorted_token_bytes: Vec<Vec<u8>> = encoder.keys().cloned().collect();
473        sorted_token_bytes.sort();
474
475        Ok(CoreBPE {
476            encoder,
477            special_tokens_encoder,
478            decoder,
479            special_tokens_decoder,
480            regex_tls: (0..MAX_NUM_THREADS).map(|_| regex.clone()).collect(),
481            special_regex_tls: (0..MAX_NUM_THREADS)
482                .map(|_| special_regex.clone())
483                .collect(),
484            sorted_token_bytes,
485        })
486    }
487}
488
#[cfg(test)]
mod tests {
    use rustc_hash::FxHashMap as HashMap;

    use crate::byte_pair_split;

    /// Two independent mergeable pairs each collapse into their own part.
    #[test]
    fn very_simple_test() {
        let ranks: HashMap<Vec<u8>, usize> = [(b"ab".to_vec(), 1), (b"cd".to_vec(), 2)]
            .iter()
            .cloned()
            .collect();

        let expected: Vec<&[u8]> = vec![b"ab", b"cd"];
        assert_eq!(byte_pair_split(b"abcd", &ranks), expected);
    }
}