1use std::collections::hash_map::DefaultHasher;
5use std::hash::{Hash, Hasher};
6use std::sync::OnceLock;
7
8use dashmap::DashMap;
9use tiktoken_rs::CoreBPE;
10use zeph_common::text::estimate_tokens;
11use zeph_llm::provider::{Message, MessagePart};
12
13static BPE: OnceLock<Option<CoreBPE>> = OnceLock::new();
14
/// Cache capacity given to counters built by `TokenCounter::new`.
const CACHE_CAP: usize = 10_000;
/// Byte length above which `count_tokens` skips both the tokenizer and the cache.
const MAX_INPUT_LEN: usize = 65_536;

// Tool-schema accounting terms, used by `count_tool_schema_tokens` and `count_schema_value`.

/// Added once per schema in `count_tool_schema_tokens`.
const FUNC_INIT: usize = 7;
/// Starting value for every JSON object visited by `count_schema_value`.
const PROP_INIT: usize = 3;
/// Added for every object key, on top of the key's own token count.
const PROP_KEY: usize = 3;
/// Signed adjustment added once per schema in `count_tool_schema_tokens`.
const ENUM_INIT: isize = -3;
/// Starting value for every JSON array visited by `count_schema_value`
/// (applied once per array, not once per element).
const ENUM_ITEM: usize = 3;
/// Added once per schema in `count_tool_schema_tokens`.
const FUNC_END: usize = 12;

// Fixed per-part terms, used by `count_part_tokens`.

/// Added for every `ToolUse` part, on top of its id, name and serialised input.
const TOOL_USE_OVERHEAD: usize = 20;
/// Added for every `ToolResult` part, on top of its tool-use id and content.
const TOOL_RESULT_OVERHEAD: usize = 15;
/// Added for every `ToolOutput` part, on top of its tool name and body.
const TOOL_OUTPUT_OVERHEAD: usize = 8;
/// Fixed term of the image cost; every image part counts as
/// `IMAGE_OVERHEAD + IMAGE_DEFAULT_TOKENS`.
const IMAGE_OVERHEAD: usize = 50;
/// Flat content term of the image cost; the image data itself is not inspected.
const IMAGE_DEFAULT_TOKENS: usize = 1000;
/// Added for every `ThinkingBlock` and `RedactedThinkingBlock` part.
const THINKING_OVERHEAD: usize = 10;
41
42pub struct TokenCounter {
43 bpe: &'static Option<CoreBPE>,
44 cache: DashMap<u64, usize>,
45 cache_cap: usize,
46}
47
48impl TokenCounter {
49 #[must_use]
53 pub fn new() -> Self {
54 let bpe = BPE.get_or_init(|| match tiktoken_rs::cl100k_base() {
55 Ok(b) => Some(b),
56 Err(e) => {
57 tracing::warn!("tiktoken cl100k_base init failed, using chars/4 fallback: {e}");
58 None
59 }
60 });
61 Self {
62 bpe,
63 cache: DashMap::new(),
64 cache_cap: CACHE_CAP,
65 }
66 }
67
68 #[must_use]
73 pub fn count_tokens(&self, text: &str) -> usize {
74 if text.is_empty() {
75 return 0;
76 }
77
78 if text.len() > MAX_INPUT_LEN {
79 return zeph_common::text::estimate_tokens(text);
80 }
81
82 let key = hash_text(text);
83
84 if let Some(cached) = self.cache.get(&key) {
85 return *cached;
86 }
87
88 let count = match self.bpe {
89 Some(bpe) => bpe.encode_with_special_tokens(text).len(),
90 None => zeph_common::text::estimate_tokens(text),
91 };
92
93 if self.cache.len() >= self.cache_cap {
96 let key_to_evict = self.cache.iter().next().map(|e| *e.key());
97 if let Some(k) = key_to_evict {
98 self.cache.remove(&k);
99 }
100 }
101 self.cache.insert(key, count);
102
103 count
104 }
105
106 #[must_use]
111 pub fn count_message_tokens(&self, msg: &Message) -> usize {
112 if msg.parts.is_empty() {
113 return self.count_tokens(&msg.content);
114 }
115 msg.parts.iter().map(|p| self.count_part_tokens(p)).sum()
116 }
117
118 #[must_use]
120 fn count_part_tokens(&self, part: &MessagePart) -> usize {
121 match part {
122 MessagePart::Text { text }
123 | MessagePart::Recall { text }
124 | MessagePart::CodeContext { text }
125 | MessagePart::Summary { text }
126 | MessagePart::CrossSession { text } => {
127 if text.trim().is_empty() {
128 return 0;
129 }
130 self.count_tokens(text)
131 }
132
133 MessagePart::ToolOutput {
136 tool_name, body, ..
137 } => TOOL_OUTPUT_OVERHEAD + self.count_tokens(tool_name) + self.count_tokens(body),
138
139 MessagePart::ToolUse { id, name, input } => {
141 TOOL_USE_OVERHEAD
142 + self.count_tokens(id)
143 + self.count_tokens(name)
144 + self.count_tokens(&input.to_string())
145 }
146
147 MessagePart::ToolResult {
149 tool_use_id,
150 content,
151 ..
152 } => TOOL_RESULT_OVERHEAD + self.count_tokens(tool_use_id) + self.count_tokens(content),
153
154 MessagePart::Image(_) => IMAGE_OVERHEAD + IMAGE_DEFAULT_TOKENS,
158
159 MessagePart::ThinkingBlock {
161 thinking,
162 signature,
163 } => THINKING_OVERHEAD + self.count_tokens(thinking) + self.count_tokens(signature),
164
165 MessagePart::RedactedThinkingBlock { data } => {
167 THINKING_OVERHEAD + estimate_tokens(data)
168 }
169
170 MessagePart::Compaction { summary } => self.count_tokens(summary),
172 }
173 }
174
175 #[must_use]
177 pub fn count_tool_schema_tokens(&self, schema: &serde_json::Value) -> usize {
178 let base = count_schema_value(self, schema);
179 let total =
180 base.cast_signed() + ENUM_INIT + FUNC_INIT.cast_signed() + FUNC_END.cast_signed();
181 total.max(0).cast_unsigned()
182 }
183}
184
185impl Default for TokenCounter {
186 fn default() -> Self {
187 Self::new()
188 }
189}
190
/// Hashes `text` into the 64-bit key used by the token cache.
///
/// A fresh `DefaultHasher::new()` is used for every call, so equal strings always
/// produce equal keys within a process.
///
/// NOTE(review): the cache stores only this hash, not the text, so two different
/// strings with the same hash would share one cached count.
fn hash_text(text: &str) -> u64 {
    let mut hasher = DefaultHasher::new();
    text.hash(&mut hasher);
    hasher.finish()
}
196
197fn count_schema_value(counter: &TokenCounter, value: &serde_json::Value) -> usize {
198 match value {
199 serde_json::Value::Object(map) => {
200 let mut tokens = PROP_INIT;
201 for (key, val) in map {
202 tokens += PROP_KEY + counter.count_tokens(key);
203 tokens += count_schema_value(counter, val);
204 }
205 tokens
206 }
207 serde_json::Value::Array(arr) => {
208 let mut tokens = ENUM_ITEM;
209 for item in arr {
210 tokens += count_schema_value(counter, item);
211 }
212 tokens
213 }
214 serde_json::Value::String(s) => counter.count_tokens(s),
215 serde_json::Value::Bool(_) | serde_json::Value::Number(_) | serde_json::Value::Null => 1,
216 }
217}
218
#[cfg(test)]
mod tests {
    use super::*;
    use zeph_llm::provider::{ImageData, Message, MessageMetadata, MessagePart, Role};

    /// A tokenizer slot that is always empty, used to force the estimate fallback.
    static BPE_NONE: Option<CoreBPE> = None;

    /// Builds a counter without a tokenizer and with the given cache capacity.
    fn counter_with_no_bpe(cache_cap: usize) -> TokenCounter {
        TokenCounter {
            bpe: &BPE_NONE,
            cache: DashMap::new(),
            cache_cap,
        }
    }

    /// Builds a user message from structured parts.
    fn make_msg(parts: Vec<MessagePart>) -> Message {
        Message::from_parts(Role::User, parts)
    }

    /// Builds a user message that has only flat content and no parts.
    fn make_msg_no_parts(content: &str) -> Message {
        Message {
            role: Role::User,
            content: content.to_string(),
            parts: vec![],
            metadata: MessageMetadata::default(),
        }
    }

    /// Without parts, the message cost is the cost of its content.
    #[test]
    fn count_message_tokens_empty_parts_falls_back_to_content() {
        let counter = TokenCounter::new();
        let msg = make_msg_no_parts("hello world");
        assert_eq!(
            counter.count_message_tokens(&msg),
            counter.count_tokens("hello world")
        );
    }

    /// A single text part costs exactly its text.
    #[test]
    fn count_message_tokens_text_part_matches_count_tokens() {
        let counter = TokenCounter::new();
        let text = "the quick brown fox jumps over the lazy dog";
        let msg = make_msg(vec![MessagePart::Text {
            text: text.to_string(),
        }]);
        assert_eq!(
            counter.count_message_tokens(&msg),
            counter.count_tokens(text)
        );
    }

    /// Structured counting of a tool call costs more than counting the flattened content.
    #[test]
    fn count_message_tokens_tool_use_exceeds_flattened_content() {
        let counter = TokenCounter::new();
        let input = serde_json::json!({"command": "find /home -name '*.rs' -type f | head -100"});
        let msg = make_msg(vec![MessagePart::ToolUse {
            id: "toolu_abc".into(),
            name: "bash".into(),
            input,
        }]);
        let structured = counter.count_message_tokens(&msg);
        let flattened = counter.count_tokens(&msg.content);
        assert!(
            structured > flattened,
            "structured={structured} should exceed flattened={flattened}"
        );
    }

    /// A tool output with an empty body costs little more than its overhead.
    #[test]
    fn count_message_tokens_compacted_tool_output_is_small() {
        let counter = TokenCounter::new();
        let msg = make_msg(vec![MessagePart::ToolOutput {
            tool_name: "bash".into(),
            body: String::new(),
            compacted_at: Some(1_700_000_000),
        }]);
        let tokens = counter.count_message_tokens(&msg);
        assert!(
            tokens <= 15,
            "compacted tool output should be small, got {tokens}"
        );
    }

    /// Images have a flat cost that does not depend on their data.
    #[test]
    fn count_message_tokens_image_returns_constant() {
        let counter = TokenCounter::new();
        let msg = make_msg(vec![MessagePart::Image(Box::new(ImageData {
            data: vec![0u8; 1000],
            mime_type: "image/jpeg".into(),
        }))]);
        assert_eq!(
            counter.count_message_tokens(&msg),
            IMAGE_OVERHEAD + IMAGE_DEFAULT_TOKENS
        );
    }

    /// Thinking blocks cost overhead plus both the thinking text and the signature.
    #[test]
    fn count_message_tokens_thinking_block_counts_text() {
        let counter = TokenCounter::new();
        let thinking = "step by step reasoning about the problem";
        let signature = "sig";
        let msg = make_msg(vec![MessagePart::ThinkingBlock {
            thinking: thinking.to_string(),
            signature: signature.to_string(),
        }]);
        let expected =
            THINKING_OVERHEAD + counter.count_tokens(thinking) + counter.count_tokens(signature);
        assert_eq!(counter.count_message_tokens(&msg), expected);
    }

    /// Empty and whitespace-only text parts cost nothing.
    #[test]
    fn count_part_tokens_empty_text_returns_zero() {
        let counter = TokenCounter::new();
        assert_eq!(
            counter.count_part_tokens(&MessagePart::Text {
                text: String::new()
            }),
            0
        );
        assert_eq!(
            counter.count_part_tokens(&MessagePart::Text {
                text: "   ".to_string()
            }),
            0
        );
        assert_eq!(
            counter.count_part_tokens(&MessagePart::Recall {
                text: "\n\t".to_string()
            }),
            0
        );
    }

    /// The message total equals the sum of its per-part counts.
    #[test]
    fn count_message_tokens_push_recompute_consistency() {
        let counter = TokenCounter::new();
        let parts = vec![
            MessagePart::Text {
                text: "hello".into(),
            },
            MessagePart::ToolOutput {
                tool_name: "bash".into(),
                body: "output data".into(),
                compacted_at: None,
            },
        ];
        let msg = make_msg(parts);
        let total = counter.count_message_tokens(&msg);
        let sum: usize = msg.parts.iter().map(|p| counter.count_part_tokens(p)).sum();
        assert_eq!(total, sum);
    }

    /// When parts are present, `content` does not contribute to the count.
    #[test]
    fn count_message_tokens_parts_take_priority_over_content() {
        let counter = TokenCounter::new();
        let parts_text = "hello from parts";
        let msg = Message {
            role: Role::User,
            content: "completely different content that should be ignored".to_string(),
            parts: vec![MessagePart::Text {
                text: parts_text.to_string(),
            }],
            metadata: MessageMetadata::default(),
        };
        let parts_based = counter.count_tokens(parts_text);
        let content_based = counter.count_tokens(&msg.content);
        assert_ne!(
            parts_based, content_based,
            "test setup: parts and content must differ"
        );
        assert_eq!(counter.count_message_tokens(&msg), parts_based);
    }

    /// Tool results cost overhead plus the tool-use id and the content.
    #[test]
    fn count_part_tokens_tool_result() {
        let counter = TokenCounter::new();
        let tool_use_id = "toolu_xyz";
        let content = "result text";
        let part = MessagePart::ToolResult {
            tool_use_id: tool_use_id.to_string(),
            content: content.to_string(),
            is_error: false,
        };
        let expected = TOOL_RESULT_OVERHEAD
            + counter.count_tokens(tool_use_id)
            + counter.count_tokens(content);
        assert_eq!(counter.count_part_tokens(&part), expected);
    }

    /// The empty string has no tokens.
    #[test]
    fn count_tokens_empty() {
        let counter = TokenCounter::new();
        assert_eq!(counter.count_tokens(""), 0);
    }

    /// Non-empty text has at least one token.
    #[test]
    fn count_tokens_non_empty() {
        let counter = TokenCounter::new();
        assert!(counter.count_tokens("hello world") > 0);
    }

    /// A cached result matches the freshly computed one.
    #[test]
    fn count_tokens_cache_hit_returns_same() {
        let counter = TokenCounter::new();
        let text = "the quick brown fox";
        let first = counter.count_tokens(text);
        let second = counter.count_tokens(text);
        assert_eq!(first, second);
    }

    /// Without a tokenizer, eight ASCII characters are estimated as two tokens.
    #[test]
    fn count_tokens_fallback_mode() {
        let counter = counter_with_no_bpe(CACHE_CAP);
        assert_eq!(counter.count_tokens("abcdefgh"), 2);
        assert_eq!(counter.count_tokens(""), 0);
    }

    /// Oversized input is estimated and leaves the cache untouched.
    #[test]
    fn count_tokens_oversized_input_uses_fallback_without_cache() {
        let counter = TokenCounter::new();
        let large = "a".repeat(MAX_INPUT_LEN + 1);
        let result = counter.count_tokens(&large);
        assert_eq!(result, zeph_common::text::estimate_tokens(&large));
        assert!(counter.cache.is_empty());
    }

    /// For mixed-script text the tokenizer and the estimate disagree.
    #[test]
    fn count_tokens_unicode_bpe_differs_from_fallback() {
        let counter = TokenCounter::new();
        let text = "Привет мир! 你好世界! こんにちは! 🌍";
        let bpe_count = counter.count_tokens(text);
        let fallback_count = zeph_common::text::estimate_tokens(text);
        assert!(bpe_count > 0, "BPE count must be positive");
        assert_ne!(
            bpe_count, fallback_count,
            "BPE tokenization should differ from chars/4 for unicode text"
        );
    }

    /// Pins the schema cost of a small sample tool definition.
    #[test]
    fn count_tool_schema_tokens_sample() {
        let counter = TokenCounter::new();
        let schema = serde_json::json!({
            "name": "get_weather",
            "description": "Get the current weather for a location",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The city name"
                    }
                },
                "required": ["location"]
            }
        });
        let tokens = counter.count_tool_schema_tokens(&schema);
        assert_eq!(tokens, 82);
    }

    /// Every counter created by `new` points at the same tokenizer slot.
    #[test]
    fn two_instances_share_bpe_pointer() {
        let a = TokenCounter::new();
        let b = TokenCounter::new();
        assert!(std::ptr::eq(a.bpe, b.bpe));
    }

    /// At capacity, inserting a new entry evicts one so the size stays at the cap.
    #[test]
    fn cache_eviction_at_capacity() {
        let counter = counter_with_no_bpe(3);
        let _ = counter.count_tokens("aaaa");
        let _ = counter.count_tokens("bbbb");
        let _ = counter.count_tokens("cccc");
        assert_eq!(counter.cache.len(), 3);
        let _ = counter.count_tokens("dddd");
        assert_eq!(counter.cache.len(), 3);
        // The newest entry must be the one that survived.
        assert!(counter.cache.contains_key(&hash_text("dddd")));
    }
}