Skip to main content

zeph_memory/
token_counter.rs

// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
// SPDX-License-Identifier: MIT OR Apache-2.0

4use std::collections::hash_map::DefaultHasher;
5use std::hash::{Hash, Hasher};
6
7use dashmap::DashMap;
8use tiktoken_rs::CoreBPE;
9use zeph_llm::provider::{Message, MessagePart};
10
/// Default capacity (in entries) of each `TokenCounter`'s token-count cache.
///
/// When the cache is full, inserting a new count first evicts one existing entry.
const CACHE_CAP: usize = 10_000;
/// Inputs larger than this many bytes (64 KiB) bypass BPE encoding and use the
/// chars/4 fallback.
///
/// Prevents CPU amplification from pathologically large inputs.
const MAX_INPUT_LEN: usize = 64 * 1024;

// OpenAI function-calling token overhead constants.
//
// Fixed structural adjustments added on top of the BPE counts of a tool
// schema's object keys and string values.
// NOTE(review): these appear to be empirical approximations; confirm their source.

/// Added once per tool schema.
const FUNC_INIT: usize = 7;
/// Added once for every JSON object in the schema.
const PROP_INIT: usize = 3;
/// Added once for every key of every JSON object in the schema.
const PROP_KEY: usize = 3;
/// Signed adjustment added once per tool schema, whether or not the schema
/// contains an `enum`.
const ENUM_INIT: isize = -3;
/// Added once for every JSON array in the schema (not once per element).
const ENUM_ITEM: usize = 3;
/// Added once per tool schema.
const FUNC_END: usize = 12;

// Structural overhead per message-part type: approximate token counts for the
// JSON (or text) framing wrapped around each part's payload.

/// `{"type":"tool_use","id":"","name":"","input":}`
const TOOL_USE_OVERHEAD: usize = 20;
/// `{"type":"tool_result","tool_use_id":"","content":""}`
const TOOL_RESULT_OVERHEAD: usize = 15;
/// `[tool output: <name>]\n` wrapping.
const TOOL_OUTPUT_OVERHEAD: usize = 8;
/// Image block JSON structure overhead (dimension-based counting is unavailable).
const IMAGE_OVERHEAD: usize = 50;
/// Default token estimate for a typical image (dimensions are not available in `ImageData`).
const IMAGE_DEFAULT_TOKENS: usize = 1000;
/// Framing for thinking and redacted-thinking blocks.
const THINKING_OVERHEAD: usize = 10;

38pub struct TokenCounter {
39    bpe: Option<CoreBPE>,
40    cache: DashMap<u64, usize>,
41    cache_cap: usize,
42}
43
44impl TokenCounter {
45    /// Create a new counter. Falls back to chars/4 if tiktoken init fails.
46    #[must_use]
47    pub fn new() -> Self {
48        let bpe = match tiktoken_rs::cl100k_base() {
49            Ok(b) => Some(b),
50            Err(e) => {
51                tracing::warn!("tiktoken cl100k_base init failed, using chars/4 fallback: {e}");
52                None
53            }
54        };
55        Self {
56            bpe,
57            cache: DashMap::new(),
58            cache_cap: CACHE_CAP,
59        }
60    }
61
62    /// Count tokens in text. Uses cache, falls back to heuristic.
63    ///
64    /// Inputs exceeding 64 KiB bypass BPE and use chars/4 without caching to
65    /// avoid CPU amplification from oversized inputs.
66    #[must_use]
67    pub fn count_tokens(&self, text: &str) -> usize {
68        if text.is_empty() {
69            return 0;
70        }
71
72        if text.len() > MAX_INPUT_LEN {
73            return text.chars().count() / 4;
74        }
75
76        let key = hash_text(text);
77
78        if let Some(cached) = self.cache.get(&key) {
79            return *cached;
80        }
81
82        let count = match &self.bpe {
83            Some(bpe) => bpe.encode_with_special_tokens(text).len(),
84            None => text.chars().count() / 4,
85        };
86
87        // TOCTOU between len() check and insert is benign: worst case we evict
88        // one extra entry and temporarily exceed the cap by one slot.
89        if self.cache.len() >= self.cache_cap {
90            let key_to_evict = self.cache.iter().next().map(|e| *e.key());
91            if let Some(k) = key_to_evict {
92                self.cache.remove(&k);
93            }
94        }
95        self.cache.insert(key, count);
96
97        count
98    }
99
100    /// Estimate token count for a message the way the LLM API will see it.
101    ///
102    /// When structured parts exist, counts from parts matching the API payload
103    /// structure. Falls back to `content` (flattened text) when parts is empty.
104    #[must_use]
105    pub fn count_message_tokens(&self, msg: &Message) -> usize {
106        if msg.parts.is_empty() {
107            return self.count_tokens(&msg.content);
108        }
109        msg.parts.iter().map(|p| self.count_part_tokens(p)).sum()
110    }
111
112    /// Estimate tokens for a single [`MessagePart`] matching the API payload structure.
113    #[must_use]
114    fn count_part_tokens(&self, part: &MessagePart) -> usize {
115        match part {
116            MessagePart::Text { text }
117            | MessagePart::Recall { text }
118            | MessagePart::CodeContext { text }
119            | MessagePart::Summary { text }
120            | MessagePart::CrossSession { text } => {
121                if text.trim().is_empty() {
122                    return 0;
123                }
124                self.count_tokens(text)
125            }
126
127            // API always emits `[tool output: {name}]\n{body}` regardless of compacted_at.
128            // When body is emptied by compaction, count_tokens(body) returns 0 naturally.
129            MessagePart::ToolOutput {
130                tool_name, body, ..
131            } => TOOL_OUTPUT_OVERHEAD + self.count_tokens(tool_name) + self.count_tokens(body),
132
133            // API sends structured JSON block: `{"type":"tool_use","id":"...","name":"...","input":...}`
134            MessagePart::ToolUse { id, name, input } => {
135                TOOL_USE_OVERHEAD
136                    + self.count_tokens(id)
137                    + self.count_tokens(name)
138                    + self.count_tokens(&input.to_string())
139            }
140
141            // API sends structured block: `{"type":"tool_result","tool_use_id":"...","content":"..."}`
142            MessagePart::ToolResult {
143                tool_use_id,
144                content,
145                ..
146            } => TOOL_RESULT_OVERHEAD + self.count_tokens(tool_use_id) + self.count_tokens(content),
147
148            // Image token count depends on pixel dimensions, which are unavailable in ImageData.
149            // Using a fixed constant is more accurate than bytes-based formula because
150            // Claude's actual formula is width*height based, not payload-size based.
151            MessagePart::Image(_) => IMAGE_OVERHEAD + IMAGE_DEFAULT_TOKENS,
152
153            // ThinkingBlock is preserved verbatim in multi-turn requests.
154            MessagePart::ThinkingBlock {
155                thinking,
156                signature,
157            } => THINKING_OVERHEAD + self.count_tokens(thinking) + self.count_tokens(signature),
158
159            // RedactedThinkingBlock is an opaque base64 blob — BPE is not meaningful here.
160            MessagePart::RedactedThinkingBlock { data } => THINKING_OVERHEAD + data.len() / 4,
161
162            // Compaction summary is sent back verbatim to the API.
163            MessagePart::Compaction { summary } => self.count_tokens(summary),
164        }
165    }
166
167    /// Count tokens for an `OpenAI` tool/function schema `JSON` value.
168    #[must_use]
169    pub fn count_tool_schema_tokens(&self, schema: &serde_json::Value) -> usize {
170        let base = count_schema_value(self, schema);
171        let total =
172            base.cast_signed() + ENUM_INIT + FUNC_INIT.cast_signed() + FUNC_END.cast_signed();
173        total.max(0).cast_unsigned()
174    }
175}
176
177impl Default for TokenCounter {
178    fn default() -> Self {
179        Self::new()
180    }
181}
182
/// Hash `text` to the 64-bit key used by the token-count cache.
///
/// Every `DefaultHasher` created with `new()` behaves the same, so equal
/// strings always map to the same key within a process. Distinct strings can
/// still collide. A collision makes `count_tokens` return the cached count of
/// the other string.
///
/// The standard library does not promise the algorithm is stable across Rust
/// releases, so these keys should not be persisted.
fn hash_text(text: &str) -> u64 {
    let mut hasher = DefaultHasher::new();
    text.hash(&mut hasher);
    hasher.finish()
}

189fn count_schema_value(counter: &TokenCounter, value: &serde_json::Value) -> usize {
190    match value {
191        serde_json::Value::Object(map) => {
192            let mut tokens = PROP_INIT;
193            for (key, val) in map {
194                tokens += PROP_KEY + counter.count_tokens(key);
195                tokens += count_schema_value(counter, val);
196            }
197            tokens
198        }
199        serde_json::Value::Array(arr) => {
200            let mut tokens = ENUM_ITEM;
201            for item in arr {
202                tokens += count_schema_value(counter, item);
203            }
204            tokens
205        }
206        serde_json::Value::String(s) => counter.count_tokens(s),
207        serde_json::Value::Bool(_) | serde_json::Value::Number(_) | serde_json::Value::Null => 1,
208    }
209}
210
#[cfg(test)]
mod tests {
    use super::*;
    use zeph_llm::provider::{ImageData, Message, MessageMetadata, MessagePart, Role};

    /// Build a user message from structured parts.
    fn make_msg(parts: Vec<MessagePart>) -> Message {
        Message::from_parts(Role::User, parts)
    }

    /// Build a user message with empty `parts`, so only `content` is counted.
    fn make_msg_no_parts(content: &str) -> Message {
        Message {
            role: Role::User,
            content: content.to_string(),
            parts: vec![],
            metadata: MessageMetadata::default(),
        }
    }

    /// Build a counter with no BPE encoder, so every count uses chars/4.
    fn fallback_counter(cache_cap: usize) -> TokenCounter {
        TokenCounter {
            bpe: None,
            cache: DashMap::new(),
            cache_cap,
        }
    }

    #[test]
    fn count_message_tokens_empty_parts_falls_back_to_content() {
        let counter = TokenCounter::new();
        let msg = make_msg_no_parts("hello world");
        assert_eq!(
            counter.count_message_tokens(&msg),
            counter.count_tokens("hello world")
        );
    }

    #[test]
    fn count_message_tokens_text_part_matches_count_tokens() {
        let counter = TokenCounter::new();
        let text = "the quick brown fox jumps over the lazy dog";
        let msg = make_msg(vec![MessagePart::Text {
            text: text.to_string(),
        }]);
        assert_eq!(
            counter.count_message_tokens(&msg),
            counter.count_tokens(text)
        );
    }

    #[test]
    fn count_message_tokens_tool_use_exceeds_flattened_content() {
        let counter = TokenCounter::new();
        // Large JSON input: structured counting should be higher than flattened "[tool_use: bash(id)]"
        let input = serde_json::json!({"command": "find /home -name '*.rs' -type f | head -100"});
        let msg = make_msg(vec![MessagePart::ToolUse {
            id: "toolu_abc".into(),
            name: "bash".into(),
            input,
        }]);
        let structured = counter.count_message_tokens(&msg);
        let flattened = counter.count_tokens(&msg.content);
        assert!(
            structured > flattened,
            "structured={structured} should exceed flattened={flattened}"
        );
    }

    #[test]
    fn count_message_tokens_compacted_tool_output_is_small() {
        let counter = TokenCounter::new();
        // Compacted ToolOutput has an empty body, so it should count close to overhead only.
        let msg = make_msg(vec![MessagePart::ToolOutput {
            tool_name: "bash".into(),
            body: String::new(),
            compacted_at: Some(1_700_000_000),
        }]);
        let tokens = counter.count_message_tokens(&msg);
        // Should be small: TOOL_OUTPUT_OVERHEAD + count_tokens("bash") + 0
        assert!(
            tokens <= 15,
            "compacted tool output should be small, got {tokens}"
        );
    }

    #[test]
    fn count_message_tokens_image_returns_constant() {
        let counter = TokenCounter::new();
        let msg = make_msg(vec![MessagePart::Image(Box::new(ImageData {
            data: vec![0u8; 1000],
            mime_type: "image/jpeg".into(),
        }))]);
        assert_eq!(
            counter.count_message_tokens(&msg),
            IMAGE_OVERHEAD + IMAGE_DEFAULT_TOKENS
        );
    }

    #[test]
    fn count_message_tokens_thinking_block_counts_text() {
        let counter = TokenCounter::new();
        let thinking = "step by step reasoning about the problem";
        let signature = "sig";
        let msg = make_msg(vec![MessagePart::ThinkingBlock {
            thinking: thinking.to_string(),
            signature: signature.to_string(),
        }]);
        let expected =
            THINKING_OVERHEAD + counter.count_tokens(thinking) + counter.count_tokens(signature);
        assert_eq!(counter.count_message_tokens(&msg), expected);
    }

    #[test]
    fn count_part_tokens_empty_text_returns_zero() {
        let counter = TokenCounter::new();
        assert_eq!(
            counter.count_part_tokens(&MessagePart::Text {
                text: String::new()
            }),
            0
        );
        assert_eq!(
            counter.count_part_tokens(&MessagePart::Text {
                text: "   ".to_string()
            }),
            0
        );
        assert_eq!(
            counter.count_part_tokens(&MessagePart::Recall {
                text: "\n\t".to_string()
            }),
            0
        );
    }

    #[test]
    fn count_message_tokens_push_recompute_consistency() {
        // count_message_tokens must equal the sum of count_part_tokens over the parts.
        let counter = TokenCounter::new();
        let parts = vec![
            MessagePart::Text {
                text: "hello".into(),
            },
            MessagePart::ToolOutput {
                tool_name: "bash".into(),
                body: "output data".into(),
                compacted_at: None,
            },
        ];
        let msg = make_msg(parts);
        let total = counter.count_message_tokens(&msg);
        let sum: usize = msg.parts.iter().map(|p| counter.count_part_tokens(p)).sum();
        assert_eq!(total, sum);
    }

    #[test]
    fn count_message_tokens_parts_take_priority_over_content() {
        // R-2: primary regression guard — when parts is non-empty, content is ignored.
        let counter = TokenCounter::new();
        let parts_text = "hello from parts";
        let msg = Message {
            role: Role::User,
            content: "completely different content that should be ignored".to_string(),
            parts: vec![MessagePart::Text {
                text: parts_text.to_string(),
            }],
            metadata: MessageMetadata::default(),
        };
        let parts_based = counter.count_tokens(parts_text);
        let content_based = counter.count_tokens(&msg.content);
        assert_ne!(
            parts_based, content_based,
            "test setup: parts and content must differ"
        );
        assert_eq!(counter.count_message_tokens(&msg), parts_based);
    }

    #[test]
    fn count_part_tokens_tool_result() {
        // R-3: verify ToolResult arm counting
        let counter = TokenCounter::new();
        let tool_use_id = "toolu_xyz";
        let content = "result text";
        let part = MessagePart::ToolResult {
            tool_use_id: tool_use_id.to_string(),
            content: content.to_string(),
            is_error: false,
        };
        let expected = TOOL_RESULT_OVERHEAD
            + counter.count_tokens(tool_use_id)
            + counter.count_tokens(content);
        assert_eq!(counter.count_part_tokens(&part), expected);
    }

    #[test]
    fn count_tokens_empty() {
        let counter = TokenCounter::new();
        assert_eq!(counter.count_tokens(""), 0);
    }

    #[test]
    fn count_tokens_non_empty() {
        let counter = TokenCounter::new();
        assert!(counter.count_tokens("hello world") > 0);
    }

    #[test]
    fn count_tokens_cache_hit_returns_same() {
        let counter = TokenCounter::new();
        let text = "the quick brown fox";
        let first = counter.count_tokens(text);
        let second = counter.count_tokens(text);
        assert_eq!(first, second);
    }

    #[test]
    fn count_tokens_fallback_mode() {
        let counter = fallback_counter(CACHE_CAP);
        // 8 chars / 4 = 2
        assert_eq!(counter.count_tokens("abcdefgh"), 2);
        assert_eq!(counter.count_tokens(""), 0);
    }

    #[test]
    fn count_tokens_oversized_input_uses_fallback_without_cache() {
        let counter = TokenCounter::new();
        // Generate input larger than MAX_INPUT_LEN (65536 bytes)
        let large = "a".repeat(MAX_INPUT_LEN + 1);
        let result = counter.count_tokens(&large);
        // chars/4 fallback: (65537 chars) / 4
        assert_eq!(result, large.chars().count() / 4);
        // Must not be cached
        assert!(counter.cache.is_empty());
    }

    #[test]
    fn count_tokens_input_at_limit_is_counted_and_cached() {
        // Exactly MAX_INPUT_LEN bytes is still within the limit (the check is `>`),
        // so it takes the normal path and is cached.
        let counter = fallback_counter(CACHE_CAP);
        let at_limit = "a".repeat(MAX_INPUT_LEN);
        assert_eq!(counter.count_tokens(&at_limit), MAX_INPUT_LEN / 4);
        assert_eq!(counter.cache.len(), 1);
    }

    #[test]
    fn count_tokens_unicode_bpe_differs_from_fallback() {
        let counter = TokenCounter::new();
        let text = "Привет мир! 你好世界! こんにちは! 🌍";
        let bpe_count = counter.count_tokens(text);
        let fallback_count = text.chars().count() / 4;
        // BPE should return > 0
        assert!(bpe_count > 0, "BPE count must be positive");
        // BPE result should differ from naive chars/4 for multi-byte text
        assert_ne!(
            bpe_count, fallback_count,
            "BPE tokenization should differ from chars/4 for unicode text"
        );
    }

    #[test]
    fn count_tool_schema_tokens_sample() {
        let counter = TokenCounter::new();
        let schema = serde_json::json!({
            "name": "get_weather",
            "description": "Get the current weather for a location",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The city name"
                    }
                },
                "required": ["location"]
            }
        });
        let tokens = counter.count_tool_schema_tokens(&schema);
        // Pinned value: computed by running the formula with cl100k_base on this exact schema.
        // If this fails, a tokenizer or formula change likely occurred.
        assert_eq!(tokens, 82);
    }

    #[test]
    fn cache_eviction_at_capacity() {
        let counter = fallback_counter(3);
        let _ = counter.count_tokens("aaaa");
        let _ = counter.count_tokens("bbbb");
        let _ = counter.count_tokens("cccc");
        assert_eq!(counter.cache.len(), 3);
        // This should evict one entry
        let _ = counter.count_tokens("dddd");
        assert_eq!(counter.cache.len(), 3);
        // The newly counted text must be the one present after eviction.
        assert!(counter.cache.contains_key(&hash_text("dddd")));
    }
}