tiktoken_rs/patched_tiktoken.rs

use super::vendor_tiktoken::*;
use anyhow::anyhow;
use anyhow::Result;
use fancy_regex::Regex;
use rustc_hash::FxHashMap as HashMap;

/// Rust API
impl CoreBPE {
    // ====================
    // Encoding
    // ====================

    /// This function is a copy of the similar function in the Python API, but it
    /// returns Rust's results and errors.
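    ///
    /// # Example
    ///
    /// A minimal sketch (not from the upstream docs) with a hypothetical
    /// two-token vocabulary and split pattern, assuming `CoreBPE` is exported at
    /// the crate root; real vocabularies and patterns come from the bundled
    /// encodings such as `cl100k_base`.
    ///
    /// ```
    ///     use rustc_hash::FxHashMap as HashMap;
    ///     use tiktoken_rs::CoreBPE;
    ///     // Toy vocabulary, for illustration only.
    ///     let mut encoder: HashMap<Vec<u8>, u32> = HashMap::default();
    ///     encoder.insert(b"ab".to_vec(), 0);
    ///     encoder.insert(b"cd".to_vec(), 1);
    ///     // No special tokens in this toy setup.
    ///     let special_tokens: HashMap<String, u32> = HashMap::default();
    ///     let bpe = CoreBPE::new(encoder, special_tokens, r"\S+").unwrap();
    ///     assert_eq!(bpe.decode(vec![0, 1]).unwrap(), "abcd");
    /// ```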
    pub fn new(
        encoder: HashMap<Vec<u8>, Rank>,
        special_tokens_encoder: HashMap<String, Rank>,
        pattern: &str,
    ) -> Result<Self> {
        let regex = Regex::new(pattern)?;

        let special_regex = {
            let parts = special_tokens_encoder
                .keys()
                .map(|s| fancy_regex::escape(s))
                .collect::<Vec<_>>();
            Regex::new(&parts.join("|"))?
        };

        let decoder: HashMap<Rank, Vec<u8>> =
            encoder.iter().map(|(k, v)| (*v, k.clone())).collect();

        assert!(
            encoder.len() == decoder.len(),
            "Encoder and decoder must be of equal length; maybe you had duplicate token indices in your encoder?"
        );

        let special_tokens_decoder: HashMap<Rank, Vec<u8>> = special_tokens_encoder
            .iter()
            .map(|(k, v)| (*v, k.as_bytes().to_vec()))
            .collect();

        // Clone the keys so we can sort an owned copy; `encoder` itself is moved into the struct below.
        let mut sorted_token_bytes: Vec<Vec<u8>> = encoder.keys().cloned().collect();
        sorted_token_bytes.sort();

        Ok(Self {
            encoder,
            special_tokens_encoder,
            decoder,
            special_tokens_decoder,
            regex_tls: (0..MAX_NUM_THREADS).map(|_| regex.clone()).collect(),
            special_regex_tls: (0..MAX_NUM_THREADS)
                .map(|_| special_regex.clone())
                .collect(),
            sorted_token_bytes,
        })
    }

    // ====================
    // Decoding
    // ====================

    /// Decode a vector of tokens into a valid UTF-8 `String`.
    ///
    /// If unicode validation is not wanted, see `_decode_native`.
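    ///
    /// # Example
    ///
    /// A round-trip sketch using the bundled `cl100k_base` encoding; encoding
    /// and then decoding returns the original text:
    ///
    /// ```
    ///     use tiktoken_rs::cl100k_base;
    ///     let bpe = cl100k_base().unwrap();
    ///     let tokens = bpe.encode_with_special_tokens("hello world");
    ///     assert_eq!(bpe.decode(tokens).unwrap(), "hello world");
    /// ```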
    pub fn decode(&self, tokens: Vec<Rank>) -> Result<String> {
        match String::from_utf8(self.decode_bytes(&tokens)?) {
            Ok(text) => Ok(text),
            Err(e) => Err(anyhow!("Unable to decode into a valid UTF-8 string: {}", e)),
        }
    }

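    /// Decode tokens into an iterator over each token's raw bytes.
    ///
    /// Each token is looked up in the ordinary decoder first, then in the
    /// special-token decoder; a token found in neither map causes a panic.
    ///
    /// # Example
    ///
    /// A minimal sketch; concatenating the per-token byte chunks reproduces the
    /// original input:
    ///
    /// ```
    ///     use tiktoken_rs::cl100k_base;
    ///     let bpe = cl100k_base().unwrap();
    ///     let tokens = bpe.encode_with_special_tokens("hello world");
    ///     let chunks: Vec<Vec<u8>> = bpe._decode_native_and_split(tokens).collect();
    ///     assert_eq!(chunks.concat(), b"hello world");
    /// ```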
    pub fn _decode_native_and_split(
        &self,
        tokens: Vec<Rank>,
    ) -> impl Iterator<Item = Vec<u8>> + '_ {
        tokens.into_iter().map(|token| {
            let token_bytes = self
                .decoder
                .get(&token)
                .unwrap_or_else(|| &self.special_tokens_decoder[&token]);
            token_bytes.clone()
        })
    }

    /// Tokenize a string and return the decoded tokens using the correct BPE model.
    ///
    /// This method encodes the input string with the BPE model, then decodes each
    /// token back into its own `String`, yielding the text split along token
    /// boundaries.
    ///
    /// # Examples
    ///
    /// ```
    ///     use tiktoken_rs::cl100k_base;
    ///     let bpe = cl100k_base().unwrap();
    ///     let tokenized: Result<Vec<_>, _> = bpe
    ///         .split_by_token("This is a test         with a lot of spaces", true);
    ///     let tokenized = tokenized.unwrap();
    ///     assert_eq!(
    ///         tokenized,
    ///         vec!["This", " is", " a", " test", "        ", " with", " a", " lot", " of", " spaces"]
    ///     );
    /// ```
    ///
    /// # Arguments
    ///
    /// * `text`: A string slice containing the text to be tokenized.
    /// * `use_special_tokens`: Whether to recognize special tokens in the input
    ///   and encode them as single tokens.
    ///
    /// # Returns
    ///
    /// * `Result<Vec<String>>`: A Result containing a vector of decoded tokens as strings.
    ///
    /// # Errors
    ///
    /// Token bytes that are not valid UTF-8 are currently decoded lossily (see
    /// `String::from_utf8_lossy`), so in practice this method does not return an
    /// error; the `Result` return type leaves room for stricter decoding.
    ///
    pub fn split_by_token<'a>(
        &'a self,
        text: &'a str,
        use_special_tokens: bool,
    ) -> Result<Vec<String>> {
        self.split_by_token_iter(text, use_special_tokens).collect()
    }

    /// Iterator for decoding and splitting a String.
    /// See `split_by_token` for more details.
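    ///
    /// # Example
    ///
    /// A small sketch; tokens are decoded lazily as the iterator is consumed,
    /// and concatenating them reproduces the input:
    ///
    /// ```
    ///     use tiktoken_rs::cl100k_base;
    ///     let bpe = cl100k_base().unwrap();
    ///     let parts: Vec<String> = bpe
    ///         .split_by_token_iter("hello world", true)
    ///         .collect::<Result<_, _>>()
    ///         .unwrap();
    ///     assert_eq!(parts.concat(), "hello world");
    /// ```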
    pub fn split_by_token_iter<'a>(
        &'a self,
        text: &'a str,
        use_special_tokens: bool,
    ) -> impl Iterator<Item = Result<String>> + 'a {
        // First, encode the text using the BPE model
        let encoded = match use_special_tokens {
            true => self.encode_with_special_tokens(text),
            false => self.encode_ordinary(text),
        };

        self._decode_native_and_split(encoded).map(|token| {
            // Map each token to a Result<String>
            Ok(String::from_utf8_lossy(token.as_slice()).to_string())
        })
    }

    /// Tokenize a string and return the decoded tokens using the correct BPE model.
    /// This method is equivalent to `split_by_token(text, false)`.
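    ///
    /// # Example
    ///
    /// A small sketch; because special tokens are not recognized here, a
    /// special-token marker such as `<|endoftext|>` is split like ordinary text
    /// (assuming, as in `cl100k_base`, it is not a single ordinary token):
    ///
    /// ```
    ///     use tiktoken_rs::cl100k_base;
    ///     let bpe = cl100k_base().unwrap();
    ///     let parts = bpe.split_by_token_ordinary("<|endoftext|>").unwrap();
    ///     assert!(parts.len() > 1);
    ///     assert_eq!(parts.concat(), "<|endoftext|>");
    /// ```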
    pub fn split_by_token_ordinary<'a>(&'a self, text: &'a str) -> Result<Vec<String>> {
        self.split_by_token(text, false)
    }

    /// Iterator for decoding and splitting a String.
    /// This method is equivalent to `split_by_token_iter(text, false)`.
    pub fn split_by_token_ordinary_iter<'a>(
        &'a self,
        text: &'a str,
    ) -> impl Iterator<Item = Result<String>> + 'a {
        self.split_by_token_iter(text, false)
    }
}