1use std::collections::hash_map::DefaultHasher;
5use std::hash::{Hash, Hasher};
6use std::sync::OnceLock;
7
8use dashmap::DashMap;
9use tiktoken_rs::CoreBPE;
10use zeph_common::text::estimate_tokens;
11use zeph_llm::provider::{Message, MessagePart};
12
/// Process-wide tokenizer slot, filled at most once by `TokenCounter::new`.
///
/// Holds `None` when `tiktoken_rs::cl100k_base()` returned an error, in which case
/// counting falls back to `estimate_tokens`.
static BPE: OnceLock<Option<CoreBPE>> = OnceLock::new();

/// Cache capacity (number of entries) that `TokenCounter::new` gives each instance.
const CACHE_CAP: usize = 10_000;
/// Inputs whose byte length (`str::len`) exceeds this skip both the tokenizer and the
/// cache in `TokenCounter::count_tokens` and are estimated instead.
const MAX_INPUT_LEN: usize = 65_536;

// Weights used by the tool-schema estimate (`count_tool_schema_tokens` and
// `count_schema_value`).
// NOTE(review): the values look like provider-specific calibration constants; their
// origin is not visible in this file — confirm before changing them.

/// Added once per schema in `count_tool_schema_tokens`.
const FUNC_INIT: usize = 7;
/// Added once for every JSON object visited by `count_schema_value`.
const PROP_INIT: usize = 3;
/// Added for every key of a JSON object, on top of the key's own token count.
const PROP_KEY: usize = 3;
/// Signed adjustment added once per schema in `count_tool_schema_tokens`; it is the
/// only negative weight, which is why that function does its arithmetic in `isize`.
const ENUM_INIT: isize = -3;
/// Added once for every JSON array visited by `count_schema_value` (per array, not
/// per element).
const ENUM_ITEM: usize = 3;
/// Added once per schema in `count_tool_schema_tokens`.
const FUNC_END: usize = 12;

// Fixed per-part overheads added by `TokenCounter::count_part_tokens`.

/// Added to every `MessagePart::ToolUse`.
const TOOL_USE_OVERHEAD: usize = 20;
/// Added to every `MessagePart::ToolResult`.
const TOOL_RESULT_OVERHEAD: usize = 15;
/// Added to every `MessagePart::ToolOutput`.
const TOOL_OUTPUT_OVERHEAD: usize = 8;
/// Added to every `MessagePart::Image`, together with `IMAGE_DEFAULT_TOKENS`.
const IMAGE_OVERHEAD: usize = 50;
/// Flat token charge for an image; the image payload itself is not inspected.
const IMAGE_DEFAULT_TOKENS: usize = 1000;
/// Added to every thinking block, redacted or not.
const THINKING_OVERHEAD: usize = 10;
41
/// Counts tokens for plain text, structured messages and tool schemas.
///
/// Uses the shared `cl100k_base` tokenizer when it is available and falls back to
/// `estimate_tokens` otherwise. Results for individual strings are memoised in a
/// bounded, per-instance cache.
pub struct TokenCounter {
    /// Shared tokenizer; `None` means tokenizer initialisation failed and every count
    /// is an estimate.
    bpe: &'static Option<CoreBPE>,
    /// Memoised counts keyed by `hash_text(text)`. Only the 64-bit hash is stored, not
    /// the text itself.
    cache: DashMap<u64, usize>,
    /// Entry count at which `count_tokens` evicts one entry before inserting a new one.
    cache_cap: usize,
}
65
66impl TokenCounter {
67 #[must_use]
71 pub fn new() -> Self {
72 let bpe = BPE.get_or_init(|| match tiktoken_rs::cl100k_base() {
73 Ok(b) => Some(b),
74 Err(e) => {
75 tracing::warn!("tiktoken cl100k_base init failed, using chars/4 fallback: {e}");
76 None
77 }
78 });
79 Self {
80 bpe,
81 cache: DashMap::new(),
82 cache_cap: CACHE_CAP,
83 }
84 }
85
86 #[must_use]
91 pub fn count_tokens(&self, text: &str) -> usize {
92 if text.is_empty() {
93 return 0;
94 }
95
96 if text.len() > MAX_INPUT_LEN {
97 return zeph_common::text::estimate_tokens(text);
98 }
99
100 let key = hash_text(text);
101
102 if let Some(cached) = self.cache.get(&key) {
103 return *cached;
104 }
105
106 let count = match self.bpe {
107 Some(bpe) => bpe.encode_with_special_tokens(text).len(),
108 None => zeph_common::text::estimate_tokens(text),
109 };
110
111 if self.cache.len() >= self.cache_cap {
114 let key_to_evict = self.cache.iter().next().map(|e| *e.key());
115 if let Some(k) = key_to_evict {
116 self.cache.remove(&k);
117 }
118 }
119 self.cache.insert(key, count);
120
121 count
122 }
123
124 #[must_use]
129 pub fn count_message_tokens(&self, msg: &Message) -> usize {
130 if msg.parts.is_empty() {
131 return self.count_tokens(&msg.content);
132 }
133 msg.parts.iter().map(|p| self.count_part_tokens(p)).sum()
134 }
135
136 #[must_use]
138 fn count_part_tokens(&self, part: &MessagePart) -> usize {
139 match part {
140 MessagePart::Text { text }
141 | MessagePart::Recall { text }
142 | MessagePart::CodeContext { text }
143 | MessagePart::Summary { text }
144 | MessagePart::CrossSession { text } => {
145 if text.trim().is_empty() {
146 return 0;
147 }
148 self.count_tokens(text)
149 }
150
151 MessagePart::ToolOutput {
154 tool_name, body, ..
155 } => {
156 TOOL_OUTPUT_OVERHEAD
157 + self.count_tokens(tool_name.as_str())
158 + self.count_tokens(body)
159 }
160
161 MessagePart::ToolUse { id, name, input } => {
163 TOOL_USE_OVERHEAD
164 + self.count_tokens(id)
165 + self.count_tokens(name)
166 + self.count_tokens(&input.to_string())
167 }
168
169 MessagePart::ToolResult {
171 tool_use_id,
172 content,
173 ..
174 } => TOOL_RESULT_OVERHEAD + self.count_tokens(tool_use_id) + self.count_tokens(content),
175
176 MessagePart::Image(_) => IMAGE_OVERHEAD + IMAGE_DEFAULT_TOKENS,
180
181 MessagePart::ThinkingBlock {
183 thinking,
184 signature,
185 } => THINKING_OVERHEAD + self.count_tokens(thinking) + self.count_tokens(signature),
186
187 MessagePart::RedactedThinkingBlock { data } => {
189 THINKING_OVERHEAD + estimate_tokens(data)
190 }
191
192 MessagePart::Compaction { summary } => self.count_tokens(summary),
194
195 _ => 0,
197 }
198 }
199
200 #[must_use]
202 pub fn count_tool_schema_tokens(&self, schema: &serde_json::Value) -> usize {
203 let base = count_schema_value(self, schema);
204 let total =
205 base.cast_signed() + ENUM_INIT + FUNC_INIT.cast_signed() + FUNC_END.cast_signed();
206 total.max(0).cast_unsigned()
207 }
208}
209
210impl Default for TokenCounter {
211 fn default() -> Self {
212 Self::new()
213 }
214}
215
216impl zeph_common::memory::TokenCounting for TokenCounter {
217 fn count_tokens(&self, text: &str) -> usize {
218 self.count_tokens(text)
219 }
220
221 fn count_tool_schema_tokens(&self, schema: &serde_json::Value) -> usize {
222 self.count_tool_schema_tokens(schema)
223 }
224}
225
/// Hashes `text` with a freshly created `DefaultHasher` and returns the 64-bit digest.
///
/// The digest is what `TokenCounter` uses as its cache key.
fn hash_text(text: &str) -> u64 {
    let mut state = DefaultHasher::new();
    Hash::hash(text, &mut state);
    state.finish()
}
231
232fn count_schema_value(counter: &TokenCounter, value: &serde_json::Value) -> usize {
233 match value {
234 serde_json::Value::Object(map) => {
235 let mut tokens = PROP_INIT;
236 for (key, val) in map {
237 tokens += PROP_KEY + counter.count_tokens(key);
238 tokens += count_schema_value(counter, val);
239 }
240 tokens
241 }
242 serde_json::Value::Array(arr) => {
243 let mut tokens = ENUM_ITEM;
244 for item in arr {
245 tokens += count_schema_value(counter, item);
246 }
247 tokens
248 }
249 serde_json::Value::String(s) => counter.count_tokens(s),
250 serde_json::Value::Bool(_) | serde_json::Value::Number(_) | serde_json::Value::Null => 1,
251 }
252}
253
#[cfg(test)]
mod tests {
    use super::*;
    use zeph_llm::provider::{ImageData, Message, MessageMetadata, MessagePart, Role};

    /// A permanently empty tokenizer slot, used to force the estimate fallback.
    static BPE_NONE: Option<CoreBPE> = None;

    /// Builds a counter without a tokenizer and with a caller-chosen cache capacity.
    fn counter_with_no_bpe(cache_cap: usize) -> TokenCounter {
        TokenCounter {
            bpe: &BPE_NONE,
            cache: DashMap::new(),
            cache_cap,
        }
    }

    /// Builds a user message from structured parts.
    fn make_msg(parts: Vec<MessagePart>) -> Message {
        Message::from_parts(Role::User, parts)
    }

    /// Builds a user message that has only flat `content` and no parts.
    fn make_msg_no_parts(content: &str) -> Message {
        Message {
            role: Role::User,
            content: content.to_string(),
            parts: vec![],
            metadata: MessageMetadata::default(),
        }
    }

    /// With no parts, the message count equals the count of its `content`.
    #[test]
    fn count_message_tokens_empty_parts_falls_back_to_content() {
        let counter = TokenCounter::new();
        let msg = make_msg_no_parts("hello world");
        assert_eq!(
            counter.count_message_tokens(&msg),
            counter.count_tokens("hello world")
        );
    }

    /// A single text part costs exactly what the raw text costs.
    #[test]
    fn count_message_tokens_text_part_matches_count_tokens() {
        let counter = TokenCounter::new();
        let text = "the quick brown fox jumps over the lazy dog";
        let msg = make_msg(vec![MessagePart::Text {
            text: text.to_string(),
        }]);
        assert_eq!(
            counter.count_message_tokens(&msg),
            counter.count_tokens(text)
        );
    }

    /// Structured counting of a tool call must exceed counting the flattened content.
    #[test]
    fn count_message_tokens_tool_use_exceeds_flattened_content() {
        let counter = TokenCounter::new();
        let input = serde_json::json!({"command": "find /home -name '*.rs' -type f | head -100"});
        let msg = make_msg(vec![MessagePart::ToolUse {
            id: "toolu_abc".into(),
            name: "bash".into(),
            input,
        }]);
        let structured = counter.count_message_tokens(&msg);
        let flattened = counter.count_tokens(&msg.content);
        assert!(
            structured > flattened,
            "structured={structured} should exceed flattened={flattened}"
        );
    }

    /// A tool output with an empty body costs only the overhead plus the tool name.
    #[test]
    fn count_message_tokens_compacted_tool_output_is_small() {
        let counter = TokenCounter::new();
        let msg = make_msg(vec![MessagePart::ToolOutput {
            tool_name: "bash".into(),
            body: String::new(),
            compacted_at: Some(1_700_000_000),
        }]);
        let tokens = counter.count_message_tokens(&msg);
        assert!(
            tokens <= 15,
            "compacted tool output should be small, got {tokens}"
        );
    }

    /// Images are charged a flat amount regardless of payload size.
    #[test]
    fn count_message_tokens_image_returns_constant() {
        let counter = TokenCounter::new();
        let msg = make_msg(vec![MessagePart::Image(Box::new(ImageData {
            data: vec![0u8; 1000],
            mime_type: "image/jpeg".into(),
        }))]);
        assert_eq!(
            counter.count_message_tokens(&msg),
            IMAGE_OVERHEAD + IMAGE_DEFAULT_TOKENS
        );
    }

    /// A thinking block costs its overhead plus both of its string fields.
    #[test]
    fn count_message_tokens_thinking_block_counts_text() {
        let counter = TokenCounter::new();
        let thinking = "step by step reasoning about the problem";
        let signature = "sig";
        let msg = make_msg(vec![MessagePart::ThinkingBlock {
            thinking: thinking.to_string(),
            signature: signature.to_string(),
        }]);
        let expected =
            THINKING_OVERHEAD + counter.count_tokens(thinking) + counter.count_tokens(signature);
        assert_eq!(counter.count_message_tokens(&msg), expected);
    }

    /// Empty and whitespace-only text parts contribute nothing.
    #[test]
    fn count_part_tokens_empty_text_returns_zero() {
        let counter = TokenCounter::new();
        assert_eq!(
            counter.count_part_tokens(&MessagePart::Text {
                text: String::new()
            }),
            0
        );
        assert_eq!(
            counter.count_part_tokens(&MessagePart::Text {
                text: "   ".to_string()
            }),
            0
        );
        assert_eq!(
            counter.count_part_tokens(&MessagePart::Recall {
                text: "\n\t".to_string()
            }),
            0
        );
    }

    /// The message total equals the sum of its per-part counts.
    #[test]
    fn count_message_tokens_push_recompute_consistency() {
        let counter = TokenCounter::new();
        let parts = vec![
            MessagePart::Text {
                text: "hello".into(),
            },
            MessagePart::ToolOutput {
                tool_name: "bash".into(),
                body: "output data".into(),
                compacted_at: None,
            },
        ];
        let msg = make_msg(parts);
        let total = counter.count_message_tokens(&msg);
        let sum: usize = msg.parts.iter().map(|p| counter.count_part_tokens(p)).sum();
        assert_eq!(total, sum);
    }

    /// When parts are present, `content` is ignored even if it disagrees with them.
    #[test]
    fn count_message_tokens_parts_take_priority_over_content() {
        let counter = TokenCounter::new();
        let parts_text = "hello from parts";
        let msg = Message {
            role: Role::User,
            content: "completely different content that should be ignored".to_string(),
            parts: vec![MessagePart::Text {
                text: parts_text.to_string(),
            }],
            metadata: MessageMetadata::default(),
        };
        let parts_based = counter.count_tokens(parts_text);
        let content_based = counter.count_tokens(&msg.content);
        assert_ne!(
            parts_based, content_based,
            "test setup: parts and content must differ"
        );
        assert_eq!(counter.count_message_tokens(&msg), parts_based);
    }

    /// A tool result costs its overhead plus the id and the content.
    #[test]
    fn count_part_tokens_tool_result() {
        let counter = TokenCounter::new();
        let tool_use_id = "toolu_xyz";
        let content = "result text";
        let part = MessagePart::ToolResult {
            tool_use_id: tool_use_id.to_string(),
            content: content.to_string(),
            is_error: false,
        };
        let expected = TOOL_RESULT_OVERHEAD
            + counter.count_tokens(tool_use_id)
            + counter.count_tokens(content);
        assert_eq!(counter.count_part_tokens(&part), expected);
    }

    #[test]
    fn count_tokens_empty() {
        let counter = TokenCounter::new();
        assert_eq!(counter.count_tokens(""), 0);
    }

    #[test]
    fn count_tokens_non_empty() {
        let counter = TokenCounter::new();
        assert!(counter.count_tokens("hello world") > 0);
    }

    /// Counting the same text twice returns the same value and leaves exactly one
    /// cache entry, stored under the text's hash.
    #[test]
    fn count_tokens_cache_hit_returns_same() {
        let counter = TokenCounter::new();
        let text = "the quick brown fox";
        let first = counter.count_tokens(text);
        let second = counter.count_tokens(text);
        assert_eq!(first, second);

        // Without these checks the test would pass even if nothing were cached.
        assert_eq!(counter.cache.len(), 1, "repeated text must occupy one entry");
        let cached = counter
            .cache
            .get(&hash_text(text))
            .map(|entry| *entry.value());
        assert_eq!(cached, Some(first), "cached value must equal the returned count");
    }

    /// Without a tokenizer, counts come from `estimate_tokens`.
    #[test]
    fn count_tokens_fallback_mode() {
        let counter = counter_with_no_bpe(CACHE_CAP);
        assert_eq!(counter.count_tokens("abcdefgh"), 2);
        assert_eq!(counter.count_tokens(""), 0);
    }

    /// Oversized input is estimated and must not be written to the cache.
    #[test]
    fn count_tokens_oversized_input_uses_fallback_without_cache() {
        let counter = TokenCounter::new();
        let large = "a".repeat(MAX_INPUT_LEN + 1);
        let result = counter.count_tokens(&large);
        assert_eq!(result, zeph_common::text::estimate_tokens(&large));
        assert!(counter.cache.is_empty());
    }

    /// For mixed-script text the tokenizer count differs from the estimate.
    #[test]
    fn count_tokens_unicode_bpe_differs_from_fallback() {
        let counter = TokenCounter::new();
        let text = "Привет мир! 你好世界! こんにちは! 🌍";
        let bpe_count = counter.count_tokens(text);
        let fallback_count = zeph_common::text::estimate_tokens(text);
        assert!(bpe_count > 0, "BPE count must be positive");
        assert_ne!(
            bpe_count, fallback_count,
            "BPE tokenization should differ from chars/4 for unicode text"
        );
    }

    /// Pins the schema estimate for a representative tool definition.
    #[test]
    fn count_tool_schema_tokens_sample() {
        let counter = TokenCounter::new();
        let schema = serde_json::json!({
            "name": "get_weather",
            "description": "Get the current weather for a location",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The city name"
                    }
                },
                "required": ["location"]
            }
        });
        let tokens = counter.count_tool_schema_tokens(&schema);
        assert_eq!(tokens, 82);
    }

    /// Every counter borrows the same process-wide tokenizer slot.
    #[test]
    fn two_instances_share_bpe_pointer() {
        let a = TokenCounter::new();
        let b = TokenCounter::new();
        assert!(std::ptr::eq(a.bpe, b.bpe));
    }

    /// At capacity, a new entry evicts one old entry and is itself stored.
    #[test]
    fn cache_eviction_at_capacity() {
        let counter = counter_with_no_bpe(3);
        let _ = counter.count_tokens("aaaa");
        let _ = counter.count_tokens("bbbb");
        let _ = counter.count_tokens("cccc");
        assert_eq!(counter.cache.len(), 3);
        let _ = counter.count_tokens("dddd");
        assert_eq!(counter.cache.len(), 3);
        // A size check alone would also pass if the new entry were silently dropped.
        assert!(
            counter.cache.contains_key(&hash_text("dddd")),
            "the newest entry must be present after eviction"
        );
    }
}