// zeph_memory/token_counter.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use std::collections::hash_map::DefaultHasher;
5use std::hash::{Hash, Hasher};
6use std::sync::OnceLock;
7
8use dashmap::DashMap;
9use tiktoken_rs::CoreBPE;
10use zeph_common::text::estimate_tokens;
11use zeph_llm::provider::{Message, MessagePart};
12
/// Process-wide tokenizer, initialized once on first `TokenCounter::new()`.
/// `None` means cl100k_base init failed and all counters use the chars/4 fallback.
static BPE: OnceLock<Option<CoreBPE>> = OnceLock::new();

/// Maximum number of memoized (text-hash -> token count) entries per counter.
const CACHE_CAP: usize = 10_000;
/// Inputs larger than this limit bypass BPE encoding and use the chars/4 fallback.
/// Prevents CPU amplification from pathologically large inputs.
const MAX_INPUT_LEN: usize = 65_536;

// OpenAI function-calling token overhead constants.
// NOTE(review): these values follow community reverse-engineering of the OpenAI
// function-calling prompt format; they are approximations, not an official spec.
const FUNC_INIT: usize = 7;
const PROP_INIT: usize = 3;
const PROP_KEY: usize = 3;
// Signed because it is a global downward adjustment applied once in
// `count_tool_schema_tokens` (total is clamped at zero there).
const ENUM_INIT: isize = -3;
const ENUM_ITEM: usize = 3;
const FUNC_END: usize = 12;

// Structural overhead per part type (approximate token counts for JSON framing)
/// `{"type":"tool_use","id":"","name":"","input":}`
const TOOL_USE_OVERHEAD: usize = 20;
/// `{"type":"tool_result","tool_use_id":"","content":""}`
const TOOL_RESULT_OVERHEAD: usize = 15;
/// `[tool output: <name>]\n` wrapping
const TOOL_OUTPUT_OVERHEAD: usize = 8;
/// Image block JSON structure overhead (dimension-based counting unavailable)
const IMAGE_OVERHEAD: usize = 50;
/// Default token estimate for a typical image (dimensions not available in `ImageData`)
const IMAGE_DEFAULT_TOKENS: usize = 1000;
/// Thinking/redacted block framing
const THINKING_OVERHEAD: usize = 10;
41
/// Token counter backed by the process-wide tiktoken `cl100k_base` BPE with a
/// bounded, concurrent memoization cache. Falls back to a chars/4 heuristic
/// when the tokenizer is unavailable.
pub struct TokenCounter {
    // Borrow of the process-wide `BPE` OnceLock contents; `None` = fallback mode.
    bpe: &'static Option<CoreBPE>,
    // text-hash -> token count memoization, shared across threads.
    cache: DashMap<u64, usize>,
    // Soft capacity bound for `cache`; one arbitrary entry is evicted at the cap.
    cache_cap: usize,
}
47
48impl TokenCounter {
49    /// Create a new counter. Falls back to chars/4 if tiktoken init fails.
50    ///
51    /// BPE data is loaded once and cached in a `OnceLock` for the process lifetime.
52    #[must_use]
53    pub fn new() -> Self {
54        let bpe = BPE.get_or_init(|| match tiktoken_rs::cl100k_base() {
55            Ok(b) => Some(b),
56            Err(e) => {
57                tracing::warn!("tiktoken cl100k_base init failed, using chars/4 fallback: {e}");
58                None
59            }
60        });
61        Self {
62            bpe,
63            cache: DashMap::new(),
64            cache_cap: CACHE_CAP,
65        }
66    }
67
68    /// Count tokens in text. Uses cache, falls back to heuristic.
69    ///
70    /// Inputs exceeding 64 KiB bypass BPE and use chars/4 without caching to
71    /// avoid CPU amplification from oversized inputs.
72    #[must_use]
73    pub fn count_tokens(&self, text: &str) -> usize {
74        if text.is_empty() {
75            return 0;
76        }
77
78        if text.len() > MAX_INPUT_LEN {
79            return zeph_common::text::estimate_tokens(text);
80        }
81
82        let key = hash_text(text);
83
84        if let Some(cached) = self.cache.get(&key) {
85            return *cached;
86        }
87
88        let count = match self.bpe {
89            Some(bpe) => bpe.encode_with_special_tokens(text).len(),
90            None => zeph_common::text::estimate_tokens(text),
91        };
92
93        // TOCTOU between len() check and insert is benign: worst case we evict
94        // one extra entry and temporarily exceed the cap by one slot.
95        if self.cache.len() >= self.cache_cap {
96            let key_to_evict = self.cache.iter().next().map(|e| *e.key());
97            if let Some(k) = key_to_evict {
98                self.cache.remove(&k);
99            }
100        }
101        self.cache.insert(key, count);
102
103        count
104    }
105
106    /// Estimate token count for a message the way the LLM API will see it.
107    ///
108    /// When structured parts exist, counts from parts matching the API payload
109    /// structure. Falls back to `content` (flattened text) when parts is empty.
110    #[must_use]
111    pub fn count_message_tokens(&self, msg: &Message) -> usize {
112        if msg.parts.is_empty() {
113            return self.count_tokens(&msg.content);
114        }
115        msg.parts.iter().map(|p| self.count_part_tokens(p)).sum()
116    }
117
118    /// Estimate tokens for a single [`MessagePart`] matching the API payload structure.
119    #[must_use]
120    fn count_part_tokens(&self, part: &MessagePart) -> usize {
121        match part {
122            MessagePart::Text { text }
123            | MessagePart::Recall { text }
124            | MessagePart::CodeContext { text }
125            | MessagePart::Summary { text }
126            | MessagePart::CrossSession { text } => {
127                if text.trim().is_empty() {
128                    return 0;
129                }
130                self.count_tokens(text)
131            }
132
133            // API always emits `[tool output: {name}]\n{body}` regardless of compacted_at.
134            // When body is emptied by compaction, count_tokens(body) returns 0 naturally.
135            MessagePart::ToolOutput {
136                tool_name, body, ..
137            } => TOOL_OUTPUT_OVERHEAD + self.count_tokens(tool_name) + self.count_tokens(body),
138
139            // API sends structured JSON block: `{"type":"tool_use","id":"...","name":"...","input":...}`
140            MessagePart::ToolUse { id, name, input } => {
141                TOOL_USE_OVERHEAD
142                    + self.count_tokens(id)
143                    + self.count_tokens(name)
144                    + self.count_tokens(&input.to_string())
145            }
146
147            // API sends structured block: `{"type":"tool_result","tool_use_id":"...","content":"..."}`
148            MessagePart::ToolResult {
149                tool_use_id,
150                content,
151                ..
152            } => TOOL_RESULT_OVERHEAD + self.count_tokens(tool_use_id) + self.count_tokens(content),
153
154            // Image token count depends on pixel dimensions, which are unavailable in ImageData.
155            // Using a fixed constant is more accurate than bytes-based formula because
156            // Claude's actual formula is width*height based, not payload-size based.
157            MessagePart::Image(_) => IMAGE_OVERHEAD + IMAGE_DEFAULT_TOKENS,
158
159            // ThinkingBlock is preserved verbatim in multi-turn requests.
160            MessagePart::ThinkingBlock {
161                thinking,
162                signature,
163            } => THINKING_OVERHEAD + self.count_tokens(thinking) + self.count_tokens(signature),
164
165            // RedactedThinkingBlock is an opaque base64 blob — BPE is not meaningful here.
166            MessagePart::RedactedThinkingBlock { data } => {
167                THINKING_OVERHEAD + estimate_tokens(data)
168            }
169
170            // Compaction summary is sent back verbatim to the API.
171            MessagePart::Compaction { summary } => self.count_tokens(summary),
172        }
173    }
174
175    /// Count tokens for an `OpenAI` tool/function schema `JSON` value.
176    #[must_use]
177    pub fn count_tool_schema_tokens(&self, schema: &serde_json::Value) -> usize {
178        let base = count_schema_value(self, schema);
179        let total =
180            base.cast_signed() + ENUM_INIT + FUNC_INIT.cast_signed() + FUNC_END.cast_signed();
181        total.max(0).cast_unsigned()
182    }
183}
184
185impl Default for TokenCounter {
186    fn default() -> Self {
187        Self::new()
188    }
189}
190
/// Derive a stable 64-bit cache key for a text chunk using the std
/// `DefaultHasher` (SipHash-based; deterministic within a process).
fn hash_text(text: &str) -> u64 {
    let mut state = DefaultHasher::default();
    text.hash(&mut state);
    state.finish()
}
196
197fn count_schema_value(counter: &TokenCounter, value: &serde_json::Value) -> usize {
198    match value {
199        serde_json::Value::Object(map) => {
200            let mut tokens = PROP_INIT;
201            for (key, val) in map {
202                tokens += PROP_KEY + counter.count_tokens(key);
203                tokens += count_schema_value(counter, val);
204            }
205            tokens
206        }
207        serde_json::Value::Array(arr) => {
208            let mut tokens = ENUM_ITEM;
209            for item in arr {
210                tokens += count_schema_value(counter, item);
211            }
212            tokens
213        }
214        serde_json::Value::String(s) => counter.count_tokens(s),
215        serde_json::Value::Bool(_) | serde_json::Value::Number(_) | serde_json::Value::Null => 1,
216    }
217}
218
#[cfg(test)]
mod tests {
    use super::*;
    use zeph_llm::provider::{ImageData, Message, MessageMetadata, MessagePart, Role};

    // Shared `None` tokenizer used to force the chars/4 fallback path.
    static BPE_NONE: Option<CoreBPE> = None;

    // Build a counter in fallback mode with a custom cache capacity,
    // bypassing `new()` so tests do not depend on tiktoken initialization.
    fn counter_with_no_bpe(cache_cap: usize) -> TokenCounter {
        TokenCounter {
            bpe: &BPE_NONE,
            cache: DashMap::new(),
            cache_cap,
        }
    }

    // Convenience: a user message built from structured parts.
    fn make_msg(parts: Vec<MessagePart>) -> Message {
        Message::from_parts(Role::User, parts)
    }

    // Convenience: a user message carrying only flattened `content`.
    fn make_msg_no_parts(content: &str) -> Message {
        Message {
            role: Role::User,
            content: content.to_string(),
            parts: vec![],
            metadata: MessageMetadata::default(),
        }
    }

    #[test]
    fn count_message_tokens_empty_parts_falls_back_to_content() {
        let counter = TokenCounter::new();
        let msg = make_msg_no_parts("hello world");
        assert_eq!(
            counter.count_message_tokens(&msg),
            counter.count_tokens("hello world")
        );
    }

    #[test]
    fn count_message_tokens_text_part_matches_count_tokens() {
        let counter = TokenCounter::new();
        let text = "the quick brown fox jumps over the lazy dog";
        let msg = make_msg(vec![MessagePart::Text {
            text: text.to_string(),
        }]);
        assert_eq!(
            counter.count_message_tokens(&msg),
            counter.count_tokens(text)
        );
    }

    #[test]
    fn count_message_tokens_tool_use_exceeds_flattened_content() {
        let counter = TokenCounter::new();
        // Large JSON input: structured counting should be higher than flattened "[tool_use: bash(id)]"
        let input = serde_json::json!({"command": "find /home -name '*.rs' -type f | head -100"});
        let msg = make_msg(vec![MessagePart::ToolUse {
            id: "toolu_abc".into(),
            name: "bash".into(),
            input,
        }]);
        let structured = counter.count_message_tokens(&msg);
        let flattened = counter.count_tokens(&msg.content);
        assert!(
            structured > flattened,
            "structured={structured} should exceed flattened={flattened}"
        );
    }

    #[test]
    fn count_message_tokens_compacted_tool_output_is_small() {
        let counter = TokenCounter::new();
        // Compacted ToolOutput has empty body — should count close to overhead only
        let msg = make_msg(vec![MessagePart::ToolOutput {
            tool_name: "bash".into(),
            body: String::new(),
            compacted_at: Some(1_700_000_000),
        }]);
        let tokens = counter.count_message_tokens(&msg);
        // Should be small: TOOL_OUTPUT_OVERHEAD + count_tokens("bash") + 0
        assert!(
            tokens <= 15,
            "compacted tool output should be small, got {tokens}"
        );
    }

    #[test]
    fn count_message_tokens_image_returns_constant() {
        let counter = TokenCounter::new();
        // Payload size (1000 bytes) must not influence the estimate — the
        // image cost is the fixed constant regardless of data length.
        let msg = make_msg(vec![MessagePart::Image(Box::new(ImageData {
            data: vec![0u8; 1000],
            mime_type: "image/jpeg".into(),
        }))]);
        assert_eq!(
            counter.count_message_tokens(&msg),
            IMAGE_OVERHEAD + IMAGE_DEFAULT_TOKENS
        );
    }

    #[test]
    fn count_message_tokens_thinking_block_counts_text() {
        let counter = TokenCounter::new();
        let thinking = "step by step reasoning about the problem";
        let signature = "sig";
        let msg = make_msg(vec![MessagePart::ThinkingBlock {
            thinking: thinking.to_string(),
            signature: signature.to_string(),
        }]);
        let expected =
            THINKING_OVERHEAD + counter.count_tokens(thinking) + counter.count_tokens(signature);
        assert_eq!(counter.count_message_tokens(&msg), expected);
    }

    #[test]
    fn count_part_tokens_empty_text_returns_zero() {
        // Empty and whitespace-only text-like parts contribute zero tokens.
        let counter = TokenCounter::new();
        assert_eq!(
            counter.count_part_tokens(&MessagePart::Text {
                text: String::new()
            }),
            0
        );
        assert_eq!(
            counter.count_part_tokens(&MessagePart::Text {
                text: "   ".to_string()
            }),
            0
        );
        assert_eq!(
            counter.count_part_tokens(&MessagePart::Recall {
                text: "\n\t".to_string()
            }),
            0
        );
    }

    #[test]
    fn count_message_tokens_push_recompute_consistency() {
        // Verify that sum of count_message_tokens per part equals recompute result
        let counter = TokenCounter::new();
        let parts = vec![
            MessagePart::Text {
                text: "hello".into(),
            },
            MessagePart::ToolOutput {
                tool_name: "bash".into(),
                body: "output data".into(),
                compacted_at: None,
            },
        ];
        let msg = make_msg(parts);
        let total = counter.count_message_tokens(&msg);
        let sum: usize = msg.parts.iter().map(|p| counter.count_part_tokens(p)).sum();
        assert_eq!(total, sum);
    }

    #[test]
    fn count_message_tokens_parts_take_priority_over_content() {
        // R-2: primary regression guard — when parts is non-empty, content is ignored.
        let counter = TokenCounter::new();
        let parts_text = "hello from parts";
        let msg = Message {
            role: Role::User,
            content: "completely different content that should be ignored".to_string(),
            parts: vec![MessagePart::Text {
                text: parts_text.to_string(),
            }],
            metadata: MessageMetadata::default(),
        };
        let parts_based = counter.count_tokens(parts_text);
        let content_based = counter.count_tokens(&msg.content);
        assert_ne!(
            parts_based, content_based,
            "test setup: parts and content must differ"
        );
        assert_eq!(counter.count_message_tokens(&msg), parts_based);
    }

    #[test]
    fn count_part_tokens_tool_result() {
        // R-3: verify ToolResult arm counting
        let counter = TokenCounter::new();
        let tool_use_id = "toolu_xyz";
        let content = "result text";
        let part = MessagePart::ToolResult {
            tool_use_id: tool_use_id.to_string(),
            content: content.to_string(),
            is_error: false,
        };
        let expected = TOOL_RESULT_OVERHEAD
            + counter.count_tokens(tool_use_id)
            + counter.count_tokens(content);
        assert_eq!(counter.count_part_tokens(&part), expected);
    }

    #[test]
    fn count_tokens_empty() {
        let counter = TokenCounter::new();
        assert_eq!(counter.count_tokens(""), 0);
    }

    #[test]
    fn count_tokens_non_empty() {
        let counter = TokenCounter::new();
        assert!(counter.count_tokens("hello world") > 0);
    }

    #[test]
    fn count_tokens_cache_hit_returns_same() {
        // Second call hits the memoization cache and must agree with the first.
        let counter = TokenCounter::new();
        let text = "the quick brown fox";
        let first = counter.count_tokens(text);
        let second = counter.count_tokens(text);
        assert_eq!(first, second);
    }

    #[test]
    fn count_tokens_fallback_mode() {
        let counter = counter_with_no_bpe(CACHE_CAP);
        // 8 chars / 4 = 2
        assert_eq!(counter.count_tokens("abcdefgh"), 2);
        assert_eq!(counter.count_tokens(""), 0);
    }

    #[test]
    fn count_tokens_oversized_input_uses_fallback_without_cache() {
        let counter = TokenCounter::new();
        // Generate input larger than MAX_INPUT_LEN (65536 bytes)
        let large = "a".repeat(MAX_INPUT_LEN + 1);
        let result = counter.count_tokens(&large);
        // chars/4 fallback: (65537 chars) / 4
        assert_eq!(result, zeph_common::text::estimate_tokens(&large));
        // Must not be cached
        assert!(counter.cache.is_empty());
    }

    #[test]
    fn count_tokens_unicode_bpe_differs_from_fallback() {
        let counter = TokenCounter::new();
        let text = "Привет мир! 你好世界! こんにちは! 🌍";
        let bpe_count = counter.count_tokens(text);
        let fallback_count = zeph_common::text::estimate_tokens(text);
        // BPE should return > 0
        assert!(bpe_count > 0, "BPE count must be positive");
        // BPE result should differ from naive chars/4 for multi-byte text
        assert_ne!(
            bpe_count, fallback_count,
            "BPE tokenization should differ from chars/4 for unicode text"
        );
    }

    #[test]
    fn count_tool_schema_tokens_sample() {
        let counter = TokenCounter::new();
        let schema = serde_json::json!({
            "name": "get_weather",
            "description": "Get the current weather for a location",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The city name"
                    }
                },
                "required": ["location"]
            }
        });
        let tokens = counter.count_tool_schema_tokens(&schema);
        // Pinned value: computed by running the formula with cl100k_base on this exact schema.
        // If this fails, a tokenizer or formula change likely occurred.
        assert_eq!(tokens, 82);
    }

    #[test]
    fn two_instances_share_bpe_pointer() {
        // Both counters must borrow the same OnceLock contents (shared BPE).
        let a = TokenCounter::new();
        let b = TokenCounter::new();
        assert!(std::ptr::eq(a.bpe, b.bpe));
    }

    #[test]
    fn cache_eviction_at_capacity() {
        // With cap 3, inserting a 4th distinct text must evict one entry so
        // the cache size stays at the cap.
        let counter = counter_with_no_bpe(3);
        let _ = counter.count_tokens("aaaa");
        let _ = counter.count_tokens("bbbb");
        let _ = counter.count_tokens("cccc");
        assert_eq!(counter.cache.len(), 3);
        // This should evict one entry
        let _ = counter.count_tokens("dddd");
        assert_eq!(counter.cache.len(), 3);
    }
}
511}