Skip to main content

tiktoken_rs/
vendor_tiktoken.rs

1#[rustfmt::skip]
2/// This file is a vendored copy of the `tiktoken` crate.
3/// Modifications are limited to commenting out python related code and adjusting visibility of some functions, and suppressing lint warnings.
/// Limit modifications to this file to make it easy to keep it in sync with upstream.
5// use std::borrow::Borrow;
6// use std::borrow::Cow;
7use std::collections::HashSet;
8use std::num::NonZeroU64;
9use std::thread;
10
11use fancy_regex::Regex;
12// #[cfg(feature = "python")]
13// use pyo3::prelude::*;
14use rustc_hash::FxHashMap as HashMap;
15
16// #[cfg(feature = "python")]
17// mod py;
18
/// Numeric identifier of a token in the vocabulary. During BPE merging, the pair whose merged
/// bytes have the lowest rank is merged first.
pub type Rank = u32;
20
21use std::collections::BinaryHeap;
22
23#[derive(Eq, PartialEq, Clone, Copy)]
24struct Merge {
25    start: usize,
26    rank: Rank,
27}
28
29impl Ord for Merge {
30    #[inline]
31    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
32        other
33            .rank
34            .cmp(&self.rank)
35            .then_with(|| other.start.cmp(&self.start))
36    }
37}
38
39impl PartialOrd for Merge {
40    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
41        Some(self.cmp(other))
42    }
43}
44
45struct State {
46    prev: usize,
47    end: usize,
48    next_end: usize,
49    next_rank: Rank,
50    cur_rank: Rank,
51}
52
53fn _byte_pair_merge_large(ranks: &HashMap<Vec<u8>, Rank>, piece: &[u8]) -> Vec<Rank> {
54    let mut state = Vec::with_capacity(piece.len());
55    state.push(State {
56        prev: usize::MAX,
57        end: 1,
58        next_end: 2,
59        next_rank: Rank::MAX,
60        cur_rank: Rank::MAX,
61    });
62
63    let mut heap = BinaryHeap::with_capacity(piece.len());
64    for i in 0..piece.len() - 1 {
65        if let Some(&rank) = ranks.get(&piece[i..i + 2]) {
66            heap.push(Merge { start: i, rank });
67            state[i].next_rank = rank;
68        }
69        // note this is happening offset by 1
70        state.push(State {
71            prev: i,
72            end: i + 2,
73            next_end: i + 3,
74            next_rank: Rank::MAX,
75            cur_rank: Rank::MAX,
76        });
77    }
78
79    // Repeatedly find the valid merge with smallest rank. We merge the (left) token that
80    // starts at `start` and ends at `state[start].end` with the (right) token that starts at
81    // `state[start].end` and ends at `state[start].next_end`.  We invalidate the old merges
82    // (the ones that started at `state[start].end` and ended at `state[start]`) and add the two
83    // new potential merges to the heap.
84
85    let potential_merge = {
86        #[inline(always)]
87        |state: &mut Vec<State>,
88         heap: &mut BinaryHeap<Merge>,
89         start: usize,
90         next_end_item: usize| {
91            state[start].next_end = next_end_item;
92            state[start].next_rank = Rank::MAX; // Always invalidate the old merge
93            if next_end_item <= piece.len()
94                && let Some(&rank) = ranks.get(&piece[start..next_end_item])
95            {
96                // We have a valid potential merge!
97                heap.push(Merge { start, rank });
98                state[start].next_rank = rank;
99            }
100        }
101    };
102
103    while let Some(left) = heap.pop() {
104        if left.rank == Rank::MAX {
105            break;
106        }
107        if left.rank != state[left.start].next_rank {
108            continue; // This merge was invalidated, ignore it
109        }
110
111        let left_start = left.start;
112        let right_start = state[left_start].end;
113        let right_end = state[left_start].next_end;
114        debug_assert!(right_end == state[right_start].end);
115        let right_next_end = state[right_start].next_end;
116
117        // Merge left and right into a single token
118        state[left_start].cur_rank = state[left_start].next_rank;
119        state[left_start].end = right_end;
120        potential_merge(&mut state, &mut heap, left_start, right_next_end);
121        if right_end < state.len() {
122            state[right_end].prev = left_start;
123        }
124        // Update the merge that ends at left_start
125        if left_start > 0 {
126            let prev_start = state[left_start].prev;
127            potential_merge(&mut state, &mut heap, prev_start, right_end);
128        }
129        // Invalidate the merge starting at right_start, so we ignore it when it comes off the heap
130        state[right_start].next_rank = Rank::MAX;
131    }
132
133    let mut result = Vec::new();
134    let mut i = 0;
135    while i < state.len() {
136        if state[i].cur_rank != Rank::MAX {
137            result.push(state[i].cur_rank);
138        } else {
139            result.push(ranks[&piece[i..state[i].end]]);
140        }
141        i = state[i].end;
142    }
143    result
144}
145
146fn _byte_pair_merge(ranks: &HashMap<Vec<u8>, Rank>, piece: &[u8]) -> Vec<(usize, Rank)> {
147    // This is a vector of (start, rank).
148    // The rank is of the pair starting at position start.
149    let mut parts = Vec::with_capacity(piece.len() + 1);
150
151    // Note that we hash bytes when indexing into `ranks`, not token pairs. As long as we train BPE
152    // the way we currently do, this is equivalent. An easy way to break this would be to decouple
153    // merge priority from token index or to prevent specific token merges.
154    let mut min_rank: (Rank, usize) = (Rank::MAX, usize::MAX);
155    for i in 0..piece.len() - 1 {
156        let rank = *ranks.get(&piece[i..i + 2]).unwrap_or(&Rank::MAX);
157        if rank < min_rank.0 {
158            min_rank = (rank, i);
159        }
160        parts.push((i, rank));
161    }
162    parts.push((piece.len() - 1, Rank::MAX));
163    parts.push((piece.len(), Rank::MAX));
164
165    let get_rank = {
166        #[inline(always)]
167        |parts: &Vec<(usize, Rank)>, i: usize| {
168            if (i + 3) < parts.len() {
169                // Similar to `piece[i..i + 2]` above. The +3 is because we haven't yet deleted
170                // parts[i + 1], see comment in the main loop.
171                *ranks
172                    .get(&piece[parts[i].0..parts[i + 3].0])
173                    .unwrap_or(&Rank::MAX)
174            } else {
175                Rank::MAX
176            }
177        }
178    };
179
180    // If you have n parts and m merges, this does O(mn) work.
181    // We could do something with a heap and do O(m log n) work.
182    // n is often very small so considerations like cache-locality outweigh the algorithmic
183    // complexity downsides of the `parts` vector.
184    while min_rank.0 != Rank::MAX {
185        let i = min_rank.1;
186        // Update parts[i] and parts[i - 1] before removing parts[i + 1], since
187        // `parts.remove(i + 1)` will thrash the cache.
188        if i > 0 {
189            parts[i - 1].1 = get_rank(&parts, i - 1);
190        }
191        parts[i].1 = get_rank(&parts, i);
192        parts.remove(i + 1);
193
194        min_rank = (Rank::MAX, usize::MAX);
195        for (i, &(_, rank)) in parts[..parts.len() - 1].iter().enumerate() {
196            if rank < min_rank.0 {
197                min_rank = (rank, i);
198            }
199        }
200    }
201    parts
202}
203
204pub fn byte_pair_encode(piece: &[u8], ranks: &HashMap<Vec<u8>, Rank>) -> Vec<Rank> {
205    let piece_len = piece.len();
206
207    if piece_len == 1 {
208        return vec![ranks[piece]];
209    }
210    if piece_len < 100 {
211        return _byte_pair_merge(ranks, piece)
212            .windows(2)
213            .map(|part| ranks[&piece[part[0].0..part[1].0]])
214            .collect();
215    }
216    _byte_pair_merge_large(ranks, piece)
217}
218
219pub fn byte_pair_split<'a>(piece: &'a [u8], ranks: &HashMap<Vec<u8>, Rank>) -> Vec<&'a [u8]> {
220    assert!(piece.len() > 1);
221    _byte_pair_merge(ranks, piece)
222        .windows(2)
223        .map(|part| &piece[part[0].0..part[1].0])
224        .collect()
225}
226
227// Various performance notes:
228//
229// Regex
230// =====
231// Most of the time is spent in regex. The easiest way to speed this up is by using less fancy
232// regex features. For instance, using a regex parse-able by `regex` crate is 3x faster than
233// the usual regex we use.
234//
235// However, given that we're using a regex parse-able by `regex`, there isn't much difference
236// between using the `regex` crate and using the `fancy_regex` crate.
237//
238// There is an important interaction between threading, `regex` and `fancy_regex`.
239// When using `fancy_regex`, we hit `regex.find_at`. It turns out that this causes contention on
240// some mutable scratch space inside of `regex`. This absolutely kills performance. When using plain
241// old `regex`, we don't hit this, because `find_iter` has a different code path.
242// Related: https://github.com/rust-lang/regex/blob/master/PERFORMANCE.md
243// Anyway, the way we get around this is with having a (mostly) thread local clone of the regex for
244// each thread.
245//
246// Threading
247// =========
248// I tried using `rayon`. It wasn't really faster than using Python threads and releasing the GIL.
249// So goodbye `rayon`! Let thread count etc be in control of our Python users.
250//
251// Caching
252// =======
253// The reference tokeniser has an lru cache over the equivalent of `byte_pair_encode`.
254// Originally, we had one too! Without it, we were only vaguely faster than Python.
255// I used an RWLock to protect the cache. This didn't seem to hurt single threaded performance
256// noticeably, but it did affect multi-threaded performance. Weirdly, it seemed to affect
// multi-threaded performance even when I only had readers (maybe I messed something up?).
258// Anyway, I realised that we could get rid of the cache, if we treat the set of tokens as a cache!
// These are exactly the set of merges that are likely to be hot. And now we don't have to think
260// about interior mutability, memory use, or cloning.
261//
262// Hashing
263// =======
264// We use FxHashMap instead of the standard HashMap. This is maybe like a 5-10% win?
265// The current implementation ends up doing a lot of hashing of bytes. In theory, this could be made
266// to be hashing of two-tuples of ints, which looks like it may also be a couple percent faster.
267
/// Stand-in with the same size as `std::thread::ThreadId`. `hash_current_thread` uses it to
/// read the thread id's private numeric value.
pub struct FakeThreadId(NonZeroU64);

/// Returns a numeric id for the calling thread. Callers reduce it modulo `MAX_NUM_THREADS` to
/// pick this thread's regex clone.
fn hash_current_thread() -> usize {
    // It's easier to use unsafe than to use nightly. Rust has this nice u64 thread id counter
    // that works great for our use case of avoiding collisions in our array. Unfortunately,
    // it's private. However, there are only so many ways you can layout a u64, so just transmute
    // https://github.com/rust-lang/rust/issues/67939
    const _: () = assert!(std::mem::size_of::<std::thread::ThreadId>() == 8);
    const _: () = assert!(std::mem::size_of::<FakeThreadId>() == 8);

    let current = thread::current().id();
    // SAFETY: both types are 8 bytes, as checked at compile time above. Beyond that, this
    // relies on `ThreadId` privately wrapping a non-zero u64 counter (see the issue linked
    // above). That layout is not a public guarantee of std, which is why the sizes are asserted.
    let raw = unsafe { std::mem::transmute::<std::thread::ThreadId, FakeThreadId>(current) };
    raw.0.get() as usize
}
282
283#[derive(Debug, Clone)]
284pub struct DecodeKeyError {
285    pub token: Rank,
286}
287
288impl std::fmt::Display for DecodeKeyError {
289    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
290        write!(f, "Invalid token for decoding: {}", self.token)
291    }
292}
293
294impl std::error::Error for DecodeKeyError {}
295
/// Error describing why a sequence of tokens could not be decoded.
#[derive(Debug, Clone)]
// Not constructed anywhere in this file; presumably kept for parity with upstream, so the
// dead-code lint is silenced rather than the type removed.
#[allow(dead_code)]
pub struct DecodeError {
    /// Human-readable reason for the failure.
    pub message: String,
}

impl std::fmt::Display for DecodeError {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let DecodeError { message } = self;
        write!(f, "Could not decode tokens: {message}")
    }
}

impl std::error::Error for DecodeError {}
309
/// Error returned by `CoreBPE::encode` when text could not be tokenized, e.g. because a regex
/// reported a runtime error.
#[derive(Debug, Clone)]
pub struct EncodeError {
    /// Human-readable reason for the failure.
    pub message: String,
}

impl std::fmt::Display for EncodeError {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let EncodeError { message } = self;
        write!(f, "Could not encode string: {message}")
    }
}

impl std::error::Error for EncodeError {}
322
/// Number of per-thread regex slots. `hash_current_thread()` is reduced modulo this value to
/// pick an entry of `CoreBPE::regex_tls` / `CoreBPE::special_regex_tls`.
pub const MAX_NUM_THREADS: usize = 128;
324
325// #[cfg_attr(feature = "python", pyclass)]
326#[derive(Clone)]
327pub struct CoreBPE {
328    pub(crate) encoder: HashMap<Vec<u8>, Rank>,
329    pub(crate) special_tokens_encoder: HashMap<String, Rank>,
330    pub(crate) decoder: HashMap<Rank, Vec<u8>>,
331    pub(crate) special_tokens_decoder: HashMap<Rank, Vec<u8>>,
332    pub(crate) regex_tls: Vec<Regex>,
333    pub(crate) special_regex_tls: Vec<Regex>,
334    #[allow(dead_code)]
335    pub(crate) sorted_token_bytes: Vec<Vec<u8>>,
336}
337
338impl CoreBPE {
339    fn _get_tl_regex(&self) -> &Regex {
340        // See performance notes above for what this is about
341        // It's also a little janky, please make a better version of it!
342        // However, it's nice that this doesn't leak memory to short-lived threads
343        &self.regex_tls[hash_current_thread() % MAX_NUM_THREADS]
344    }
345
346    fn _get_tl_special_regex(&self) -> &Regex {
347        &self.special_regex_tls[hash_current_thread() % MAX_NUM_THREADS]
348    }
349
350    /// Decodes tokens into a list of bytes.
351    ///
352    /// The bytes are not gauranteed to be a valid utf-8 string.
353    pub fn decode_bytes(&self, tokens: &[Rank]) -> Result<Vec<u8>, DecodeKeyError> {
354        let mut ret = Vec::with_capacity(tokens.len() * 2);
355        for &token in tokens {
356            let token_bytes = match self.decoder.get(&token) {
357                Some(bytes) => bytes,
358                None => self
359                    .special_tokens_decoder
360                    .get(&token)
361                    .ok_or(DecodeKeyError { token })?,
362            };
363            ret.extend(token_bytes);
364        }
365        Ok(ret)
366    }
367
368    pub fn encode_ordinary(&self, text: &str) -> Vec<Rank> {
369        // This is the core of the encoding logic; the other functions in here
370        // just make things complicated :-)
371        let regex = self._get_tl_regex();
372        let mut ret = vec![];
373        for mat in regex.find_iter(text) {
374            let piece = mat.unwrap().as_str().as_bytes();
375            match self.encoder.get(piece) {
376                Some(token) => ret.push(*token),
377                None => ret.extend(&byte_pair_encode(piece, &self.encoder)),
378            }
379        }
380        ret
381    }
382
383    pub fn encode(
384        &self,
385        text: &str,
386        allowed_special: &HashSet<&str>,
387    ) -> Result<(Vec<Rank>, usize), EncodeError> {
388        let special_regex = self._get_tl_special_regex();
389        let regex = self._get_tl_regex();
390        let mut ret = vec![];
391
392        let mut start = 0;
393        let mut last_piece_token_len = 0;
394        loop {
395            let mut next_special;
396            let mut start_find = start;
397            loop {
398                // Find the next allowed special token, if any
399                next_special = special_regex.find_from_pos(text, start_find).unwrap();
400                match next_special {
401                    Some(m) => {
402                        if allowed_special.contains(&text[m.start()..m.end()]) {
403                            break;
404                        }
405                        start_find = m.start() + 1;
406                    }
407                    None => break,
408                }
409            }
410            let end = next_special.map_or(text.len(), |m| m.start());
411
412            // Okay, here we go, compare this logic to encode_ordinary
413            for mat_res in regex.find_iter(&text[start..end]) {
414                let mat = match mat_res {
415                    Ok(m) => m,
416                    Err(e) => {
417                        return Err(EncodeError {
418                            message: format!("Regex error while tokenizing: {e}"),
419                        });
420                    }
421                };
422
423                let piece = mat.as_str().as_bytes();
424                if let Some(token) = self.encoder.get(piece) {
425                    last_piece_token_len = 1;
426                    ret.push(*token);
427                    continue;
428                }
429                let tokens = byte_pair_encode(piece, &self.encoder);
430                last_piece_token_len = tokens.len();
431                ret.extend(&tokens);
432            }
433
434            match next_special {
435                // And here we push the special token
436                Some(m) => {
437                    let piece = m.as_str();
438                    let token = self.special_tokens_encoder[piece];
439                    ret.push(token);
440                    start = m.end();
441                    last_piece_token_len = 0;
442                }
443                None => break,
444            }
445        }
446
447        // last_piece_token_len is how many tokens came from the last regex split. This is used
448        // for determining unstable tokens, since you can't merge across (stable) regex splits
449        Ok((ret, last_piece_token_len))
450    }
451
452    fn _increase_last_piece_token_len(
453        &self,
454        tokens: Vec<Rank>,
455        mut last_piece_token_len: usize,
456    ) -> (Vec<Rank>, usize) {
457        // Unfortunately, the locations where our regex splits can be unstable.
458        // For the purposes of determining unstable tokens, unstable regex splitting
459        // is only a problem if a split that was present disappears, since this can
460        // lead to merging of tokens otherwise thought to be stable.
461        // cl100k_base makes our life hard by including the \s*[\r\n]+
462        // pattern. This can e.g. cause "\n" + " " to become "\n \n".
463        // Here is a quick and dirty fix:
464        {
465            let token_is_all_space = |token| {
466                self.decoder
467                    .get(token)
468                    .map(|token_bytes| token_bytes.iter().rev().all(|&b| b" \n\t".contains(&b)))
469                    .unwrap_or(false)
470            };
471            if last_piece_token_len > 0
472                && token_is_all_space(&tokens[tokens.len() - last_piece_token_len])
473            {
474                while (last_piece_token_len < tokens.len())
475                    && token_is_all_space(&tokens[tokens.len() - last_piece_token_len - 1])
476                {
477                    last_piece_token_len += 1;
478                }
479            }
480        }
481        debug_assert!(last_piece_token_len <= tokens.len());
482
483        (tokens, last_piece_token_len)
484    }
485
486    pub fn _encode_unstable_native(
487        &self,
488        text: &str,
489        allowed_special: &HashSet<&str>,
490    ) -> (Vec<Rank>, HashSet<Vec<Rank>>) {
491        let (tokens, last_piece_token_len) = self.encode(text, allowed_special).unwrap();
492        if last_piece_token_len == 0 {
493            // If last_piece_token_len is zero, the last token was a special token and we have
494            // no unstable bytes
495            return (tokens, HashSet::new());
496        }
497        let (mut tokens, last_piece_token_len) =
498            self._increase_last_piece_token_len(tokens, last_piece_token_len);
499
500        let unstable_bytes = self
501            .decode_bytes(&tokens[tokens.len() - last_piece_token_len..])
502            .unwrap();
503        tokens.truncate(tokens.len() - last_piece_token_len);
504
505        // TODO: we should try harder to find additional stable tokens
506        // This would reduce the amount of retokenising when determining completions
507        // Refer to the logic in an older version of this file
508
509        let mut completions = HashSet::new();
510        if unstable_bytes.is_empty() {
511            return (tokens, completions);
512        }
513
514        // This is the easy bit. Just find all single tokens that start with unstable_bytes
515        // (including tokens that exactly match unstable_bytes)
516        // Separating this from the loop below helps with performance in a common case.
517        let mut point = self
518            .sorted_token_bytes
519            .partition_point(|x| x.as_slice() < unstable_bytes.as_slice());
520        while point < self.sorted_token_bytes.len()
521            && self.sorted_token_bytes[point].starts_with(&unstable_bytes)
522        {
523            completions.insert(vec![
524                self.encoder[self.sorted_token_bytes[point].as_slice()],
525            ]);
526            point += 1;
527        }
528
529        // Now apply even more brute force. At every (other) possible position for the straddling
530        // token, concatenate additional bytes from that token (if any) to unstable_bytes,
531        // and retokenise the whole thing and see what we get.
532        for i in 1..unstable_bytes.len() {
533            let prefix = &unstable_bytes[..i];
534            let suffix = &unstable_bytes[i..];
535            let mut point = self
536                .sorted_token_bytes
537                .partition_point(|x| x.as_slice() < suffix);
538            // TODO: Perf optimisation if suffix starts with " "?
539            while point < self.sorted_token_bytes.len()
540                && self.sorted_token_bytes[point].starts_with(suffix)
541            {
542                let possibility = [prefix, self.sorted_token_bytes[point].as_slice()].concat();
543                let encoded = match std::str::from_utf8(&possibility) {
544                    // Morally, this is byte_pair_encode(&possibility, &self.encoder)
545                    // But we might have introduced a regex split which would prevent merges.
546                    // (particularly possible in the presence of unstable regex splits)
547                    // So convert to UTF-8 and do regex splitting.
548                    // E.g. with cl100k_base "  !" gets split to " " + " !",
549                    // but byte_pair_encode("  !") != byte_pair_encode(" ")
550                    Ok(s) => self.encode_ordinary(s),
551
552                    // Technically, whether or not this arm is correct depends on whether there
553                    // would be a regex split before the UTF-8 truncation point.
554                    // Probably niche enough that no one will ever notice (after all, people didn't
555                    // notice all the big holes in the previous unstable token implementation)
556                    Err(_) => byte_pair_encode(&possibility, &self.encoder),
557                    // Something like the following is intriguing but incorrect:
558                    // Err(e) => self.encode_ordinary(unsafe {
559                    //     std::str::from_utf8_unchecked(&possibility[..e.valid_up_to()])
560                    // }),
561                };
562                let mut seq = Vec::new();
563                let mut seq_len = 0;
564                for token in encoded {
565                    seq.push(token);
566                    seq_len += self.decoder[&token].len();
567                    if seq_len >= unstable_bytes.len() {
568                        break;
569                    }
570                }
571                completions.insert(seq);
572                point += 1;
573            }
574        }
575
576        // This is also not straightforward. While we generally assume that regex splits are stable,
577        // unfortunately, they are not. That is, if adding bytes were to make a split appear in
578        // unstable_bytes, this could make tokens possible which our logic would otherwise think
579        // would be merged.
580        // For example, with gpt2, the use of \s+(?!\S) means that "\n\n" could
581        // develop a split, e.g. "\n\n0" splits into "\n"+"\n"+"0", making "\n" a possible token.
582        // Here is a quick and dirty fix:
583        // This isn't right if we ever remove \s+(?!\S)
584        if unstable_bytes.len() > 1 {
585            let last_decoded = bstr::decode_last_utf8(unstable_bytes.as_slice());
586            if unstable_bytes.len() - last_decoded.1 > 0
587                && last_decoded.0.is_some_and(|c| c.is_whitespace())
588            {
589                let mut reencoded = byte_pair_encode(
590                    &unstable_bytes[..unstable_bytes.len() - last_decoded.1],
591                    &self.encoder,
592                );
593                reencoded.extend(byte_pair_encode(
594                    &unstable_bytes[unstable_bytes.len() - last_decoded.1..],
595                    &self.encoder,
596                ));
597                completions.insert(reencoded);
598            }
599        }
600
601        (tokens, completions)
602    }
603
604    // pub fn new<E, SE, NSE>(
605    //     encoder: E,
606    //     special_tokens_encoder: SE,
607    //     pattern: &str,
608    // ) -> Result<Self, Box<dyn std::error::Error + Send + Sync>>
609    // where
610    //     E: IntoIterator<Item = (Vec<u8>, Rank)>,
611    //     SE: IntoIterator<Item = (String, Rank)>,
612    //     NSE: IntoIterator<Item = (String, (Rank, Rank))>,
613    // {
614    //     Self::new_internal(
615    //         HashMap::from_iter(encoder),
616    //         HashMap::from_iter(special_tokens_encoder),
617    //         pattern,
618    //     )
619    // }
620
621    // fn new_internal(
622    //     encoder: HashMap<Vec<u8>, Rank>,
623    //     special_tokens_encoder: HashMap<String, Rank>,
624    //     pattern: &str,
625    // ) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
626    //     let regex = Regex::new(pattern)?;
627
628    //     let special_regex = {
629    //         let parts = special_tokens_encoder
630    //             .keys()
631    //             .map(|s| fancy_regex::escape(s))
632    //             .collect::<Vec<_>>();
633    //         Regex::new(&parts.join("|"))?
634    //     };
635
636    //     let decoder: HashMap<Rank, Vec<u8>> =
637    //         encoder.iter().map(|(k, v)| (*v, k.clone())).collect();
638
639    //     assert!(
640    //         encoder.len() == decoder.len(),
641    //         "Encoder and decoder must be of equal length; maybe you had duplicate token indices in your encoder?"
642    //     );
643
644    //     let special_tokens_decoder: HashMap<Rank, Vec<u8>> = special_tokens_encoder
645    //         .iter()
646    //         .map(|(k, v)| (*v, k.as_bytes().to_vec()))
647    //         .collect();
648
649    //     // Clone because I don't know how to tell Rust I'm not going to change the map
650    //     let mut sorted_token_bytes: Vec<Vec<u8>> = encoder.keys().cloned().collect();
651    //     sorted_token_bytes.sort();
652
653    //     Ok(Self {
654    //         encoder,
655    //         special_tokens_encoder,
656    //         decoder,
657    //         special_tokens_decoder,
658    //         regex_tls: (0..MAX_NUM_THREADS).map(|_| regex.clone()).collect(),
659    //         special_regex_tls: (0..MAX_NUM_THREADS)
660    //             .map(|_| special_regex.clone())
661    //             .collect(),
662    //         sorted_token_bytes,
663    //     })
664    // }
665
666    pub fn special_tokens(&self) -> HashSet<&str> {
667        self.special_tokens_encoder
668            .keys()
669            .map(|s| s.as_str())
670            .collect()
671    }
672
673    pub fn encode_with_special_tokens(&self, text: &str) -> Vec<Rank> {
674        let allowed_special = self.special_tokens();
675        self.encode(text, &allowed_special).unwrap().0
676    }
677}
678
#[cfg(test)]
mod tests {
    // use fancy_regex::Regex;
    use rustc_hash::FxHashMap as HashMap;

    // `super::` rather than `crate::` so the private merge routines are reachable as well.
    use super::{Rank, _byte_pair_merge, _byte_pair_merge_large, byte_pair_encode, byte_pair_split};

    /// Two merges, "ab" and "cd", and no single-byte entries.
    fn setup_ranks() -> HashMap<Vec<u8>, Rank> {
        HashMap::from_iter([(b"ab".to_vec(), 0), (b"cd".to_vec(), 1)])
    }

    /// Every byte in a..=d plus several overlapping and chained merges, so that merge order
    /// and leftmost tie-breaking actually matter.
    fn setup_mixed_ranks() -> HashMap<Vec<u8>, Rank> {
        HashMap::from_iter([
            (b"ab".to_vec(), 0),
            (b"cd".to_vec(), 1),
            (b"bc".to_vec(), 2),
            (b"aa".to_vec(), 3),
            (b"abcd".to_vec(), 4),
            (b"aaaa".to_vec(), 5),
            (b"da".to_vec(), 6),
            (b"a".to_vec(), 10),
            (b"b".to_vec(), 11),
            (b"c".to_vec(), 12),
            (b"d".to_vec(), 13),
        ])
    }

    /// Deterministic pseudo-random piece of `len` bytes drawn from b"abcd".
    fn mixed_piece(len: usize, seed: u32) -> Vec<u8> {
        let mut x = seed;
        (0..len)
            .map(|_| {
                x = x.wrapping_mul(1_103_515_245).wrapping_add(12_345);
                b"abcd"[((x >> 16) & 3) as usize]
            })
            .collect()
    }

    #[test]
    fn test_simple_characters() {
        let ranks = setup_ranks();
        let res = byte_pair_split(b"abcd", &ranks);
        assert_eq!(res, vec![b"ab", b"cd"]);
    }

    #[test]
    fn test_repeated_characters() {
        let ranks = setup_ranks();
        let res = byte_pair_split(b"abab", &ranks);
        assert_eq!(res, vec![b"ab", b"ab"]);
    }

    /// The heap-based merge used for long pieces must produce exactly the tokens of the
    /// linear merge.
    #[test]
    fn test_large_merge_matches_linear_merge() {
        let ranks = setup_mixed_ranks();
        for (seed, len) in [(0, 150), (1, 150), (2, 257), (3, 101), (4, 500)] {
            let piece = mixed_piece(len, seed);
            let expected: Vec<Rank> = _byte_pair_merge(&ranks, &piece)
                .windows(2)
                .map(|part| ranks[&piece[part[0].0..part[1].0]])
                .collect();
            assert_eq!(_byte_pair_merge_large(&ranks, &piece), expected, "seed {seed}");
        }
    }

    /// Pieces of 100+ bytes are routed to the heap-based merge by `byte_pair_encode`.
    #[test]
    fn test_byte_pair_encode_long_piece() {
        let ranks = setup_ranks();
        let piece = b"abcd".repeat(30);
        let expected: Vec<Rank> = [0, 1].repeat(30);
        assert_eq!(byte_pair_encode(&piece, &ranks), expected);
    }
}