// zeph_memory/token_counter.rs
1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use std::collections::hash_map::DefaultHasher;
5use std::hash::{Hash, Hasher};
6
7use dashmap::DashMap;
8use tiktoken_rs::CoreBPE;
9use zeph_llm::provider::{Message, MessagePart};
10
/// Maximum number of memoized text→token-count entries before eviction kicks in.
const CACHE_CAP: usize = 10_000;
/// Inputs larger than this limit bypass BPE encoding and use the chars/4 fallback.
/// Prevents CPU amplification from pathologically large inputs.
const MAX_INPUT_LEN: usize = 65_536;

// OpenAI function-calling token overhead constants
// (additive terms in `count_tool_schema_tokens`'s heuristic formula).
/// Fixed framing tokens added once per function schema.
const FUNC_INIT: usize = 7;
/// Framing tokens added once per JSON object.
const PROP_INIT: usize = 3;
/// Framing tokens added per object property key.
const PROP_KEY: usize = 3;
/// Applied once per schema; negative, hence `isize` — the signed arithmetic
/// in `count_tool_schema_tokens` clamps the total at zero.
const ENUM_INIT: isize = -3;
/// Framing tokens added per JSON array (enum) encountered.
const ENUM_ITEM: usize = 3;
/// Fixed trailing framing tokens added once per function schema.
const FUNC_END: usize = 12;

// Structural overhead per part type (approximate token counts for JSON framing)
/// `{"type":"tool_use","id":"","name":"","input":}`
const TOOL_USE_OVERHEAD: usize = 20;
/// `{"type":"tool_result","tool_use_id":"","content":""}`
const TOOL_RESULT_OVERHEAD: usize = 15;
/// `[tool output: <name>]\n` wrapping
const TOOL_OUTPUT_OVERHEAD: usize = 8;
/// Image block JSON structure overhead (dimension-based counting unavailable)
const IMAGE_OVERHEAD: usize = 50;
/// Default token estimate for a typical image (dimensions not available in `ImageData`)
const IMAGE_DEFAULT_TOKENS: usize = 1000;
/// Thinking/redacted block framing
const THINKING_OVERHEAD: usize = 10;
37
/// Token counter backed by tiktoken's `cl100k_base` BPE with a bounded
/// memoization cache. Thread-safe via `DashMap`.
pub struct TokenCounter {
    // `None` when tokenizer init failed; counting then falls back to chars/4.
    bpe: Option<CoreBPE>,
    // hash(text) -> token count; keyed by `hash_text` (see below).
    cache: DashMap<u64, usize>,
    // Soft capacity bound; one arbitrary entry is evicted once reached.
    cache_cap: usize,
}
43
44impl TokenCounter {
45    /// Create a new counter. Falls back to chars/4 if tiktoken init fails.
46    #[must_use]
47    pub fn new() -> Self {
48        let bpe = match tiktoken_rs::cl100k_base() {
49            Ok(b) => Some(b),
50            Err(e) => {
51                tracing::warn!("tiktoken cl100k_base init failed, using chars/4 fallback: {e}");
52                None
53            }
54        };
55        Self {
56            bpe,
57            cache: DashMap::new(),
58            cache_cap: CACHE_CAP,
59        }
60    }
61
62    /// Count tokens in text. Uses cache, falls back to heuristic.
63    ///
64    /// Inputs exceeding 64 KiB bypass BPE and use chars/4 without caching to
65    /// avoid CPU amplification from oversized inputs.
66    #[must_use]
67    pub fn count_tokens(&self, text: &str) -> usize {
68        if text.is_empty() {
69            return 0;
70        }
71
72        if text.len() > MAX_INPUT_LEN {
73            return text.chars().count() / 4;
74        }
75
76        let key = hash_text(text);
77
78        if let Some(cached) = self.cache.get(&key) {
79            return *cached;
80        }
81
82        let count = match &self.bpe {
83            Some(bpe) => bpe.encode_with_special_tokens(text).len(),
84            None => text.chars().count() / 4,
85        };
86
87        // TOCTOU between len() check and insert is benign: worst case we evict
88        // one extra entry and temporarily exceed the cap by one slot.
89        if self.cache.len() >= self.cache_cap {
90            let key_to_evict = self.cache.iter().next().map(|e| *e.key());
91            if let Some(k) = key_to_evict {
92                self.cache.remove(&k);
93            }
94        }
95        self.cache.insert(key, count);
96
97        count
98    }
99
100    /// Estimate token count for a message the way the LLM API will see it.
101    ///
102    /// When structured parts exist, counts from parts matching the API payload
103    /// structure. Falls back to `content` (flattened text) when parts is empty.
104    #[must_use]
105    pub fn count_message_tokens(&self, msg: &Message) -> usize {
106        if msg.parts.is_empty() {
107            return self.count_tokens(&msg.content);
108        }
109        msg.parts.iter().map(|p| self.count_part_tokens(p)).sum()
110    }
111
112    /// Estimate tokens for a single [`MessagePart`] matching the API payload structure.
113    #[must_use]
114    fn count_part_tokens(&self, part: &MessagePart) -> usize {
115        match part {
116            MessagePart::Text { text }
117            | MessagePart::Recall { text }
118            | MessagePart::CodeContext { text }
119            | MessagePart::Summary { text }
120            | MessagePart::CrossSession { text } => {
121                if text.trim().is_empty() {
122                    return 0;
123                }
124                self.count_tokens(text)
125            }
126
127            // API always emits `[tool output: {name}]\n{body}` regardless of compacted_at.
128            // When body is emptied by compaction, count_tokens(body) returns 0 naturally.
129            MessagePart::ToolOutput {
130                tool_name, body, ..
131            } => TOOL_OUTPUT_OVERHEAD + self.count_tokens(tool_name) + self.count_tokens(body),
132
133            // API sends structured JSON block: `{"type":"tool_use","id":"...","name":"...","input":...}`
134            MessagePart::ToolUse { id, name, input } => {
135                TOOL_USE_OVERHEAD
136                    + self.count_tokens(id)
137                    + self.count_tokens(name)
138                    + self.count_tokens(&input.to_string())
139            }
140
141            // API sends structured block: `{"type":"tool_result","tool_use_id":"...","content":"..."}`
142            MessagePart::ToolResult {
143                tool_use_id,
144                content,
145                ..
146            } => TOOL_RESULT_OVERHEAD + self.count_tokens(tool_use_id) + self.count_tokens(content),
147
148            // Image token count depends on pixel dimensions, which are unavailable in ImageData.
149            // Using a fixed constant is more accurate than bytes-based formula because
150            // Claude's actual formula is width*height based, not payload-size based.
151            MessagePart::Image(_) => IMAGE_OVERHEAD + IMAGE_DEFAULT_TOKENS,
152
153            // ThinkingBlock is preserved verbatim in multi-turn requests.
154            MessagePart::ThinkingBlock {
155                thinking,
156                signature,
157            } => THINKING_OVERHEAD + self.count_tokens(thinking) + self.count_tokens(signature),
158
159            // RedactedThinkingBlock is an opaque base64 blob — BPE is not meaningful here.
160            MessagePart::RedactedThinkingBlock { data } => THINKING_OVERHEAD + data.len() / 4,
161        }
162    }
163
164    /// Count tokens for an `OpenAI` tool/function schema `JSON` value.
165    #[must_use]
166    pub fn count_tool_schema_tokens(&self, schema: &serde_json::Value) -> usize {
167        let base = count_schema_value(self, schema);
168        let total =
169            base.cast_signed() + ENUM_INIT + FUNC_INIT.cast_signed() + FUNC_END.cast_signed();
170        total.max(0).cast_unsigned()
171    }
172}
173
174impl Default for TokenCounter {
175    fn default() -> Self {
176        Self::new()
177    }
178}
179
/// Deterministic 64-bit cache key for a text value.
///
/// Uses std's `DefaultHasher` (SipHash); collisions are possible in principle
/// but astronomically unlikely at cache-sized working sets.
fn hash_text(text: &str) -> u64 {
    let mut state = DefaultHasher::default();
    Hash::hash(text, &mut state);
    Hasher::finish(&state)
}
185
186fn count_schema_value(counter: &TokenCounter, value: &serde_json::Value) -> usize {
187    match value {
188        serde_json::Value::Object(map) => {
189            let mut tokens = PROP_INIT;
190            for (key, val) in map {
191                tokens += PROP_KEY + counter.count_tokens(key);
192                tokens += count_schema_value(counter, val);
193            }
194            tokens
195        }
196        serde_json::Value::Array(arr) => {
197            let mut tokens = ENUM_ITEM;
198            for item in arr {
199                tokens += count_schema_value(counter, item);
200            }
201            tokens
202        }
203        serde_json::Value::String(s) => counter.count_tokens(s),
204        serde_json::Value::Bool(_) | serde_json::Value::Number(_) | serde_json::Value::Null => 1,
205    }
206}
207
#[cfg(test)]
mod tests {
    //! Unit tests for [`TokenCounter`]: message/part counting, cache behavior,
    //! fallback mode, and the OpenAI schema-token formula.
    use super::*;
    use zeph_llm::provider::{ImageData, Message, MessageMetadata, MessagePart, Role};

    // Build a user message from structured parts.
    fn make_msg(parts: Vec<MessagePart>) -> Message {
        Message::from_parts(Role::User, parts)
    }

    // Build a user message carrying only flattened `content`, with no parts,
    // to exercise the content fallback path.
    fn make_msg_no_parts(content: &str) -> Message {
        Message {
            role: Role::User,
            content: content.to_string(),
            parts: vec![],
            metadata: MessageMetadata::default(),
        }
    }

    #[test]
    fn count_message_tokens_empty_parts_falls_back_to_content() {
        let counter = TokenCounter::new();
        let msg = make_msg_no_parts("hello world");
        assert_eq!(
            counter.count_message_tokens(&msg),
            counter.count_tokens("hello world")
        );
    }

    #[test]
    fn count_message_tokens_text_part_matches_count_tokens() {
        let counter = TokenCounter::new();
        let text = "the quick brown fox jumps over the lazy dog";
        let msg = make_msg(vec![MessagePart::Text {
            text: text.to_string(),
        }]);
        assert_eq!(
            counter.count_message_tokens(&msg),
            counter.count_tokens(text)
        );
    }

    #[test]
    fn count_message_tokens_tool_use_exceeds_flattened_content() {
        let counter = TokenCounter::new();
        // Large JSON input: structured counting should be higher than flattened "[tool_use: bash(id)]"
        let input = serde_json::json!({"command": "find /home -name '*.rs' -type f | head -100"});
        let msg = make_msg(vec![MessagePart::ToolUse {
            id: "toolu_abc".into(),
            name: "bash".into(),
            input,
        }]);
        let structured = counter.count_message_tokens(&msg);
        let flattened = counter.count_tokens(&msg.content);
        assert!(
            structured > flattened,
            "structured={structured} should exceed flattened={flattened}"
        );
    }

    #[test]
    fn count_message_tokens_compacted_tool_output_is_small() {
        let counter = TokenCounter::new();
        // Compacted ToolOutput has empty body — should count close to overhead only
        let msg = make_msg(vec![MessagePart::ToolOutput {
            tool_name: "bash".into(),
            body: String::new(),
            compacted_at: Some(1_700_000_000),
        }]);
        let tokens = counter.count_message_tokens(&msg);
        // Should be small: TOOL_OUTPUT_OVERHEAD + count_tokens("bash") + 0
        assert!(
            tokens <= 15,
            "compacted tool output should be small, got {tokens}"
        );
    }

    #[test]
    fn count_message_tokens_image_returns_constant() {
        let counter = TokenCounter::new();
        // Payload size must not influence the image estimate — only the constants do.
        let msg = make_msg(vec![MessagePart::Image(Box::new(ImageData {
            data: vec![0u8; 1000],
            mime_type: "image/jpeg".into(),
        }))]);
        assert_eq!(
            counter.count_message_tokens(&msg),
            IMAGE_OVERHEAD + IMAGE_DEFAULT_TOKENS
        );
    }

    #[test]
    fn count_message_tokens_thinking_block_counts_text() {
        let counter = TokenCounter::new();
        let thinking = "step by step reasoning about the problem";
        let signature = "sig";
        let msg = make_msg(vec![MessagePart::ThinkingBlock {
            thinking: thinking.to_string(),
            signature: signature.to_string(),
        }]);
        let expected =
            THINKING_OVERHEAD + counter.count_tokens(thinking) + counter.count_tokens(signature);
        assert_eq!(counter.count_message_tokens(&msg), expected);
    }

    #[test]
    fn count_part_tokens_empty_text_returns_zero() {
        // Whitespace-only text parts (spaces, newlines, tabs) count as zero.
        let counter = TokenCounter::new();
        assert_eq!(
            counter.count_part_tokens(&MessagePart::Text {
                text: String::new()
            }),
            0
        );
        assert_eq!(
            counter.count_part_tokens(&MessagePart::Text {
                text: "   ".to_string()
            }),
            0
        );
        assert_eq!(
            counter.count_part_tokens(&MessagePart::Recall {
                text: "\n\t".to_string()
            }),
            0
        );
    }

    #[test]
    fn count_message_tokens_push_recompute_consistency() {
        // Verify that sum of count_message_tokens per part equals recompute result
        let counter = TokenCounter::new();
        let parts = vec![
            MessagePart::Text {
                text: "hello".into(),
            },
            MessagePart::ToolOutput {
                tool_name: "bash".into(),
                body: "output data".into(),
                compacted_at: None,
            },
        ];
        let msg = make_msg(parts);
        let total = counter.count_message_tokens(&msg);
        let sum: usize = msg.parts.iter().map(|p| counter.count_part_tokens(p)).sum();
        assert_eq!(total, sum);
    }

    #[test]
    fn count_message_tokens_parts_take_priority_over_content() {
        // R-2: primary regression guard — when parts is non-empty, content is ignored.
        let counter = TokenCounter::new();
        let parts_text = "hello from parts";
        let msg = Message {
            role: Role::User,
            content: "completely different content that should be ignored".to_string(),
            parts: vec![MessagePart::Text {
                text: parts_text.to_string(),
            }],
            metadata: MessageMetadata::default(),
        };
        let parts_based = counter.count_tokens(parts_text);
        let content_based = counter.count_tokens(&msg.content);
        assert_ne!(
            parts_based, content_based,
            "test setup: parts and content must differ"
        );
        assert_eq!(counter.count_message_tokens(&msg), parts_based);
    }

    #[test]
    fn count_part_tokens_tool_result() {
        // R-3: verify ToolResult arm counting
        let counter = TokenCounter::new();
        let tool_use_id = "toolu_xyz";
        let content = "result text";
        let part = MessagePart::ToolResult {
            tool_use_id: tool_use_id.to_string(),
            content: content.to_string(),
            is_error: false,
        };
        let expected = TOOL_RESULT_OVERHEAD
            + counter.count_tokens(tool_use_id)
            + counter.count_tokens(content);
        assert_eq!(counter.count_part_tokens(&part), expected);
    }

    #[test]
    fn count_tokens_empty() {
        let counter = TokenCounter::new();
        assert_eq!(counter.count_tokens(""), 0);
    }

    #[test]
    fn count_tokens_non_empty() {
        let counter = TokenCounter::new();
        assert!(counter.count_tokens("hello world") > 0);
    }

    #[test]
    fn count_tokens_cache_hit_returns_same() {
        // Second call hits the cache; results must agree.
        let counter = TokenCounter::new();
        let text = "the quick brown fox";
        let first = counter.count_tokens(text);
        let second = counter.count_tokens(text);
        assert_eq!(first, second);
    }

    #[test]
    fn count_tokens_fallback_mode() {
        // Force bpe = None to exercise the chars/4 heuristic directly.
        let counter = TokenCounter {
            bpe: None,
            cache: DashMap::new(),
            cache_cap: CACHE_CAP,
        };
        // 8 chars / 4 = 2
        assert_eq!(counter.count_tokens("abcdefgh"), 2);
        assert_eq!(counter.count_tokens(""), 0);
    }

    #[test]
    fn count_tokens_oversized_input_uses_fallback_without_cache() {
        let counter = TokenCounter::new();
        // Generate input larger than MAX_INPUT_LEN (65536 bytes)
        let large = "a".repeat(MAX_INPUT_LEN + 1);
        let result = counter.count_tokens(&large);
        // chars/4 fallback: (65537 chars) / 4
        assert_eq!(result, large.chars().count() / 4);
        // Must not be cached
        assert!(counter.cache.is_empty());
    }

    #[test]
    fn count_tokens_unicode_bpe_differs_from_fallback() {
        let counter = TokenCounter::new();
        let text = "Привет мир! 你好世界! こんにちは! 🌍";
        let bpe_count = counter.count_tokens(text);
        let fallback_count = text.chars().count() / 4;
        // BPE should return > 0
        assert!(bpe_count > 0, "BPE count must be positive");
        // BPE result should differ from naive chars/4 for multi-byte text
        assert_ne!(
            bpe_count, fallback_count,
            "BPE tokenization should differ from chars/4 for unicode text"
        );
    }

    #[test]
    fn count_tool_schema_tokens_sample() {
        let counter = TokenCounter::new();
        let schema = serde_json::json!({
            "name": "get_weather",
            "description": "Get the current weather for a location",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The city name"
                    }
                },
                "required": ["location"]
            }
        });
        let tokens = counter.count_tool_schema_tokens(&schema);
        // Pinned value: computed by running the formula with cl100k_base on this exact schema.
        // If this fails, a tokenizer or formula change likely occurred.
        assert_eq!(tokens, 82);
    }

    #[test]
    fn cache_eviction_at_capacity() {
        // Tiny cap (3) so the fourth insert must evict an arbitrary entry.
        let counter = TokenCounter {
            bpe: None,
            cache: DashMap::new(),
            cache_cap: 3,
        };
        let _ = counter.count_tokens("aaaa");
        let _ = counter.count_tokens("bbbb");
        let _ = counter.count_tokens("cccc");
        assert_eq!(counter.cache.len(), 3);
        // This should evict one entry
        let _ = counter.count_tokens("dddd");
        assert_eq!(counter.cache.len(), 3);
    }
}