1use std::collections::hash_map::DefaultHasher;
5use std::hash::{Hash, Hasher};
6
7use dashmap::DashMap;
8use tiktoken_rs::CoreBPE;
9use zeph_llm::provider::{Message, MessagePart};
10
/// Maximum number of memoized token counts held before eviction kicks in.
const CACHE_CAP: usize = 10_000;
/// Inputs longer than this (in *bytes*) skip BPE encoding entirely and use
/// the chars/4 heuristic, bounding worst-case tokenization latency.
const MAX_INPUT_LEN: usize = 65_536;

// Tool-schema accounting constants. These mirror the community-derived
// heuristic for function-call token counting (fixed per-function and
// per-property overheads); the values are approximations, not exact
// tokenizer facts. NOTE(review): presumably tuned for cl100k-family
// models — confirm against the target provider.
const FUNC_INIT: usize = 7;
const PROP_INIT: usize = 3;
const PROP_KEY: usize = 3;
// Signed: this is a one-shot negative adjustment, so the running total is
// computed in signed space and clamped at zero afterwards.
const ENUM_INIT: isize = -3;
const ENUM_ITEM: usize = 3;
const FUNC_END: usize = 12;

// Fixed per-part framing overheads (approximate structural tokens).
const TOOL_USE_OVERHEAD: usize = 20;
const TOOL_RESULT_OVERHEAD: usize = 15;
const TOOL_OUTPUT_OVERHEAD: usize = 8;
const IMAGE_OVERHEAD: usize = 50;
/// Flat estimate for image payloads; real cost depends on provider scaling.
const IMAGE_DEFAULT_TOKENS: usize = 1000;
const THINKING_OVERHEAD: usize = 10;
37
/// Approximate token counter backed by the `cl100k_base` BPE, with a bounded
/// memoization cache keyed by a 64-bit hash of the input text.
pub struct TokenCounter {
    // `None` when tokenizer init failed; counting then degrades to chars/4.
    bpe: Option<CoreBPE>,
    // hash_text(text) -> token count. Eviction at capacity removes an
    // arbitrary entry (whatever `iter().next()` yields), not LRU.
    cache: DashMap<u64, usize>,
    // Soft capacity bound for `cache` (CACHE_CAP in production).
    cache_cap: usize,
}
43
44impl TokenCounter {
45 #[must_use]
47 pub fn new() -> Self {
48 let bpe = match tiktoken_rs::cl100k_base() {
49 Ok(b) => Some(b),
50 Err(e) => {
51 tracing::warn!("tiktoken cl100k_base init failed, using chars/4 fallback: {e}");
52 None
53 }
54 };
55 Self {
56 bpe,
57 cache: DashMap::new(),
58 cache_cap: CACHE_CAP,
59 }
60 }
61
62 #[must_use]
67 pub fn count_tokens(&self, text: &str) -> usize {
68 if text.is_empty() {
69 return 0;
70 }
71
72 if text.len() > MAX_INPUT_LEN {
73 return text.chars().count() / 4;
74 }
75
76 let key = hash_text(text);
77
78 if let Some(cached) = self.cache.get(&key) {
79 return *cached;
80 }
81
82 let count = match &self.bpe {
83 Some(bpe) => bpe.encode_with_special_tokens(text).len(),
84 None => text.chars().count() / 4,
85 };
86
87 if self.cache.len() >= self.cache_cap {
90 let key_to_evict = self.cache.iter().next().map(|e| *e.key());
91 if let Some(k) = key_to_evict {
92 self.cache.remove(&k);
93 }
94 }
95 self.cache.insert(key, count);
96
97 count
98 }
99
100 #[must_use]
105 pub fn count_message_tokens(&self, msg: &Message) -> usize {
106 if msg.parts.is_empty() {
107 return self.count_tokens(&msg.content);
108 }
109 msg.parts.iter().map(|p| self.count_part_tokens(p)).sum()
110 }
111
112 #[must_use]
114 fn count_part_tokens(&self, part: &MessagePart) -> usize {
115 match part {
116 MessagePart::Text { text }
117 | MessagePart::Recall { text }
118 | MessagePart::CodeContext { text }
119 | MessagePart::Summary { text }
120 | MessagePart::CrossSession { text } => {
121 if text.trim().is_empty() {
122 return 0;
123 }
124 self.count_tokens(text)
125 }
126
127 MessagePart::ToolOutput {
130 tool_name, body, ..
131 } => TOOL_OUTPUT_OVERHEAD + self.count_tokens(tool_name) + self.count_tokens(body),
132
133 MessagePart::ToolUse { id, name, input } => {
135 TOOL_USE_OVERHEAD
136 + self.count_tokens(id)
137 + self.count_tokens(name)
138 + self.count_tokens(&input.to_string())
139 }
140
141 MessagePart::ToolResult {
143 tool_use_id,
144 content,
145 ..
146 } => TOOL_RESULT_OVERHEAD + self.count_tokens(tool_use_id) + self.count_tokens(content),
147
148 MessagePart::Image(_) => IMAGE_OVERHEAD + IMAGE_DEFAULT_TOKENS,
152
153 MessagePart::ThinkingBlock {
155 thinking,
156 signature,
157 } => THINKING_OVERHEAD + self.count_tokens(thinking) + self.count_tokens(signature),
158
159 MessagePart::RedactedThinkingBlock { data } => THINKING_OVERHEAD + data.len() / 4,
161 }
162 }
163
164 #[must_use]
166 pub fn count_tool_schema_tokens(&self, schema: &serde_json::Value) -> usize {
167 let base = count_schema_value(self, schema);
168 let total =
169 base.cast_signed() + ENUM_INIT + FUNC_INIT.cast_signed() + FUNC_END.cast_signed();
170 total.max(0).cast_unsigned()
171 }
172}
173
174impl Default for TokenCounter {
175 fn default() -> Self {
176 Self::new()
177 }
178}
179
/// 64-bit hash of `text` used as the cache key. `DefaultHasher::new()` uses
/// fixed keys, so the result is deterministic within and across runs of the
/// same std version.
fn hash_text(text: &str) -> u64 {
    let mut state = DefaultHasher::new();
    text.hash(&mut state);
    state.finish()
}
185
186fn count_schema_value(counter: &TokenCounter, value: &serde_json::Value) -> usize {
187 match value {
188 serde_json::Value::Object(map) => {
189 let mut tokens = PROP_INIT;
190 for (key, val) in map {
191 tokens += PROP_KEY + counter.count_tokens(key);
192 tokens += count_schema_value(counter, val);
193 }
194 tokens
195 }
196 serde_json::Value::Array(arr) => {
197 let mut tokens = ENUM_ITEM;
198 for item in arr {
199 tokens += count_schema_value(counter, item);
200 }
201 tokens
202 }
203 serde_json::Value::String(s) => counter.count_tokens(s),
204 serde_json::Value::Bool(_) | serde_json::Value::Number(_) | serde_json::Value::Null => 1,
205 }
206}
207
#[cfg(test)]
mod tests {
    use super::*;
    use zeph_llm::provider::{ImageData, Message, MessageMetadata, MessagePart, Role};

    /// Builds a user message from structured parts via the provider helper.
    fn make_msg(parts: Vec<MessagePart>) -> Message {
        Message::from_parts(Role::User, parts)
    }

    /// Builds a user message with flat `content` only (no structured parts).
    fn make_msg_no_parts(content: &str) -> Message {
        Message {
            role: Role::User,
            content: content.to_string(),
            parts: vec![],
            metadata: MessageMetadata::default(),
        }
    }

    #[test]
    fn count_message_tokens_empty_parts_falls_back_to_content() {
        let counter = TokenCounter::new();
        let msg = make_msg_no_parts("hello world");
        assert_eq!(
            counter.count_message_tokens(&msg),
            counter.count_tokens("hello world")
        );
    }

    #[test]
    fn count_message_tokens_text_part_matches_count_tokens() {
        let counter = TokenCounter::new();
        let text = "the quick brown fox jumps over the lazy dog";
        let msg = make_msg(vec![MessagePart::Text {
            text: text.to_string(),
        }]);
        assert_eq!(
            counter.count_message_tokens(&msg),
            counter.count_tokens(text)
        );
    }

    // The structured ToolUse count includes fixed overhead plus id/name/input,
    // so it must exceed counting only the flattened content string.
    #[test]
    fn count_message_tokens_tool_use_exceeds_flattened_content() {
        let counter = TokenCounter::new();
        let input = serde_json::json!({"command": "find /home -name '*.rs' -type f | head -100"});
        let msg = make_msg(vec![MessagePart::ToolUse {
            id: "toolu_abc".into(),
            name: "bash".into(),
            input,
        }]);
        let structured = counter.count_message_tokens(&msg);
        let flattened = counter.count_tokens(&msg.content);
        assert!(
            structured > flattened,
            "structured={structured} should exceed flattened={flattened}"
        );
    }

    // A compacted tool output has an empty body, so only the fixed overhead
    // and tool name remain.
    #[test]
    fn count_message_tokens_compacted_tool_output_is_small() {
        let counter = TokenCounter::new();
        let msg = make_msg(vec![MessagePart::ToolOutput {
            tool_name: "bash".into(),
            body: String::new(),
            compacted_at: Some(1_700_000_000),
        }]);
        let tokens = counter.count_message_tokens(&msg);
        assert!(
            tokens <= 15,
            "compacted tool output should be small, got {tokens}"
        );
    }

    // Image parts are a flat constant regardless of payload size.
    #[test]
    fn count_message_tokens_image_returns_constant() {
        let counter = TokenCounter::new();
        let msg = make_msg(vec![MessagePart::Image(Box::new(ImageData {
            data: vec![0u8; 1000],
            mime_type: "image/jpeg".into(),
        }))]);
        assert_eq!(
            counter.count_message_tokens(&msg),
            IMAGE_OVERHEAD + IMAGE_DEFAULT_TOKENS
        );
    }

    #[test]
    fn count_message_tokens_thinking_block_counts_text() {
        let counter = TokenCounter::new();
        let thinking = "step by step reasoning about the problem";
        let signature = "sig";
        let msg = make_msg(vec![MessagePart::ThinkingBlock {
            thinking: thinking.to_string(),
            signature: signature.to_string(),
        }]);
        let expected =
            THINKING_OVERHEAD + counter.count_tokens(thinking) + counter.count_tokens(signature);
        assert_eq!(counter.count_message_tokens(&msg), expected);
    }

    // Whitespace-only text parts must cost zero (they are trimmed before
    // tokenization).
    #[test]
    fn count_part_tokens_empty_text_returns_zero() {
        let counter = TokenCounter::new();
        assert_eq!(
            counter.count_part_tokens(&MessagePart::Text {
                text: String::new()
            }),
            0
        );
        assert_eq!(
            counter.count_part_tokens(&MessagePart::Text {
                text: " ".to_string()
            }),
            0
        );
        assert_eq!(
            counter.count_part_tokens(&MessagePart::Recall {
                text: "\n\t".to_string()
            }),
            0
        );
    }

    // The message-level total must equal the sum of per-part counts.
    #[test]
    fn count_message_tokens_push_recompute_consistency() {
        let counter = TokenCounter::new();
        let parts = vec![
            MessagePart::Text {
                text: "hello".into(),
            },
            MessagePart::ToolOutput {
                tool_name: "bash".into(),
                body: "output data".into(),
                compacted_at: None,
            },
        ];
        let msg = make_msg(parts);
        let total = counter.count_message_tokens(&msg);
        let sum: usize = msg.parts.iter().map(|p| counter.count_part_tokens(p)).sum();
        assert_eq!(total, sum);
    }

    // When parts are present, flat `content` is ignored entirely.
    #[test]
    fn count_message_tokens_parts_take_priority_over_content() {
        let counter = TokenCounter::new();
        let parts_text = "hello from parts";
        let msg = Message {
            role: Role::User,
            content: "completely different content that should be ignored".to_string(),
            parts: vec![MessagePart::Text {
                text: parts_text.to_string(),
            }],
            metadata: MessageMetadata::default(),
        };
        let parts_based = counter.count_tokens(parts_text);
        let content_based = counter.count_tokens(&msg.content);
        assert_ne!(
            parts_based, content_based,
            "test setup: parts and content must differ"
        );
        assert_eq!(counter.count_message_tokens(&msg), parts_based);
    }

    #[test]
    fn count_part_tokens_tool_result() {
        let counter = TokenCounter::new();
        let tool_use_id = "toolu_xyz";
        let content = "result text";
        let part = MessagePart::ToolResult {
            tool_use_id: tool_use_id.to_string(),
            content: content.to_string(),
            is_error: false,
        };
        let expected = TOOL_RESULT_OVERHEAD
            + counter.count_tokens(tool_use_id)
            + counter.count_tokens(content);
        assert_eq!(counter.count_part_tokens(&part), expected);
    }

    #[test]
    fn count_tokens_empty() {
        let counter = TokenCounter::new();
        assert_eq!(counter.count_tokens(""), 0);
    }

    #[test]
    fn count_tokens_non_empty() {
        let counter = TokenCounter::new();
        assert!(counter.count_tokens("hello world") > 0);
    }

    #[test]
    fn count_tokens_cache_hit_returns_same() {
        let counter = TokenCounter::new();
        let text = "the quick brown fox";
        let first = counter.count_tokens(text);
        let second = counter.count_tokens(text);
        assert_eq!(first, second);
    }

    // With bpe = None the counter must use the chars/4 heuristic.
    #[test]
    fn count_tokens_fallback_mode() {
        let counter = TokenCounter {
            bpe: None,
            cache: DashMap::new(),
            cache_cap: CACHE_CAP,
        };
        assert_eq!(counter.count_tokens("abcdefgh"), 2);
        assert_eq!(counter.count_tokens(""), 0);
    }

    // Oversized inputs take the heuristic path and must not populate the cache.
    #[test]
    fn count_tokens_oversized_input_uses_fallback_without_cache() {
        let counter = TokenCounter::new();
        let large = "a".repeat(MAX_INPUT_LEN + 1);
        let result = counter.count_tokens(&large);
        assert_eq!(result, large.chars().count() / 4);
        assert!(counter.cache.is_empty());
    }

    // Sanity check that real BPE output differs from the crude chars/4
    // estimate on multi-byte/unicode text.
    #[test]
    fn count_tokens_unicode_bpe_differs_from_fallback() {
        let counter = TokenCounter::new();
        let text = "Привет мир! 你好世界! こんにちは! 🌍";
        let bpe_count = counter.count_tokens(text);
        let fallback_count = text.chars().count() / 4;
        assert!(bpe_count > 0, "BPE count must be positive");
        assert_ne!(
            bpe_count, fallback_count,
            "BPE tokenization should differ from chars/4 for unicode text"
        );
    }

    // Golden value: depends on cl100k_base tokenization of the schema strings
    // plus the fixed FUNC/PROP/ENUM constants. Update if either changes.
    #[test]
    fn count_tool_schema_tokens_sample() {
        let counter = TokenCounter::new();
        let schema = serde_json::json!({
            "name": "get_weather",
            "description": "Get the current weather for a location",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The city name"
                    }
                },
                "required": ["location"]
            }
        });
        let tokens = counter.count_tool_schema_tokens(&schema);
        assert_eq!(tokens, 82);
    }

    // At capacity, each new insert evicts one entry so len stays at the cap.
    #[test]
    fn cache_eviction_at_capacity() {
        let counter = TokenCounter {
            bpe: None,
            cache: DashMap::new(),
            cache_cap: 3,
        };
        let _ = counter.count_tokens("aaaa");
        let _ = counter.count_tokens("bbbb");
        let _ = counter.count_tokens("cccc");
        assert_eq!(counter.cache.len(), 3);
        let _ = counter.count_tokens("dddd");
        assert_eq!(counter.cache.len(), 3);
    }
}