tiktoken_rs/vendor_tiktoken.rs
1#[rustfmt::skip]
2/// This file is a vendored copy of the `tiktoken` crate.
3/// Modifications are limited to commenting out python related code and adjusting visibility of some functions, and suppressing lint warnings.
/// Limit modifications to this file to make it easy to keep it in sync with upstream
5// use std::borrow::Borrow;
6// use std::borrow::Cow;
7use std::collections::HashSet;
8use std::num::NonZeroU64;
9use std::thread;
10
11use fancy_regex::Regex;
12// #[cfg(feature = "python")]
13// use pyo3::prelude::*;
14use rustc_hash::FxHashMap as HashMap;
15
16// #[cfg(feature = "python")]
17// mod py;
18
19pub type Rank = u32;
20
/// Core BPE merge loop: repeatedly merges the lowest-ranked adjacent byte pair of `piece`
/// (per the `ranks` merge table) until no mergeable pair remains.
///
/// Returns the part boundaries as `(start, rank)` entries, where `rank` is the rank of the
/// pair *starting* at `start`. The result always ends with two sentinel entries (at
/// `piece.len() - 1` and `piece.len()`, both `Rank::MAX`) so callers can walk `windows(2)`
/// to recover each part's byte range.
///
/// NOTE(review): `piece.len() - 1` underflows for an empty `piece`; callers are expected to
/// pass at least one byte (and in practice at least two).
fn _byte_pair_merge(ranks: &HashMap<Vec<u8>, Rank>, piece: &[u8]) -> Vec<(usize, Rank)> {
    // This is a vector of (start, rank).
    // The rank is of the pair starting at position start.
    let mut parts = Vec::with_capacity(piece.len() + 1);

    // Note that we hash bytes when indexing into `ranks`, not token pairs. As long as we train BPE
    // the way we currently do, this is equivalent. An easy way to break this would be to decouple
    // merge priority from token index or to prevent specific token merges.
    let mut min_rank: (Rank, usize) = (Rank::MAX, usize::MAX);
    for i in 0..piece.len() - 1 {
        // Rank of the two-byte pair starting at i; MAX means "not mergeable".
        let rank = *ranks.get(&piece[i..i + 2]).unwrap_or(&Rank::MAX);
        if rank < min_rank.0 {
            min_rank = (rank, i);
        }
        parts.push((i, rank));
    }
    // Sentinels: last byte (no pair starts there) and one-past-the-end boundary.
    parts.push((piece.len() - 1, Rank::MAX));
    parts.push((piece.len(), Rank::MAX));

    let get_rank = {
        #[inline(always)]
        |parts: &Vec<(usize, Rank)>, i: usize| {
            if (i + 3) < parts.len() {
                // Similar to `piece[i..i + 2]` above. The +3 is because we haven't yet deleted
                // parts[i + 1], see comment in the main loop.
                *ranks
                    .get(&piece[parts[i].0..parts[i + 3].0])
                    .unwrap_or(&Rank::MAX)
            } else {
                Rank::MAX
            }
        }
    };

    // If you have n parts and m merges, this does O(mn) work.
    // We could do something with a heap and do O(m log n) work.
    // n is often very small so considerations like cache-locality outweigh the algorithmic
    // complexity downsides of the `parts` vector.
    while min_rank.0 != Rank::MAX {
        let i = min_rank.1;
        // Update parts[i] and parts[i - 1] before removing parts[i + 1], since
        // `parts.remove(i + 1)` will thrash the cache.
        if i > 0 {
            parts[i - 1].1 = get_rank(&parts, i - 1);
        }
        parts[i].1 = get_rank(&parts, i);
        parts.remove(i + 1);

        // Rescan for the new lowest-ranked pair (skipping the final sentinel).
        min_rank = (Rank::MAX, usize::MAX);
        for (i, &(_, rank)) in parts[..parts.len() - 1].iter().enumerate() {
            if rank < min_rank.0 {
                min_rank = (rank, i);
            }
        }
    }
    parts
}
78
79pub fn byte_pair_encode(piece: &[u8], ranks: &HashMap<Vec<u8>, Rank>) -> Vec<Rank> {
80 if piece.len() == 1 {
81 return vec![ranks[piece]];
82 }
83 _byte_pair_merge(ranks, piece)
84 .windows(2)
85 .map(|part| ranks[&piece[part[0].0..part[1].0]])
86 .collect()
87}
88
89pub fn byte_pair_split<'a>(piece: &'a [u8], ranks: &HashMap<Vec<u8>, Rank>) -> Vec<&'a [u8]> {
90 assert!(piece.len() > 1);
91 _byte_pair_merge(ranks, piece)
92 .windows(2)
93 .map(|part| &piece[part[0].0..part[1].0])
94 .collect()
95}
96
97// Various performance notes:
98//
99// Regex
100// =====
101// Most of the time is spent in regex. The easiest way to speed this up is by using less fancy
102// regex features. For instance, using a regex parse-able by `regex` crate is 3x faster than
103// the usual regex we use.
104//
105// However, given that we're using a regex parse-able by `regex`, there isn't much difference
106// between using the `regex` crate and using the `fancy_regex` crate.
107//
108// There is an important interaction between threading, `regex` and `fancy_regex`.
109// When using `fancy_regex`, we hit `regex.find_at`. It turns out that this causes contention on
110// some mutable scratch space inside of `regex`. This absolutely kills performance. When using plain
111// old `regex`, we don't hit this, because `find_iter` has a different code path.
112// Related: https://github.com/rust-lang/regex/blob/master/PERFORMANCE.md
113// Anyway, the way we get around this is with having a (mostly) thread local clone of the regex for
114// each thread.
115//
116// Threading
117// =========
118// I tried using `rayon`. It wasn't really faster than using Python threads and releasing the GIL.
119// So goodbye `rayon`! Let thread count etc be in control of our Python users.
120//
121// Caching
122// =======
123// The reference tokeniser has an lru cache over the equivalent of `byte_pair_encode`.
124// Originally, we had one too! Without it, we were only vaguely faster than Python.
125// I used an RWLock to protect the cache. This didn't seem to hurt single threaded performance
126// noticeably, but it did affect multi-threaded performance. Weirdly, it seemed to affect
// multi-threaded performance even when I only had readers (maybe I messed something up?).
128// Anyway, I realised that we could get rid of the cache, if we treat the set of tokens as a cache!
// These are exactly the set of merges that are likely to be hot. And now we don't have to think
130// about interior mutability, memory use, or cloning.
131//
132// Hashing
133// =======
134// We use FxHashMap instead of the standard HashMap. This is maybe like a 5-10% win?
135// The current implementation ends up doing a lot of hashing of bytes. In theory, this could be made
136// to be hashing of two-tuples of ints, which looks like it may also be a couple percent faster.
137
/// Newtype with the layout we assume `std::thread::ThreadId` has (a `NonZeroU64`),
/// used only as the transmute target in `hash_current_thread`.
pub struct FakeThreadId(NonZeroU64);

/// Returns a small, collision-resistant integer derived from the current thread's id,
/// used to index the per-thread regex clone arrays (mod `MAX_NUM_THREADS`).
fn hash_current_thread() -> usize {
    // It's easier to use unsafe than to use nightly. Rust has this nice u64 thread id counter
    // that works great for our use case of avoiding collisions in our array. Unfortunately,
    // it's private. However, there are only so many ways you can layout a u64, so just transmute
    // https://github.com/rust-lang/rust/issues/67939
    //
    // Compile-time size checks: these consts fail to compile unless both types are exactly
    // 8 bytes, which is the only layout evidence we can get on stable.
    const _: [u8; 8] = [0; std::mem::size_of::<std::thread::ThreadId>()];
    const _: [u8; 8] = [0; std::mem::size_of::<FakeThreadId>()];
    // SAFETY: assumes ThreadId is a newtype over NonZeroU64 (size-checked above).
    // NOTE(review): this would silently break if ThreadId's internal repr ever changes
    // while keeping the same size — revisit on toolchain upgrades.
    let x = unsafe {
        std::mem::transmute::<std::thread::ThreadId, FakeThreadId>(thread::current().id()).0
    };
    u64::from(x) as usize
}
152
153#[derive(Debug, Clone)]
154pub struct DecodeKeyError {
155 pub token: Rank,
156}
157
158impl std::fmt::Display for DecodeKeyError {
159 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
160 write!(f, "Invalid token for decoding: {}", self.token)
161 }
162}
163
164impl std::error::Error for DecodeKeyError {}
165
/// Error returned when decoded token bytes cannot be interpreted as requested
/// (for example, when they are not valid UTF-8).
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub struct DecodeError {
    pub message: String,
}

impl std::fmt::Display for DecodeError {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        f.write_fmt(format_args!("Could not decode tokens: {}", self.message))
    }
}

impl std::error::Error for DecodeError {}
179
/// Number of pre-cloned regex slots; thread ids are hashed into this many buckets
/// (see the threading/performance notes above).
pub const MAX_NUM_THREADS: usize = 128;

// #[cfg_attr(feature = "python", pyclass)]
#[derive(Clone)]
pub struct CoreBPE {
    /// Ordinary token bytes -> rank (token id).
    pub(crate) encoder: HashMap<Vec<u8>, Rank>,
    /// Special token string -> rank.
    pub(crate) special_tokens_encoder: HashMap<String, Rank>,
    /// Rank -> ordinary token bytes (inverse of `encoder`).
    pub(crate) decoder: HashMap<Rank, Vec<u8>>,
    /// Rank -> special token bytes (inverse of `special_tokens_encoder`).
    pub(crate) special_tokens_decoder: HashMap<Rank, Vec<u8>>,
    /// Per-thread-slot clones of the splitting regex (see performance notes).
    pub(crate) regex_tls: Vec<Regex>,
    /// Per-thread-slot clones of the special-token regex.
    pub(crate) special_regex_tls: Vec<Regex>,
    /// All ordinary token byte strings, sorted; used by the unstable-token
    /// completion search in `_encode_unstable_native`.
    #[allow(dead_code)]
    pub(crate) sorted_token_bytes: Vec<Vec<u8>>,
}
194
impl CoreBPE {
    /// Returns this thread's slot in the pre-cloned regex array.
    fn _get_tl_regex(&self) -> &Regex {
        // See performance notes above for what this is about
        // It's also a little janky, please make a better version of it!
        // However, it's nice that this doesn't leak memory to short-lived threads
        &self.regex_tls[hash_current_thread() % MAX_NUM_THREADS]
    }

    /// Returns this thread's slot in the pre-cloned special-token regex array.
    fn _get_tl_special_regex(&self) -> &Regex {
        &self.special_regex_tls[hash_current_thread() % MAX_NUM_THREADS]
    }

    /// Decodes tokens into a list of bytes.
    ///
    /// The bytes are not guaranteed to be a valid utf-8 string.
    pub(crate) fn decode_bytes(&self, tokens: &[Rank]) -> Result<Vec<u8>, DecodeKeyError> {
        let mut ret = Vec::with_capacity(tokens.len() * 2);
        for &token in tokens {
            let token_bytes = match self.decoder.get(&token) {
                Some(bytes) => bytes,
                // Fall back to the special-token table; an id in neither map is an error.
                None => self
                    .special_tokens_decoder
                    .get(&token)
                    .ok_or(DecodeKeyError { token })?,
            };
            ret.extend(token_bytes);
        }
        Ok(ret)
    }

    /// Encodes `text` using only ordinary tokens; special-token strings in the input
    /// are tokenized as plain text.
    pub fn encode_ordinary(&self, text: &str) -> Vec<Rank> {
        // This is the core of the encoding logic; the other functions in here
        // just make things complicated :-)
        let regex = self._get_tl_regex();
        let mut ret = vec![];
        for mat in regex.find_iter(text) {
            // fancy_regex yields Result matches; unwrap mirrors upstream behavior.
            let piece = mat.unwrap().as_str().as_bytes();
            match self.encoder.get(piece) {
                // Whole piece is a known token: no merge loop needed.
                Some(token) => ret.push(*token),
                None => ret.extend(&byte_pair_encode(piece, &self.encoder)),
            }
        }
        ret
    }

    /// Encodes `text`, emitting each special token in `allowed_special` as a single token.
    ///
    /// Returns `(tokens, last_piece_token_len)`, where `last_piece_token_len` is the number
    /// of tokens produced by the final regex split (0 if the text ended with a special
    /// token). Callers use it to reason about unstable trailing tokens.
    pub fn encode(&self, text: &str, allowed_special: &HashSet<&str>) -> (Vec<Rank>, usize) {
        let special_regex = self._get_tl_special_regex();
        let regex = self._get_tl_regex();
        let mut ret = vec![];

        let mut start = 0;
        let mut last_piece_token_len = 0;
        loop {
            let mut next_special;
            let mut start_find = start;
            loop {
                // Find the next allowed special token, if any
                next_special = special_regex.find_from_pos(text, start_find).unwrap();
                match next_special {
                    Some(m) => {
                        if allowed_special.contains(&text[m.start()..m.end()]) {
                            break;
                        }
                        // Disallowed special token: resume the search one byte later so
                        // overlapping occurrences are still considered.
                        start_find = m.start() + 1;
                    }
                    None => break,
                }
            }
            // Ordinary text runs up to the next allowed special token (or end of input).
            let end = next_special.map_or(text.len(), |m| m.start());

            // Okay, here we go, compare this logic to encode_ordinary
            for mat in regex.find_iter(&text[start..end]) {
                let piece = mat.unwrap().as_str().as_bytes();
                if let Some(token) = self.encoder.get(piece) {
                    last_piece_token_len = 1;
                    ret.push(*token);
                    continue;
                }
                let tokens = byte_pair_encode(piece, &self.encoder);
                last_piece_token_len = tokens.len();
                ret.extend(&tokens);
            }

            match next_special {
                // And here we push the special token
                Some(m) => {
                    let piece = m.as_str();
                    let token = self.special_tokens_encoder[piece];
                    ret.push(token);
                    start = m.end();
                    last_piece_token_len = 0;
                }
                None => break,
            }
        }

        // last_piece_token_len is how many tokens came from the last regex split. This is used
        // for determining unstable tokens, since you can't merge across (stable) regex splits
        (ret, last_piece_token_len)
    }

    /// Grows `last_piece_token_len` backwards over trailing all-whitespace tokens, to
    /// compensate for regex splits that can shift when more text is appended.
    fn _increase_last_piece_token_len(
        &self,
        tokens: Vec<Rank>,
        mut last_piece_token_len: usize,
    ) -> (Vec<Rank>, usize) {
        // Unfortunately, the locations where our regex splits can be unstable.
        // For the purposes of determining unstable tokens, unstable regex splitting
        // is only a problem if a split that was present disappears, since this can
        // lead to merging of tokens otherwise thought to be stable.
        // cl100k_base makes our life hard by including the \s*[\r\n]+
        // pattern. This can e.g. cause "\n" + " " to become "\n \n".
        // Here is a quick and dirty fix:
        {
            // True if every byte of the token is space/newline/tab.
            let token_is_all_space = |token| {
                self.decoder
                    .get(token)
                    .map(|token_bytes| {
                        token_bytes
                            .iter()
                            .rev()
                            .all(|&b| [b' ', b'\n', b'\t'].contains(&b))
                    })
                    .unwrap_or(false)
            };
            if last_piece_token_len > 0
                && token_is_all_space(&tokens[tokens.len() - last_piece_token_len])
            {
                while (last_piece_token_len < tokens.len())
                    && token_is_all_space(&tokens[tokens.len() - last_piece_token_len - 1])
                {
                    last_piece_token_len += 1;
                }
            }
        }
        debug_assert!(last_piece_token_len <= tokens.len());

        (tokens, last_piece_token_len)
    }

    /// Like `encode`, but also returns the set of token sequences that the unstable
    /// trailing bytes of `text` could begin — used for completion-style retokenization.
    pub fn _encode_unstable_native(
        &self,
        text: &str,
        allowed_special: &HashSet<&str>,
    ) -> (Vec<Rank>, HashSet<Vec<Rank>>) {
        let (tokens, last_piece_token_len) = self.encode(text, allowed_special);
        if last_piece_token_len == 0 {
            // If last_piece_token_len is zero, the last token was a special token and we have
            // no unstable bytes
            return (tokens, HashSet::new());
        }
        let (mut tokens, last_piece_token_len) =
            self._increase_last_piece_token_len(tokens, last_piece_token_len);

        // Peel the unstable tokens off the end; only their raw bytes matter from here on.
        let unstable_bytes = self
            .decode_bytes(&tokens[tokens.len() - last_piece_token_len..])
            .unwrap();
        tokens.truncate(tokens.len() - last_piece_token_len);

        // TODO: we should try harder to find additional stable tokens
        // This would reduce the amount of retokenising when determining completions
        // Refer to the logic in an older version of this file

        let mut completions = HashSet::new();
        if unstable_bytes.is_empty() {
            return (tokens, completions);
        }

        // This is the easy bit. Just find all single tokens that start with unstable_bytes
        // (including tokens that exactly match unstable_bytes)
        // Separating this from the loop below helps with performance in a common case.
        let mut point = self
            .sorted_token_bytes
            .partition_point(|x| x.as_slice() < unstable_bytes.as_slice());
        while point < self.sorted_token_bytes.len()
            && self.sorted_token_bytes[point].starts_with(&unstable_bytes)
        {
            completions.insert(vec![
                self.encoder[self.sorted_token_bytes[point].as_slice()],
            ]);
            point += 1;
        }

        // Now apply even more brute force. At every (other) possible position for the straddling
        // token, concatenate additional bytes from that token (if any) to unstable_bytes,
        // and retokenise the whole thing and see what we get.
        for i in 1..unstable_bytes.len() {
            let prefix = &unstable_bytes[..i];
            let suffix = &unstable_bytes[i..];
            let mut point = self
                .sorted_token_bytes
                .partition_point(|x| x.as_slice() < suffix);
            // TODO: Perf optimisation if suffix starts with " "?
            while point < self.sorted_token_bytes.len()
                && self.sorted_token_bytes[point].starts_with(suffix)
            {
                let possibility = [prefix, self.sorted_token_bytes[point].as_slice()].concat();
                let encoded = match std::str::from_utf8(&possibility) {
                    // Morally, this is byte_pair_encode(&possibility, &self.encoder)
                    // But we might have introduced a regex split which would prevent merges.
                    // (particularly possible in the presence of unstable regex splits)
                    // So convert to UTF-8 and do regex splitting.
                    // E.g. with cl100k_base " !" gets split to " " + " !",
                    // but byte_pair_encode(" !") != byte_pair_encode(" ")
                    Ok(s) => self.encode_ordinary(s),

                    // Technically, whether or not this arm is correct depends on whether there
                    // would be a regex split before the UTF-8 truncation point.
                    // Probably niche enough that no one will ever notice (after all, people didn't
                    // notice all the big holes in the previous unstable token implementation)
                    Err(_) => byte_pair_encode(&possibility, &self.encoder),
                    // Something like the following is intriguing but incorrect:
                    // Err(e) => self.encode_ordinary(unsafe {
                    //     std::str::from_utf8_unchecked(&possibility[..e.valid_up_to()])
                    // }),
                };
                // Keep only as many tokens as are needed to cover the unstable bytes.
                let mut seq = Vec::new();
                let mut seq_len = 0;
                for token in encoded {
                    seq.push(token);
                    seq_len += self.decoder[&token].len();
                    if seq_len >= unstable_bytes.len() {
                        break;
                    }
                }
                completions.insert(seq);
                point += 1;
            }
        }

        // This is also not straightforward. While we generally assume that regex splits are stable,
        // unfortunately, they are not. That is, if adding bytes were to make a split appear in
        // unstable_bytes, this could make tokens possible which our logic would otherwise think
        // would be merged.
        // For example, with gpt2, the use of \s+(?!\S) means that "\n\n" could
        // develop a split, e.g. "\n\n0" splits into "\n"+"\n"+"0", making "\n" a possible token.
        // Here is a quick and dirty fix:
        // This isn't right if we ever remove \s+(?!\S)
        if unstable_bytes.len() > 1 {
            let last_decoded = bstr::decode_last_utf8(unstable_bytes.as_slice());
            if unstable_bytes.len() - last_decoded.1 > 0
                && last_decoded.0.map_or(false, |c| c.is_whitespace())
            {
                let mut reencoded = byte_pair_encode(
                    &unstable_bytes[..unstable_bytes.len() - last_decoded.1],
                    &self.encoder,
                );
                reencoded.extend(byte_pair_encode(
                    &unstable_bytes[unstable_bytes.len() - last_decoded.1..],
                    &self.encoder,
                ));
                completions.insert(reencoded);
            }
        }

        (tokens, completions)
    }

    // pub fn new<E, SE, NSE>(
    //     encoder: E,
    //     special_tokens_encoder: SE,
    //     pattern: &str,
    // ) -> Result<Self, Box<dyn std::error::Error + Send + Sync>>
    // where
    //     E: IntoIterator<Item = (Vec<u8>, Rank)>,
    //     SE: IntoIterator<Item = (String, Rank)>,
    //     NSE: IntoIterator<Item = (String, (Rank, Rank))>,
    // {
    //     Self::new_internal(
    //         HashMap::from_iter(encoder),
    //         HashMap::from_iter(special_tokens_encoder),
    //         pattern,
    //     )
    // }

    // fn new_internal(
    //     encoder: HashMap<Vec<u8>, Rank>,
    //     special_tokens_encoder: HashMap<String, Rank>,
    //     pattern: &str,
    // ) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
    //     let regex = Regex::new(pattern)?;

    //     let special_regex = {
    //         let parts = special_tokens_encoder
    //             .keys()
    //             .map(|s| fancy_regex::escape(s))
    //             .collect::<Vec<_>>();
    //         Regex::new(&parts.join("|"))?
    //     };

    //     let decoder: HashMap<Rank, Vec<u8>> =
    //         encoder.iter().map(|(k, v)| (*v, k.clone())).collect();

    //     assert!(
    //         encoder.len() == decoder.len(),
    //         "Encoder and decoder must be of equal length; maybe you had duplicate token indices in your encoder?"
    //     );

    //     let special_tokens_decoder: HashMap<Rank, Vec<u8>> = special_tokens_encoder
    //         .iter()
    //         .map(|(k, v)| (*v, k.as_bytes().to_vec()))
    //         .collect();

    //     // Clone because I don't know how to tell Rust I'm not going to change the map
    //     let mut sorted_token_bytes: Vec<Vec<u8>> = encoder.keys().cloned().collect();
    //     sorted_token_bytes.sort();

    //     Ok(Self {
    //         encoder,
    //         special_tokens_encoder,
    //         decoder,
    //         special_tokens_decoder,
    //         regex_tls: (0..MAX_NUM_THREADS).map(|_| regex.clone()).collect(),
    //         special_regex_tls: (0..MAX_NUM_THREADS)
    //             .map(|_| special_regex.clone())
    //             .collect(),
    //         sorted_token_bytes,
    //     })
    // }

    /// Returns the set of all special token strings this encoder knows about.
    pub fn special_tokens(&self) -> HashSet<&str> {
        self.special_tokens_encoder
            .keys()
            .map(|s| s.as_str())
            .collect()
    }

    /// Encodes `text`, allowing every special token in the vocabulary.
    pub fn encode_with_special_tokens(&self, text: &str) -> Vec<Rank> {
        let allowed_special = self.special_tokens();
        self.encode(text, &allowed_special).0
    }
}
527
#[cfg(test)]
mod tests {
    // use fancy_regex::Regex;
    use rustc_hash::FxHashMap as HashMap;

    use crate::{byte_pair_split, Rank};

    // Builds a tiny two-entry merge table: "ab" -> 0, "cd" -> 1.
    fn setup_ranks() -> HashMap<Vec<u8>, Rank> {
        let mut ranks = HashMap::default();
        ranks.insert(b"ab".to_vec(), 0);
        ranks.insert(b"cd".to_vec(), 1);
        ranks
    }

    #[test]
    fn test_simple_characters() {
        let ranks = setup_ranks();
        assert_eq!(byte_pair_split(b"abcd", &ranks), vec![b"ab", b"cd"]);
    }

    #[test]
    fn test_repeated_characters() {
        let ranks = setup_ranks();
        assert_eq!(byte_pair_split(b"abab", &ranks), vec![b"ab", b"ab"]);
    }
}