// traitclaw_core/token_counter.rs
#[cfg(feature = "tiktoken")]
21mod inner {
22 use tiktoken_rs::{cl100k_base, o200k_base, CoreBPE};
23
24 use crate::token_counting::TokenCounter;
25 use crate::types::message::{Message, MessageRole};
26
/// Fixed number of tokens added for every message by `count_messages`, on top
/// of the encoded role and content tokens.
27 const MESSAGE_OVERHEAD_TOKENS: usize = 4;
/// Fixed number of tokens added once to the total returned by `count_messages`.
/// NOTE(review): presumably accounts for priming the assistant reply — confirm.
30 const REPLY_PRIMING_TOKENS: usize = 3;
32
33 pub struct TikTokenCounter {
49 bpe: CoreBPE,
50 model_name: String,
51 }
52
53 impl TikTokenCounter {
54 #[must_use]
66 pub fn for_model(model: &str) -> Self {
67 let (bpe, used_encoding) = select_encoding(model);
68 if used_encoding == "fallback" {
69 tracing::warn!(
70 "TikTokenCounter: unknown model '{}', falling back to cl100k_base encoding.",
71 model
72 );
73 }
74 Self {
75 bpe,
76 model_name: model.to_string(),
77 }
78 }
79
80 #[must_use]
82 pub fn model_name(&self) -> &str {
83 &self.model_name
84 }
85
86 #[must_use]
88 pub fn count_str(&self, text: &str) -> usize {
89 self.bpe.encode_with_special_tokens(text).len()
90 }
91
92 #[must_use]
97 pub fn count_messages(&self, messages: &[Message]) -> usize {
98 let content_tokens: usize = messages
99 .iter()
100 .map(|m| {
101 let role_str = match &m.role {
102 MessageRole::System => "system",
103 MessageRole::User => "user",
104 MessageRole::Assistant => "assistant",
105 MessageRole::Tool => "tool",
106 };
107 let role_tokens = self.bpe.encode_with_special_tokens(role_str).len();
108 let content_tokens = self.bpe.encode_with_special_tokens(&m.content).len();
109 role_tokens + content_tokens + MESSAGE_OVERHEAD_TOKENS
110 })
111 .sum();
112 content_tokens + REPLY_PRIMING_TOKENS
113 }
114
115 #[must_use]
125 pub fn estimate_for_model(messages: &[Message], model: &str) -> usize {
126 Self::for_model(model).count_messages(messages)
127 }
128 }
129
130 impl TokenCounter for TikTokenCounter {
131 fn count_messages(&self, messages: &[Message]) -> usize {
132 self.count_messages(messages)
133 }
134
135 fn count_str(&self, text: &str) -> usize {
136 self.count_str(text)
137 }
138 }
139
140 fn select_encoding(model: &str) -> (CoreBPE, &'static str) {
145 let use_o200k = model.starts_with("gpt-4o")
147 || model.starts_with("o1")
148 || model.starts_with("o3")
149 || model.starts_with("o4");
150
151 if use_o200k {
152 return (
153 o200k_base().expect("tiktoken-rs o200k_base init"),
154 "o200k_base",
155 );
156 }
157
158 let use_cl100k = model.starts_with("gpt-4")
160 || model.starts_with("gpt-3.5")
161 || model.starts_with("text-embedding-ada")
162 || model.starts_with("text-embedding-3");
163
164 if use_cl100k {
165 return (
166 cl100k_base().expect("tiktoken-rs cl100k_base init"),
167 "cl100k_base",
168 );
169 }
170
171 (
173 cl100k_base().expect("tiktoken-rs cl100k_base init"),
174 "fallback",
175 )
176 }
177
#[cfg(test)]
mod tests {
    use super::*;
    use crate::token_counting::CharApproxCounter;
    use crate::types::message::MessageRole;

    /// English sentences used to compare the BPE count with the character
    /// approximation.
    const ENGLISH_SAMPLES: [&str; 20] = [
        "The quick brown fox jumps over the lazy dog.",
        "Artificial intelligence is transforming how we work.",
        "Rust is a systems programming language focused on safety.",
        "Hello, world! This is a test message.",
        "Machine learning models require large amounts of data.",
        "The weather today is sunny with a high of 75 degrees.",
        "Please summarize the following document for me.",
        "What are the key differences between Rust and Go?",
        "I need to schedule a meeting for next Tuesday at 2 PM.",
        "The annual report shows revenue growth of 15% year over year.",
        "Can you help me debug this code snippet please?",
        "The new framework makes it easy to build APIs.",
        "Please translate this text into Spanish.",
        "How do I implement a binary search tree in Rust?",
        "The project deadline is approaching fast.",
        "We need to improve our test coverage to at least 80%.",
        "The database query is running too slowly.",
        "Can you recommend a good book on distributed systems?",
        "The API rate limit has been exceeded.",
        "Please review the pull request when you get a chance.",
    ];

    /// Builds a user-role message with the given text and no tool call id.
    fn user_msg(content: &str) -> Message {
        Message {
            role: MessageRole::User,
            content: content.to_owned(),
            tool_call_id: None,
        }
    }

    #[test]
    fn test_for_model_gpt4o_valid() {
        let gpt4o = TikTokenCounter::for_model("gpt-4o");
        assert_eq!(gpt4o.model_name(), "gpt-4o");

        let token_count = gpt4o.count_str("Hello!");
        assert!(token_count > 0, "Expected non-zero token count");
    }

    #[test]
    fn test_for_model_gpt4_classic() {
        let gpt4_turbo = TikTokenCounter::for_model("gpt-4-turbo");
        let token_count = gpt4_turbo.count_str("test");
        assert!(token_count > 0);
    }

    #[test]
    fn test_for_model_unknown_fallback() {
        // An unrecognised model name must still yield a usable counter.
        let unknown = TikTokenCounter::for_model("my-custom-model-v99");
        let token_count = unknown.count_str("Hello, world!");
        assert!(token_count > 0, "Fallback should still count tokens");
    }

    #[test]
    fn test_count_messages_nonzero() {
        let gpt4o = TikTokenCounter::for_model("gpt-4o");
        let conversation = [
            user_msg("Hello, my name is Alice."),
            user_msg("What is the capital of France?"),
        ];

        let total = gpt4o.count_messages(&conversation);
        assert!(
            total > 0,
            "Token count should be non-zero for non-empty messages"
        );
    }

    #[test]
    fn test_count_messages_exact_known_text() {
        let gpt4 = TikTokenCounter::for_model("gpt-4");
        let token_count = gpt4.count_str("Hello world");
        assert_eq!(token_count, 2);
    }

    #[test]
    fn test_accuracy_vs_char_approx() {
        let bpe_counter = TikTokenCounter::for_model("gpt-4");
        let approx_counter = CharApproxCounter::default();

        // Cycle through the sample sentences until 100 inputs have been
        // counted, accumulating both totals in a single pass.
        let mut bpe_total: usize = 0;
        let mut approx_total: usize = 0;
        for sentence in ENGLISH_SAMPLES.iter().copied().cycle().take(100) {
            bpe_total += bpe_counter.count_str(sentence);
            approx_total += approx_counter.count_str(sentence);
        }

        // Relative deviation of the approximation from the BPE count.
        let deviation = (bpe_total as f64 - approx_total as f64).abs();
        let error = deviation / bpe_total as f64;

        assert!(
            error < 0.50,
            "Error rate {:.1}% should be within 50% for English text (char approx is approximate)",
            error * 100.0
        );
        assert!(bpe_total > 0);
        assert!(approx_total > 0);
    }

    #[test]
    fn test_token_counter_trait_object() {
        let gpt4o = TikTokenCounter::for_model("gpt-4o");
        let as_trait: &dyn TokenCounter = &gpt4o;

        let conversation = [user_msg("Test message")];
        assert!(as_trait.count_messages(&conversation) > 0);
        assert!(as_trait.count_str("test") > 0);
    }

    #[test]
    fn test_estimate_for_model_helper() {
        let conversation = [user_msg("What is Rust?")];
        let estimate = TikTokenCounter::estimate_for_model(&conversation, "gpt-4");
        assert!(estimate > 0);
    }
}
310}
311
312#[cfg(feature = "tiktoken")]
313pub use inner::TikTokenCounter;