tiktoken_rs/vendor_tiktoken.rs
1#[rustfmt::skip]
2/// This file is a vendored copy of the `tiktoken` crate.
3/// Modifications are limited to commenting out python related code and adjusting visibility of some functions, and suppressing lint warnings.
/// Limit modifications to this file to make it easy to keep it in sync with upstream
5// use std::borrow::Borrow;
6// use std::borrow::Cow;
7use std::collections::HashSet;
8use std::num::NonZeroU64;
9use std::thread;
10
11use fancy_regex::Regex;
12// #[cfg(feature = "python")]
13// use pyo3::prelude::*;
14use rustc_hash::FxHashMap as HashMap;
15
16// #[cfg(feature = "python")]
17// mod py;
18
/// Numeric identifier of a token. Within this file it also acts as the merge
/// priority: a lower rank is merged first.
pub type Rank = u32;

use std::collections::BinaryHeap;

/// A candidate merge held in the priority queue of `_byte_pair_merge_large`.
///
/// `start` is the byte offset at which the left token of the pair begins and
/// `rank` is the rank looked up for the bytes the merge would cover.
#[derive(Eq, PartialEq, Clone, Copy)]
struct Merge {
    start: usize,
    rank: Rank,
}

impl Ord for Merge {
    /// Reversed lexicographic comparison on `(rank, start)`.
    ///
    /// `BinaryHeap` is a max-heap, so reversing the order makes `pop` return
    /// the lowest rank first, with ties going to the smallest `start`.
    #[inline]
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        (other.rank, other.start).cmp(&(self.rank, self.start))
    }
}

impl PartialOrd for Merge {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(Ord::cmp(self, other))
    }
}
44
45struct State {
46 prev: usize,
47 end: usize,
48 next_end: usize,
49 next_rank: Rank,
50 cur_rank: Rank,
51}
52
53fn _byte_pair_merge_large(ranks: &HashMap<Vec<u8>, Rank>, piece: &[u8]) -> Vec<Rank> {
54 let mut state = Vec::with_capacity(piece.len());
55 state.push(State {
56 prev: usize::MAX,
57 end: 1,
58 next_end: 2,
59 next_rank: Rank::MAX,
60 cur_rank: Rank::MAX,
61 });
62
63 let mut heap = BinaryHeap::with_capacity(piece.len());
64 for i in 0..piece.len() - 1 {
65 if let Some(&rank) = ranks.get(&piece[i..i + 2]) {
66 heap.push(Merge { start: i, rank });
67 state[i].next_rank = rank;
68 }
69 // note this is happening offset by 1
70 state.push(State {
71 prev: i,
72 end: i + 2,
73 next_end: i + 3,
74 next_rank: Rank::MAX,
75 cur_rank: Rank::MAX,
76 });
77 }
78
79 // Repeatedly find the valid merge with smallest rank. We merge the (left) token that
80 // starts at `start` and ends at `state[start].end` with the (right) token that starts at
81 // `state[start].end` and ends at `state[start].next_end`. We invalidate the old merges
82 // (the ones that started at `state[start].end` and ended at `state[start]`) and add the two
83 // new potential merges to the heap.
84
85 let potential_merge = {
86 #[inline(always)]
87 |state: &mut Vec<State>,
88 heap: &mut BinaryHeap<Merge>,
89 start: usize,
90 next_end_item: usize| {
91 state[start].next_end = next_end_item;
92 state[start].next_rank = Rank::MAX; // Always invalidate the old merge
93 if next_end_item <= piece.len()
94 && let Some(&rank) = ranks.get(&piece[start..next_end_item])
95 {
96 // We have a valid potential merge!
97 heap.push(Merge { start, rank });
98 state[start].next_rank = rank;
99 }
100 }
101 };
102
103 while let Some(left) = heap.pop() {
104 if left.rank == Rank::MAX {
105 break;
106 }
107 if left.rank != state[left.start].next_rank {
108 continue; // This merge was invalidated, ignore it
109 }
110
111 let left_start = left.start;
112 let right_start = state[left_start].end;
113 let right_end = state[left_start].next_end;
114 debug_assert!(right_end == state[right_start].end);
115 let right_next_end = state[right_start].next_end;
116
117 // Merge left and right into a single token
118 state[left_start].cur_rank = state[left_start].next_rank;
119 state[left_start].end = right_end;
120 potential_merge(&mut state, &mut heap, left_start, right_next_end);
121 if right_end < state.len() {
122 state[right_end].prev = left_start;
123 }
124 // Update the merge that ends at left_start
125 if left_start > 0 {
126 let prev_start = state[left_start].prev;
127 potential_merge(&mut state, &mut heap, prev_start, right_end);
128 }
129 // Invalidate the merge starting at right_start, so we ignore it when it comes off the heap
130 state[right_start].next_rank = Rank::MAX;
131 }
132
133 let mut result = Vec::new();
134 let mut i = 0;
135 while i < state.len() {
136 if state[i].cur_rank != Rank::MAX {
137 result.push(state[i].cur_rank);
138 } else {
139 result.push(ranks[&piece[i..state[i].end]]);
140 }
141 i = state[i].end;
142 }
143 result
144}
145
146fn _byte_pair_merge(ranks: &HashMap<Vec<u8>, Rank>, piece: &[u8]) -> Vec<(usize, Rank)> {
147 // This is a vector of (start, rank).
148 // The rank is of the pair starting at position start.
149 let mut parts = Vec::with_capacity(piece.len() + 1);
150
151 // Note that we hash bytes when indexing into `ranks`, not token pairs. As long as we train BPE
152 // the way we currently do, this is equivalent. An easy way to break this would be to decouple
153 // merge priority from token index or to prevent specific token merges.
154 let mut min_rank: (Rank, usize) = (Rank::MAX, usize::MAX);
155 for i in 0..piece.len() - 1 {
156 let rank = *ranks.get(&piece[i..i + 2]).unwrap_or(&Rank::MAX);
157 if rank < min_rank.0 {
158 min_rank = (rank, i);
159 }
160 parts.push((i, rank));
161 }
162 parts.push((piece.len() - 1, Rank::MAX));
163 parts.push((piece.len(), Rank::MAX));
164
165 let get_rank = {
166 #[inline(always)]
167 |parts: &Vec<(usize, Rank)>, i: usize| {
168 if (i + 3) < parts.len() {
169 // Similar to `piece[i..i + 2]` above. The +3 is because we haven't yet deleted
170 // parts[i + 1], see comment in the main loop.
171 *ranks
172 .get(&piece[parts[i].0..parts[i + 3].0])
173 .unwrap_or(&Rank::MAX)
174 } else {
175 Rank::MAX
176 }
177 }
178 };
179
180 // If you have n parts and m merges, this does O(mn) work.
181 // We could do something with a heap and do O(m log n) work.
182 // n is often very small so considerations like cache-locality outweigh the algorithmic
183 // complexity downsides of the `parts` vector.
184 while min_rank.0 != Rank::MAX {
185 let i = min_rank.1;
186 // Update parts[i] and parts[i - 1] before removing parts[i + 1], since
187 // `parts.remove(i + 1)` will thrash the cache.
188 if i > 0 {
189 parts[i - 1].1 = get_rank(&parts, i - 1);
190 }
191 parts[i].1 = get_rank(&parts, i);
192 parts.remove(i + 1);
193
194 min_rank = (Rank::MAX, usize::MAX);
195 for (i, &(_, rank)) in parts[..parts.len() - 1].iter().enumerate() {
196 if rank < min_rank.0 {
197 min_rank = (rank, i);
198 }
199 }
200 }
201 parts
202}
203
204pub fn byte_pair_encode(piece: &[u8], ranks: &HashMap<Vec<u8>, Rank>) -> Vec<Rank> {
205 let piece_len = piece.len();
206
207 if piece_len == 1 {
208 return vec![ranks[piece]];
209 }
210 if piece_len < 100 {
211 return _byte_pair_merge(ranks, piece)
212 .windows(2)
213 .map(|part| ranks[&piece[part[0].0..part[1].0]])
214 .collect();
215 }
216 _byte_pair_merge_large(ranks, piece)
217}
218
219pub fn byte_pair_split<'a>(piece: &'a [u8], ranks: &HashMap<Vec<u8>, Rank>) -> Vec<&'a [u8]> {
220 assert!(piece.len() > 1);
221 _byte_pair_merge(ranks, piece)
222 .windows(2)
223 .map(|part| &piece[part[0].0..part[1].0])
224 .collect()
225}
226
227// Various performance notes:
228//
229// Regex
230// =====
231// Most of the time is spent in regex. The easiest way to speed this up is by using less fancy
232// regex features. For instance, using a regex parse-able by `regex` crate is 3x faster than
233// the usual regex we use.
234//
235// However, given that we're using a regex parse-able by `regex`, there isn't much difference
236// between using the `regex` crate and using the `fancy_regex` crate.
237//
238// There is an important interaction between threading, `regex` and `fancy_regex`.
239// When using `fancy_regex`, we hit `regex.find_at`. It turns out that this causes contention on
240// some mutable scratch space inside of `regex`. This absolutely kills performance. When using plain
241// old `regex`, we don't hit this, because `find_iter` has a different code path.
242// Related: https://github.com/rust-lang/regex/blob/master/PERFORMANCE.md
243// Anyway, the way we get around this is with having a (mostly) thread local clone of the regex for
244// each thread.
245//
246// Threading
247// =========
248// I tried using `rayon`. It wasn't really faster than using Python threads and releasing the GIL.
249// So goodbye `rayon`! Let thread count etc be in control of our Python users.
250//
251// Caching
252// =======
253// The reference tokeniser has an lru cache over the equivalent of `byte_pair_encode`.
254// Originally, we had one too! Without it, we were only vaguely faster than Python.
255// I used an RWLock to protect the cache. This didn't seem to hurt single threaded performance
256// noticeably, but it did affect multi-threaded performance. Weirdly, it seemed to affect
// multi-threaded performance even when I only had readers (maybe I messed something up?).
258// Anyway, I realised that we could get rid of the cache, if we treat the set of tokens as a cache!
// These are exactly the set of merges that are likely to be hot. And now we don't have to think
260// about interior mutability, memory use, or cloning.
261//
262// Hashing
263// =======
264// We use FxHashMap instead of the standard HashMap. This is maybe like a 5-10% win?
265// The current implementation ends up doing a lot of hashing of bytes. In theory, this could be made
266// to be hashing of two-tuples of ints, which looks like it may also be a couple percent faster.
267
/// Eight-byte stand-in that `hash_current_thread` reinterprets a
/// `std::thread::ThreadId` as in order to read its numeric value.
pub struct FakeThreadId(NonZeroU64);

/// Returns the current thread's id as a number, for use as an index hash.
fn hash_current_thread() -> usize {
    // It's easier to use unsafe than to use nightly. Rust has this nice u64 thread id counter
    // that works great for our use case of avoiding collisions in our array. Unfortunately,
    // it's private. However, there are only so many ways you can layout a u64, so just transmute
    // https://github.com/rust-lang/rust/issues/67939

    // Compile-time checks: both types must be exactly 8 bytes wide.
    const _: [u8; 8] = [0; std::mem::size_of::<std::thread::ThreadId>()];
    const _: [u8; 8] = [0; std::mem::size_of::<FakeThreadId>()];

    let current = thread::current().id();
    // SAFETY: the constants above guarantee that source and target have the same size.
    // That the bytes of a `ThreadId` form a non-zero `u64` is an assumption about the
    // standard library's private representation, not something the compiler checks.
    let FakeThreadId(raw) =
        unsafe { std::mem::transmute::<std::thread::ThreadId, FakeThreadId>(current) };
    raw.get() as usize
}
282
283#[derive(Debug, Clone)]
284pub struct DecodeKeyError {
285 pub token: Rank,
286}
287
288impl std::fmt::Display for DecodeKeyError {
289 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
290 write!(f, "Invalid token for decoding: {}", self.token)
291 }
292}
293
294impl std::error::Error for DecodeKeyError {}
295
/// Error describing why a token sequence could not be decoded.
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub struct DecodeError {
    /// Human-readable description of the failure.
    pub message: String,
}

impl std::fmt::Display for DecodeError {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        f.write_str("Could not decode tokens: ")?;
        f.write_str(&self.message)
    }
}

impl std::error::Error for DecodeError {}
309
/// Error describing why a string could not be encoded.
#[derive(Debug, Clone)]
pub struct EncodeError {
    /// Human-readable description of the failure.
    pub message: String,
}

impl std::fmt::Display for EncodeError {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        f.write_fmt(format_args!("Could not encode string: {}", self.message))
    }
}

impl std::error::Error for EncodeError {}
322
/// Modulus applied to the thread id hash when `CoreBPE` selects an entry of
/// `regex_tls` / `special_regex_tls`; both vectors must therefore hold at least
/// this many regexes.
pub const MAX_NUM_THREADS: usize = 128;

// #[cfg_attr(feature = "python", pyclass)]
/// Byte pair encoder/decoder state: lookup tables in both directions plus the
/// regexes used to split text before encoding.
#[derive(Clone)]
pub struct CoreBPE {
    /// Token bytes -> rank, for ordinary tokens.
    pub(crate) encoder: HashMap<Vec<u8>, Rank>,
    /// Special token text -> rank.
    pub(crate) special_tokens_encoder: HashMap<String, Rank>,
    /// Rank -> token bytes, for ordinary tokens.
    pub(crate) decoder: HashMap<Rank, Vec<u8>>,
    /// Rank -> token bytes, consulted by `decode_bytes` when `decoder` has no entry.
    pub(crate) special_tokens_decoder: HashMap<Rank, Vec<u8>>,
    /// Regexes for splitting ordinary text, indexed by
    /// `hash_current_thread() % MAX_NUM_THREADS`.
    pub(crate) regex_tls: Vec<Regex>,
    /// Regexes for locating special tokens, indexed the same way as `regex_tls`.
    pub(crate) special_regex_tls: Vec<Regex>,
    /// Token byte strings searched with `partition_point` by
    /// `_encode_unstable_native`; expected to be the encoder's keys in ascending
    /// order (the construction code is not part of this file, so confirm there).
    #[allow(dead_code)]
    pub(crate) sorted_token_bytes: Vec<Vec<u8>>,
}
337
338impl CoreBPE {
339 fn _get_tl_regex(&self) -> &Regex {
340 // See performance notes above for what this is about
341 // It's also a little janky, please make a better version of it!
342 // However, it's nice that this doesn't leak memory to short-lived threads
343 &self.regex_tls[hash_current_thread() % MAX_NUM_THREADS]
344 }
345
346 fn _get_tl_special_regex(&self) -> &Regex {
347 &self.special_regex_tls[hash_current_thread() % MAX_NUM_THREADS]
348 }
349
350 /// Decodes tokens into a list of bytes.
351 ///
352 /// The bytes are not gauranteed to be a valid utf-8 string.
353 pub fn decode_bytes(&self, tokens: &[Rank]) -> Result<Vec<u8>, DecodeKeyError> {
354 let mut ret = Vec::with_capacity(tokens.len() * 2);
355 for &token in tokens {
356 let token_bytes = match self.decoder.get(&token) {
357 Some(bytes) => bytes,
358 None => self
359 .special_tokens_decoder
360 .get(&token)
361 .ok_or(DecodeKeyError { token })?,
362 };
363 ret.extend(token_bytes);
364 }
365 Ok(ret)
366 }
367
368 pub fn encode_ordinary(&self, text: &str) -> Vec<Rank> {
369 // This is the core of the encoding logic; the other functions in here
370 // just make things complicated :-)
371 let regex = self._get_tl_regex();
372 let mut ret = vec![];
373 for mat in regex.find_iter(text) {
374 let piece = mat.unwrap().as_str().as_bytes();
375 match self.encoder.get(piece) {
376 Some(token) => ret.push(*token),
377 None => ret.extend(&byte_pair_encode(piece, &self.encoder)),
378 }
379 }
380 ret
381 }
382
383 pub fn encode(
384 &self,
385 text: &str,
386 allowed_special: &HashSet<&str>,
387 ) -> Result<(Vec<Rank>, usize), EncodeError> {
388 let special_regex = self._get_tl_special_regex();
389 let regex = self._get_tl_regex();
390 let mut ret = vec![];
391
392 let mut start = 0;
393 let mut last_piece_token_len = 0;
394 loop {
395 let mut next_special;
396 let mut start_find = start;
397 loop {
398 // Find the next allowed special token, if any
399 next_special = special_regex.find_from_pos(text, start_find).unwrap();
400 match next_special {
401 Some(m) => {
402 if allowed_special.contains(&text[m.start()..m.end()]) {
403 break;
404 }
405 start_find = m.start() + 1;
406 }
407 None => break,
408 }
409 }
410 let end = next_special.map_or(text.len(), |m| m.start());
411
412 // Okay, here we go, compare this logic to encode_ordinary
413 for mat_res in regex.find_iter(&text[start..end]) {
414 let mat = match mat_res {
415 Ok(m) => m,
416 Err(e) => {
417 return Err(EncodeError {
418 message: format!("Regex error while tokenizing: {e}"),
419 });
420 }
421 };
422
423 let piece = mat.as_str().as_bytes();
424 if let Some(token) = self.encoder.get(piece) {
425 last_piece_token_len = 1;
426 ret.push(*token);
427 continue;
428 }
429 let tokens = byte_pair_encode(piece, &self.encoder);
430 last_piece_token_len = tokens.len();
431 ret.extend(&tokens);
432 }
433
434 match next_special {
435 // And here we push the special token
436 Some(m) => {
437 let piece = m.as_str();
438 let token = self.special_tokens_encoder[piece];
439 ret.push(token);
440 start = m.end();
441 last_piece_token_len = 0;
442 }
443 None => break,
444 }
445 }
446
447 // last_piece_token_len is how many tokens came from the last regex split. This is used
448 // for determining unstable tokens, since you can't merge across (stable) regex splits
449 Ok((ret, last_piece_token_len))
450 }
451
452 fn _increase_last_piece_token_len(
453 &self,
454 tokens: Vec<Rank>,
455 mut last_piece_token_len: usize,
456 ) -> (Vec<Rank>, usize) {
457 // Unfortunately, the locations where our regex splits can be unstable.
458 // For the purposes of determining unstable tokens, unstable regex splitting
459 // is only a problem if a split that was present disappears, since this can
460 // lead to merging of tokens otherwise thought to be stable.
461 // cl100k_base makes our life hard by including the \s*[\r\n]+
462 // pattern. This can e.g. cause "\n" + " " to become "\n \n".
463 // Here is a quick and dirty fix:
464 {
465 let token_is_all_space = |token| {
466 self.decoder
467 .get(token)
468 .map(|token_bytes| token_bytes.iter().rev().all(|&b| b" \n\t".contains(&b)))
469 .unwrap_or(false)
470 };
471 if last_piece_token_len > 0
472 && token_is_all_space(&tokens[tokens.len() - last_piece_token_len])
473 {
474 while (last_piece_token_len < tokens.len())
475 && token_is_all_space(&tokens[tokens.len() - last_piece_token_len - 1])
476 {
477 last_piece_token_len += 1;
478 }
479 }
480 }
481 debug_assert!(last_piece_token_len <= tokens.len());
482
483 (tokens, last_piece_token_len)
484 }
485
486 pub fn _encode_unstable_native(
487 &self,
488 text: &str,
489 allowed_special: &HashSet<&str>,
490 ) -> (Vec<Rank>, HashSet<Vec<Rank>>) {
491 let (tokens, last_piece_token_len) = self.encode(text, allowed_special).unwrap();
492 if last_piece_token_len == 0 {
493 // If last_piece_token_len is zero, the last token was a special token and we have
494 // no unstable bytes
495 return (tokens, HashSet::new());
496 }
497 let (mut tokens, last_piece_token_len) =
498 self._increase_last_piece_token_len(tokens, last_piece_token_len);
499
500 let unstable_bytes = self
501 .decode_bytes(&tokens[tokens.len() - last_piece_token_len..])
502 .unwrap();
503 tokens.truncate(tokens.len() - last_piece_token_len);
504
505 // TODO: we should try harder to find additional stable tokens
506 // This would reduce the amount of retokenising when determining completions
507 // Refer to the logic in an older version of this file
508
509 let mut completions = HashSet::new();
510 if unstable_bytes.is_empty() {
511 return (tokens, completions);
512 }
513
514 // This is the easy bit. Just find all single tokens that start with unstable_bytes
515 // (including tokens that exactly match unstable_bytes)
516 // Separating this from the loop below helps with performance in a common case.
517 let mut point = self
518 .sorted_token_bytes
519 .partition_point(|x| x.as_slice() < unstable_bytes.as_slice());
520 while point < self.sorted_token_bytes.len()
521 && self.sorted_token_bytes[point].starts_with(&unstable_bytes)
522 {
523 completions.insert(vec![
524 self.encoder[self.sorted_token_bytes[point].as_slice()],
525 ]);
526 point += 1;
527 }
528
529 // Now apply even more brute force. At every (other) possible position for the straddling
530 // token, concatenate additional bytes from that token (if any) to unstable_bytes,
531 // and retokenise the whole thing and see what we get.
532 for i in 1..unstable_bytes.len() {
533 let prefix = &unstable_bytes[..i];
534 let suffix = &unstable_bytes[i..];
535 let mut point = self
536 .sorted_token_bytes
537 .partition_point(|x| x.as_slice() < suffix);
538 // TODO: Perf optimisation if suffix starts with " "?
539 while point < self.sorted_token_bytes.len()
540 && self.sorted_token_bytes[point].starts_with(suffix)
541 {
542 let possibility = [prefix, self.sorted_token_bytes[point].as_slice()].concat();
543 let encoded = match std::str::from_utf8(&possibility) {
544 // Morally, this is byte_pair_encode(&possibility, &self.encoder)
545 // But we might have introduced a regex split which would prevent merges.
546 // (particularly possible in the presence of unstable regex splits)
547 // So convert to UTF-8 and do regex splitting.
548 // E.g. with cl100k_base " !" gets split to " " + " !",
549 // but byte_pair_encode(" !") != byte_pair_encode(" ")
550 Ok(s) => self.encode_ordinary(s),
551
552 // Technically, whether or not this arm is correct depends on whether there
553 // would be a regex split before the UTF-8 truncation point.
554 // Probably niche enough that no one will ever notice (after all, people didn't
555 // notice all the big holes in the previous unstable token implementation)
556 Err(_) => byte_pair_encode(&possibility, &self.encoder),
557 // Something like the following is intriguing but incorrect:
558 // Err(e) => self.encode_ordinary(unsafe {
559 // std::str::from_utf8_unchecked(&possibility[..e.valid_up_to()])
560 // }),
561 };
562 let mut seq = Vec::new();
563 let mut seq_len = 0;
564 for token in encoded {
565 seq.push(token);
566 seq_len += self.decoder[&token].len();
567 if seq_len >= unstable_bytes.len() {
568 break;
569 }
570 }
571 completions.insert(seq);
572 point += 1;
573 }
574 }
575
576 // This is also not straightforward. While we generally assume that regex splits are stable,
577 // unfortunately, they are not. That is, if adding bytes were to make a split appear in
578 // unstable_bytes, this could make tokens possible which our logic would otherwise think
579 // would be merged.
580 // For example, with gpt2, the use of \s+(?!\S) means that "\n\n" could
581 // develop a split, e.g. "\n\n0" splits into "\n"+"\n"+"0", making "\n" a possible token.
582 // Here is a quick and dirty fix:
583 // This isn't right if we ever remove \s+(?!\S)
584 if unstable_bytes.len() > 1 {
585 let last_decoded = bstr::decode_last_utf8(unstable_bytes.as_slice());
586 if unstable_bytes.len() - last_decoded.1 > 0
587 && last_decoded.0.is_some_and(|c| c.is_whitespace())
588 {
589 let mut reencoded = byte_pair_encode(
590 &unstable_bytes[..unstable_bytes.len() - last_decoded.1],
591 &self.encoder,
592 );
593 reencoded.extend(byte_pair_encode(
594 &unstable_bytes[unstable_bytes.len() - last_decoded.1..],
595 &self.encoder,
596 ));
597 completions.insert(reencoded);
598 }
599 }
600
601 (tokens, completions)
602 }
603
604 // pub fn new<E, SE, NSE>(
605 // encoder: E,
606 // special_tokens_encoder: SE,
607 // pattern: &str,
608 // ) -> Result<Self, Box<dyn std::error::Error + Send + Sync>>
609 // where
610 // E: IntoIterator<Item = (Vec<u8>, Rank)>,
611 // SE: IntoIterator<Item = (String, Rank)>,
612 // NSE: IntoIterator<Item = (String, (Rank, Rank))>,
613 // {
614 // Self::new_internal(
615 // HashMap::from_iter(encoder),
616 // HashMap::from_iter(special_tokens_encoder),
617 // pattern,
618 // )
619 // }
620
621 // fn new_internal(
622 // encoder: HashMap<Vec<u8>, Rank>,
623 // special_tokens_encoder: HashMap<String, Rank>,
624 // pattern: &str,
625 // ) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
626 // let regex = Regex::new(pattern)?;
627
628 // let special_regex = {
629 // let parts = special_tokens_encoder
630 // .keys()
631 // .map(|s| fancy_regex::escape(s))
632 // .collect::<Vec<_>>();
633 // Regex::new(&parts.join("|"))?
634 // };
635
636 // let decoder: HashMap<Rank, Vec<u8>> =
637 // encoder.iter().map(|(k, v)| (*v, k.clone())).collect();
638
639 // assert!(
640 // encoder.len() == decoder.len(),
641 // "Encoder and decoder must be of equal length; maybe you had duplicate token indices in your encoder?"
642 // );
643
644 // let special_tokens_decoder: HashMap<Rank, Vec<u8>> = special_tokens_encoder
645 // .iter()
646 // .map(|(k, v)| (*v, k.as_bytes().to_vec()))
647 // .collect();
648
649 // // Clone because I don't know how to tell Rust I'm not going to change the map
650 // let mut sorted_token_bytes: Vec<Vec<u8>> = encoder.keys().cloned().collect();
651 // sorted_token_bytes.sort();
652
653 // Ok(Self {
654 // encoder,
655 // special_tokens_encoder,
656 // decoder,
657 // special_tokens_decoder,
658 // regex_tls: (0..MAX_NUM_THREADS).map(|_| regex.clone()).collect(),
659 // special_regex_tls: (0..MAX_NUM_THREADS)
660 // .map(|_| special_regex.clone())
661 // .collect(),
662 // sorted_token_bytes,
663 // })
664 // }
665
666 pub fn special_tokens(&self) -> HashSet<&str> {
667 self.special_tokens_encoder
668 .keys()
669 .map(|s| s.as_str())
670 .collect()
671 }
672
673 pub fn encode_with_special_tokens(&self, text: &str) -> Vec<Rank> {
674 let allowed_special = self.special_tokens();
675 self.encode(text, &allowed_special).unwrap().0
676 }
677}
678
#[cfg(test)]
mod tests {
    // use fancy_regex::Regex;
    use rustc_hash::FxHashMap as HashMap;

    use crate::{Rank, byte_pair_split};

    /// Builds a rank table that only knows the pairs "ab" (rank 0) and "cd" (rank 1).
    fn setup_ranks() -> HashMap<Vec<u8>, Rank> {
        let mut ranks = HashMap::default();
        ranks.insert(b"ab".to_vec(), 0);
        ranks.insert(b"cd".to_vec(), 1);
        ranks
    }

    /// Two different known pairs are each merged into one token.
    #[test]
    fn test_simple_characters() {
        let pieces = byte_pair_split(b"abcd", &setup_ranks());
        assert_eq!(pieces, vec![b"ab", b"cd"]);
    }

    /// The same known pair occurring twice is merged at both positions.
    #[test]
    fn test_repeated_characters() {
        let pieces = byte_pair_split(b"abab", &setup_ranks());
        assert_eq!(pieces, vec![b"ab", b"ab"]);
    }
}
703}