use once_cell::sync::Lazy;
use std::collections::{HashMap, HashSet};
use std::fmt::Debug;
use std::io;
use std::string::FromUtf8Error;
use thiserror::Error;

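/// Which special tokens may be encoded as special tokens:
/// `All` permits every special token the encoding knows about, while
/// `Allowed` permits only the listed ones.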
#[derive(Debug, Clone)]
pub enum AllowedSpecial<'a> {
    All,
    Allowed(HashSet<&'a str>),
}

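/// Which special tokens should be rejected if they occur in the input text:
/// `All` rejects every special token that is not explicitly allowed, while
/// `Disallowed` rejects only the listed ones (see `SpecialTokenError`).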
#[derive(Debug, Clone)]
pub enum DisallowedSpecial<'a> {
    All,
    Disallowed(HashSet<&'a str>),
}

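/// How invalid UTF-8 is handled when decoding tokens back into a `String`:
/// `Strict` surfaces the error (as `ConvertStringError`), while `Replace`
/// substitutes the Unicode replacement character (U+FFFD).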
#[derive(Debug, Clone)]
pub enum DecodeMode {
    Strict,
    Replace,
}

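/// Errors produced while encoding or decoding, and while looking up
/// encodings by name or model.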
#[derive(Debug, Error)]
pub enum EncodeError {
    #[error("regex error: {0}")]
    RegexError(#[from] fancy_regex::Error),
    #[error("token `{0}` not found")]
    TokenNotFoundError(usize),
    #[error("could not encode `{0:?}` to token")]
    TokenEncodeError(Vec<u8>),
    #[error(
        "Encountered text corresponding to disallowed special token '{0}'.\n\
         If you want this text to be encoded as a special token, pass it to `allowed_special`.\n\
         If you want this text to be encoded as normal text, disable the check for this token \
         by passing `disallowed_special=(enc.special_tokens_set - {{'{0}'}})`.\n\
         To disable this check for all special tokens, pass `disallowed_special=()`.\n"
    )]
    SpecialTokenError(String),
    #[error("convert bytes to string error: {0}")]
    ConvertStringError(#[from] FromUtf8Error),
    #[error(
        "Could not automatically map {0} to a tokeniser.\n\
         Please use `tiktoken_rust::get_encoding` to explicitly get the tokeniser you expect."
    )]
    ModelNameError(String),
    #[error("Unknown encoding {0}")]
    EncodingNameError(String),
    #[error("I/O error: {0}")]
    IOError(#[from] io::Error),
    #[error("Network error: {0}")]
    HTTPError(#[from] reqwest::Error),
}

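/// Encodings for model-name *prefixes*: consulted as a fallback for dated or
/// versioned releases (e.g. "gpt-4-0314" matches the "gpt-4-" prefix) when a
/// model name has no exact entry in `MODEL_TO_ENCODING`.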
pub static MODEL_PREFIX_TO_ENCODING: Lazy<HashMap<&str, &str>> = Lazy::new(|| {
    HashMap::from([
        // e.g., gpt-4-0314, gpt-4-32k
        ("gpt-4-", "cl100k_base"),
        // e.g., gpt-3.5-turbo-0301
        ("gpt-3.5-turbo-", "cl100k_base"),
        // Azure deployment name
        ("gpt-35-turbo-", "cl100k_base"),
    ])
});

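/// Encodings for exact model names.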
pub static MODEL_TO_ENCODING: Lazy<HashMap<&str, &str>> = Lazy::new(|| {
    HashMap::from([
        // chat
        ("gpt-4", "cl100k_base"),
        ("gpt-3.5-turbo", "cl100k_base"),
        // Azure deployment name
        ("gpt-35-turbo", "cl100k_base"),
        // text
        ("text-davinci-003", "p50k_base"),
        ("text-davinci-002", "p50k_base"),
        ("text-davinci-001", "r50k_base"),
        ("text-curie-001", "r50k_base"),
        ("text-babbage-001", "r50k_base"),
        ("text-ada-001", "r50k_base"),
        ("davinci", "r50k_base"),
        ("curie", "r50k_base"),
        ("babbage", "r50k_base"),
        ("ada", "r50k_base"),
        // code
        ("code-davinci-002", "p50k_base"),
        ("code-davinci-001", "p50k_base"),
        ("code-cushman-002", "p50k_base"),
        ("code-cushman-001", "p50k_base"),
        ("davinci-codex", "p50k_base"),
        ("cushman-codex", "p50k_base"),
        // edit
        ("text-davinci-edit-001", "p50k_edit"),
        ("code-davinci-edit-001", "p50k_edit"),
        // embeddings
        ("text-embedding-ada-002", "cl100k_base"),
        // old embeddings
        ("text-similarity-davinci-001", "r50k_base"),
        ("text-similarity-curie-001", "r50k_base"),
        ("text-similarity-babbage-001", "r50k_base"),
        ("text-similarity-ada-001", "r50k_base"),
        ("text-search-davinci-doc-001", "r50k_base"),
        ("text-search-curie-doc-001", "r50k_base"),
        ("text-search-babbage-doc-001", "r50k_base"),
        ("text-search-ada-doc-001", "r50k_base"),
        ("code-search-babbage-code-001", "r50k_base"),
        ("code-search-ada-code-001", "r50k_base"),
        // open source
        ("gpt2", "gpt2"),
    ])
});
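
// A minimal sketch of how the two tables above are typically combined
// (cf. `encoding_name_for_model` in upstream tiktoken): exact matches win,
// then prefixes are tried. The function name and the error variant chosen
// here are illustrative assumptions, not necessarily this crate's public API.
fn encoding_name_for_model(model_name: &str) -> Result<&'static str, EncodeError> {
    // Exact model names take precedence.
    if let Some(&encoding) = MODEL_TO_ENCODING.get(model_name) {
        return Ok(encoding);
    }
    // Fall back to prefix matching, e.g. "gpt-3.5-turbo-0301" -> "gpt-3.5-turbo-".
    for (&prefix, &encoding) in MODEL_PREFIX_TO_ENCODING.iter() {
        if model_name.starts_with(prefix) {
            return Ok(encoding);
        }
    }
    Err(EncodeError::ModelNameError(model_name.to_string()))
}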