1use std::collections::hash_map::DefaultHasher;
5use std::hash::{Hash, Hasher};
6
7use dashmap::DashMap;
8use tiktoken_rs::CoreBPE;
9use zeph_llm::provider::{Message, MessagePart};
10
/// Maximum number of memoized token counts held before eviction kicks in.
const CACHE_CAP: usize = 10_000;
/// Inputs larger than this (in bytes) bypass both the BPE and the cache
/// and use the chars/4 heuristic directly (see `count_tokens`).
const MAX_INPUT_LEN: usize = 65_536;

/// Heuristic opening overhead applied once per tool schema.
const FUNC_INIT: usize = 7;
/// Heuristic overhead for entering a JSON object in a schema.
const PROP_INIT: usize = 3;
/// Heuristic overhead per JSON object key in a schema.
const PROP_KEY: usize = 3;
/// Heuristic correction applied once per schema; negative, hence `isize`.
const ENUM_INIT: isize = -3;
/// Heuristic overhead for entering a JSON array in a schema.
const ENUM_ITEM: usize = 3;
/// Heuristic closing overhead applied once per tool schema.
const FUNC_END: usize = 12;

/// Fixed token overhead added for a `ToolUse` part.
const TOOL_USE_OVERHEAD: usize = 20;
/// Fixed token overhead added for a `ToolResult` part.
const TOOL_RESULT_OVERHEAD: usize = 15;
/// Fixed token overhead added for a `ToolOutput` part.
const TOOL_OUTPUT_OVERHEAD: usize = 8;
/// Fixed framing overhead added for an image part.
const IMAGE_OVERHEAD: usize = 50;
/// Flat token estimate used for any image payload, regardless of size.
const IMAGE_DEFAULT_TOKENS: usize = 1000;
/// Fixed overhead for thinking and redacted-thinking parts.
const THINKING_OVERHEAD: usize = 10;
37
/// Token counter backed by the tiktoken `cl100k_base` BPE, with a bounded
/// memoization cache and a chars/4 fallback when the BPE is unavailable.
pub struct TokenCounter {
    // `None` when tiktoken initialization failed; counting then falls back
    // to the chars/4 heuristic (see `count_tokens`).
    bpe: Option<CoreBPE>,
    // Memoized text-hash -> token-count entries, bounded by `cache_cap`.
    cache: DashMap<u64, usize>,
    // Eviction threshold: `CACHE_CAP` in production, smaller in tests.
    cache_cap: usize,
}
43
44impl TokenCounter {
45 #[must_use]
47 pub fn new() -> Self {
48 let bpe = match tiktoken_rs::cl100k_base() {
49 Ok(b) => Some(b),
50 Err(e) => {
51 tracing::warn!("tiktoken cl100k_base init failed, using chars/4 fallback: {e}");
52 None
53 }
54 };
55 Self {
56 bpe,
57 cache: DashMap::new(),
58 cache_cap: CACHE_CAP,
59 }
60 }
61
62 #[must_use]
67 pub fn count_tokens(&self, text: &str) -> usize {
68 if text.is_empty() {
69 return 0;
70 }
71
72 if text.len() > MAX_INPUT_LEN {
73 return text.chars().count() / 4;
74 }
75
76 let key = hash_text(text);
77
78 if let Some(cached) = self.cache.get(&key) {
79 return *cached;
80 }
81
82 let count = match &self.bpe {
83 Some(bpe) => bpe.encode_with_special_tokens(text).len(),
84 None => text.chars().count() / 4,
85 };
86
87 if self.cache.len() >= self.cache_cap {
90 let key_to_evict = self.cache.iter().next().map(|e| *e.key());
91 if let Some(k) = key_to_evict {
92 self.cache.remove(&k);
93 }
94 }
95 self.cache.insert(key, count);
96
97 count
98 }
99
100 #[must_use]
105 pub fn count_message_tokens(&self, msg: &Message) -> usize {
106 if msg.parts.is_empty() {
107 return self.count_tokens(&msg.content);
108 }
109 msg.parts.iter().map(|p| self.count_part_tokens(p)).sum()
110 }
111
112 #[must_use]
114 fn count_part_tokens(&self, part: &MessagePart) -> usize {
115 match part {
116 MessagePart::Text { text }
117 | MessagePart::Recall { text }
118 | MessagePart::CodeContext { text }
119 | MessagePart::Summary { text }
120 | MessagePart::CrossSession { text } => {
121 if text.trim().is_empty() {
122 return 0;
123 }
124 self.count_tokens(text)
125 }
126
127 MessagePart::ToolOutput {
130 tool_name, body, ..
131 } => TOOL_OUTPUT_OVERHEAD + self.count_tokens(tool_name) + self.count_tokens(body),
132
133 MessagePart::ToolUse { id, name, input } => {
135 TOOL_USE_OVERHEAD
136 + self.count_tokens(id)
137 + self.count_tokens(name)
138 + self.count_tokens(&input.to_string())
139 }
140
141 MessagePart::ToolResult {
143 tool_use_id,
144 content,
145 ..
146 } => TOOL_RESULT_OVERHEAD + self.count_tokens(tool_use_id) + self.count_tokens(content),
147
148 MessagePart::Image(_) => IMAGE_OVERHEAD + IMAGE_DEFAULT_TOKENS,
152
153 MessagePart::ThinkingBlock {
155 thinking,
156 signature,
157 } => THINKING_OVERHEAD + self.count_tokens(thinking) + self.count_tokens(signature),
158
159 MessagePart::RedactedThinkingBlock { data } => THINKING_OVERHEAD + data.len() / 4,
161
162 MessagePart::Compaction { summary } => self.count_tokens(summary),
164 }
165 }
166
167 #[must_use]
169 pub fn count_tool_schema_tokens(&self, schema: &serde_json::Value) -> usize {
170 let base = count_schema_value(self, schema);
171 let total =
172 base.cast_signed() + ENUM_INIT + FUNC_INIT.cast_signed() + FUNC_END.cast_signed();
173 total.max(0).cast_unsigned()
174 }
175}
176
177impl Default for TokenCounter {
178 fn default() -> Self {
179 Self::new()
180 }
181}
182
/// 64-bit cache key for a piece of text.
///
/// Uses the std `DefaultHasher`; values are only meaningful within a
/// single process and must not be persisted.
fn hash_text(text: &str) -> u64 {
    let mut state = DefaultHasher::new();
    Hash::hash(text, &mut state);
    Hasher::finish(&state)
}
188
189fn count_schema_value(counter: &TokenCounter, value: &serde_json::Value) -> usize {
190 match value {
191 serde_json::Value::Object(map) => {
192 let mut tokens = PROP_INIT;
193 for (key, val) in map {
194 tokens += PROP_KEY + counter.count_tokens(key);
195 tokens += count_schema_value(counter, val);
196 }
197 tokens
198 }
199 serde_json::Value::Array(arr) => {
200 let mut tokens = ENUM_ITEM;
201 for item in arr {
202 tokens += count_schema_value(counter, item);
203 }
204 tokens
205 }
206 serde_json::Value::String(s) => counter.count_tokens(s),
207 serde_json::Value::Bool(_) | serde_json::Value::Number(_) | serde_json::Value::Null => 1,
208 }
209}
210
#[cfg(test)]
mod tests {
    use super::*;
    use zeph_llm::provider::{ImageData, Message, MessageMetadata, MessagePart, Role};

    // Helper: build a user message from structured parts.
    fn make_msg(parts: Vec<MessagePart>) -> Message {
        Message::from_parts(Role::User, parts)
    }

    // Helper: build a user message with flat content and no parts,
    // to exercise the content fallback in `count_message_tokens`.
    fn make_msg_no_parts(content: &str) -> Message {
        Message {
            role: Role::User,
            content: content.to_string(),
            parts: vec![],
            metadata: MessageMetadata::default(),
        }
    }

    // With no parts, counting falls back to the flat `content` string.
    #[test]
    fn count_message_tokens_empty_parts_falls_back_to_content() {
        let counter = TokenCounter::new();
        let msg = make_msg_no_parts("hello world");
        assert_eq!(
            counter.count_message_tokens(&msg),
            counter.count_tokens("hello world")
        );
    }

    // A single text part carries no extra overhead beyond its tokens.
    #[test]
    fn count_message_tokens_text_part_matches_count_tokens() {
        let counter = TokenCounter::new();
        let text = "the quick brown fox jumps over the lazy dog";
        let msg = make_msg(vec![MessagePart::Text {
            text: text.to_string(),
        }]);
        assert_eq!(
            counter.count_message_tokens(&msg),
            counter.count_tokens(text)
        );
    }

    // Structured ToolUse counting (overhead + id + name + serialized JSON)
    // must exceed counting the message's flattened content string.
    #[test]
    fn count_message_tokens_tool_use_exceeds_flattened_content() {
        let counter = TokenCounter::new();
        let input = serde_json::json!({"command": "find /home -name '*.rs' -type f | head -100"});
        let msg = make_msg(vec![MessagePart::ToolUse {
            id: "toolu_abc".into(),
            name: "bash".into(),
            input,
        }]);
        let structured = counter.count_message_tokens(&msg);
        let flattened = counter.count_tokens(&msg.content);
        assert!(
            structured > flattened,
            "structured={structured} should exceed flattened={flattened}"
        );
    }

    // A compacted (emptied-body) tool output should cost only roughly
    // its fixed overhead plus the tool name.
    #[test]
    fn count_message_tokens_compacted_tool_output_is_small() {
        let counter = TokenCounter::new();
        let msg = make_msg(vec![MessagePart::ToolOutput {
            tool_name: "bash".into(),
            body: String::new(),
            compacted_at: Some(1_700_000_000),
        }]);
        let tokens = counter.count_message_tokens(&msg);
        assert!(
            tokens <= 15,
            "compacted tool output should be small, got {tokens}"
        );
    }

    // Images cost a flat estimate regardless of payload size.
    #[test]
    fn count_message_tokens_image_returns_constant() {
        let counter = TokenCounter::new();
        let msg = make_msg(vec![MessagePart::Image(Box::new(ImageData {
            data: vec![0u8; 1000],
            mime_type: "image/jpeg".into(),
        }))]);
        assert_eq!(
            counter.count_message_tokens(&msg),
            IMAGE_OVERHEAD + IMAGE_DEFAULT_TOKENS
        );
    }

    // Thinking blocks count their overhead plus both text fields.
    #[test]
    fn count_message_tokens_thinking_block_counts_text() {
        let counter = TokenCounter::new();
        let thinking = "step by step reasoning about the problem";
        let signature = "sig";
        let msg = make_msg(vec![MessagePart::ThinkingBlock {
            thinking: thinking.to_string(),
            signature: signature.to_string(),
        }]);
        let expected =
            THINKING_OVERHEAD + counter.count_tokens(thinking) + counter.count_tokens(signature);
        assert_eq!(counter.count_message_tokens(&msg), expected);
    }

    // Empty and whitespace-only text-like parts count as zero tokens.
    #[test]
    fn count_part_tokens_empty_text_returns_zero() {
        let counter = TokenCounter::new();
        assert_eq!(
            counter.count_part_tokens(&MessagePart::Text {
                text: String::new()
            }),
            0
        );
        assert_eq!(
            counter.count_part_tokens(&MessagePart::Text {
                text: " ".to_string()
            }),
            0
        );
        assert_eq!(
            counter.count_part_tokens(&MessagePart::Recall {
                text: "\n\t".to_string()
            }),
            0
        );
    }

    // The message total must equal the per-part sum (no hidden extras).
    #[test]
    fn count_message_tokens_push_recompute_consistency() {
        let counter = TokenCounter::new();
        let parts = vec![
            MessagePart::Text {
                text: "hello".into(),
            },
            MessagePart::ToolOutput {
                tool_name: "bash".into(),
                body: "output data".into(),
                compacted_at: None,
            },
        ];
        let msg = make_msg(parts);
        let total = counter.count_message_tokens(&msg);
        let sum: usize = msg.parts.iter().map(|p| counter.count_part_tokens(p)).sum();
        assert_eq!(total, sum);
    }

    // When both parts and content are present, parts win.
    #[test]
    fn count_message_tokens_parts_take_priority_over_content() {
        let counter = TokenCounter::new();
        let parts_text = "hello from parts";
        let msg = Message {
            role: Role::User,
            content: "completely different content that should be ignored".to_string(),
            parts: vec![MessagePart::Text {
                text: parts_text.to_string(),
            }],
            metadata: MessageMetadata::default(),
        };
        let parts_based = counter.count_tokens(parts_text);
        let content_based = counter.count_tokens(&msg.content);
        assert_ne!(
            parts_based, content_based,
            "test setup: parts and content must differ"
        );
        assert_eq!(counter.count_message_tokens(&msg), parts_based);
    }

    // ToolResult counts overhead + tool_use_id + content.
    #[test]
    fn count_part_tokens_tool_result() {
        let counter = TokenCounter::new();
        let tool_use_id = "toolu_xyz";
        let content = "result text";
        let part = MessagePart::ToolResult {
            tool_use_id: tool_use_id.to_string(),
            content: content.to_string(),
            is_error: false,
        };
        let expected = TOOL_RESULT_OVERHEAD
            + counter.count_tokens(tool_use_id)
            + counter.count_tokens(content);
        assert_eq!(counter.count_part_tokens(&part), expected);
    }

    #[test]
    fn count_tokens_empty() {
        let counter = TokenCounter::new();
        assert_eq!(counter.count_tokens(""), 0);
    }

    #[test]
    fn count_tokens_non_empty() {
        let counter = TokenCounter::new();
        assert!(counter.count_tokens("hello world") > 0);
    }

    // Cached and freshly-computed counts must agree.
    #[test]
    fn count_tokens_cache_hit_returns_same() {
        let counter = TokenCounter::new();
        let text = "the quick brown fox";
        let first = counter.count_tokens(text);
        let second = counter.count_tokens(text);
        assert_eq!(first, second);
    }

    // With bpe forced to None, the chars/4 heuristic is used directly.
    #[test]
    fn count_tokens_fallback_mode() {
        let counter = TokenCounter {
            bpe: None,
            cache: DashMap::new(),
            cache_cap: CACHE_CAP,
        };
        assert_eq!(counter.count_tokens("abcdefgh"), 2);
        assert_eq!(counter.count_tokens(""), 0);
    }

    // Inputs over MAX_INPUT_LEN use chars/4 and never touch the cache.
    #[test]
    fn count_tokens_oversized_input_uses_fallback_without_cache() {
        let counter = TokenCounter::new();
        let large = "a".repeat(MAX_INPUT_LEN + 1);
        let result = counter.count_tokens(&large);
        assert_eq!(result, large.chars().count() / 4);
        assert!(counter.cache.is_empty());
    }

    // Multi-byte unicode makes BPE and chars/4 visibly disagree.
    #[test]
    fn count_tokens_unicode_bpe_differs_from_fallback() {
        let counter = TokenCounter::new();
        let text = "Привет мир! 你好世界! こんにちは! 🌍";
        let bpe_count = counter.count_tokens(text);
        let fallback_count = text.chars().count() / 4;
        assert!(bpe_count > 0, "BPE count must be positive");
        assert_ne!(
            bpe_count, fallback_count,
            "BPE tokenization should differ from chars/4 for unicode text"
        );
    }

    // Pins the heuristic schema estimate for a representative tool schema;
    // changing any schema constant or the walk will move this number.
    #[test]
    fn count_tool_schema_tokens_sample() {
        let counter = TokenCounter::new();
        let schema = serde_json::json!({
            "name": "get_weather",
            "description": "Get the current weather for a location",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The city name"
                    }
                },
                "required": ["location"]
            }
        });
        let tokens = counter.count_tool_schema_tokens(&schema);
        assert_eq!(tokens, 82);
    }

    // At capacity, inserting a new entry evicts one so len stays at cap.
    #[test]
    fn cache_eviction_at_capacity() {
        let counter = TokenCounter {
            bpe: None,
            cache: DashMap::new(),
            cache_cap: 3,
        };
        let _ = counter.count_tokens("aaaa");
        let _ = counter.count_tokens("bbbb");
        let _ = counter.count_tokens("cccc");
        assert_eq!(counter.cache.len(), 3);
        let _ = counter.count_tokens("dddd");
        assert_eq!(counter.cache.len(), 3);
    }
}