tiktoken_rs/patched_tiktoken.rs
use super::vendor_tiktoken::*;
use anyhow::anyhow;
use anyhow::Result;
use fancy_regex::Regex;
use rustc_hash::FxHashMap as HashMap;

/// Rust API
impl CoreBPE {
    // ====================
    // Encoding
    // ====================

    /// This function is a copy of the similar function in the Python API,
    /// but it returns Rust results and errors.
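    ///
    /// # Examples
    ///
    /// A minimal sketch with a hypothetical two-entry vocabulary, not a real BPE
    /// model (and assuming `Rank` is `u32`, as in upstream tiktoken):
    ///
    /// ```
    /// use rustc_hash::FxHashMap;
    /// use tiktoken_rs::CoreBPE;
    ///
    /// let mut encoder: FxHashMap<Vec<u8>, u32> = FxHashMap::default();
    /// encoder.insert(b"a".to_vec(), 0);
    /// encoder.insert(b"b".to_vec(), 1);
    ///
    /// let mut special_tokens: FxHashMap<String, u32> = FxHashMap::default();
    /// special_tokens.insert("<|endoftext|>".to_string(), 2);
    ///
    /// let bpe = CoreBPE::new(encoder, special_tokens, r"a|b").unwrap();
    /// ```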
    pub fn new(
        encoder: HashMap<Vec<u8>, Rank>,
        special_tokens_encoder: HashMap<String, Rank>,
        pattern: &str,
    ) -> Result<Self> {
        let regex = Regex::new(pattern)?;

        // Build one alternation that matches any special token literally.
        let special_regex = {
            let parts = special_tokens_encoder
                .keys()
                .map(|s| fancy_regex::escape(s))
                .collect::<Vec<_>>();
            Regex::new(&parts.join("|"))?
        };

        let decoder: HashMap<Rank, Vec<u8>> =
            encoder.iter().map(|(k, v)| (*v, k.clone())).collect();

        assert!(
            encoder.len() == decoder.len(),
            "Encoder and decoder must be of equal length; maybe you had duplicate token indices in your encoder?"
        );

        let special_tokens_decoder: HashMap<Rank, Vec<u8>> = special_tokens_encoder
            .iter()
            .map(|(k, v)| (*v, k.as_bytes().to_vec()))
            .collect();

        // Clone because I don't know how to tell Rust I'm not going to change the map
        let mut sorted_token_bytes: Vec<Vec<u8>> = encoder.keys().cloned().collect();
        sorted_token_bytes.sort();

        Ok(Self {
            encoder,
            special_tokens_encoder,
            decoder,
            special_tokens_decoder,
            // One pre-compiled regex clone per thread, so threads never share state.
            regex_tls: (0..MAX_NUM_THREADS).map(|_| regex.clone()).collect(),
            special_regex_tls: (0..MAX_NUM_THREADS)
                .map(|_| special_regex.clone())
                .collect(),
            sorted_token_bytes,
        })
    }

    // ====================
    // Decoding
    // ====================

    /// Decode a vector of tokens into a valid UTF-8 String
    ///
    /// If unicode validation is not wanted, see `_decode_native`.
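    ///
    /// # Examples
    ///
    /// A small round-trip sketch:
    ///
    /// ```
    /// use tiktoken_rs::cl100k_base;
    /// let bpe = cl100k_base().unwrap();
    /// let tokens = bpe.encode_with_special_tokens("hello world");
    /// assert_eq!(bpe.decode(tokens).unwrap(), "hello world");
    /// ```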
    pub fn decode(&self, tokens: Vec<Rank>) -> Result<String> {
        match String::from_utf8(self.decode_bytes(&tokens)?) {
            Ok(text) => Ok(text),
            Err(e) => Err(anyhow!("Unable to decode into a valid UTF-8 string: {}", e)),
        }
    }

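    /// Decode each token into its raw byte chunk, without UTF-8 validation.
    ///
    /// A small sketch: concatenating the chunks reproduces the encoded bytes.
    ///
    /// ```
    /// use tiktoken_rs::cl100k_base;
    /// let bpe = cl100k_base().unwrap();
    /// let tokens = bpe.encode_with_special_tokens("hi there");
    /// let chunks: Vec<Vec<u8>> = bpe._decode_native_and_split(tokens).collect();
    /// assert_eq!(chunks.concat(), b"hi there".to_vec());
    /// ```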
    pub fn _decode_native_and_split(
        &self,
        tokens: Vec<Rank>,
    ) -> impl Iterator<Item = Vec<u8>> + '_ {
        tokens.into_iter().map(|token| {
            let token_bytes = self
                .decoder
                .get(&token)
                // Fall back to the special-tokens decoder; this panics if the
                // token is unknown to both maps.
                .unwrap_or_else(|| &self.special_tokens_decoder[&token]);
            token_bytes.clone()
        })
    }

    /// Tokenize a string and return the decoded tokens using the correct BPE model.
    ///
    /// This method takes a string, encodes it using the BPE model, and decodes the
    /// encoded tokens into a vector of strings.
    ///
    /// # Examples
    ///
    /// ```
    /// use tiktoken_rs::cl100k_base;
    /// let bpe = cl100k_base().unwrap();
    /// let tokenized: Result<Vec<_>, _> = bpe
    ///     .split_by_token("This is a test         with a lot of spaces", true);
    /// let tokenized = tokenized.unwrap();
    /// assert_eq!(
    ///     tokenized,
    ///     vec!["This", " is", " a", " test", "        ", " with", " a", " lot", " of", " spaces"]
    /// );
    /// ```
    ///
    /// # Arguments
    ///
    /// * `text`: A string slice containing the text to be tokenized.
    /// * `use_special_tokens`: A boolean indicating whether to use the special
    ///   tokens in the BPE model.
    ///
    /// # Returns
    ///
    /// * `Result<Vec<String>>`: A `Result` containing a vector of decoded tokens as
    ///   strings, or an error if the string cannot be converted into a valid UTF-8
    ///   string.
    ///
    /// # Errors
    ///
    /// This function will return an error if:
    ///
    /// * The input text cannot be converted into a valid UTF-8 string during the
    ///   decoding process.
    pub fn split_by_token<'a>(
        &'a self,
        text: &'a str,
        use_special_tokens: bool,
    ) -> Result<Vec<String>> {
        self.split_by_token_iter(text, use_special_tokens).collect()
    }

    /// Iterator for decoding and splitting a String.
    /// See `split_by_token` for more details.
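    ///
    /// A small sketch of lazy use, taking only the first token:
    ///
    /// ```
    /// use tiktoken_rs::cl100k_base;
    /// let bpe = cl100k_base().unwrap();
    /// let first = bpe.split_by_token_iter("hello world", true).next();
    /// assert_eq!(first.unwrap().unwrap(), "hello");
    /// ```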
    pub fn split_by_token_iter<'a>(
        &'a self,
        text: &'a str,
        use_special_tokens: bool,
    ) -> impl Iterator<Item = Result<String>> + 'a {
        // First, encode the text using the BPE model.
        let encoded = match use_special_tokens {
            true => self.encode_with_special_tokens(text),
            false => self.encode_ordinary(text),
        };

        self._decode_native_and_split(encoded).map(|token| {
            // Map each token's bytes to a Result<String>, replacing invalid
            // UTF-8 sequences rather than failing.
            Ok(String::from_utf8_lossy(token.as_slice()).to_string())
        })
    }

    /// Tokenize a string and return the decoded tokens using the correct BPE model.
    /// This method is equivalent to `split_by_token(text, false)`.
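    ///
    /// A small sketch: special-token markup is split like ordinary text here,
    /// and concatenating the parts reproduces the input:
    ///
    /// ```
    /// use tiktoken_rs::cl100k_base;
    /// let bpe = cl100k_base().unwrap();
    /// let parts = bpe.split_by_token_ordinary("ab<|endoftext|>").unwrap();
    /// assert_eq!(parts.concat(), "ab<|endoftext|>");
    /// ```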
    pub fn split_by_token_ordinary<'a>(&'a self, text: &'a str) -> Result<Vec<String>> {
        self.split_by_token(text, false)
    }

    /// Iterator for decoding and splitting a String.
    /// This method is equivalent to `split_by_token_iter(text, false)`.
    pub fn split_by_token_ordinary_iter<'a>(
        &'a self,
        text: &'a str,
    ) -> impl Iterator<Item = Result<String>> + 'a {
        self.split_by_token_iter(text, false)
    }
}