Skip to main content

traitclaw_core/
token_counter.rs

1//! Accurate tiktoken-based token counting.
2//!
3//! This module is only available when the `tiktoken` feature is enabled.
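//! It provides exact OpenAI-compatible BPE token counts for strings via `tiktoken-rs`.
//! Chat-message counts add a fixed, approximate per-message framing overhead on top of those counts.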
5//!
6//! # Usage
7//!
8//! ```toml
9//! # Cargo.toml
10//! traitclaw-core = { version = "*", features = ["tiktoken"] }
11//! ```
12//!
13//! ```rust,ignore
14//! use traitclaw_core::token_counter::TikTokenCounter;
15//!
16//! let counter = TikTokenCounter::for_model("gpt-4o");
17//! let tokens = counter.count_messages(&messages);
18//! ```
19
#[cfg(feature = "tiktoken")]
mod inner {
    use tiktoken_rs::{cl100k_base, o200k_base, CoreBPE};

    use crate::token_counting::TokenCounter;
    use crate::types::message::{Message, MessageRole};

    // Per-message framing overhead in the ChatML-style layout
    // `<|im_start|>role\n{content}<|im_end|>\n`, approximated as 4 tokens.
    // The role text and the content are counted separately on top of this.
    const MESSAGE_OVERHEAD_TOKENS: usize = 4;
    // Final reply priming `<|im_start|>assistant\n` ≈ 3 tokens, added once per
    // message list (not once per message).
    const REPLY_PRIMING_TOKENS: usize = 3;

    /// Lowercase model-name prefixes that use the `o200k_base` encoding.
    ///
    /// Must be checked before `CL100K_PREFIXES`: several of these names
    /// (`gpt-4o`, `gpt-4.1`, `gpt-4.5`) also start with the cl100k prefix `gpt-4`.
    const O200K_PREFIXES: &[&str] = &[
        "gpt-4o",
        "chatgpt-4o",
        "gpt-4.1",
        "gpt-4.5",
        "gpt-5",
        "o1",
        "o3",
        "o4",
    ];

    /// Lowercase model-name prefixes that use the `cl100k_base` encoding.
    const CL100K_PREFIXES: &[&str] = &[
        "gpt-4",
        "gpt-3.5",
        "gpt-35", // Azure deployment naming, e.g. `gpt-35-turbo`
        "text-embedding-ada",
        "text-embedding-3",
    ];

    /// BPE encodings this counter knows how to load.
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    enum Encoding {
        O200kBase,
        Cl100kBase,
    }

    impl Encoding {
        /// Build the tiktoken-rs tokenizer for this encoding.
        ///
        /// # Panics
        ///
        /// Panics if tiktoken-rs fails to initialize the encoding.
        fn load(self) -> CoreBPE {
            match self {
                Encoding::O200kBase => o200k_base().expect("tiktoken-rs o200k_base init"),
                Encoding::Cl100kBase => cl100k_base().expect("tiktoken-rs cl100k_base init"),
            }
        }
    }

    /// Map a model name to its BPE encoding, or `None` when the model is not
    /// recognized.
    ///
    /// Matching is by prefix and ASCII-case-insensitive, so `GPT-4o` and
    /// `gpt-4o-2024-08-06` both resolve to `o200k_base`.
    fn encoding_for_model(model: &str) -> Option<Encoding> {
        let model = model.to_ascii_lowercase();
        let has_any_prefix =
            |prefixes: &[&str]| prefixes.iter().any(|prefix| model.starts_with(*prefix));

        if has_any_prefix(O200K_PREFIXES) {
            Some(Encoding::O200kBase)
        } else if has_any_prefix(CL100K_PREFIXES) {
            Some(Encoding::Cl100kBase)
        } else {
            None
        }
    }

    /// The role name as it appears in the chat transcript.
    fn role_label(role: &MessageRole) -> &'static str {
        match role {
            MessageRole::System => "system",
            MessageRole::User => "user",
            MessageRole::Assistant => "assistant",
            MessageRole::Tool => "tool",
        }
    }

    /// Exact token counter using OpenAI-compatible BPE tokenization via tiktoken-rs.
    ///
    /// Much more accurate than [`CharApproxCounter`] for context budget decisions.
    /// Automatically selects the right encoding based on the model name.
    ///
    /// [`CharApproxCounter`]: crate::token_counting::CharApproxCounter
    ///
    /// # Example
    ///
    /// ```rust,ignore
    /// use traitclaw_core::token_counter::TikTokenCounter;
    ///
    /// let counter = TikTokenCounter::for_model("gpt-4o");
    /// let count = counter.count_str("Hello, world!");
    /// ```
    pub struct TikTokenCounter {
        bpe: CoreBPE,
        model_name: String,
    }

    impl std::fmt::Debug for TikTokenCounter {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            // The tokenizer tables are omitted; the model name identifies the counter.
            f.debug_struct("TikTokenCounter")
                .field("model_name", &self.model_name)
                .finish_non_exhaustive()
        }
    }

    impl TikTokenCounter {
        /// Create a counter for the given model name.
        ///
        /// The encoding is selected by ASCII-case-insensitive prefix match on
        /// the model name:
        /// - `gpt-4o*`, `chatgpt-4o*`, `gpt-4.1*`, `gpt-4.5*`, `gpt-5*`, `o1*`,
        ///   `o3*`, `o4*` → `o200k_base`
        /// - other `gpt-4*`, `gpt-3.5*`, `gpt-35*` (Azure naming),
        ///   `text-embedding-ada*`, `text-embedding-3*` → `cl100k_base`
        /// - Unknown models → `cl100k_base`, with a `tracing` warning
        ///
        /// Each call builds a new tokenizer through tiktoken-rs. When counting
        /// repeatedly, construct one counter and reuse it.
        ///
        /// # Panics
        ///
        /// Panics if the tiktoken-rs library fails to initialize the selected
        /// encoding (this should never happen in practice as the encodings are
        /// bundled).
        #[must_use]
        pub fn for_model(model: &str) -> Self {
            let encoding = encoding_for_model(model).unwrap_or_else(|| {
                tracing::warn!(
                    "TikTokenCounter: unknown model '{}', falling back to cl100k_base encoding.",
                    model
                );
                Encoding::Cl100kBase
            });
            Self {
                bpe: encoding.load(),
                model_name: model.to_string(),
            }
        }

        /// The model name this counter was created for, exactly as passed to
        /// [`TikTokenCounter::for_model`].
        #[must_use]
        pub fn model_name(&self) -> &str {
            &self.model_name
        }

        /// Count BPE tokens in a single string.
        ///
        /// Uses tiktoken-rs `encode_with_special_tokens`, so special-token text
        /// such as `<|endoftext|>` appearing in `text` is encoded as the special
        /// token rather than as plain text.
        #[must_use]
        pub fn count_str(&self, text: &str) -> usize {
            self.bpe.encode_with_special_tokens(text).len()
        }

        /// Count tokens in a message list, including framing overhead.
        ///
        /// For each message this adds the tokens of its role name, the tokens
        /// of its `content`, and a fixed `MESSAGE_OVERHEAD_TOKENS` (4) for the
        /// ChatML-style framing `<|im_start|>role\n{content}<|im_end|>\n`.
        /// A one-time `REPLY_PRIMING_TOKENS` (3) is then added for the
        /// assistant reply header.
        ///
        /// An empty slice therefore yields 3.
        ///
        /// Only `role` and `content` are counted. Other message fields (such as
        /// `tool_call_id`) are not included.
        #[must_use]
        pub fn count_messages(&self, messages: &[Message]) -> usize {
            let per_message_total: usize = messages
                .iter()
                .map(|m| {
                    let role_tokens = self.count_str(role_label(&m.role));
                    let content_tokens = self.count_str(&m.content);
                    role_tokens + content_tokens + MESSAGE_OVERHEAD_TOKENS
                })
                .sum();
            per_message_total + REPLY_PRIMING_TOKENS
        }

        /// Standalone helper: count tokens for a list of messages given a model name.
        ///
        /// Useful from [`ContextManager`] implementations as a one-shot call.
        /// Note that this builds a fresh counter (and tokenizer) on every call.
        /// Prefer reusing a [`TikTokenCounter`] when counting in a loop.
        ///
        /// ```rust,ignore
        /// let n = TikTokenCounter::estimate_for_model(&messages, "gpt-4o");
        /// ```
        ///
        /// [`ContextManager`]: crate::traits::context_manager::ContextManager
        #[must_use]
        pub fn estimate_for_model(messages: &[Message], model: &str) -> usize {
            Self::for_model(model).count_messages(messages)
        }
    }

    impl TokenCounter for TikTokenCounter {
        // Both methods delegate to the inherent methods of the same name.
        // The explicit `TikTokenCounter::` path makes it clear that this is not
        // a recursive call into the trait method.
        fn count_messages(&self, messages: &[Message]) -> usize {
            TikTokenCounter::count_messages(self, messages)
        }

        fn count_str(&self, text: &str) -> usize {
            TikTokenCounter::count_str(self, text)
        }
    }

    #[cfg(test)]
    mod tests {
        use super::*;
        use crate::token_counting::CharApproxCounter;
        use crate::types::message::MessageRole;

        fn user_msg(content: &str) -> Message {
            Message {
                role: MessageRole::User,
                content: content.to_string(),
                tool_call_id: None,
            }
        }

        #[test]
        fn test_for_model_gpt4o_valid() {
            // AC #2: for_model("gpt-4o") creates valid counter
            let counter = TikTokenCounter::for_model("gpt-4o");
            assert_eq!(counter.model_name(), "gpt-4o");
            // Should count something
            let n = counter.count_str("Hello!");
            assert!(n > 0, "Expected non-zero token count");
        }

        #[test]
        fn test_for_model_gpt4_classic() {
            let counter = TikTokenCounter::for_model("gpt-4-turbo");
            assert!(counter.count_str("test") > 0);
        }

        #[test]
        fn test_for_model_unknown_fallback() {
            // AC #4: unknown model falls back to cl100k_base
            let counter = TikTokenCounter::for_model("my-custom-model-v99");
            // Should still count tokens (cl100k_base used as fallback)
            let n = counter.count_str("Hello, world!");
            assert!(n > 0, "Fallback should still count tokens");
        }

        #[test]
        fn test_encoding_for_model_prefixes() {
            // o200k_base family, including names that also start with `gpt-4`.
            assert_eq!(encoding_for_model("gpt-4o"), Some(Encoding::O200kBase));
            assert_eq!(encoding_for_model("gpt-4o-mini"), Some(Encoding::O200kBase));
            assert_eq!(encoding_for_model("chatgpt-4o-latest"), Some(Encoding::O200kBase));
            assert_eq!(encoding_for_model("gpt-4.1"), Some(Encoding::O200kBase));
            assert_eq!(encoding_for_model("gpt-4.1-mini"), Some(Encoding::O200kBase));
            assert_eq!(encoding_for_model("gpt-5"), Some(Encoding::O200kBase));
            assert_eq!(encoding_for_model("o3-mini"), Some(Encoding::O200kBase));
            // Case-insensitive matching.
            assert_eq!(encoding_for_model("GPT-4o"), Some(Encoding::O200kBase));
            // cl100k_base family.
            assert_eq!(encoding_for_model("gpt-4"), Some(Encoding::Cl100kBase));
            assert_eq!(encoding_for_model("gpt-4-turbo"), Some(Encoding::Cl100kBase));
            assert_eq!(encoding_for_model("gpt-3.5-turbo"), Some(Encoding::Cl100kBase));
            assert_eq!(encoding_for_model("gpt-35-turbo"), Some(Encoding::Cl100kBase));
            assert_eq!(
                encoding_for_model("text-embedding-3-small"),
                Some(Encoding::Cl100kBase)
            );
            // Unknown.
            assert_eq!(encoding_for_model("my-custom-model-v99"), None);
        }

        #[test]
        fn test_count_messages_nonzero() {
            // AC #3: count_tokens returns non-zero for non-empty messages
            let counter = TikTokenCounter::for_model("gpt-4o");
            let messages = vec![
                user_msg("Hello, my name is Alice."),
                user_msg("What is the capital of France?"),
            ];
            let count = counter.count_messages(&messages);
            assert!(
                count > 0,
                "Token count should be non-zero for non-empty messages"
            );
        }

        #[test]
        fn test_count_messages_empty_is_reply_priming_only() {
            let counter = TikTokenCounter::for_model("gpt-4o");
            assert_eq!(counter.count_messages(&[]), REPLY_PRIMING_TOKENS);
        }

        #[test]
        fn test_count_messages_adds_role_content_and_overhead() {
            let counter = TikTokenCounter::for_model("gpt-4");
            let expected = counter.count_str("user")
                + counter.count_str("Hello world")
                + MESSAGE_OVERHEAD_TOKENS
                + REPLY_PRIMING_TOKENS;
            assert_eq!(counter.count_messages(&[user_msg("Hello world")]), expected);
        }

        #[test]
        fn test_count_str_exact_known_text() {
            // Verify against known token counts for canonical text
            let counter = TikTokenCounter::for_model("gpt-4");
            // cl100k_base encodes "Hello world" as ["Hello", " world"] = 2 tokens
            assert_eq!(counter.count_str("Hello world"), 2);
        }

        #[test]
        fn test_accuracy_vs_char_approx() {
            // AC #8 sanity check: CharApproxCounter should land in the same
            // ballpark as exact BPE counts on English text. The asserted bound
            // is deliberately loose (50%) because the char approximation is rough.
            let tiktoken = TikTokenCounter::for_model("gpt-4");
            let char_approx = CharApproxCounter::default();

            // 20 distinct English samples, cycled below to 100 inputs
            let samples: Vec<&str> = vec![
                "The quick brown fox jumps over the lazy dog.",
                "Artificial intelligence is transforming how we work.",
                "Rust is a systems programming language focused on safety.",
                "Hello, world! This is a test message.",
                "Machine learning models require large amounts of data.",
                "The weather today is sunny with a high of 75 degrees.",
                "Please summarize the following document for me.",
                "What are the key differences between Rust and Go?",
                "I need to schedule a meeting for next Tuesday at 2 PM.",
                "The annual report shows revenue growth of 15% year over year.",
                "Can you help me debug this code snippet please?",
                "The new framework makes it easy to build APIs.",
                "Please translate this text into Spanish.",
                "How do I implement a binary search tree in Rust?",
                "The project deadline is approaching fast.",
                "We need to improve our test coverage to at least 80%.",
                "The database query is running too slowly.",
                "Can you recommend a good book on distributed systems?",
                "The API rate limit has been exceeded.",
                "Please review the pull request when you get a chance.",
            ];

            // Pad to 100 by repeating
            let all_samples: Vec<&str> = samples.iter().cycle().take(100).copied().collect();

            let total_tiktoken: usize = all_samples.iter().map(|s| tiktoken.count_str(s)).sum();
            let total_char_approx: usize =
                all_samples.iter().map(|s| char_approx.count_str(s)).sum();

            // Error = |tiktoken - char_approx| / tiktoken
            let error =
                (total_tiktoken as f64 - total_char_approx as f64).abs() / total_tiktoken as f64;

            assert!(
                error < 0.50,
                "Error rate {:.1}% should be within 50% for English text (char approx is approximate)",
                error * 100.0
            );
            // But both should be in the same order of magnitude
            assert!(total_tiktoken > 0);
            assert!(total_char_approx > 0);
        }

        #[test]
        fn test_token_counter_trait_object() {
            // AC #5: Can be used as &dyn TokenCounter
            let counter = TikTokenCounter::for_model("gpt-4o");
            let tc: &dyn TokenCounter = &counter;
            let messages = vec![user_msg("Test message")];
            assert!(tc.count_messages(&messages) > 0);
            assert!(tc.count_str("test") > 0);
        }

        #[test]
        fn test_estimate_for_model_helper() {
            // AC #5: standalone helper function
            let messages = vec![user_msg("What is Rust?")];
            let count = TikTokenCounter::estimate_for_model(&messages, "gpt-4");
            assert!(count > 0);
        }
    }
}
311
// Re-export the counter at this module's root so callers can write
// `traitclaw_core::token_counter::TikTokenCounter` (only with the `tiktoken` feature).
#[cfg(feature = "tiktoken")]
pub use self::inner::TikTokenCounter;