1use std::collections::hash_map::DefaultHasher;
5use std::hash::{Hash, Hasher};
6use std::sync::OnceLock;
7
8use dashmap::DashMap;
9use tiktoken_rs::CoreBPE;
10use zeph_common::text::estimate_tokens;
11use zeph_llm::provider::{Message, MessagePart};
12
/// Process-wide tokenizer slot, filled on the first call to `TokenCounter::new`.
///
/// Holds `None` when `tiktoken_rs::cl100k_base()` returned an error during that first
/// initialisation; counters then take the `estimate_tokens` fallback path.
static BPE: OnceLock<Option<CoreBPE>> = OnceLock::new();
14
/// Cache capacity that `TokenCounter::new` assigns to `cache_cap`.
const CACHE_CAP: usize = 10_000;
/// Largest input, in bytes (`str::len`), that `TokenCounter::count_tokens` tokenizes and
/// caches; longer inputs are handed straight to `estimate_tokens`.
const MAX_INPUT_LEN: usize = 65_536;
19
// Fixed costs used by the tool-schema estimate (`count_tool_schema_tokens` and
// `count_schema_value`).
// NOTE(review): the values look like they come from a published function-calling token
// formula — confirm the source before changing any of them.

/// Added once per schema in `count_tool_schema_tokens`.
const FUNC_INIT: usize = 7;
/// Starting cost of every JSON object visited by `count_schema_value`.
const PROP_INIT: usize = 3;
/// Cost added for each object key, on top of the key's own token count.
const PROP_KEY: usize = 3;
/// Signed adjustment added once per schema in `count_tool_schema_tokens`; it is negative,
/// which is why that sum is carried out in `isize`.
const ENUM_INIT: isize = -3;
/// Starting cost of every JSON array visited by `count_schema_value`.
const ENUM_ITEM: usize = 3;
/// Added once per schema in `count_tool_schema_tokens`.
const FUNC_END: usize = 12;
27
// Fixed per-part costs applied by `TokenCounter::count_part_tokens`.

/// Added to every `MessagePart::ToolUse` on top of its id, name and input tokens.
const TOOL_USE_OVERHEAD: usize = 20;
/// Added to every `MessagePart::ToolResult` on top of its id and content tokens.
const TOOL_RESULT_OVERHEAD: usize = 15;
/// Added to every `MessagePart::ToolOutput` on top of its tool name and body tokens.
const TOOL_OUTPUT_OVERHEAD: usize = 8;
/// Added to every `MessagePart::Image` together with `IMAGE_DEFAULT_TOKENS`.
const IMAGE_OVERHEAD: usize = 50;
/// Flat token charge for an image; the image payload itself is not inspected.
const IMAGE_DEFAULT_TOKENS: usize = 1000;
/// Added to every thinking block (plain or redacted) on top of its content tokens.
const THINKING_OVERHEAD: usize = 10;
41
/// Counts tokens in plain text, messages and tool schemas, memoising per-string results.
pub struct TokenCounter {
    /// Tokenizer reference (the shared `BPE` static when built through `new`); `None`
    /// selects the `estimate_tokens` fallback.
    bpe: &'static Option<CoreBPE>,
    /// Memoised counts, keyed by the 64-bit `hash_text` digest of the input string.
    cache: DashMap<u64, usize>,
    /// Entry count at which `count_tokens` evicts one entry before inserting a new one.
    cache_cap: usize,
}
65
66impl TokenCounter {
67 #[must_use]
71 pub fn new() -> Self {
72 let bpe = BPE.get_or_init(|| match tiktoken_rs::cl100k_base() {
73 Ok(b) => Some(b),
74 Err(e) => {
75 tracing::warn!("tiktoken cl100k_base init failed, using chars/4 fallback: {e}");
76 None
77 }
78 });
79 Self {
80 bpe,
81 cache: DashMap::new(),
82 cache_cap: CACHE_CAP,
83 }
84 }
85
86 #[must_use]
91 pub fn count_tokens(&self, text: &str) -> usize {
92 if text.is_empty() {
93 return 0;
94 }
95
96 if text.len() > MAX_INPUT_LEN {
97 return zeph_common::text::estimate_tokens(text);
98 }
99
100 let key = hash_text(text);
101
102 if let Some(cached) = self.cache.get(&key) {
103 return *cached;
104 }
105
106 let count = match self.bpe {
107 Some(bpe) => bpe.encode_with_special_tokens(text).len(),
108 None => zeph_common::text::estimate_tokens(text),
109 };
110
111 if self.cache.len() >= self.cache_cap {
114 let key_to_evict = self.cache.iter().next().map(|e| *e.key());
115 if let Some(k) = key_to_evict {
116 self.cache.remove(&k);
117 }
118 }
119 self.cache.insert(key, count);
120
121 count
122 }
123
124 #[must_use]
129 pub fn count_message_tokens(&self, msg: &Message) -> usize {
130 if msg.parts.is_empty() {
131 return self.count_tokens(&msg.content);
132 }
133 msg.parts.iter().map(|p| self.count_part_tokens(p)).sum()
134 }
135
136 #[must_use]
138 fn count_part_tokens(&self, part: &MessagePart) -> usize {
139 match part {
140 MessagePart::Text { text }
141 | MessagePart::Recall { text }
142 | MessagePart::CodeContext { text }
143 | MessagePart::Summary { text }
144 | MessagePart::CrossSession { text } => {
145 if text.trim().is_empty() {
146 return 0;
147 }
148 self.count_tokens(text)
149 }
150
151 MessagePart::ToolOutput {
154 tool_name, body, ..
155 } => {
156 TOOL_OUTPUT_OVERHEAD
157 + self.count_tokens(tool_name.as_str())
158 + self.count_tokens(body)
159 }
160
161 MessagePart::ToolUse { id, name, input } => {
163 TOOL_USE_OVERHEAD
164 + self.count_tokens(id)
165 + self.count_tokens(name)
166 + self.count_tokens(&input.to_string())
167 }
168
169 MessagePart::ToolResult {
171 tool_use_id,
172 content,
173 ..
174 } => TOOL_RESULT_OVERHEAD + self.count_tokens(tool_use_id) + self.count_tokens(content),
175
176 MessagePart::Image(_) => IMAGE_OVERHEAD + IMAGE_DEFAULT_TOKENS,
180
181 MessagePart::ThinkingBlock {
183 thinking,
184 signature,
185 } => THINKING_OVERHEAD + self.count_tokens(thinking) + self.count_tokens(signature),
186
187 MessagePart::RedactedThinkingBlock { data } => {
189 THINKING_OVERHEAD + estimate_tokens(data)
190 }
191
192 MessagePart::Compaction { summary } => self.count_tokens(summary),
194 }
195 }
196
197 #[must_use]
199 pub fn count_tool_schema_tokens(&self, schema: &serde_json::Value) -> usize {
200 let base = count_schema_value(self, schema);
201 let total =
202 base.cast_signed() + ENUM_INIT + FUNC_INIT.cast_signed() + FUNC_END.cast_signed();
203 total.max(0).cast_unsigned()
204 }
205}
206
207impl Default for TokenCounter {
208 fn default() -> Self {
209 Self::new()
210 }
211}
212
213impl zeph_common::memory::TokenCounting for TokenCounter {
214 fn count_tokens(&self, text: &str) -> usize {
215 self.count_tokens(text)
216 }
217
218 fn count_tool_schema_tokens(&self, schema: &serde_json::Value) -> usize {
219 self.count_tool_schema_tokens(schema)
220 }
221}
222
/// Returns the 64-bit digest of `text` produced by a freshly created `DefaultHasher`.
fn hash_text(text: &str) -> u64 {
    let mut state = DefaultHasher::new();
    Hash::hash(text, &mut state);
    Hasher::finish(&state)
}
228
229fn count_schema_value(counter: &TokenCounter, value: &serde_json::Value) -> usize {
230 match value {
231 serde_json::Value::Object(map) => {
232 let mut tokens = PROP_INIT;
233 for (key, val) in map {
234 tokens += PROP_KEY + counter.count_tokens(key);
235 tokens += count_schema_value(counter, val);
236 }
237 tokens
238 }
239 serde_json::Value::Array(arr) => {
240 let mut tokens = ENUM_ITEM;
241 for item in arr {
242 tokens += count_schema_value(counter, item);
243 }
244 tokens
245 }
246 serde_json::Value::String(s) => counter.count_tokens(s),
247 serde_json::Value::Bool(_) | serde_json::Value::Number(_) | serde_json::Value::Null => 1,
248 }
249}
250
251#[cfg(test)]
252mod tests {
253 use super::*;
254 use zeph_llm::provider::{ImageData, Message, MessageMetadata, MessagePart, Role};
255
    /// Tokenizer slot that is always `None`, used to build counters that take the
    /// `estimate_tokens` fallback path.
    static BPE_NONE: Option<CoreBPE> = None;
257
258 fn counter_with_no_bpe(cache_cap: usize) -> TokenCounter {
259 TokenCounter {
260 bpe: &BPE_NONE,
261 cache: DashMap::new(),
262 cache_cap,
263 }
264 }
265
266 fn make_msg(parts: Vec<MessagePart>) -> Message {
267 Message::from_parts(Role::User, parts)
268 }
269
270 fn make_msg_no_parts(content: &str) -> Message {
271 Message {
272 role: Role::User,
273 content: content.to_string(),
274 parts: vec![],
275 metadata: MessageMetadata::default(),
276 }
277 }
278
279 #[test]
280 fn count_message_tokens_empty_parts_falls_back_to_content() {
281 let counter = TokenCounter::new();
282 let msg = make_msg_no_parts("hello world");
283 assert_eq!(
284 counter.count_message_tokens(&msg),
285 counter.count_tokens("hello world")
286 );
287 }
288
289 #[test]
290 fn count_message_tokens_text_part_matches_count_tokens() {
291 let counter = TokenCounter::new();
292 let text = "the quick brown fox jumps over the lazy dog";
293 let msg = make_msg(vec![MessagePart::Text {
294 text: text.to_string(),
295 }]);
296 assert_eq!(
297 counter.count_message_tokens(&msg),
298 counter.count_tokens(text)
299 );
300 }
301
302 #[test]
303 fn count_message_tokens_tool_use_exceeds_flattened_content() {
304 let counter = TokenCounter::new();
305 let input = serde_json::json!({"command": "find /home -name '*.rs' -type f | head -100"});
307 let msg = make_msg(vec![MessagePart::ToolUse {
308 id: "toolu_abc".into(),
309 name: "bash".into(),
310 input,
311 }]);
312 let structured = counter.count_message_tokens(&msg);
313 let flattened = counter.count_tokens(&msg.content);
314 assert!(
315 structured > flattened,
316 "structured={structured} should exceed flattened={flattened}"
317 );
318 }
319
320 #[test]
321 fn count_message_tokens_compacted_tool_output_is_small() {
322 let counter = TokenCounter::new();
323 let msg = make_msg(vec![MessagePart::ToolOutput {
325 tool_name: "bash".into(),
326 body: String::new(),
327 compacted_at: Some(1_700_000_000),
328 }]);
329 let tokens = counter.count_message_tokens(&msg);
330 assert!(
332 tokens <= 15,
333 "compacted tool output should be small, got {tokens}"
334 );
335 }
336
337 #[test]
338 fn count_message_tokens_image_returns_constant() {
339 let counter = TokenCounter::new();
340 let msg = make_msg(vec![MessagePart::Image(Box::new(ImageData {
341 data: vec![0u8; 1000],
342 mime_type: "image/jpeg".into(),
343 }))]);
344 assert_eq!(
345 counter.count_message_tokens(&msg),
346 IMAGE_OVERHEAD + IMAGE_DEFAULT_TOKENS
347 );
348 }
349
350 #[test]
351 fn count_message_tokens_thinking_block_counts_text() {
352 let counter = TokenCounter::new();
353 let thinking = "step by step reasoning about the problem";
354 let signature = "sig";
355 let msg = make_msg(vec![MessagePart::ThinkingBlock {
356 thinking: thinking.to_string(),
357 signature: signature.to_string(),
358 }]);
359 let expected =
360 THINKING_OVERHEAD + counter.count_tokens(thinking) + counter.count_tokens(signature);
361 assert_eq!(counter.count_message_tokens(&msg), expected);
362 }
363
364 #[test]
365 fn count_part_tokens_empty_text_returns_zero() {
366 let counter = TokenCounter::new();
367 assert_eq!(
368 counter.count_part_tokens(&MessagePart::Text {
369 text: String::new()
370 }),
371 0
372 );
373 assert_eq!(
374 counter.count_part_tokens(&MessagePart::Text {
375 text: " ".to_string()
376 }),
377 0
378 );
379 assert_eq!(
380 counter.count_part_tokens(&MessagePart::Recall {
381 text: "\n\t".to_string()
382 }),
383 0
384 );
385 }
386
387 #[test]
388 fn count_message_tokens_push_recompute_consistency() {
389 let counter = TokenCounter::new();
391 let parts = vec![
392 MessagePart::Text {
393 text: "hello".into(),
394 },
395 MessagePart::ToolOutput {
396 tool_name: "bash".into(),
397 body: "output data".into(),
398 compacted_at: None,
399 },
400 ];
401 let msg = make_msg(parts);
402 let total = counter.count_message_tokens(&msg);
403 let sum: usize = msg.parts.iter().map(|p| counter.count_part_tokens(p)).sum();
404 assert_eq!(total, sum);
405 }
406
407 #[test]
408 fn count_message_tokens_parts_take_priority_over_content() {
409 let counter = TokenCounter::new();
411 let parts_text = "hello from parts";
412 let msg = Message {
413 role: Role::User,
414 content: "completely different content that should be ignored".to_string(),
415 parts: vec![MessagePart::Text {
416 text: parts_text.to_string(),
417 }],
418 metadata: MessageMetadata::default(),
419 };
420 let parts_based = counter.count_tokens(parts_text);
421 let content_based = counter.count_tokens(&msg.content);
422 assert_ne!(
423 parts_based, content_based,
424 "test setup: parts and content must differ"
425 );
426 assert_eq!(counter.count_message_tokens(&msg), parts_based);
427 }
428
429 #[test]
430 fn count_part_tokens_tool_result() {
431 let counter = TokenCounter::new();
433 let tool_use_id = "toolu_xyz";
434 let content = "result text";
435 let part = MessagePart::ToolResult {
436 tool_use_id: tool_use_id.to_string(),
437 content: content.to_string(),
438 is_error: false,
439 };
440 let expected = TOOL_RESULT_OVERHEAD
441 + counter.count_tokens(tool_use_id)
442 + counter.count_tokens(content);
443 assert_eq!(counter.count_part_tokens(&part), expected);
444 }
445
446 #[test]
447 fn count_tokens_empty() {
448 let counter = TokenCounter::new();
449 assert_eq!(counter.count_tokens(""), 0);
450 }
451
452 #[test]
453 fn count_tokens_non_empty() {
454 let counter = TokenCounter::new();
455 assert!(counter.count_tokens("hello world") > 0);
456 }
457
458 #[test]
459 fn count_tokens_cache_hit_returns_same() {
460 let counter = TokenCounter::new();
461 let text = "the quick brown fox";
462 let first = counter.count_tokens(text);
463 let second = counter.count_tokens(text);
464 assert_eq!(first, second);
465 }
466
467 #[test]
468 fn count_tokens_fallback_mode() {
469 let counter = counter_with_no_bpe(CACHE_CAP);
470 assert_eq!(counter.count_tokens("abcdefgh"), 2);
472 assert_eq!(counter.count_tokens(""), 0);
473 }
474
475 #[test]
476 fn count_tokens_oversized_input_uses_fallback_without_cache() {
477 let counter = TokenCounter::new();
478 let large = "a".repeat(MAX_INPUT_LEN + 1);
480 let result = counter.count_tokens(&large);
481 assert_eq!(result, zeph_common::text::estimate_tokens(&large));
483 assert!(counter.cache.is_empty());
485 }
486
487 #[test]
488 fn count_tokens_unicode_bpe_differs_from_fallback() {
489 let counter = TokenCounter::new();
490 let text = "Привет мир! 你好世界! こんにちは! 🌍";
491 let bpe_count = counter.count_tokens(text);
492 let fallback_count = zeph_common::text::estimate_tokens(text);
493 assert!(bpe_count > 0, "BPE count must be positive");
495 assert_ne!(
497 bpe_count, fallback_count,
498 "BPE tokenization should differ from chars/4 for unicode text"
499 );
500 }
501
502 #[test]
503 fn count_tool_schema_tokens_sample() {
504 let counter = TokenCounter::new();
505 let schema = serde_json::json!({
506 "name": "get_weather",
507 "description": "Get the current weather for a location",
508 "parameters": {
509 "type": "object",
510 "properties": {
511 "location": {
512 "type": "string",
513 "description": "The city name"
514 }
515 },
516 "required": ["location"]
517 }
518 });
519 let tokens = counter.count_tool_schema_tokens(&schema);
520 assert_eq!(tokens, 82);
523 }
524
525 #[test]
526 fn two_instances_share_bpe_pointer() {
527 let a = TokenCounter::new();
528 let b = TokenCounter::new();
529 assert!(std::ptr::eq(a.bpe, b.bpe));
530 }
531
532 #[test]
533 fn cache_eviction_at_capacity() {
534 let counter = counter_with_no_bpe(3);
535 let _ = counter.count_tokens("aaaa");
536 let _ = counter.count_tokens("bbbb");
537 let _ = counter.count_tokens("cccc");
538 assert_eq!(counter.cache.len(), 3);
539 let _ = counter.count_tokens("dddd");
541 assert_eq!(counter.cache.len(), 3);
542 }
543}