1use std::collections::hash_map::DefaultHasher;
5use std::hash::{Hash, Hasher};
6use std::sync::OnceLock;
7
8use dashmap::DashMap;
9use tiktoken_rs::CoreBPE;
10use zeph_common::text::estimate_tokens;
11use zeph_llm::provider::{Message, MessagePart};
12
// Process-wide cl100k_base tokenizer, initialized lazily on first use.
// `None` records a failed initialization so the chars/4 fallback is used
// without retrying on every call.
static BPE: OnceLock<Option<CoreBPE>> = OnceLock::new();

// Soft cap on memoized (text-hash -> token-count) cache entries.
const CACHE_CAP: usize = 10_000;
// Inputs longer than this (bytes) skip exact BPE and the cache entirely.
const MAX_INPUT_LEN: usize = 65_536;

// Tool-schema token heuristic constants; presumably calibrated against the
// community estimates of function-calling serialization overhead — TODO
// confirm calibration against the target provider.
const FUNC_INIT: usize = 7; // per-function preamble
const PROP_INIT: usize = 3; // per-JSON-object overhead
const PROP_KEY: usize = 3; // per-property-key overhead
const ENUM_INIT: isize = -3; // one-off signed adjustment applied at the schema top level
const ENUM_ITEM: usize = 3; // per-JSON-array overhead
const FUNC_END: usize = 12; // per-function epilogue

// Fixed framing overheads added per message-part kind.
const TOOL_USE_OVERHEAD: usize = 20;
const TOOL_RESULT_OVERHEAD: usize = 15;
const TOOL_OUTPUT_OVERHEAD: usize = 8;
const IMAGE_OVERHEAD: usize = 50;
// Flat estimate for image content tokens; actual cost varies by provider.
const IMAGE_DEFAULT_TOKENS: usize = 1000;
const THINKING_OVERHEAD: usize = 10;
41
/// Token counter backed by a process-wide BPE tokenizer and a bounded
/// per-instance memoization cache.
pub struct TokenCounter {
    // Shared tokenizer; `None` means init failed and the chars/4 fallback
    // is used for every count.
    bpe: &'static Option<CoreBPE>,
    // Memoized text-hash -> token-count results.
    cache: DashMap<u64, usize>,
    // Soft upper bound on cache entries; one arbitrary entry is evicted per
    // insert once reached.
    cache_cap: usize,
}
65
66impl TokenCounter {
67 #[must_use]
71 pub fn new() -> Self {
72 let bpe = BPE.get_or_init(|| match tiktoken_rs::cl100k_base() {
73 Ok(b) => Some(b),
74 Err(e) => {
75 tracing::warn!("tiktoken cl100k_base init failed, using chars/4 fallback: {e}");
76 None
77 }
78 });
79 Self {
80 bpe,
81 cache: DashMap::new(),
82 cache_cap: CACHE_CAP,
83 }
84 }
85
86 #[must_use]
91 pub fn count_tokens(&self, text: &str) -> usize {
92 if text.is_empty() {
93 return 0;
94 }
95
96 if text.len() > MAX_INPUT_LEN {
97 return zeph_common::text::estimate_tokens(text);
98 }
99
100 let key = hash_text(text);
101
102 if let Some(cached) = self.cache.get(&key) {
103 return *cached;
104 }
105
106 let count = match self.bpe {
107 Some(bpe) => bpe.encode_with_special_tokens(text).len(),
108 None => zeph_common::text::estimate_tokens(text),
109 };
110
111 if self.cache.len() >= self.cache_cap {
114 let key_to_evict = self.cache.iter().next().map(|e| *e.key());
115 if let Some(k) = key_to_evict {
116 self.cache.remove(&k);
117 }
118 }
119 self.cache.insert(key, count);
120
121 count
122 }
123
124 #[must_use]
129 pub fn count_message_tokens(&self, msg: &Message) -> usize {
130 if msg.parts.is_empty() {
131 return self.count_tokens(&msg.content);
132 }
133 msg.parts.iter().map(|p| self.count_part_tokens(p)).sum()
134 }
135
136 #[must_use]
138 fn count_part_tokens(&self, part: &MessagePart) -> usize {
139 match part {
140 MessagePart::Text { text }
141 | MessagePart::Recall { text }
142 | MessagePart::CodeContext { text }
143 | MessagePart::Summary { text }
144 | MessagePart::CrossSession { text } => {
145 if text.trim().is_empty() {
146 return 0;
147 }
148 self.count_tokens(text)
149 }
150
151 MessagePart::ToolOutput {
154 tool_name, body, ..
155 } => {
156 TOOL_OUTPUT_OVERHEAD
157 + self.count_tokens(tool_name.as_str())
158 + self.count_tokens(body)
159 }
160
161 MessagePart::ToolUse { id, name, input } => {
163 TOOL_USE_OVERHEAD
164 + self.count_tokens(id)
165 + self.count_tokens(name)
166 + self.count_tokens(&input.to_string())
167 }
168
169 MessagePart::ToolResult {
171 tool_use_id,
172 content,
173 ..
174 } => TOOL_RESULT_OVERHEAD + self.count_tokens(tool_use_id) + self.count_tokens(content),
175
176 MessagePart::Image(_) => IMAGE_OVERHEAD + IMAGE_DEFAULT_TOKENS,
180
181 MessagePart::ThinkingBlock {
183 thinking,
184 signature,
185 } => THINKING_OVERHEAD + self.count_tokens(thinking) + self.count_tokens(signature),
186
187 MessagePart::RedactedThinkingBlock { data } => {
189 THINKING_OVERHEAD + estimate_tokens(data)
190 }
191
192 MessagePart::Compaction { summary } => self.count_tokens(summary),
194 }
195 }
196
197 #[must_use]
199 pub fn count_tool_schema_tokens(&self, schema: &serde_json::Value) -> usize {
200 let base = count_schema_value(self, schema);
201 let total =
202 base.cast_signed() + ENUM_INIT + FUNC_INIT.cast_signed() + FUNC_END.cast_signed();
203 total.max(0).cast_unsigned()
204 }
205}
206
207impl Default for TokenCounter {
208 fn default() -> Self {
209 Self::new()
210 }
211}
212
/// Produces a 64-bit cache key for `text` using the std `DefaultHasher`
/// (stable within a process, which is all the in-memory cache needs).
fn hash_text(text: &str) -> u64 {
    let mut state = DefaultHasher::default();
    Hash::hash(text, &mut state);
    Hasher::finish(&state)
}
218
219fn count_schema_value(counter: &TokenCounter, value: &serde_json::Value) -> usize {
220 match value {
221 serde_json::Value::Object(map) => {
222 let mut tokens = PROP_INIT;
223 for (key, val) in map {
224 tokens += PROP_KEY + counter.count_tokens(key);
225 tokens += count_schema_value(counter, val);
226 }
227 tokens
228 }
229 serde_json::Value::Array(arr) => {
230 let mut tokens = ENUM_ITEM;
231 for item in arr {
232 tokens += count_schema_value(counter, item);
233 }
234 tokens
235 }
236 serde_json::Value::String(s) => counter.count_tokens(s),
237 serde_json::Value::Bool(_) | serde_json::Value::Number(_) | serde_json::Value::Null => 1,
238 }
239}
240
// Unit tests: exercise caching, fallback, per-part overheads, and the
// schema-token heuristic.
#[cfg(test)]
mod tests {
    use super::*;
    use zeph_llm::provider::{ImageData, Message, MessageMetadata, MessagePart, Role};

    static BPE_NONE: Option<CoreBPE> = None;

    // Builds a counter that always uses the chars/4 fallback, with a custom
    // cache capacity (used by the eviction test).
    fn counter_with_no_bpe(cache_cap: usize) -> TokenCounter {
        TokenCounter {
            bpe: &BPE_NONE,
            cache: DashMap::new(),
            cache_cap,
        }
    }

    fn make_msg(parts: Vec<MessagePart>) -> Message {
        Message::from_parts(Role::User, parts)
    }

    // Message with empty `parts`, forcing the flattened-content path.
    fn make_msg_no_parts(content: &str) -> Message {
        Message {
            role: Role::User,
            content: content.to_string(),
            parts: vec![],
            metadata: MessageMetadata::default(),
        }
    }

    #[test]
    fn count_message_tokens_empty_parts_falls_back_to_content() {
        let counter = TokenCounter::new();
        let msg = make_msg_no_parts("hello world");
        assert_eq!(
            counter.count_message_tokens(&msg),
            counter.count_tokens("hello world")
        );
    }

    #[test]
    fn count_message_tokens_text_part_matches_count_tokens() {
        let counter = TokenCounter::new();
        let text = "the quick brown fox jumps over the lazy dog";
        let msg = make_msg(vec![MessagePart::Text {
            text: text.to_string(),
        }]);
        assert_eq!(
            counter.count_message_tokens(&msg),
            counter.count_tokens(text)
        );
    }

    #[test]
    fn count_message_tokens_tool_use_exceeds_flattened_content() {
        let counter = TokenCounter::new();
        let input = serde_json::json!({"command": "find /home -name '*.rs' -type f | head -100"});
        let msg = make_msg(vec![MessagePart::ToolUse {
            id: "toolu_abc".into(),
            name: "bash".into(),
            input,
        }]);
        // Structured counting adds TOOL_USE_OVERHEAD plus id/name/input,
        // so it must exceed counting the flattened content alone.
        let structured = counter.count_message_tokens(&msg);
        let flattened = counter.count_tokens(&msg.content);
        assert!(
            structured > flattened,
            "structured={structured} should exceed flattened={flattened}"
        );
    }

    #[test]
    fn count_message_tokens_compacted_tool_output_is_small() {
        let counter = TokenCounter::new();
        let msg = make_msg(vec![MessagePart::ToolOutput {
            tool_name: "bash".into(),
            body: String::new(),
            compacted_at: Some(1_700_000_000),
        }]);
        let tokens = counter.count_message_tokens(&msg);
        // Empty body: only TOOL_OUTPUT_OVERHEAD + tool-name tokens remain.
        assert!(
            tokens <= 15,
            "compacted tool output should be small, got {tokens}"
        );
    }

    #[test]
    fn count_message_tokens_image_returns_constant() {
        let counter = TokenCounter::new();
        let msg = make_msg(vec![MessagePart::Image(Box::new(ImageData {
            data: vec![0u8; 1000],
            mime_type: "image/jpeg".into(),
        }))]);
        // Image payload size must not affect the flat estimate.
        assert_eq!(
            counter.count_message_tokens(&msg),
            IMAGE_OVERHEAD + IMAGE_DEFAULT_TOKENS
        );
    }

    #[test]
    fn count_message_tokens_thinking_block_counts_text() {
        let counter = TokenCounter::new();
        let thinking = "step by step reasoning about the problem";
        let signature = "sig";
        let msg = make_msg(vec![MessagePart::ThinkingBlock {
            thinking: thinking.to_string(),
            signature: signature.to_string(),
        }]);
        let expected =
            THINKING_OVERHEAD + counter.count_tokens(thinking) + counter.count_tokens(signature);
        assert_eq!(counter.count_message_tokens(&msg), expected);
    }

    #[test]
    fn count_part_tokens_empty_text_returns_zero() {
        let counter = TokenCounter::new();
        assert_eq!(
            counter.count_part_tokens(&MessagePart::Text {
                text: String::new()
            }),
            0
        );
        // Whitespace-only text is trimmed to nothing.
        assert_eq!(
            counter.count_part_tokens(&MessagePart::Text {
                text: "   ".to_string()
            }),
            0
        );
        assert_eq!(
            counter.count_part_tokens(&MessagePart::Recall {
                text: "\n\t".to_string()
            }),
            0
        );
    }

    #[test]
    fn count_message_tokens_push_recompute_consistency() {
        let counter = TokenCounter::new();
        let parts = vec![
            MessagePart::Text {
                text: "hello".into(),
            },
            MessagePart::ToolOutput {
                tool_name: "bash".into(),
                body: "output data".into(),
                compacted_at: None,
            },
        ];
        let msg = make_msg(parts);
        // Whole-message count must equal the sum over its parts.
        let total = counter.count_message_tokens(&msg);
        let sum: usize = msg.parts.iter().map(|p| counter.count_part_tokens(p)).sum();
        assert_eq!(total, sum);
    }

    #[test]
    fn count_message_tokens_parts_take_priority_over_content() {
        let counter = TokenCounter::new();
        let parts_text = "hello from parts";
        let msg = Message {
            role: Role::User,
            content: "completely different content that should be ignored".to_string(),
            parts: vec![MessagePart::Text {
                text: parts_text.to_string(),
            }],
            metadata: MessageMetadata::default(),
        };
        let parts_based = counter.count_tokens(parts_text);
        let content_based = counter.count_tokens(&msg.content);
        assert_ne!(
            parts_based, content_based,
            "test setup: parts and content must differ"
        );
        assert_eq!(counter.count_message_tokens(&msg), parts_based);
    }

    #[test]
    fn count_part_tokens_tool_result() {
        let counter = TokenCounter::new();
        let tool_use_id = "toolu_xyz";
        let content = "result text";
        let part = MessagePart::ToolResult {
            tool_use_id: tool_use_id.to_string(),
            content: content.to_string(),
            is_error: false,
        };
        let expected = TOOL_RESULT_OVERHEAD
            + counter.count_tokens(tool_use_id)
            + counter.count_tokens(content);
        assert_eq!(counter.count_part_tokens(&part), expected);
    }

    #[test]
    fn count_tokens_empty() {
        let counter = TokenCounter::new();
        assert_eq!(counter.count_tokens(""), 0);
    }

    #[test]
    fn count_tokens_non_empty() {
        let counter = TokenCounter::new();
        assert!(counter.count_tokens("hello world") > 0);
    }

    #[test]
    fn count_tokens_cache_hit_returns_same() {
        let counter = TokenCounter::new();
        let text = "the quick brown fox";
        let first = counter.count_tokens(text);
        let second = counter.count_tokens(text);
        assert_eq!(first, second);
    }

    #[test]
    fn count_tokens_fallback_mode() {
        let counter = counter_with_no_bpe(CACHE_CAP);
        // chars/4 fallback: 8 chars -> 2 tokens.
        assert_eq!(counter.count_tokens("abcdefgh"), 2);
        assert_eq!(counter.count_tokens(""), 0);
    }

    #[test]
    fn count_tokens_oversized_input_uses_fallback_without_cache() {
        let counter = TokenCounter::new();
        let large = "a".repeat(MAX_INPUT_LEN + 1);
        let result = counter.count_tokens(&large);
        assert_eq!(result, zeph_common::text::estimate_tokens(&large));
        // Oversized inputs must not populate the cache.
        assert!(counter.cache.is_empty());
    }

    #[test]
    fn count_tokens_unicode_bpe_differs_from_fallback() {
        let counter = TokenCounter::new();
        let text = "Привет мир! 你好世界! こんにちは! 🌍";
        let bpe_count = counter.count_tokens(text);
        let fallback_count = zeph_common::text::estimate_tokens(text);
        assert!(bpe_count > 0, "BPE count must be positive");
        assert_ne!(
            bpe_count, fallback_count,
            "BPE tokenization should differ from chars/4 for unicode text"
        );
    }

    #[test]
    fn count_tool_schema_tokens_sample() {
        let counter = TokenCounter::new();
        let schema = serde_json::json!({
            "name": "get_weather",
            "description": "Get the current weather for a location",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The city name"
                    }
                },
                "required": ["location"]
            }
        });
        let tokens = counter.count_tool_schema_tokens(&schema);
        // Snapshot of the current heuristic for this schema; update if the
        // per-value constants change.
        assert_eq!(tokens, 82);
    }

    #[test]
    fn two_instances_share_bpe_pointer() {
        let a = TokenCounter::new();
        let b = TokenCounter::new();
        // Both counters must reference the same OnceLock-backed tokenizer.
        assert!(std::ptr::eq(a.bpe, b.bpe));
    }

    #[test]
    fn cache_eviction_at_capacity() {
        let counter = counter_with_no_bpe(3);
        let _ = counter.count_tokens("aaaa");
        let _ = counter.count_tokens("bbbb");
        let _ = counter.count_tokens("cccc");
        assert_eq!(counter.cache.len(), 3);
        // A fourth insert evicts one entry, keeping the cache at capacity.
        let _ = counter.count_tokens("dddd");
        assert_eq!(counter.cache.len(), 3);
    }
}
533}