toklab_core/lib.rs

//! Pure-Rust core for `toklab`. Thin wrapper around
//! [tiktoken-rs](https://crates.io/crates/tiktoken-rs) that adds:
//!
//! - **Bulk APIs** (`count_many`, optional rayon parallelism) — the win over
//!   pure-Python `tiktoken` is in long lists where Python interpreter
//!   overhead dominates.
//! - **Length-budgeting helpers** (`fits`, `truncate_to`) so the common
//!   patterns are one call instead of three.
//! - **Model-name lookup** that maps OpenAI model names to encodings via
//!   `tiktoken_rs::get_bpe_from_model`.
//!
//! Encodings supported out of the box: `cl100k_base` (GPT-3.5, GPT-4,
//! text-embedding-3-*) and `o200k_base` (GPT-4o, GPT-4.1, the o-series,
//! GPT-5).
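//!
//! # Example
//!
//! A quick tour as a doc-test (a sketch; the model name is illustrative):
//!
//! ```
//! use toklab_core::Tokenizer;
//!
//! let tok = Tokenizer::for_model("gpt-4o")?;
//! assert_eq!(tok.encoding_name(), "o200k_base");
//!
//! // Bulk counting; pass `true` to fan the work out across rayon's pool.
//! let counts = tok.count_many(&["hello", "world"], false);
//! assert_eq!(counts.len(), 2);
//!
//! // Budget helpers.
//! assert!(tok.fits("hello", 10));
//! # Ok::<(), toklab_core::TokenizerError>(())
//! ```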

#![deny(unsafe_code)]
#![warn(missing_docs)]
#![warn(rust_2018_idioms)]

use rayon::prelude::*;
use thiserror::Error;
use tiktoken_rs::CoreBPE;

/// Crate-wide result alias.
pub type Result<T> = std::result::Result<T, TokenizerError>;

/// All errors surfaced by `toklab-core`.
#[derive(Error, Debug)]
pub enum TokenizerError {
    /// Unknown encoding name passed to [`Tokenizer::for_encoding`].
    #[error("unknown encoding: {0} (expected cl100k_base or o200k_base)")]
    UnknownEncoding(String),
    /// tiktoken-rs failed to load BPE tables. Should be unreachable for the
    /// bundled encodings; surfaces if a future version makes them optional.
    #[error("tiktoken-rs error: {0}")]
    Tiktoken(String),
}

/// Wraps a `CoreBPE` for one specific encoding.
pub struct Tokenizer {
    bpe: CoreBPE,
    encoding_name: String,
}

impl Tokenizer {
    /// Construct from an OpenAI model name (`"gpt-4"`, `"gpt-4o"`,
    /// `"gpt-4.1"`, `"gpt-5"`, etc.). Tries `tiktoken_rs::get_bpe_from_model`
    /// first; if that fails (the model is too new for the bundled
    /// mapping), falls back to encoding inference via name pattern.
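    ///
    /// # Examples
    ///
    /// A doc-test sketch (the model name is illustrative):
    ///
    /// ```
    /// use toklab_core::Tokenizer;
    ///
    /// let tok = Tokenizer::for_model("gpt-4o")?;
    /// assert_eq!(tok.encoding_name(), "o200k_base");
    /// # Ok::<(), toklab_core::TokenizerError>(())
    /// ```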
    pub fn for_model(model: &str) -> Result<Self> {
        match tiktoken_rs::get_bpe_from_model(model) {
            Ok(bpe) => Ok(Self {
                bpe,
                encoding_name: encoding_for_model(model).to_string(),
            }),
            Err(_) => {
                // Fallback: route by name pattern to the right base encoding.
                // Catches future model names tiktoken-rs hasn't enumerated.
                let encoding = encoding_for_model(model);
                Self::for_encoding(encoding)
            }
        }
    }

    /// Construct from an encoding name. Accepts `"cl100k_base"` and
    /// `"o200k_base"`.
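    ///
    /// # Examples
    ///
    /// A minimal doc-test sketch:
    ///
    /// ```
    /// use toklab_core::Tokenizer;
    ///
    /// let tok = Tokenizer::for_encoding("cl100k_base")?;
    /// assert_eq!(tok.encoding_name(), "cl100k_base");
    /// assert!(Tokenizer::for_encoding("nope_base").is_err());
    /// # Ok::<(), toklab_core::TokenizerError>(())
    /// ```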
    pub fn for_encoding(name: &str) -> Result<Self> {
        let bpe = match name {
            "cl100k_base" => tiktoken_rs::cl100k_base()
                .map_err(|e| TokenizerError::Tiktoken(e.to_string()))?,
            "o200k_base" => tiktoken_rs::o200k_base()
                .map_err(|e| TokenizerError::Tiktoken(e.to_string()))?,
            other => return Err(TokenizerError::UnknownEncoding(other.to_string())),
        };
        Ok(Self {
            bpe,
            encoding_name: name.to_string(),
        })
    }

    /// Encoding name (`"cl100k_base"` or `"o200k_base"`).
    pub fn encoding_name(&self) -> &str {
        &self.encoding_name
    }

    /// Count BPE tokens in `text`, ignoring special tokens.
    pub fn count(&self, text: &str) -> usize {
        self.bpe.encode_ordinary(text).len()
    }

    /// Bulk count. With `parallel = true`, the work is distributed across
    /// rayon's global thread pool.
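    ///
    /// # Examples
    ///
    /// A doc-test sketch; serial and parallel results must agree:
    ///
    /// ```
    /// use toklab_core::Tokenizer;
    ///
    /// let tok = Tokenizer::for_encoding("cl100k_base")?;
    /// let texts = ["hi", "lorem ipsum dolor"];
    /// assert_eq!(tok.count_many(&texts, false), tok.count_many(&texts, true));
    /// # Ok::<(), toklab_core::TokenizerError>(())
    /// ```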
    pub fn count_many(&self, texts: &[&str], parallel: bool) -> Vec<usize> {
        if parallel {
            texts
                .par_iter()
                .map(|t| self.bpe.encode_ordinary(t).len())
                .collect()
        } else {
            texts
                .iter()
                .map(|t| self.bpe.encode_ordinary(t).len())
                .collect()
        }
    }

    /// Encode to BPE token IDs (ordinary mode, no special tokens).
    pub fn encode(&self, text: &str) -> Vec<u32> {
        // tiktoken-rs 0.6 returns Vec<Rank> where Rank == u32; if a future
        // version changes this we'll catch it here.
        self.bpe.encode_ordinary(text)
    }

    /// Decode a slice of BPE token IDs back to a string.
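    ///
    /// # Examples
    ///
    /// Round-trip doc-test sketch:
    ///
    /// ```
    /// use toklab_core::Tokenizer;
    ///
    /// let tok = Tokenizer::for_encoding("cl100k_base")?;
    /// let ids = tok.encode("hello world");
    /// assert_eq!(tok.decode(&ids)?, "hello world");
    /// # Ok::<(), toklab_core::TokenizerError>(())
    /// ```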
    pub fn decode(&self, tokens: &[u32]) -> Result<String> {
        self.bpe
            .decode(tokens.to_vec())
            .map_err(|e| TokenizerError::Tiktoken(e.to_string()))
    }

    /// True iff `text` encodes to `<= budget` BPE tokens.
    pub fn fits(&self, text: &str, budget: usize) -> bool {
        self.count(text) <= budget
    }

    /// Encode `text`, truncate to the first `budget` tokens, and decode back.
    /// If `text` already fits, returns it unchanged. Note that cl100k/o200k
    /// are byte-level BPEs, so token boundaries need not align with character
    /// boundaries: a cut inside a multi-byte character can make tiktoken-rs's
    /// `decode` return an error (surfaced as [`TokenizerError::Tiktoken`])
    /// rather than mangled text.
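    ///
    /// # Examples
    ///
    /// A doc-test sketch (ASCII input, so the mid-character failure mode
    /// noted above cannot trigger):
    ///
    /// ```
    /// use toklab_core::Tokenizer;
    ///
    /// let tok = Tokenizer::for_encoding("cl100k_base")?;
    /// assert_eq!(tok.truncate_to("hi", 100)?, "hi");
    /// let short = tok.truncate_to("the quick brown fox jumps", 2)?;
    /// assert!(tok.count(&short) <= 2);
    /// # Ok::<(), toklab_core::TokenizerError>(())
    /// ```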
    pub fn truncate_to(&self, text: &str, budget: usize) -> Result<String> {
        let mut tokens = self.bpe.encode_ordinary(text);
        if tokens.len() <= budget {
            return Ok(text.to_string());
        }
        tokens.truncate(budget);
        self.bpe
            .decode(tokens)
            .map_err(|e| TokenizerError::Tiktoken(e.to_string()))
    }
}

/// Map an OpenAI model name to its encoding name. Used for diagnostics so
/// callers can see which encoding their model resolved to. Conservative
/// when the family is unknown — defaults to `cl100k_base`, the safer
/// "older" encoding that tends to over-count tokens.
fn encoding_for_model(model: &str) -> &'static str {
    // Newer reasoning + multimodal families use o200k_base.
    if model.starts_with("gpt-4o")
        || model.starts_with("gpt-4.1")
        || model.starts_with("gpt-5")
        || model.starts_with("o1")
        || model.starts_with("o3")
        || model.starts_with("o4")
        || model.starts_with("chatgpt-4o")
    {
        "o200k_base"
    } else {
        // gpt-4, gpt-3.5*, and text-embedding-3-* all sit on cl100k_base.
        // The gpt-4o/gpt-4.1 prefix checks above must match before plain
        // gpt-4 falls through to here.
        "cl100k_base"
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn round_trip_simple_text() {
        let tok = Tokenizer::for_encoding("cl100k_base").unwrap();
        let text = "hello world";
        let toks = tok.encode(text);
        let decoded = tok.decode(&toks).unwrap();
        assert_eq!(decoded, text);
    }

    #[test]
    fn count_matches_encode_len() {
        let tok = Tokenizer::for_encoding("cl100k_base").unwrap();
        let text = "the quick brown fox jumps over the lazy dog";
        assert_eq!(tok.count(text), tok.encode(text).len());
    }

    #[test]
    fn count_many_serial_and_parallel_agree() {
        let tok = Tokenizer::for_encoding("cl100k_base").unwrap();
        let texts: Vec<&str> = vec!["hi", "world", "lorem ipsum dolor sit amet"];
        let serial = tok.count_many(&texts, false);
        let par = tok.count_many(&texts, true);
        assert_eq!(serial, par);
    }

    #[test]
    fn for_model_gpt4_is_cl100k() {
        let tok = Tokenizer::for_model("gpt-4").unwrap();
        assert_eq!(tok.encoding_name(), "cl100k_base");
    }

    #[test]
    fn for_model_gpt5_is_o200k() {
        // gpt-5 may not exist in tiktoken-rs's mapping; we should still
        // resolve via the encoding_for_model fallback.
        let tok = Tokenizer::for_model("gpt-5").unwrap();
        assert_eq!(tok.encoding_name(), "o200k_base");
    }

    #[test]
    fn for_model_o3_is_o200k() {
        let tok = Tokenizer::for_model("o3-mini").unwrap();
        assert_eq!(tok.encoding_name(), "o200k_base");
    }

    #[test]
    fn for_model_unknown_falls_back_to_cl100k() {
        // Truly unknown family (won't be in tiktoken-rs mapping); fallback
        // routes to cl100k_base because it doesn't match the o200k prefixes.
        let tok = Tokenizer::for_model("future-model-7b").unwrap();
        assert_eq!(tok.encoding_name(), "cl100k_base");
    }

    #[test]
    fn for_model_gpt4o_is_o200k() {
        let tok = Tokenizer::for_model("gpt-4o").unwrap();
        assert_eq!(tok.encoding_name(), "o200k_base");
    }
225
226    #[test]
227    fn unknown_encoding_rejected() {
228        assert!(Tokenizer::for_encoding("unknown_base").is_err());
229    }
230
231    #[test]
232    fn fits_and_truncate() {
233        let tok = Tokenizer::for_encoding("cl100k_base").unwrap();
234        let text = "the quick brown fox";
235        let n = tok.count(text);
236        assert!(tok.fits(text, n));
237        assert!(tok.fits(text, n + 1));
238        assert!(!tok.fits(text, n - 1));
239
240        let truncated = tok.truncate_to(text, 2).unwrap();
241        assert!(tok.count(&truncated) <= 2);
242        assert!(truncated.len() <= text.len());
243    }
244
245    #[test]
246    fn truncate_returns_input_when_fits() {
247        let tok = Tokenizer::for_encoding("cl100k_base").unwrap();
248        let text = "hi";
249        assert_eq!(tok.truncate_to(text, 100).unwrap(), text);
250    }
251
252    #[test]
253    fn empty_text_is_zero_tokens() {
254        let tok = Tokenizer::for_encoding("cl100k_base").unwrap();
255        assert_eq!(tok.count(""), 0);
256        assert_eq!(tok.encode(""), Vec::<u32>::new());
257    }
258
259    #[test]
260    fn unicode_text_round_trips() {
261        let tok = Tokenizer::for_encoding("cl100k_base").unwrap();
262        let text = "你好世界 🌍";
263        let toks = tok.encode(text);
264        assert_eq!(tok.decode(&toks).unwrap(), text);
265    }
266
267    #[test]
268    fn count_many_handles_empty_list() {
269        let tok = Tokenizer::for_encoding("cl100k_base").unwrap();
270        let empty: Vec<&str> = vec![];
271        assert!(tok.count_many(&empty, false).is_empty());
272        assert!(tok.count_many(&empty, true).is_empty());
273    }
274}