1use std::collections::hash_map::DefaultHasher;
5use std::hash::{Hash, Hasher};
6use std::sync::OnceLock;
7
8use dashmap::DashMap;
9use tiktoken_rs::CoreBPE;
10use zeph_llm::provider::{Message, MessagePart};
11
/// Process-wide, lazily initialized cl100k_base tokenizer.
/// `None` means initialization failed and counting falls back to chars/4.
static BPE: OnceLock<Option<CoreBPE>> = OnceLock::new();

/// Maximum number of memoized token counts kept in a counter's cache.
const CACHE_CAP: usize = 10_000;
/// Inputs longer than this many bytes skip BPE (and the cache) and use
/// the cheap chars/4 estimate instead.
const MAX_INPUT_LEN: usize = 65_536;

// Per-element token costs used when walking a tool's JSON schema.
// NOTE(review): these look empirically tuned for function-calling schema
// accounting (ENUM_INIT is a deliberate negative correction) — confirm
// against the target provider's token accounting before changing.
const FUNC_INIT: usize = 7;
const PROP_INIT: usize = 3;
const PROP_KEY: usize = 3;
const ENUM_INIT: isize = -3;
const ENUM_ITEM: usize = 3;
const FUNC_END: usize = 12;

/// Fixed framing overhead (tokens) added for a tool-use part.
const TOOL_USE_OVERHEAD: usize = 20;
/// Fixed framing overhead (tokens) added for a tool-result part.
const TOOL_RESULT_OVERHEAD: usize = 15;
/// Fixed framing overhead (tokens) added for a tool-output part.
const TOOL_OUTPUT_OVERHEAD: usize = 8;
/// Fixed framing overhead (tokens) added for an image part.
const IMAGE_OVERHEAD: usize = 50;
/// Flat token estimate for an image payload regardless of its byte size.
const IMAGE_DEFAULT_TOKENS: usize = 1000;
/// Fixed framing overhead (tokens) added for a thinking block.
const THINKING_OVERHEAD: usize = 10;
/// Token counter backed by the shared cl100k_base BPE with a bounded
/// memoization cache for repeated texts.
pub struct TokenCounter {
    /// Shared process-wide BPE; `None` selects the chars/4 fallback.
    bpe: &'static Option<CoreBPE>,
    /// text-hash -> token-count memo; one arbitrary entry is evicted
    /// once `cache_cap` is reached.
    cache: DashMap<u64, usize>,
    /// Upper bound on the number of cached entries.
    cache_cap: usize,
}
46
47impl TokenCounter {
48 #[must_use]
52 pub fn new() -> Self {
53 let bpe = BPE.get_or_init(|| match tiktoken_rs::cl100k_base() {
54 Ok(b) => Some(b),
55 Err(e) => {
56 tracing::warn!("tiktoken cl100k_base init failed, using chars/4 fallback: {e}");
57 None
58 }
59 });
60 Self {
61 bpe,
62 cache: DashMap::new(),
63 cache_cap: CACHE_CAP,
64 }
65 }
66
67 #[must_use]
72 pub fn count_tokens(&self, text: &str) -> usize {
73 if text.is_empty() {
74 return 0;
75 }
76
77 if text.len() > MAX_INPUT_LEN {
78 return zeph_common::text::estimate_tokens(text);
79 }
80
81 let key = hash_text(text);
82
83 if let Some(cached) = self.cache.get(&key) {
84 return *cached;
85 }
86
87 let count = match self.bpe {
88 Some(bpe) => bpe.encode_with_special_tokens(text).len(),
89 None => zeph_common::text::estimate_tokens(text),
90 };
91
92 if self.cache.len() >= self.cache_cap {
95 let key_to_evict = self.cache.iter().next().map(|e| *e.key());
96 if let Some(k) = key_to_evict {
97 self.cache.remove(&k);
98 }
99 }
100 self.cache.insert(key, count);
101
102 count
103 }
104
105 #[must_use]
110 pub fn count_message_tokens(&self, msg: &Message) -> usize {
111 if msg.parts.is_empty() {
112 return self.count_tokens(&msg.content);
113 }
114 msg.parts.iter().map(|p| self.count_part_tokens(p)).sum()
115 }
116
117 #[must_use]
119 fn count_part_tokens(&self, part: &MessagePart) -> usize {
120 match part {
121 MessagePart::Text { text }
122 | MessagePart::Recall { text }
123 | MessagePart::CodeContext { text }
124 | MessagePart::Summary { text }
125 | MessagePart::CrossSession { text } => {
126 if text.trim().is_empty() {
127 return 0;
128 }
129 self.count_tokens(text)
130 }
131
132 MessagePart::ToolOutput {
135 tool_name, body, ..
136 } => TOOL_OUTPUT_OVERHEAD + self.count_tokens(tool_name) + self.count_tokens(body),
137
138 MessagePart::ToolUse { id, name, input } => {
140 TOOL_USE_OVERHEAD
141 + self.count_tokens(id)
142 + self.count_tokens(name)
143 + self.count_tokens(&input.to_string())
144 }
145
146 MessagePart::ToolResult {
148 tool_use_id,
149 content,
150 ..
151 } => TOOL_RESULT_OVERHEAD + self.count_tokens(tool_use_id) + self.count_tokens(content),
152
153 MessagePart::Image(_) => IMAGE_OVERHEAD + IMAGE_DEFAULT_TOKENS,
157
158 MessagePart::ThinkingBlock {
160 thinking,
161 signature,
162 } => THINKING_OVERHEAD + self.count_tokens(thinking) + self.count_tokens(signature),
163
164 MessagePart::RedactedThinkingBlock { data } => THINKING_OVERHEAD + data.len() / 4,
166
167 MessagePart::Compaction { summary } => self.count_tokens(summary),
169 }
170 }
171
172 #[must_use]
174 pub fn count_tool_schema_tokens(&self, schema: &serde_json::Value) -> usize {
175 let base = count_schema_value(self, schema);
176 let total =
177 base.cast_signed() + ENUM_INIT + FUNC_INIT.cast_signed() + FUNC_END.cast_signed();
178 total.max(0).cast_unsigned()
179 }
180}
181
impl Default for TokenCounter {
    /// Equivalent to [`TokenCounter::new`].
    fn default() -> Self {
        Self::new()
    }
}
187
/// Stable 64-bit hash of `text`, used as the memo-cache key.
fn hash_text(text: &str) -> u64 {
    let mut state = DefaultHasher::default();
    Hash::hash(text, &mut state);
    state.finish()
}
193
194fn count_schema_value(counter: &TokenCounter, value: &serde_json::Value) -> usize {
195 match value {
196 serde_json::Value::Object(map) => {
197 let mut tokens = PROP_INIT;
198 for (key, val) in map {
199 tokens += PROP_KEY + counter.count_tokens(key);
200 tokens += count_schema_value(counter, val);
201 }
202 tokens
203 }
204 serde_json::Value::Array(arr) => {
205 let mut tokens = ENUM_ITEM;
206 for item in arr {
207 tokens += count_schema_value(counter, item);
208 }
209 tokens
210 }
211 serde_json::Value::String(s) => counter.count_tokens(s),
212 serde_json::Value::Bool(_) | serde_json::Value::Number(_) | serde_json::Value::Null => 1,
213 }
214}
215
#[cfg(test)]
mod tests {
    use super::*;
    use zeph_llm::provider::{ImageData, Message, MessageMetadata, MessagePart, Role};

    // Always-`None` BPE used to force the chars/4 fallback path.
    static BPE_NONE: Option<CoreBPE> = None;

    // Counter that skips BPE entirely, with a caller-chosen cache cap.
    fn counter_with_no_bpe(cache_cap: usize) -> TokenCounter {
        TokenCounter {
            bpe: &BPE_NONE,
            cache: DashMap::new(),
            cache_cap,
        }
    }

    fn make_msg(parts: Vec<MessagePart>) -> Message {
        Message::from_parts(Role::User, parts)
    }

    // Message carrying only a flat `content` string, no structured parts.
    fn make_msg_no_parts(content: &str) -> Message {
        Message {
            role: Role::User,
            content: content.to_string(),
            parts: vec![],
            metadata: MessageMetadata::default(),
        }
    }

    #[test]
    fn count_message_tokens_empty_parts_falls_back_to_content() {
        let counter = TokenCounter::new();
        let msg = make_msg_no_parts("hello world");
        assert_eq!(
            counter.count_message_tokens(&msg),
            counter.count_tokens("hello world")
        );
    }

    #[test]
    fn count_message_tokens_text_part_matches_count_tokens() {
        let counter = TokenCounter::new();
        let text = "the quick brown fox jumps over the lazy dog";
        let msg = make_msg(vec![MessagePart::Text {
            text: text.to_string(),
        }]);
        assert_eq!(
            counter.count_message_tokens(&msg),
            counter.count_tokens(text)
        );
    }

    // Structured tool-use accounting includes fixed overheads, so it must
    // exceed tokenizing the flattened content string.
    #[test]
    fn count_message_tokens_tool_use_exceeds_flattened_content() {
        let counter = TokenCounter::new();
        let input = serde_json::json!({"command": "find /home -name '*.rs' -type f | head -100"});
        let msg = make_msg(vec![MessagePart::ToolUse {
            id: "toolu_abc".into(),
            name: "bash".into(),
            input,
        }]);
        let structured = counter.count_message_tokens(&msg);
        let flattened = counter.count_tokens(&msg.content);
        assert!(
            structured > flattened,
            "structured={structured} should exceed flattened={flattened}"
        );
    }

    // A compacted tool output (empty body) should cost roughly only the
    // fixed overhead plus the tool name.
    #[test]
    fn count_message_tokens_compacted_tool_output_is_small() {
        let counter = TokenCounter::new();
        let msg = make_msg(vec![MessagePart::ToolOutput {
            tool_name: "bash".into(),
            body: String::new(),
            compacted_at: Some(1_700_000_000),
        }]);
        let tokens = counter.count_message_tokens(&msg);
        assert!(
            tokens <= 15,
            "compacted tool output should be small, got {tokens}"
        );
    }

    // Image cost is a flat constant independent of payload bytes.
    #[test]
    fn count_message_tokens_image_returns_constant() {
        let counter = TokenCounter::new();
        let msg = make_msg(vec![MessagePart::Image(Box::new(ImageData {
            data: vec![0u8; 1000],
            mime_type: "image/jpeg".into(),
        }))]);
        assert_eq!(
            counter.count_message_tokens(&msg),
            IMAGE_OVERHEAD + IMAGE_DEFAULT_TOKENS
        );
    }

    #[test]
    fn count_message_tokens_thinking_block_counts_text() {
        let counter = TokenCounter::new();
        let thinking = "step by step reasoning about the problem";
        let signature = "sig";
        let msg = make_msg(vec![MessagePart::ThinkingBlock {
            thinking: thinking.to_string(),
            signature: signature.to_string(),
        }]);
        let expected =
            THINKING_OVERHEAD + counter.count_tokens(thinking) + counter.count_tokens(signature);
        assert_eq!(counter.count_message_tokens(&msg), expected);
    }

    // Empty and whitespace-only text parts must cost zero tokens.
    #[test]
    fn count_part_tokens_empty_text_returns_zero() {
        let counter = TokenCounter::new();
        assert_eq!(
            counter.count_part_tokens(&MessagePart::Text {
                text: String::new()
            }),
            0
        );
        assert_eq!(
            counter.count_part_tokens(&MessagePart::Text {
                text: " ".to_string()
            }),
            0
        );
        assert_eq!(
            counter.count_part_tokens(&MessagePart::Recall {
                text: "\n\t".to_string()
            }),
            0
        );
    }

    // Whole-message count must equal the sum over individual parts.
    #[test]
    fn count_message_tokens_push_recompute_consistency() {
        let counter = TokenCounter::new();
        let parts = vec![
            MessagePart::Text {
                text: "hello".into(),
            },
            MessagePart::ToolOutput {
                tool_name: "bash".into(),
                body: "output data".into(),
                compacted_at: None,
            },
        ];
        let msg = make_msg(parts);
        let total = counter.count_message_tokens(&msg);
        let sum: usize = msg.parts.iter().map(|p| counter.count_part_tokens(p)).sum();
        assert_eq!(total, sum);
    }

    // When both `parts` and `content` are set, only `parts` is counted.
    #[test]
    fn count_message_tokens_parts_take_priority_over_content() {
        let counter = TokenCounter::new();
        let parts_text = "hello from parts";
        let msg = Message {
            role: Role::User,
            content: "completely different content that should be ignored".to_string(),
            parts: vec![MessagePart::Text {
                text: parts_text.to_string(),
            }],
            metadata: MessageMetadata::default(),
        };
        let parts_based = counter.count_tokens(parts_text);
        let content_based = counter.count_tokens(&msg.content);
        assert_ne!(
            parts_based, content_based,
            "test setup: parts and content must differ"
        );
        assert_eq!(counter.count_message_tokens(&msg), parts_based);
    }

    #[test]
    fn count_part_tokens_tool_result() {
        let counter = TokenCounter::new();
        let tool_use_id = "toolu_xyz";
        let content = "result text";
        let part = MessagePart::ToolResult {
            tool_use_id: tool_use_id.to_string(),
            content: content.to_string(),
            is_error: false,
        };
        let expected = TOOL_RESULT_OVERHEAD
            + counter.count_tokens(tool_use_id)
            + counter.count_tokens(content);
        assert_eq!(counter.count_part_tokens(&part), expected);
    }

    #[test]
    fn count_tokens_empty() {
        let counter = TokenCounter::new();
        assert_eq!(counter.count_tokens(""), 0);
    }

    #[test]
    fn count_tokens_non_empty() {
        let counter = TokenCounter::new();
        assert!(counter.count_tokens("hello world") > 0);
    }

    // A second call for the same text must hit the cache and agree.
    #[test]
    fn count_tokens_cache_hit_returns_same() {
        let counter = TokenCounter::new();
        let text = "the quick brown fox";
        let first = counter.count_tokens(text);
        let second = counter.count_tokens(text);
        assert_eq!(first, second);
    }

    // Fallback mode is chars/4: 8 chars -> 2 tokens.
    #[test]
    fn count_tokens_fallback_mode() {
        let counter = counter_with_no_bpe(CACHE_CAP);
        assert_eq!(counter.count_tokens("abcdefgh"), 2);
        assert_eq!(counter.count_tokens(""), 0);
    }

    // Oversized inputs use the estimate and must not pollute the cache.
    #[test]
    fn count_tokens_oversized_input_uses_fallback_without_cache() {
        let counter = TokenCounter::new();
        let large = "a".repeat(MAX_INPUT_LEN + 1);
        let result = counter.count_tokens(&large);
        assert_eq!(result, zeph_common::text::estimate_tokens(&large));
        assert!(counter.cache.is_empty());
    }

    #[test]
    fn count_tokens_unicode_bpe_differs_from_fallback() {
        let counter = TokenCounter::new();
        let text = "Привет мир! 你好世界! こんにちは! 🌍";
        let bpe_count = counter.count_tokens(text);
        let fallback_count = zeph_common::text::estimate_tokens(text);
        assert!(bpe_count > 0, "BPE count must be positive");
        assert_ne!(
            bpe_count, fallback_count,
            "BPE tokenization should differ from chars/4 for unicode text"
        );
    }

    // Pins the schema-walk accounting for a representative tool schema.
    #[test]
    fn count_tool_schema_tokens_sample() {
        let counter = TokenCounter::new();
        let schema = serde_json::json!({
            "name": "get_weather",
            "description": "Get the current weather for a location",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The city name"
                    }
                },
                "required": ["location"]
            }
        });
        let tokens = counter.count_tool_schema_tokens(&schema);
        assert_eq!(tokens, 82);
    }

    // The OnceLock guarantees every counter shares the same BPE instance.
    #[test]
    fn two_instances_share_bpe_pointer() {
        let a = TokenCounter::new();
        let b = TokenCounter::new();
        assert!(std::ptr::eq(a.bpe, b.bpe));
    }

    // Inserting past capacity evicts so the cache never exceeds its cap.
    #[test]
    fn cache_eviction_at_capacity() {
        let counter = counter_with_no_bpe(3);
        let _ = counter.count_tokens("aaaa");
        let _ = counter.count_tokens("bbbb");
        let _ = counter.count_tokens("cccc");
        assert_eq!(counter.cache.len(), 3);
        let _ = counter.count_tokens("dddd");
        assert_eq!(counter.cache.len(), 3);
    }
}