tiktoken_rust/
core.rs

use crate::model::*;
use crate::openai_public::find_encoding_constructor;
use crate::CoreBPE;
use fancy_regex::Regex;
use rayon::prelude::*;
use rustc_hash::FxHashMap;
use std::cmp::max;
use std::collections::{HashMap, HashSet};
use std::fmt::{Debug, Display, Formatter};
use std::hash::Hash;

pub type Result<T> = std::result::Result<T, EncodeError>;

/// Returns the `Encoding` object with the given name.
/// TODO: cache created Encoding objects
pub fn get_encoding(encoding_name: &str) -> Result<Encoding> {
    match find_encoding_constructor(encoding_name) {
        Some(func) => Encoding::new(func()),
        None => Err(EncodeError::EncodingNameError(encoding_name.to_string())),
    }
}

/// Returns the encoding used by a model.
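///
/// A minimal usage sketch (an assumption: `encoding_for_model` and `Encoding`
/// are re-exported at the root of the `tiktoken_rust` crate):
///
/// ```no_run
/// // "gpt-4" maps to the "cl100k_base" encoding in MODEL_TO_ENCODING.
/// let enc = tiktoken_rust::encoding_for_model("gpt-4").unwrap();
/// assert_eq!(enc.name(), "cl100k_base");
/// ```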
pub fn encoding_for_model(model_name: &str) -> Result<Encoding> {
    let encoding_opt = MODEL_TO_ENCODING
        .get(model_name)
        .map(|&encoding| get_encoding(encoding));
    if let Some(encoding) = encoding_opt {
        return encoding;
    }

    // Check if the model matches a known prefix.
    // Prefix matching avoids needing library updates for every model version release.
    // Note that this can match on non-existent models (e.g., gpt-3.5-turbo-FAKE).
    for (&model_prefix, &model_encoding_name) in MODEL_PREFIX_TO_ENCODING.iter() {
        if model_name.starts_with(model_prefix) {
            return get_encoding(model_encoding_name);
        }
    }

    Err(EncodeError::ModelNameError(model_name.to_string()))
}

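/// Raw parameters from which an [`Encoding`] is built.
///
/// A minimal construction sketch using a hypothetical toy vocabulary (real rank
/// tables come from `openai_public`, not hand-written maps; the re-export path
/// is an assumption):
///
/// ```no_run
/// use std::collections::HashMap;
/// use tiktoken_rust::EncodingParam;
///
/// let mut ranks: HashMap<Vec<u8>, usize> = HashMap::new();
/// ranks.insert(b"a".to_vec(), 0);
/// ranks.insert(b"b".to_vec(), 1);
/// ranks.insert(b"ab".to_vec(), 2); // lower rank = higher merge priority
///
/// let mut specials = HashMap::new();
/// specials.insert("<|endoftext|>".to_string(), 3);
///
/// let param = EncodingParam::new(
///     "toy".to_string(),
///     r"\S+|\s+".to_string(), // regex used to split input text before BPE
///     ranks,
///     specials,
///     Some(4), // 3 mergeable tokens + 1 special token
/// );
/// ```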
pub struct EncodingParam {
    name: String,
    pat_str: String,
    mergeable_ranks: HashMap<Vec<u8>, usize>,
    special_tokens: HashMap<String, usize>,
    explicit_n_vocab: Option<usize>,
}

impl EncodingParam {
    pub fn new(
        name: String,
        pat_str: String,
        mergeable_ranks: HashMap<Vec<u8>, usize>,
        special_tokens: HashMap<String, usize>,
        explicit_n_vocab: Option<usize>,
    ) -> Self {
        EncodingParam {
            name,
            pat_str,
            mergeable_ranks,
            special_tokens,
            explicit_n_vocab,
        }
    }
}

pub struct Encoding {
    name: String,
    _pat_str: String,
    special_tokens: HashMap<String, usize>,

    max_token_value: usize,
    core_bpe: CoreBPE,
}

impl Debug for Encoding {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        write!(f, "<Encoding '{}'>", self.name)
    }
}

impl Display for Encoding {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        write!(f, "<Encoding '{}'>", self.name)
    }
}

/// Private methods
impl Encoding {
    /// Creates an Encoding object.
    ///
    /// See `openai_public.rs` for examples of how to construct an Encoding object.
    ///
    /// Args:
    /// name: The name of the encoding. It should be clear from the name of the encoding
    ///       what behaviour to expect; in particular, encodings with different special tokens
    ///       should have different names.
    /// pat_str: A regex pattern string that is used to split the input text.
    /// mergeable_ranks: A dictionary mapping mergeable token bytes to their ranks. The ranks
    ///                  must correspond to merge priority.
    /// special_tokens: A dictionary mapping special token strings to their token values.
    /// explicit_n_vocab: The number of tokens in the vocabulary. If provided, it is checked
    ///                   that the number of mergeable tokens and special tokens is equal to this number.
    fn new(param: EncodingParam) -> Result<Self> {
        let max_token_value = max(
            param
                .mergeable_ranks
                .values()
                .max()
                .copied()
                .unwrap_or_default(),
            param
                .special_tokens
                .values()
                .max()
                .copied()
                .unwrap_or_default(),
        );
        if let Some(n_vocab) = param.explicit_n_vocab {
            assert_eq!(
                param.mergeable_ranks.len() + param.special_tokens.len(),
                n_vocab
            );
            assert_eq!(max_token_value, n_vocab - 1);
        }

        let core_bpe = CoreBPE::new(
            convert_to_fx_hashmap(&param.mergeable_ranks),
            convert_to_fx_hashmap(&param.special_tokens),
            param.pat_str.as_str(),
        )?;

        Ok(Encoding {
            name: param.name,
            _pat_str: param.pat_str,
            special_tokens: param.special_tokens,
            max_token_value,
            core_bpe,
        })
    }
}

/// Public interfaces for encoding
impl Encoding {
    /// Encodes a string into tokens, ignoring special tokens.
    ///
    /// This is equivalent to calling `encode` with an empty `disallowed_special` set
    /// (but slightly faster).
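    ///
    /// A common use is counting tokens (a sketch; assumes crate-root re-exports):
    ///
    /// ```no_run
    /// let enc = tiktoken_rust::get_encoding("cl100k_base").unwrap();
    /// let n_tokens = enc.encode_ordinary("hello world").len();
    /// ```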
    pub fn encode_ordinary(&self, text: &str) -> Vec<usize> {
        self.core_bpe._encode_ordinary_native(text)
    }

    /// Encodes a list of strings into tokens, in parallel, ignoring special tokens.
    ///
    /// This is equivalent to calling `encode_batch` with an empty `disallowed_special` set
    /// (but slightly faster).
    pub fn encode_ordinary_batch(&self, texts: Vec<&str>) -> Vec<Vec<usize>> {
        texts
            .par_iter()
            .map(|&txt| self.encode_ordinary(txt))
            .collect()
    }

    /// Encodes a string into tokens.
    ///
    /// Special tokens are artificial tokens used to unlock capabilities from a model,
    /// such as fill-in-the-middle. So we want to be careful about accidentally encoding special
    /// tokens, since they can be used to trick a model into doing something we don't want it to do.
    /// Hence, by default, encode will raise an error if it encounters text that corresponds
    /// to a special token. This can be controlled on a per-token level using the `allowed_special`
    /// and `disallowed_special` parameters. In particular:
    /// - Setting `disallowed_special` to an empty set will prevent this function from raising
    ///   errors and cause all text corresponding to special tokens to be encoded as natural text.
    /// - Setting `allowed_special` to `AllowedSpecial::All` will cause this function to encode
    ///   all text corresponding to special tokens as special tokens.
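    ///
    /// A usage sketch (an assumption: these types are re-exported at the root of
    /// the `tiktoken_rust` crate):
    ///
    /// ```no_run
    /// use std::collections::HashSet;
    /// use tiktoken_rust::{get_encoding, AllowedSpecial, DisallowedSpecial};
    ///
    /// let enc = get_encoding("cl100k_base").unwrap();
    ///
    /// // Plain text encodes fine with no special tokens allowed.
    /// let tokens = enc
    ///     .encode("hello world", AllowedSpecial::Allowed(HashSet::new()), DisallowedSpecial::All)
    ///     .unwrap();
    ///
    /// // Text containing a special token must allow it explicitly, or `encode` errors.
    /// let allowed = AllowedSpecial::Allowed(HashSet::from(["<|endoftext|>"]));
    /// let tokens = enc
    ///     .encode("hello <|endoftext|>", allowed, DisallowedSpecial::All)
    ///     .unwrap();
    /// ```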
    pub fn encode(
        &self,
        text: &str,
        allowed_special: AllowedSpecial,
        disallowed_special: DisallowedSpecial,
    ) -> Result<Vec<usize>> {
        let allowed_special_set = match allowed_special {
            AllowedSpecial::All => self.special_tokens_set(),
            AllowedSpecial::Allowed(allowed) => allowed,
        };
        let disallowed_special_set = match disallowed_special {
            DisallowedSpecial::All => self
                .special_tokens_set()
                .difference(&allowed_special_set)
                .copied()
                .collect(),
            DisallowedSpecial::Disallowed(disallowed) => disallowed,
        };

        if !disallowed_special_set.is_empty() {
            let re = special_token_regex(disallowed_special_set)?;
            if let Ok(Some(cap)) = re.captures(text) {
                return Err(EncodeError::SpecialTokenError(String::from(
                    cap.get(0).unwrap().as_str(),
                )));
            }
        }

        Ok(self.core_bpe._encode_native(text, &allowed_special_set).0)
    }

    /// Encodes a list of strings into tokens, in parallel.
    ///
    /// See `encode` for more details on `allowed_special` and `disallowed_special`.
    pub fn encode_batch(
        &self,
        texts: Vec<&str>,
        allowed_special: AllowedSpecial,
        disallowed_special: DisallowedSpecial,
    ) -> Result<Vec<Vec<usize>>> {
        let data: Vec<Result<Vec<usize>>> = texts
            .par_iter()
            .map(|&txt| self.encode(txt, allowed_special.clone(), disallowed_special.clone()))
            .collect();

        let mut res = Vec::new();
        for item in data {
            res.push(item?);
        }
        Ok(res)
    }

    /// Encodes a string into stable tokens and possible completion sequences.
    ///
    /// Note that the stable tokens will only represent a substring of `text`.
    /// See `encode` for more details on `allowed_special` and `disallowed_special`.
    ///
    /// This API should itself be considered unstable.
    pub fn encode_with_unstable(
        &self,
        text: &str,
        allowed_special: AllowedSpecial,
        disallowed_special: DisallowedSpecial,
    ) -> Result<(Vec<usize>, Vec<Vec<usize>>)> {
        let allowed_special_set = match allowed_special {
            AllowedSpecial::All => self.special_tokens_set(),
            AllowedSpecial::Allowed(allowed) => allowed,
        };
        let disallowed_special_set = match disallowed_special {
            DisallowedSpecial::All => self
                .special_tokens_set()
                .difference(&allowed_special_set)
                .copied()
                .collect(),
            DisallowedSpecial::Disallowed(disallowed) => disallowed,
        };

        if !disallowed_special_set.is_empty() {
            let re = special_token_regex(disallowed_special_set)?;
            if let Ok(Some(cap)) = re.captures(text) {
                return Err(EncodeError::SpecialTokenError(String::from(
                    cap.get(0).unwrap().as_str(),
                )));
            }
        }

        let (tokens, completions) = self
            .core_bpe
            ._encode_unstable_native(text, &allowed_special_set);
        let completions = completions.into_iter().collect();
        Ok((tokens, completions))
    }

    /// Encodes text corresponding to a single token to its token value.
    ///
    /// NOTE: this will encode all special tokens.
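    ///
    /// A round-trip sketch with `decode_single_token_bytes` (assumptions:
    /// crate-root re-exports, and token 1000 exists in `cl100k_base`):
    ///
    /// ```no_run
    /// let enc = tiktoken_rust::get_encoding("cl100k_base").unwrap();
    /// let bytes = enc.decode_single_token_bytes(1000).unwrap();
    /// assert_eq!(enc.encode_single_token(&bytes).unwrap(), 1000);
    /// ```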
    pub fn encode_single_token(&self, piece: &[u8]) -> Result<usize> {
        if let Some(token) = self.core_bpe.encoder.get(piece).copied() {
            return Ok(token);
        }
        if let Ok(piece_str) = std::str::from_utf8(piece) {
            if let Some(token) = self.core_bpe.special_tokens_encoder.get(piece_str).copied() {
                return Ok(token);
            }
        }
        Err(EncodeError::TokenEncodeError(piece.to_owned()))
    }
}

/// Public interfaces for decoding
impl Encoding {
    /// Decodes a list of tokens into bytes.
    pub fn decode_bytes(&self, tokens: &[usize]) -> Vec<u8> {
        self.core_bpe._decode_native(tokens)
    }

    /// Decodes a batch (list of lists of tokens) into a list of bytes.
    pub fn decode_bytes_batch(&self, batch: &[&[usize]]) -> Vec<Vec<u8>> {
        batch
            .par_iter()
            .map(|tokens| self.decode_bytes(tokens))
            .collect()
    }

    /// Decodes a list of tokens into a string.
    ///
    /// WARNING: the decoded bytes are not guaranteed to be valid UTF-8.
    /// You can control this behaviour using the `mode` parameter:
    /// - `Strict` validates the bytes and returns `Err` if they are not valid UTF-8.
    /// - `Replace` replaces invalid UTF-8 sequences with U+FFFD.
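    ///
    /// A usage sketch of the two modes (an assumption: crate-root re-exports):
    ///
    /// ```no_run
    /// use tiktoken_rust::{get_encoding, DecodeMode};
    ///
    /// let enc = get_encoding("cl100k_base").unwrap();
    /// let tokens = enc.encode_ordinary("hello world");
    /// // Strict: fails on invalid UTF-8.
    /// let text = enc.decode(&tokens, DecodeMode::Strict).unwrap();
    /// // Replace: never fails; invalid sequences become U+FFFD.
    /// let lossy = enc.decode(&tokens, DecodeMode::Replace).unwrap();
    /// assert_eq!(text, lossy);
    /// ```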
    pub fn decode(&self, tokens: &[usize], mode: DecodeMode) -> Result<String> {
        let bytes = self.decode_bytes(tokens);
        match mode {
            DecodeMode::Strict => String::from_utf8(bytes).map_err(EncodeError::ConvertStringError),
            DecodeMode::Replace => Ok(String::from_utf8_lossy(&bytes).to_string()),
        }
    }

    /// Decodes a batch (list of lists of tokens) into a list of strings.
    pub fn decode_batch(&self, batch: &[&[usize]], mode: DecodeMode) -> Vec<Result<String>> {
        batch
            .par_iter()
            .map(|tokens| self.decode(tokens, mode.clone()))
            .collect()
    }

    /// Decodes a token into bytes.
    ///
    /// NOTE: this will decode all special tokens.
    pub fn decode_single_token_bytes(&self, token: usize) -> Result<Vec<u8>> {
        if let Some(bytes) = self.core_bpe.decoder.get(&token) {
            return Ok(bytes.to_vec());
        }
        if let Some(bytes) = self.core_bpe.special_tokens_decoder.get(&token) {
            return Ok(bytes.to_vec());
        }
        Err(EncodeError::TokenNotFoundError(token))
    }

    /// Decodes a list of tokens into a list of bytes.
    ///
    /// Useful for visualising tokenisation.
    pub fn decode_tokens_bytes(&self, tokens: &[usize]) -> Result<Vec<Vec<u8>>> {
        let data: Vec<Result<Vec<u8>>> = tokens
            .par_iter()
            .map(|&token| self.decode_single_token_bytes(token))
            .collect();

        let mut res = Vec::new();
        for item in data {
            res.push(item?);
        }
        Ok(res)
    }

    /// Decodes a list of tokens into a string and a list of offsets.
    ///
    /// Each offset is the index into text corresponding to the start of each token.
    /// If UTF-8 character boundaries do not line up with token boundaries, the offset is the index
    /// of the first character that contains bytes from the token.
    ///
    /// This will currently raise if given tokens that decode to invalid UTF-8; this behaviour may
    /// change in the future to be more permissive.
    ///
    /// For example, with the `gpt2` encoding, decoding `[31373, 995]` yields
    /// `("hello world", vec![0, 5])`.
    pub fn decode_with_offsets(&self, tokens: &[usize]) -> Result<(String, Vec<usize>)> {
        let token_bytes = self.decode_tokens_bytes(tokens)?;
        let mut text_len: usize = 0;
        let mut offsets = vec![];

        for token in token_bytes {
            // A token whose first byte is a UTF-8 continuation byte (0x80..0xC0)
            // starts mid-character, so attribute it to the character it continues.
            let offset = if token[0] >= 0x80 && token[0] < 0xC0 {
                text_len.saturating_sub(1)
            } else {
                text_len
            };
            offsets.push(offset);
            // Every byte that is not a continuation byte starts a new character.
            text_len += token
                .iter()
                .map(|&c| if c < 0x80 || c >= 0xC0 { 1 } else { 0 })
                .sum::<usize>();
        }

        let text = self.decode(tokens, DecodeMode::Strict)?;

        Ok((text, offsets))
    }
}

/// Miscellaneous interfaces
impl Encoding {
    /// Returns the name of this encoding.
    pub fn name(&self) -> &str {
        self.name.as_str()
    }

    /// Returns the list of all token byte values.
    pub fn token_byte_values(&self) -> Vec<Vec<u8>> {
        self.core_bpe
            .sorted_token_bytes
            .iter()
            .map(|x| x.to_vec())
            .collect()
    }

    /// Returns the token value of the `<|endoftext|>` special token, if this
    /// encoding defines one.
    pub fn eot_token(&self) -> Option<usize> {
        self.special_tokens.get("<|endoftext|>").copied()
    }

    /// For backwards compatibility. Prefer to use `enc.max_token_value + 1`.
    pub fn n_vocab(&self) -> usize {
        self.max_token_value + 1
    }

    // TODO: lazy evaluation
    pub fn special_tokens_set(&self) -> HashSet<&str> {
        HashSet::from_iter(self.special_tokens.keys().map(|k| k.as_str()))
    }
}

// TODO: LRU cache
fn special_token_regex(tokens: HashSet<&str>) -> Result<Regex> {
    let inner: Vec<_> = tokens.iter().map(|&t| regex::escape(t)).collect();
    let re = Regex::new(format!("({})", inner.join("|")).as_str())?;
    Ok(re)
}

fn convert_to_fx_hashmap<K, V>(origin: &HashMap<K, V>) -> FxHashMap<K, V>
where
    K: Hash + Eq + Clone,
    V: Clone,
{
    origin
        .iter()
        .map(|(k, v)| (k.clone(), v.clone()))
        .collect()
}
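
#[cfg(test)]
mod tests {
    use super::*;

    // Sketch of a sanity check: special tokens contain regex metacharacters
    // such as `|`, so `special_token_regex` must escape them to match literally.
    #[test]
    fn special_token_regex_escapes_metacharacters() {
        let tokens = HashSet::from(["<|endoftext|>"]);
        let re = special_token_regex(tokens).unwrap();
        assert!(re.is_match("foo <|endoftext|> bar").unwrap());
        assert!(!re.is_match("endoftext").unwrap());
    }
}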