tiktoken_rust/
model.rs

use once_cell::sync::Lazy;
use std::collections::{HashMap, HashSet};
use std::fmt::Debug;
use std::io;
use std::string::FromUtf8Error;
use thiserror::Error;

/// Which special tokens may be encoded as special tokens: all of them, or an explicit set.
#[derive(Debug, Clone)]
pub enum AllowedSpecial<'a> {
    All,
    Allowed(HashSet<&'a str>),
}

/// Which special tokens should raise an error when found in the input: all of them, or an explicit set.
#[derive(Debug, Clone)]
pub enum DisallowedSpecial<'a> {
    All,
    Disallowed(HashSet<&'a str>),
}

/// How invalid UTF-8 is handled when decoding tokens back into a string.
#[derive(Debug, Clone)]
pub enum DecodeMode {
    Strict,  // fail on invalid UTF-8
    Replace, // replace invalid characters instead of failing
}

/// Errors that can occur while loading an encoding or while encoding/decoding text.
#[derive(Debug, Error)]
pub enum EncodeError {
    #[error("regex error: {0}")]
    RegexError(#[from] fancy_regex::Error),
    #[error("token `{0}` not found")]
    TokenNotFoundError(usize),
    #[error("could not encode `{0:?}` to token")]
    TokenEncodeError(Vec<u8>),
    #[error(
        "Encountered text corresponding to disallowed special token '{0}'.\n\
         If you want this text to be encoded as a special token, pass it to `allowed_special`.\n\
         If you want this text to be encoded as normal text, disable the check for this token \
         by passing `disallowed_special=(enc.special_tokens_set - {{'{0}'}})`.\n\
         To disable this check for all special tokens, pass `disallowed_special=()`.\n"
    )]
    SpecialTokenError(String),
    #[error("convert bytes to string error: {0}")]
    ConvertStringError(#[from] FromUtf8Error),
    #[error(
        "Could not automatically map {0} to a tokeniser.\n\
         Please use `tiktoken_rust::get_encoding` to explicitly get the tokeniser you expect."
    )]
    ModelNameError(String),
    #[error("Unknown encoding {0}")]
    EncodingNameError(String),
    #[error("IO error: {0}")]
    IOError(#[from] io::Error),
    #[error("Network error: {0}")]
    HTTPError(#[from] reqwest::Error),
}

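// Illustrative sketch (an assumption, not part of the original file): one way
// `DecodeMode` could govern the final byte-to-string conversion when decoding
// tokens. The helper name `bytes_to_string` is hypothetical.
#[allow(dead_code)]
fn bytes_to_string(bytes: Vec<u8>, mode: DecodeMode) -> Result<String, EncodeError> {
    match mode {
        // Strict: propagate invalid UTF-8 as a ConvertStringError via `#[from]`.
        DecodeMode::Strict => Ok(String::from_utf8(bytes)?),
        // Replace: substitute invalid sequences with U+FFFD instead of failing.
        DecodeMode::Replace => Ok(String::from_utf8_lossy(&bytes).into_owned()),
    }
}
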
// TODO: these will likely be replaced by an API endpoint
/// Maps a model-name prefix (e.g. "gpt-4-") to the name of the encoding it uses.
pub static MODEL_PREFIX_TO_ENCODING: Lazy<HashMap<&str, &str>> = Lazy::new(|| {
    HashMap::from([
        // chat
        ("gpt-4-", "cl100k_base"), // e.g., gpt-4-0314, etc., plus gpt-4-32k
        ("gpt-3.5-turbo-", "cl100k_base"), // e.g., gpt-3.5-turbo-0301, -0401, etc.
        ("gpt-35-turbo", "cl100k_base"), // Azure deployment name
    ])
});

/// Maps a model name to the name of the encoding it uses.
pub static MODEL_TO_ENCODING: Lazy<HashMap<&str, &str>> = Lazy::new(|| {
    HashMap::from([
        // chat
        ("gpt-4", "cl100k_base"),
        ("gpt-3.5-turbo", "cl100k_base"),
        ("gpt-35-turbo", "cl100k_base"), // Azure deployment name
        // text
        ("text-davinci-003", "p50k_base"),
        ("text-davinci-002", "p50k_base"),
        ("text-davinci-001", "r50k_base"),
        ("text-curie-001", "r50k_base"),
        ("text-babbage-001", "r50k_base"),
        ("text-ada-001", "r50k_base"),
        ("davinci", "r50k_base"),
        ("curie", "r50k_base"),
        ("babbage", "r50k_base"),
        ("ada", "r50k_base"),
        // code
        ("code-davinci-002", "p50k_base"),
        ("code-davinci-001", "p50k_base"),
        ("code-cushman-002", "p50k_base"),
        ("code-cushman-001", "p50k_base"),
        ("davinci-codex", "p50k_base"),
        ("cushman-codex", "p50k_base"),
        // edit
        ("text-davinci-edit-001", "p50k_edit"),
        ("code-davinci-edit-001", "p50k_edit"),
        // embeddings
        ("text-embedding-ada-002", "cl100k_base"),
        // old embeddings
        ("text-similarity-davinci-001", "r50k_base"),
        ("text-similarity-curie-001", "r50k_base"),
        ("text-similarity-babbage-001", "r50k_base"),
        ("text-similarity-ada-001", "r50k_base"),
        ("text-search-davinci-doc-001", "r50k_base"),
        ("text-search-curie-doc-001", "r50k_base"),
        ("text-search-babbage-doc-001", "r50k_base"),
        ("text-search-ada-doc-001", "r50k_base"),
        ("code-search-babbage-code-001", "r50k_base"),
        ("code-search-ada-code-001", "r50k_base"),
        // open source
        ("gpt2", "gpt2"),
    ])
});
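
// Illustrative sketch (an assumption about usage, not the crate's confirmed API):
// how these two tables are typically consulted to resolve a model name to an
// encoding name. The function name `encoding_name_for_model` is hypothetical.
#[allow(dead_code)]
fn encoding_name_for_model(model_name: &str) -> Result<&'static str, EncodeError> {
    // Exact model names take precedence over prefix matches.
    if let Some(&encoding) = MODEL_TO_ENCODING.get(model_name) {
        return Ok(encoding);
    }
    // Fall back to prefix matching, e.g. "gpt-4-0314" matches the "gpt-4-" prefix.
    for (&prefix, &encoding) in MODEL_PREFIX_TO_ENCODING.iter() {
        if model_name.starts_with(prefix) {
            return Ok(encoding);
        }
    }
    Err(EncodeError::ModelNameError(model_name.to_string()))
}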