tiktoken_rs/
vendor_tiktoken.rs

1#[rustfmt::skip]
2/// This file is a vendored copy of the `tiktoken` crate.
3/// Modifications are limited to commenting out python related code and adjusting visibility of some functions, and suppressing lint warnings.
/// Limit modifications to this file to make it easy to keep it in sync with upstream
5// use std::borrow::Borrow;
6// use std::borrow::Cow;
7use std::collections::HashSet;
8use std::num::NonZeroU64;
9use std::thread;
10
11use fancy_regex::Regex;
12// #[cfg(feature = "python")]
13// use pyo3::prelude::*;
14use rustc_hash::FxHashMap as HashMap;
15
16// #[cfg(feature = "python")]
17// mod py;
18
/// A token id. Doubles as the merge priority: in `_byte_pair_merge` the pair
/// with the lowest rank is merged first.
pub type Rank = u32;
20
/// Core of BPE: repeatedly merges the lowest-ranked adjacent byte pair of
/// `piece` until no remaining pair appears in `ranks`.
///
/// Returns the surviving part boundaries as `(start, rank)` entries, including
/// two trailing sentinel entries; callers take `windows(2)` over the result to
/// recover each token's byte range within `piece`.
///
/// NOTE(review): assumes `piece` is non-empty — `piece.len() - 1` below would
/// underflow for an empty slice. Callers must guarantee at least one byte.
fn _byte_pair_merge(ranks: &HashMap<Vec<u8>, Rank>, piece: &[u8]) -> Vec<(usize, Rank)> {
    // This is a vector of (start, rank).
    // The rank is of the pair starting at position start.
    let mut parts = Vec::with_capacity(piece.len() + 1);

    // Note that we hash bytes when indexing into `ranks`, not token pairs. As long as we train BPE
    // the way we currently do, this is equivalent. An easy way to break this would be to decouple
    // merge priority from token index or to prevent specific token merges.
    let mut min_rank: (Rank, usize) = (Rank::MAX, usize::MAX);
    for i in 0..piece.len() - 1 {
        let rank = *ranks.get(&piece[i..i + 2]).unwrap_or(&Rank::MAX);
        if rank < min_rank.0 {
            min_rank = (rank, i);
        }
        parts.push((i, rank));
    }
    // Sentinels: the last byte (no pair starts there) and one-past-the-end.
    parts.push((piece.len() - 1, Rank::MAX));
    parts.push((piece.len(), Rank::MAX));

    let get_rank = {
        #[inline(always)]
        |parts: &Vec<(usize, Rank)>, i: usize| {
            if (i + 3) < parts.len() {
                // Similar to `piece[i..i + 2]` above. The +3 is because we haven't yet deleted
                // parts[i + 1], see comment in the main loop.
                *ranks
                    .get(&piece[parts[i].0..parts[i + 3].0])
                    .unwrap_or(&Rank::MAX)
            } else {
                Rank::MAX
            }
        }
    };

    // If you have n parts and m merges, this does O(mn) work.
    // We could do something with a heap and do O(m log n) work.
    // n is often very small so considerations like cache-locality outweigh the algorithmic
    // complexity downsides of the `parts` vector.
    while min_rank.0 != Rank::MAX {
        let i = min_rank.1;
        // Update parts[i] and parts[i - 1] before removing parts[i + 1], since
        // `parts.remove(i + 1)` will thrash the cache.
        if i > 0 {
            parts[i - 1].1 = get_rank(&parts, i - 1);
        }
        parts[i].1 = get_rank(&parts, i);
        parts.remove(i + 1);

        // Rescan for the next lowest-ranked mergeable pair.
        min_rank = (Rank::MAX, usize::MAX);
        for (i, &(_, rank)) in parts[..parts.len() - 1].iter().enumerate() {
            if rank < min_rank.0 {
                min_rank = (rank, i);
            }
        }
    }
    parts
}
78
79pub fn byte_pair_encode(piece: &[u8], ranks: &HashMap<Vec<u8>, Rank>) -> Vec<Rank> {
80    if piece.len() == 1 {
81        return vec![ranks[piece]];
82    }
83    _byte_pair_merge(ranks, piece)
84        .windows(2)
85        .map(|part| ranks[&piece[part[0].0..part[1].0]])
86        .collect()
87}
88
89pub fn byte_pair_split<'a>(piece: &'a [u8], ranks: &HashMap<Vec<u8>, Rank>) -> Vec<&'a [u8]> {
90    assert!(piece.len() > 1);
91    _byte_pair_merge(ranks, piece)
92        .windows(2)
93        .map(|part| &piece[part[0].0..part[1].0])
94        .collect()
95}
96
97// Various performance notes:
98//
99// Regex
100// =====
101// Most of the time is spent in regex. The easiest way to speed this up is by using less fancy
102// regex features. For instance, using a regex parse-able by `regex` crate is 3x faster than
103// the usual regex we use.
104//
105// However, given that we're using a regex parse-able by `regex`, there isn't much difference
106// between using the `regex` crate and using the `fancy_regex` crate.
107//
108// There is an important interaction between threading, `regex` and `fancy_regex`.
109// When using `fancy_regex`, we hit `regex.find_at`. It turns out that this causes contention on
110// some mutable scratch space inside of `regex`. This absolutely kills performance. When using plain
111// old `regex`, we don't hit this, because `find_iter` has a different code path.
112// Related: https://github.com/rust-lang/regex/blob/master/PERFORMANCE.md
113// Anyway, the way we get around this is with having a (mostly) thread local clone of the regex for
114// each thread.
115//
116// Threading
117// =========
118// I tried using `rayon`. It wasn't really faster than using Python threads and releasing the GIL.
119// So goodbye `rayon`! Let thread count etc be in control of our Python users.
120//
121// Caching
122// =======
123// The reference tokeniser has an lru cache over the equivalent of `byte_pair_encode`.
124// Originally, we had one too! Without it, we were only vaguely faster than Python.
125// I used an RWLock to protect the cache. This didn't seem to hurt single threaded performance
126// noticeably, but it did affect multi-threaded performance. Weirdly, it seemed to affect
// multi-threaded performance even when I only had readers (maybe I messed something up?).
128// Anyway, I realised that we could get rid of the cache, if we treat the set of tokens as a cache!
// These are exactly the set of merges that are likely to be hot. And now we don't have to think
130// about interior mutability, memory use, or cloning.
131//
132// Hashing
133// =======
134// We use FxHashMap instead of the standard HashMap. This is maybe like a 5-10% win?
135// The current implementation ends up doing a lot of hashing of bytes. In theory, this could be made
136// to be hashing of two-tuples of ints, which looks like it may also be a couple percent faster.
137
/// Mirror of `std::thread::ThreadId`'s presumed layout (a single `NonZeroU64`
/// counter), used only as the transmute target in `hash_current_thread`.
pub struct FakeThreadId(NonZeroU64);
139
/// Hashes the current thread's id into a `usize`, used (modulo
/// `MAX_NUM_THREADS`) to pick this thread's regex clone from
/// `CoreBPE::regex_tls` / `special_regex_tls`.
fn hash_current_thread() -> usize {
    // It's easier to use unsafe than to use nightly. Rust has this nice u64 thread id counter
    // that works great for our use case of avoiding collisions in our array. Unfortunately,
    // it's private. However, there are only so many ways you can layout a u64, so just transmute
    // https://github.com/rust-lang/rust/issues/67939
    // Compile-time guards: both types must be exactly 8 bytes, otherwise the
    // transmute below would not even compile.
    const _: [u8; 8] = [0; std::mem::size_of::<std::thread::ThreadId>()];
    const _: [u8; 8] = [0; std::mem::size_of::<FakeThreadId>()];
    // SAFETY(review): assumes ThreadId is layout-identical to NonZeroU64. Only
    // the size is checked above; the layout is private and not guaranteed
    // stable across compiler versions — revisit if rust-lang/rust#67939 lands.
    let x = unsafe {
        std::mem::transmute::<std::thread::ThreadId, FakeThreadId>(thread::current().id()).0
    };
    u64::from(x) as usize
}
152
/// Error returned when a token id is present in neither the regular decoder
/// nor the special-token decoder.
#[derive(Debug, Clone)]
pub struct DecodeKeyError {
    /// The token id that could not be decoded.
    pub token: Rank,
}
157
158impl std::fmt::Display for DecodeKeyError {
159    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
160        write!(f, "Invalid token for decoding: {}", self.token)
161    }
162}
163
// Marker impl: the default `Error` methods suffice; `Display` carries the detail.
impl std::error::Error for DecodeKeyError {}
165
/// Error for token sequences that cannot be decoded (e.g. into valid UTF-8),
/// carrying a human-readable description.
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub struct DecodeError {
    /// Human-readable explanation of why decoding failed.
    pub message: String,
}
171
172impl std::fmt::Display for DecodeError {
173    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
174        write!(f, "Could not decode tokens: {}", self.message)
175    }
176}
177
// Marker impl: the default `Error` methods suffice; `Display` carries the detail.
impl std::error::Error for DecodeError {}
179
/// Number of per-thread regex clones kept in `CoreBPE::regex_tls` /
/// `special_regex_tls`. Thread-id hashes are reduced modulo this, so distinct
/// threads may collide; that only re-introduces some of the contention the
/// clones exist to avoid (see performance notes above).
pub const MAX_NUM_THREADS: usize = 128;
181
// #[cfg_attr(feature = "python", pyclass)]
/// A byte-pair encoder/decoder together with its special-token tables and
/// per-thread regex clones.
#[derive(Clone)]
pub struct CoreBPE {
    /// Maps token bytes to their rank (token id).
    pub(crate) encoder: HashMap<Vec<u8>, Rank>,
    /// Maps special-token strings to their rank.
    pub(crate) special_tokens_encoder: HashMap<String, Rank>,
    /// Inverse of `encoder`.
    pub(crate) decoder: HashMap<Rank, Vec<u8>>,
    /// Inverse of `special_tokens_encoder`, values as UTF-8 bytes.
    pub(crate) special_tokens_decoder: HashMap<Rank, Vec<u8>>,
    /// Per-thread clones of the main splitting regex, indexed by hashed thread id.
    pub(crate) regex_tls: Vec<Regex>,
    /// Per-thread clones of the special-token regex.
    pub(crate) special_regex_tls: Vec<Regex>,
    /// All regular token byte strings in sorted order; used for prefix searches
    /// over candidate completions in `_encode_unstable_native`.
    #[allow(dead_code)]
    pub(crate) sorted_token_bytes: Vec<Vec<u8>>,
}
194
195impl CoreBPE {
    /// Returns this thread's clone of the main splitting regex.
    ///
    /// Clones exist per thread to avoid contention on fancy_regex's internal
    /// scratch space (see performance notes above). Hash collisions between
    /// thread ids are possible and harmless — every slot holds the same regex.
    fn _get_tl_regex(&self) -> &Regex {
        // See performance notes above for what this is about
        // It's also a little janky, please make a better version of it!
        // However, it's nice that this doesn't leak memory to short-lived threads
        &self.regex_tls[hash_current_thread() % MAX_NUM_THREADS]
    }
202
    /// Returns this thread's clone of the special-token regex, using the same
    /// per-thread scheme as `_get_tl_regex`.
    fn _get_tl_special_regex(&self) -> &Regex {
        &self.special_regex_tls[hash_current_thread() % MAX_NUM_THREADS]
    }
206
207    /// Decodes tokens into a list of bytes.
208    ///
209    /// The bytes are not gauranteed to be a valid utf-8 string.
210    pub(crate) fn decode_bytes(&self, tokens: &[Rank]) -> Result<Vec<u8>, DecodeKeyError> {
211        let mut ret = Vec::with_capacity(tokens.len() * 2);
212        for &token in tokens {
213            let token_bytes = match self.decoder.get(&token) {
214                Some(bytes) => bytes,
215                None => self
216                    .special_tokens_decoder
217                    .get(&token)
218                    .ok_or(DecodeKeyError { token })?,
219            };
220            ret.extend(token_bytes);
221        }
222        Ok(ret)
223    }
224
225    pub fn encode_ordinary(&self, text: &str) -> Vec<Rank> {
226        // This is the core of the encoding logic; the other functions in here
227        // just make things complicated :-)
228        let regex = self._get_tl_regex();
229        let mut ret = vec![];
230        for mat in regex.find_iter(text) {
231            let piece = mat.unwrap().as_str().as_bytes();
232            match self.encoder.get(piece) {
233                Some(token) => ret.push(*token),
234                None => ret.extend(&byte_pair_encode(piece, &self.encoder)),
235            }
236        }
237        ret
238    }
239
    /// Encodes `text`, letting the special tokens listed in `allowed_special`
    /// match as single tokens; all other text is split by the ordinary regex
    /// and byte-pair encoded.
    ///
    /// Returns the tokens plus `last_piece_token_len`: how many trailing tokens
    /// came from the final non-special regex split (0 if the text ended with a
    /// special token). Callers use this for unstable-token analysis, since
    /// merges never cross regex splits.
    pub fn encode(&self, text: &str, allowed_special: &HashSet<&str>) -> (Vec<Rank>, usize) {
        let special_regex = self._get_tl_special_regex();
        let regex = self._get_tl_regex();
        let mut ret = vec![];

        let mut start = 0;
        let mut last_piece_token_len = 0;
        loop {
            let mut next_special;
            let mut start_find = start;
            loop {
                // Find the next allowed special token, if any
                next_special = special_regex.find_from_pos(text, start_find).unwrap();
                match next_special {
                    Some(m) => {
                        if allowed_special.contains(&text[m.start()..m.end()]) {
                            break;
                        }
                        // Matched a special token that isn't allowed here: skip
                        // past its first byte and keep scanning.
                        start_find = m.start() + 1;
                    }
                    None => break,
                }
            }
            // Ordinary text runs up to the next allowed special token (or EOF).
            let end = next_special.map_or(text.len(), |m| m.start());

            // Okay, here we go, compare this logic to encode_ordinary
            for mat in regex.find_iter(&text[start..end]) {
                let piece = mat.unwrap().as_str().as_bytes();
                if let Some(token) = self.encoder.get(piece) {
                    last_piece_token_len = 1;
                    ret.push(*token);
                    continue;
                }
                let tokens = byte_pair_encode(piece, &self.encoder);
                last_piece_token_len = tokens.len();
                ret.extend(&tokens);
            }

            match next_special {
                // And here we push the special token
                Some(m) => {
                    let piece = m.as_str();
                    let token = self.special_tokens_encoder[piece];
                    ret.push(token);
                    start = m.end();
                    // A special token is a hard boundary: nothing after it can
                    // merge backwards, so the last piece is empty.
                    last_piece_token_len = 0;
                }
                None => break,
            }
        }

        // last_piece_token_len is how many tokens came from the last regex split. This is used
        // for determining unstable tokens, since you can't merge across (stable) regex splits
        (ret, last_piece_token_len)
    }
295
    /// Widens `last_piece_token_len` to also cover any run of all-whitespace
    /// tokens immediately preceding the last piece, compensating for regex
    /// splits that can shift around whitespace. Returns the tokens (unchanged)
    /// and the possibly increased length.
    fn _increase_last_piece_token_len(
        &self,
        tokens: Vec<Rank>,
        mut last_piece_token_len: usize,
    ) -> (Vec<Rank>, usize) {
        // Unfortunately, the locations where our regex splits can be unstable.
        // For the purposes of determining unstable tokens, unstable regex splitting
        // is only a problem if a split that was present disappears, since this can
        // lead to merging of tokens otherwise thought to be stable.
        // cl100k_base makes our life hard by including the \s*[\r\n]+
        // pattern. This can e.g. cause "\n" + " " to become "\n \n".
        // Here is a quick and dirty fix:
        {
            // True if every byte of the token is space/newline/tab.
            let token_is_all_space = |token| {
                self.decoder
                    .get(token)
                    .map(|token_bytes| {
                        token_bytes
                            .iter()
                            .rev()
                            .all(|&b| [b' ', b'\n', b'\t'].contains(&b))
                    })
                    .unwrap_or(false)
            };
            // Only widen when the last piece itself starts with an
            // all-whitespace token; then absorb the preceding whitespace run.
            if last_piece_token_len > 0
                && token_is_all_space(&tokens[tokens.len() - last_piece_token_len])
            {
                while (last_piece_token_len < tokens.len())
                    && token_is_all_space(&tokens[tokens.len() - last_piece_token_len - 1])
                {
                    last_piece_token_len += 1;
                }
            }
        }
        debug_assert!(last_piece_token_len <= tokens.len());

        (tokens, last_piece_token_len)
    }
334
    /// Encodes `text` and, rather than committing to the trailing tokens,
    /// returns the stable token prefix plus the set of every token sequence
    /// ("completion") that could validly retokenise the unstable trailing
    /// bytes if more text were appended.
    ///
    /// NOTE(review): presumably used by completion-style callers that need to
    /// know which tokenisations remain possible at the text boundary — confirm
    /// against call sites outside this file.
    pub fn _encode_unstable_native(
        &self,
        text: &str,
        allowed_special: &HashSet<&str>,
    ) -> (Vec<Rank>, HashSet<Vec<Rank>>) {
        let (tokens, last_piece_token_len) = self.encode(text, allowed_special);
        if last_piece_token_len == 0 {
            // If last_piece_token_len is zero, the last token was a special token and we have
            // no unstable bytes
            return (tokens, HashSet::new());
        }
        let (mut tokens, last_piece_token_len) =
            self._increase_last_piece_token_len(tokens, last_piece_token_len);

        // Everything in the (widened) last piece is unstable; strip it off and
        // work with its raw bytes from here on.
        let unstable_bytes = self
            .decode_bytes(&tokens[tokens.len() - last_piece_token_len..])
            .unwrap();
        tokens.truncate(tokens.len() - last_piece_token_len);

        // TODO: we should try harder to find additional stable tokens
        // This would reduce the amount of retokenising when determining completions
        // Refer to the logic in an older version of this file

        let mut completions = HashSet::new();
        if unstable_bytes.is_empty() {
            return (tokens, completions);
        }

        // This is the easy bit. Just find all single tokens that start with unstable_bytes
        // (including tokens that exactly match unstable_bytes)
        // Separating this from the loop below helps with performance in a common case.
        let mut point = self
            .sorted_token_bytes
            .partition_point(|x| x.as_slice() < unstable_bytes.as_slice());
        while point < self.sorted_token_bytes.len()
            && self.sorted_token_bytes[point].starts_with(&unstable_bytes)
        {
            completions.insert(vec![
                self.encoder[self.sorted_token_bytes[point].as_slice()],
            ]);
            point += 1;
        }

        // Now apply even more brute force. At every (other) possible position for the straddling
        // token, concatenate additional bytes from that token (if any) to unstable_bytes,
        // and retokenise the whole thing and see what we get.
        for i in 1..unstable_bytes.len() {
            let prefix = &unstable_bytes[..i];
            let suffix = &unstable_bytes[i..];
            // Binary search for tokens whose bytes start with `suffix`.
            let mut point = self
                .sorted_token_bytes
                .partition_point(|x| x.as_slice() < suffix);
            // TODO: Perf optimisation if suffix starts with " "?
            while point < self.sorted_token_bytes.len()
                && self.sorted_token_bytes[point].starts_with(suffix)
            {
                let possibility = [prefix, self.sorted_token_bytes[point].as_slice()].concat();
                let encoded = match std::str::from_utf8(&possibility) {
                    // Morally, this is byte_pair_encode(&possibility, &self.encoder)
                    // But we might have introduced a regex split which would prevent merges.
                    // (particularly possible in the presence of unstable regex splits)
                    // So convert to UTF-8 and do regex splitting.
                    // E.g. with cl100k_base "  !" gets split to " " + " !",
                    // but byte_pair_encode("  !") != byte_pair_encode(" ")
                    Ok(s) => self.encode_ordinary(s),

                    // Technically, whether or not this arm is correct depends on whether there
                    // would be a regex split before the UTF-8 truncation point.
                    // Probably niche enough that no one will ever notice (after all, people didn't
                    // notice all the big holes in the previous unstable token implementation)
                    Err(_) => byte_pair_encode(&possibility, &self.encoder),
                    // Something like the following is intriguing but incorrect:
                    // Err(e) => self.encode_ordinary(unsafe {
                    //     std::str::from_utf8_unchecked(&possibility[..e.valid_up_to()])
                    // }),
                };
                // Keep only as many tokens as needed to cover unstable_bytes.
                let mut seq = Vec::new();
                let mut seq_len = 0;
                for token in encoded {
                    seq.push(token);
                    seq_len += self.decoder[&token].len();
                    if seq_len >= unstable_bytes.len() {
                        break;
                    }
                }
                completions.insert(seq);
                point += 1;
            }
        }

        // This is also not straightforward. While we generally assume that regex splits are stable,
        // unfortunately, they are not. That is, if adding bytes were to make a split appear in
        // unstable_bytes, this could make tokens possible which our logic would otherwise think
        // would be merged.
        // For example, with gpt2, the use of \s+(?!\S) means that "\n\n" could
        // develop a split, e.g. "\n\n0" splits into "\n"+"\n"+"0", making "\n" a possible token.
        // Here is a quick and dirty fix:
        // This isn't right if we ever remove \s+(?!\S)
        if unstable_bytes.len() > 1 {
            let last_decoded = bstr::decode_last_utf8(unstable_bytes.as_slice());
            if unstable_bytes.len() - last_decoded.1 > 0
                && last_decoded.0.map_or(false, |c| c.is_whitespace())
            {
                // Re-encode with a forced split before the final whitespace char
                // and add that tokenisation as a further possibility.
                let mut reencoded = byte_pair_encode(
                    &unstable_bytes[..unstable_bytes.len() - last_decoded.1],
                    &self.encoder,
                );
                reencoded.extend(byte_pair_encode(
                    &unstable_bytes[unstable_bytes.len() - last_decoded.1..],
                    &self.encoder,
                ));
                completions.insert(reencoded);
            }
        }

        (tokens, completions)
    }
452
453    // pub fn new<E, SE, NSE>(
454    //     encoder: E,
455    //     special_tokens_encoder: SE,
456    //     pattern: &str,
457    // ) -> Result<Self, Box<dyn std::error::Error + Send + Sync>>
458    // where
459    //     E: IntoIterator<Item = (Vec<u8>, Rank)>,
460    //     SE: IntoIterator<Item = (String, Rank)>,
461    //     NSE: IntoIterator<Item = (String, (Rank, Rank))>,
462    // {
463    //     Self::new_internal(
464    //         HashMap::from_iter(encoder),
465    //         HashMap::from_iter(special_tokens_encoder),
466    //         pattern,
467    //     )
468    // }
469
470    // fn new_internal(
471    //     encoder: HashMap<Vec<u8>, Rank>,
472    //     special_tokens_encoder: HashMap<String, Rank>,
473    //     pattern: &str,
474    // ) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
475    //     let regex = Regex::new(pattern)?;
476
477    //     let special_regex = {
478    //         let parts = special_tokens_encoder
479    //             .keys()
480    //             .map(|s| fancy_regex::escape(s))
481    //             .collect::<Vec<_>>();
482    //         Regex::new(&parts.join("|"))?
483    //     };
484
485    //     let decoder: HashMap<Rank, Vec<u8>> =
486    //         encoder.iter().map(|(k, v)| (*v, k.clone())).collect();
487
488    //     assert!(
489    //         encoder.len() == decoder.len(),
490    //         "Encoder and decoder must be of equal length; maybe you had duplicate token indices in your encoder?"
491    //     );
492
493    //     let special_tokens_decoder: HashMap<Rank, Vec<u8>> = special_tokens_encoder
494    //         .iter()
495    //         .map(|(k, v)| (*v, k.as_bytes().to_vec()))
496    //         .collect();
497
498    //     // Clone because I don't know how to tell Rust I'm not going to change the map
499    //     let mut sorted_token_bytes: Vec<Vec<u8>> = encoder.keys().cloned().collect();
500    //     sorted_token_bytes.sort();
501
502    //     Ok(Self {
503    //         encoder,
504    //         special_tokens_encoder,
505    //         decoder,
506    //         special_tokens_decoder,
507    //         regex_tls: (0..MAX_NUM_THREADS).map(|_| regex.clone()).collect(),
508    //         special_regex_tls: (0..MAX_NUM_THREADS)
509    //             .map(|_| special_regex.clone())
510    //             .collect(),
511    //         sorted_token_bytes,
512    //     })
513    // }
514
515    pub fn special_tokens(&self) -> HashSet<&str> {
516        self.special_tokens_encoder
517            .keys()
518            .map(|s| s.as_str())
519            .collect()
520    }
521
522    pub fn encode_with_special_tokens(&self, text: &str) -> Vec<Rank> {
523        let allowed_special = self.special_tokens();
524        self.encode(text, &allowed_special).0
525    }
526}
527
#[cfg(test)]
mod tests {
    // use fancy_regex::Regex;
    use rustc_hash::FxHashMap as HashMap;

    use crate::{byte_pair_split, Rank};

    /// Minimal merge table: "ab" and "cd" are the only known pairs.
    fn setup_ranks() -> HashMap<Vec<u8>, Rank> {
        let mut ranks = HashMap::default();
        ranks.insert(b"ab".to_vec(), 0);
        ranks.insert(b"cd".to_vec(), 1);
        ranks
    }

    #[test]
    fn test_simple_characters() {
        let ranks = setup_ranks();
        assert_eq!(byte_pair_split(b"abcd", &ranks), vec![b"ab", b"cd"]);
    }

    #[test]
    fn test_repeated_characters() {
        let ranks = setup_ranks();
        assert_eq!(byte_pair_split(b"abab", &ranks), vec![b"ab", b"ab"]);
    }
}