1#![deny(unsafe_code)]
16#![warn(missing_docs)]
17#![warn(rust_2018_idioms)]
18
19use rayon::prelude::*;
20use thiserror::Error;
21use tiktoken_rs::CoreBPE;
22
/// Crate-wide result alias: any fallible tokenizer operation returns
/// [`TokenizerError`] on failure.
pub type Result<T> = std::result::Result<T, TokenizerError>;
25
/// Errors produced while constructing or using a [`Tokenizer`].
#[derive(Error, Debug)]
pub enum TokenizerError {
    /// The requested encoding name is not supported; only `cl100k_base`
    /// and `o200k_base` are accepted by [`Tokenizer::for_encoding`].
    #[error("unknown encoding: {0} (expected cl100k_base or o200k_base)")]
    UnknownEncoding(String),
    /// An error reported by the underlying `tiktoken-rs` library
    /// (encoder construction or token decoding), carried as a string.
    #[error("tiktoken-rs error: {0}")]
    Tiktoken(String),
}
37
/// A BPE tokenizer that wraps a `tiktoken-rs` [`CoreBPE`] and remembers
/// the name of the encoding it was built from.
pub struct Tokenizer {
    // The underlying byte-pair encoder from tiktoken-rs.
    bpe: CoreBPE,
    // Encoding identifier, e.g. "cl100k_base" or "o200k_base"; exposed
    // read-only via `encoding_name()`.
    encoding_name: String,
}
43
44impl Tokenizer {
45 pub fn for_model(model: &str) -> Result<Self> {
50 match tiktoken_rs::get_bpe_from_model(model) {
51 Ok(bpe) => Ok(Self {
52 bpe,
53 encoding_name: encoding_for_model(model).to_string(),
54 }),
55 Err(_) => {
56 let encoding = encoding_for_model(model);
59 Self::for_encoding(encoding)
60 }
61 }
62 }
63
64 pub fn for_encoding(name: &str) -> Result<Self> {
67 let bpe =
68 match name {
69 "cl100k_base" => tiktoken_rs::cl100k_base()
70 .map_err(|e| TokenizerError::Tiktoken(e.to_string()))?,
71 "o200k_base" => tiktoken_rs::o200k_base()
72 .map_err(|e| TokenizerError::Tiktoken(e.to_string()))?,
73 other => return Err(TokenizerError::UnknownEncoding(other.to_string())),
74 };
75 Ok(Self {
76 bpe,
77 encoding_name: name.to_string(),
78 })
79 }
80
81 pub fn encoding_name(&self) -> &str {
83 &self.encoding_name
84 }
85
86 pub fn count(&self, text: &str) -> usize {
88 self.bpe.encode_ordinary(text).len()
89 }
90
91 pub fn count_many(&self, texts: &[&str], parallel: bool) -> Vec<usize> {
93 if parallel {
94 texts
95 .par_iter()
96 .map(|t| self.bpe.encode_ordinary(t).len())
97 .collect()
98 } else {
99 texts
100 .iter()
101 .map(|t| self.bpe.encode_ordinary(t).len())
102 .collect()
103 }
104 }
105
106 pub fn encode(&self, text: &str) -> Vec<u32> {
108 self.bpe.encode_ordinary(text)
111 }
112
113 pub fn decode(&self, tokens: &[u32]) -> Result<String> {
115 self.bpe
116 .decode(tokens.to_vec())
117 .map_err(|e| TokenizerError::Tiktoken(e.to_string()))
118 }
119
120 pub fn fits(&self, text: &str, budget: usize) -> bool {
122 self.count(text) <= budget
123 }
124
125 pub fn truncate_to(&self, text: &str, budget: usize) -> Result<String> {
131 let mut tokens = self.bpe.encode_ordinary(text);
132 if tokens.len() <= budget {
133 return Ok(text.to_string());
134 }
135 tokens.truncate(budget);
136 self.bpe
137 .decode(tokens)
138 .map_err(|e| TokenizerError::Tiktoken(e.to_string()))
139 }
140}
141
/// Maps an OpenAI model name to the tiktoken encoding it uses.
///
/// Model families released with `o200k_base` (gpt-4o, gpt-4.1, gpt-4.5,
/// gpt-5, o1/o3/o4 reasoning models, chatgpt-4o) are matched by prefix;
/// everything else — including unknown models — falls back to
/// `cl100k_base`, the most common modern encoding.
fn encoding_for_model(model: &str) -> &'static str {
    // Prefix table per tiktoken's model-to-encoding map; extends the
    // original list with gpt-4.1 / gpt-4.5, which also use o200k_base.
    const O200K_PREFIXES: [&str; 8] = [
        "gpt-4o",
        "gpt-4.1",
        "gpt-4.5",
        "gpt-5",
        "o1",
        "o3",
        "o4",
        "chatgpt-4o",
    ];
    if O200K_PREFIXES.iter().any(|p| model.starts_with(p)) {
        "o200k_base"
    } else {
        "cl100k_base"
    }
}
162
#[cfg(test)]
mod tests {
    use super::*;

    // Shared fixture: every test that needs a concrete encoder uses
    // cl100k_base, so build it in one place.
    fn cl100k() -> Tokenizer {
        Tokenizer::for_encoding("cl100k_base").expect("cl100k_base must load")
    }

    #[test]
    fn round_trip_simple_text() {
        let tokenizer = cl100k();
        let input = "hello world";
        let ids = tokenizer.encode(input);
        assert_eq!(tokenizer.decode(&ids).unwrap(), input);
    }

    #[test]
    fn count_matches_encode_len() {
        let tokenizer = cl100k();
        let input = "the quick brown fox jumps over the lazy dog";
        assert_eq!(tokenizer.count(input), tokenizer.encode(input).len());
    }

    #[test]
    fn count_many_serial_and_parallel_agree() {
        let tokenizer = cl100k();
        let inputs = vec!["hi", "world", "lorem ipsum dolor sit amet"];
        assert_eq!(
            tokenizer.count_many(&inputs, false),
            tokenizer.count_many(&inputs, true)
        );
    }

    #[test]
    fn for_model_gpt4_is_cl100k() {
        let tokenizer = Tokenizer::for_model("gpt-4").unwrap();
        assert_eq!(tokenizer.encoding_name(), "cl100k_base");
    }

    #[test]
    fn for_model_gpt5_is_o200k() {
        let tokenizer = Tokenizer::for_model("gpt-5").unwrap();
        assert_eq!(tokenizer.encoding_name(), "o200k_base");
    }

    #[test]
    fn for_model_o3_is_o200k() {
        let tokenizer = Tokenizer::for_model("o3-mini").unwrap();
        assert_eq!(tokenizer.encoding_name(), "o200k_base");
    }

    #[test]
    fn for_model_unknown_falls_back_to_cl100k() {
        let tokenizer = Tokenizer::for_model("future-model-7b").unwrap();
        assert_eq!(tokenizer.encoding_name(), "cl100k_base");
    }

    #[test]
    fn for_model_gpt4o_is_o200k() {
        let tokenizer = Tokenizer::for_model("gpt-4o").unwrap();
        assert_eq!(tokenizer.encoding_name(), "o200k_base");
    }

    #[test]
    fn unknown_encoding_rejected() {
        assert!(Tokenizer::for_encoding("unknown_base").is_err());
    }

    #[test]
    fn fits_and_truncate() {
        let tokenizer = cl100k();
        let input = "the quick brown fox";
        let token_count = tokenizer.count(input);

        // Exactly at, above, and below the budget.
        assert!(tokenizer.fits(input, token_count));
        assert!(tokenizer.fits(input, token_count + 1));
        assert!(!tokenizer.fits(input, token_count - 1));

        let shortened = tokenizer.truncate_to(input, 2).unwrap();
        assert!(tokenizer.count(&shortened) <= 2);
        assert!(shortened.len() <= input.len());
    }

    #[test]
    fn truncate_returns_input_when_fits() {
        let tokenizer = cl100k();
        assert_eq!(tokenizer.truncate_to("hi", 100).unwrap(), "hi");
    }

    #[test]
    fn empty_text_is_zero_tokens() {
        let tokenizer = cl100k();
        assert_eq!(tokenizer.count(""), 0);
        assert!(tokenizer.encode("").is_empty());
    }

    #[test]
    fn unicode_text_round_trips() {
        let tokenizer = cl100k();
        let input = "你好世界 🌍";
        let ids = tokenizer.encode(input);
        assert_eq!(tokenizer.decode(&ids).unwrap(), input);
    }

    #[test]
    fn count_many_handles_empty_list() {
        let tokenizer = cl100k();
        let no_texts: Vec<&str> = Vec::new();
        assert!(tokenizer.count_many(&no_texts, false).is_empty());
        assert!(tokenizer.count_many(&no_texts, true).is_empty());
    }
}