// zeph_memory/token_counter.rs
1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use std::collections::hash_map::DefaultHasher;
5use std::hash::{Hash, Hasher};
6use std::sync::OnceLock;
7
8use dashmap::DashMap;
9use tiktoken_rs::CoreBPE;
10use zeph_common::text::estimate_tokens;
11use zeph_llm::provider::{Message, MessagePart};
12
/// Process-wide `cl100k_base` BPE, initialized once by [`TokenCounter::new`].
/// Holds `None` when tiktoken initialization failed; counters then fall back
/// to the chars/4 heuristic.
static BPE: OnceLock<Option<CoreBPE>> = OnceLock::new();

/// Maximum number of entries kept in a counter's token-count cache.
const CACHE_CAP: usize = 10_000;
/// Inputs larger than this limit bypass BPE encoding and use the chars/4 fallback.
/// Prevents CPU amplification from pathologically large inputs.
const MAX_INPUT_LEN: usize = 65_536;

// OpenAI function-calling token overhead constants.
// Fixed framing costs combined by `count_tool_schema_tokens` / `count_schema_value`.
const FUNC_INIT: usize = 7;
const PROP_INIT: usize = 3;
const PROP_KEY: usize = 3;
// Signed because it is a negative correction; added exactly once per schema
// in `count_tool_schema_tokens`.
const ENUM_INIT: isize = -3;
const ENUM_ITEM: usize = 3;
const FUNC_END: usize = 12;

// Structural overhead per part type (approximate token counts for JSON framing)
/// `{"type":"tool_use","id":"","name":"","input":}`
const TOOL_USE_OVERHEAD: usize = 20;
/// `{"type":"tool_result","tool_use_id":"","content":""}`
const TOOL_RESULT_OVERHEAD: usize = 15;
/// `[tool output: <name>]\n` wrapping
const TOOL_OUTPUT_OVERHEAD: usize = 8;
/// Image block JSON structure overhead (dimension-based counting unavailable)
const IMAGE_OVERHEAD: usize = 50;
/// Default token estimate for a typical image (dimensions not available in `ImageData`)
const IMAGE_DEFAULT_TOKENS: usize = 1000;
/// Thinking/redacted block framing
const THINKING_OVERHEAD: usize = 10;
41
/// Token counter backed by `tiktoken` `cl100k_base` BPE encoding.
///
/// Estimates how many tokens a piece of text or a full [`Message`] will consume when
/// sent to an LLM API.  Uses a process-scoped [`OnceLock`] so BPE data is loaded once
/// and shared across all instances.
///
/// Falls back to a `chars/4` heuristic when tiktoken init fails or when the input
/// exceeds `MAX_INPUT_LEN` bytes (64 KiB).
///
/// # Examples
///
/// ```
/// use zeph_memory::TokenCounter;
///
/// let counter = TokenCounter::new();
/// let n = counter.count_tokens("Hello, world!");
/// assert!(n > 0);
/// ```
pub struct TokenCounter {
    /// Shared process-wide BPE; `None` means init failed and chars/4 is used.
    bpe: &'static Option<CoreBPE>,
    /// Cache of `hash_text(text)` -> token count for small inputs.
    cache: DashMap<u64, usize>,
    /// Soft cap on `cache` size; one arbitrary entry is evicted when reached.
    cache_cap: usize,
}
65
66impl TokenCounter {
67    /// Create a new counter. Falls back to chars/4 if tiktoken init fails.
68    ///
69    /// BPE data is loaded once and cached in a `OnceLock` for the process lifetime.
70    #[must_use]
71    pub fn new() -> Self {
72        let bpe = BPE.get_or_init(|| match tiktoken_rs::cl100k_base() {
73            Ok(b) => Some(b),
74            Err(e) => {
75                tracing::warn!("tiktoken cl100k_base init failed, using chars/4 fallback: {e}");
76                None
77            }
78        });
79        Self {
80            bpe,
81            cache: DashMap::new(),
82            cache_cap: CACHE_CAP,
83        }
84    }
85
86    /// Count tokens in text. Uses cache, falls back to heuristic.
87    ///
88    /// Inputs exceeding 64 KiB bypass BPE and use chars/4 without caching to
89    /// avoid CPU amplification from oversized inputs.
90    #[must_use]
91    pub fn count_tokens(&self, text: &str) -> usize {
92        if text.is_empty() {
93            return 0;
94        }
95
96        if text.len() > MAX_INPUT_LEN {
97            return zeph_common::text::estimate_tokens(text);
98        }
99
100        let key = hash_text(text);
101
102        if let Some(cached) = self.cache.get(&key) {
103            return *cached;
104        }
105
106        let count = match self.bpe {
107            Some(bpe) => bpe.encode_with_special_tokens(text).len(),
108            None => zeph_common::text::estimate_tokens(text),
109        };
110
111        // TOCTOU between len() check and insert is benign: worst case we evict
112        // one extra entry and temporarily exceed the cap by one slot.
113        if self.cache.len() >= self.cache_cap {
114            let key_to_evict = self.cache.iter().next().map(|e| *e.key());
115            if let Some(k) = key_to_evict {
116                self.cache.remove(&k);
117            }
118        }
119        self.cache.insert(key, count);
120
121        count
122    }
123
124    /// Estimate token count for a message the way the LLM API will see it.
125    ///
126    /// When structured parts exist, counts from parts matching the API payload
127    /// structure. Falls back to `content` (flattened text) when parts is empty.
128    #[must_use]
129    pub fn count_message_tokens(&self, msg: &Message) -> usize {
130        if msg.parts.is_empty() {
131            return self.count_tokens(&msg.content);
132        }
133        msg.parts.iter().map(|p| self.count_part_tokens(p)).sum()
134    }
135
136    /// Estimate tokens for a single [`MessagePart`] matching the API payload structure.
137    #[must_use]
138    fn count_part_tokens(&self, part: &MessagePart) -> usize {
139        match part {
140            MessagePart::Text { text }
141            | MessagePart::Recall { text }
142            | MessagePart::CodeContext { text }
143            | MessagePart::Summary { text }
144            | MessagePart::CrossSession { text } => {
145                if text.trim().is_empty() {
146                    return 0;
147                }
148                self.count_tokens(text)
149            }
150
151            // API always emits `[tool output: {name}]\n{body}` regardless of compacted_at.
152            // When body is emptied by compaction, count_tokens(body) returns 0 naturally.
153            MessagePart::ToolOutput {
154                tool_name, body, ..
155            } => {
156                TOOL_OUTPUT_OVERHEAD
157                    + self.count_tokens(tool_name.as_str())
158                    + self.count_tokens(body)
159            }
160
161            // API sends structured JSON block: `{"type":"tool_use","id":"...","name":"...","input":...}`
162            MessagePart::ToolUse { id, name, input } => {
163                TOOL_USE_OVERHEAD
164                    + self.count_tokens(id)
165                    + self.count_tokens(name)
166                    + self.count_tokens(&input.to_string())
167            }
168
169            // API sends structured block: `{"type":"tool_result","tool_use_id":"...","content":"..."}`
170            MessagePart::ToolResult {
171                tool_use_id,
172                content,
173                ..
174            } => TOOL_RESULT_OVERHEAD + self.count_tokens(tool_use_id) + self.count_tokens(content),
175
176            // Image token count depends on pixel dimensions, which are unavailable in ImageData.
177            // Using a fixed constant is more accurate than bytes-based formula because
178            // Claude's actual formula is width*height based, not payload-size based.
179            MessagePart::Image(_) => IMAGE_OVERHEAD + IMAGE_DEFAULT_TOKENS,
180
181            // ThinkingBlock is preserved verbatim in multi-turn requests.
182            MessagePart::ThinkingBlock {
183                thinking,
184                signature,
185            } => THINKING_OVERHEAD + self.count_tokens(thinking) + self.count_tokens(signature),
186
187            // RedactedThinkingBlock is an opaque base64 blob — BPE is not meaningful here.
188            MessagePart::RedactedThinkingBlock { data } => {
189                THINKING_OVERHEAD + estimate_tokens(data)
190            }
191
192            // Compaction summary is sent back verbatim to the API.
193            MessagePart::Compaction { summary } => self.count_tokens(summary),
194        }
195    }
196
197    /// Count tokens for an `OpenAI` tool/function schema `JSON` value.
198    #[must_use]
199    pub fn count_tool_schema_tokens(&self, schema: &serde_json::Value) -> usize {
200        let base = count_schema_value(self, schema);
201        let total =
202            base.cast_signed() + ENUM_INIT + FUNC_INIT.cast_signed() + FUNC_END.cast_signed();
203        total.max(0).cast_unsigned()
204    }
205}
206
207impl Default for TokenCounter {
208    fn default() -> Self {
209        Self::new()
210    }
211}
212
/// Produce the 64-bit cache key for `text` using the std `DefaultHasher`.
///
/// Deterministic within a process run; collisions are possible but benign
/// (a colliding entry would merely return a cached count for different text).
fn hash_text(text: &str) -> u64 {
    let mut state = DefaultHasher::new();
    Hash::hash(text, &mut state);
    state.finish()
}
218
219fn count_schema_value(counter: &TokenCounter, value: &serde_json::Value) -> usize {
220    match value {
221        serde_json::Value::Object(map) => {
222            let mut tokens = PROP_INIT;
223            for (key, val) in map {
224                tokens += PROP_KEY + counter.count_tokens(key);
225                tokens += count_schema_value(counter, val);
226            }
227            tokens
228        }
229        serde_json::Value::Array(arr) => {
230            let mut tokens = ENUM_ITEM;
231            for item in arr {
232                tokens += count_schema_value(counter, item);
233            }
234            tokens
235        }
236        serde_json::Value::String(s) => counter.count_tokens(s),
237        serde_json::Value::Bool(_) | serde_json::Value::Number(_) | serde_json::Value::Null => 1,
238    }
239}
240
#[cfg(test)]
mod tests {
    use super::*;
    use zeph_llm::provider::{ImageData, Message, MessageMetadata, MessagePart, Role};

    // Sentinel "no BPE" value so tests can force the chars/4 fallback path.
    static BPE_NONE: Option<CoreBPE> = None;

    /// Build a counter that always uses the chars/4 fallback (no BPE).
    fn counter_with_no_bpe(cache_cap: usize) -> TokenCounter {
        TokenCounter {
            bpe: &BPE_NONE,
            cache: DashMap::new(),
            cache_cap,
        }
    }

    /// Build a user message from structured parts.
    fn make_msg(parts: Vec<MessagePart>) -> Message {
        Message::from_parts(Role::User, parts)
    }

    /// Build a user message with flattened content only (empty `parts`).
    fn make_msg_no_parts(content: &str) -> Message {
        Message {
            role: Role::User,
            content: content.to_string(),
            parts: vec![],
            metadata: MessageMetadata::default(),
        }
    }

    #[test]
    fn count_message_tokens_empty_parts_falls_back_to_content() {
        let counter = TokenCounter::new();
        let msg = make_msg_no_parts("hello world");
        assert_eq!(
            counter.count_message_tokens(&msg),
            counter.count_tokens("hello world")
        );
    }

    #[test]
    fn count_message_tokens_text_part_matches_count_tokens() {
        let counter = TokenCounter::new();
        let text = "the quick brown fox jumps over the lazy dog";
        let msg = make_msg(vec![MessagePart::Text {
            text: text.to_string(),
        }]);
        assert_eq!(
            counter.count_message_tokens(&msg),
            counter.count_tokens(text)
        );
    }

    #[test]
    fn count_message_tokens_tool_use_exceeds_flattened_content() {
        let counter = TokenCounter::new();
        // Large JSON input: structured counting should be higher than flattened "[tool_use: bash(id)]"
        let input = serde_json::json!({"command": "find /home -name '*.rs' -type f | head -100"});
        let msg = make_msg(vec![MessagePart::ToolUse {
            id: "toolu_abc".into(),
            name: "bash".into(),
            input,
        }]);
        let structured = counter.count_message_tokens(&msg);
        let flattened = counter.count_tokens(&msg.content);
        assert!(
            structured > flattened,
            "structured={structured} should exceed flattened={flattened}"
        );
    }

    #[test]
    fn count_message_tokens_compacted_tool_output_is_small() {
        let counter = TokenCounter::new();
        // Compacted ToolOutput has empty body — should count close to overhead only
        let msg = make_msg(vec![MessagePart::ToolOutput {
            tool_name: "bash".into(),
            body: String::new(),
            compacted_at: Some(1_700_000_000),
        }]);
        let tokens = counter.count_message_tokens(&msg);
        // Should be small: TOOL_OUTPUT_OVERHEAD + count_tokens("bash") + 0
        assert!(
            tokens <= 15,
            "compacted tool output should be small, got {tokens}"
        );
    }

    #[test]
    fn count_message_tokens_image_returns_constant() {
        let counter = TokenCounter::new();
        // Payload size must not influence the image estimate — it is a fixed constant.
        let msg = make_msg(vec![MessagePart::Image(Box::new(ImageData {
            data: vec![0u8; 1000],
            mime_type: "image/jpeg".into(),
        }))]);
        assert_eq!(
            counter.count_message_tokens(&msg),
            IMAGE_OVERHEAD + IMAGE_DEFAULT_TOKENS
        );
    }

    #[test]
    fn count_message_tokens_thinking_block_counts_text() {
        let counter = TokenCounter::new();
        let thinking = "step by step reasoning about the problem";
        let signature = "sig";
        let msg = make_msg(vec![MessagePart::ThinkingBlock {
            thinking: thinking.to_string(),
            signature: signature.to_string(),
        }]);
        let expected =
            THINKING_OVERHEAD + counter.count_tokens(thinking) + counter.count_tokens(signature);
        assert_eq!(counter.count_message_tokens(&msg), expected);
    }

    #[test]
    fn count_part_tokens_empty_text_returns_zero() {
        // Empty and whitespace-only text parts must contribute zero tokens.
        let counter = TokenCounter::new();
        assert_eq!(
            counter.count_part_tokens(&MessagePart::Text {
                text: String::new()
            }),
            0
        );
        assert_eq!(
            counter.count_part_tokens(&MessagePart::Text {
                text: "   ".to_string()
            }),
            0
        );
        assert_eq!(
            counter.count_part_tokens(&MessagePart::Recall {
                text: "\n\t".to_string()
            }),
            0
        );
    }

    #[test]
    fn count_message_tokens_push_recompute_consistency() {
        // Verify that sum of count_message_tokens per part equals recompute result
        let counter = TokenCounter::new();
        let parts = vec![
            MessagePart::Text {
                text: "hello".into(),
            },
            MessagePart::ToolOutput {
                tool_name: "bash".into(),
                body: "output data".into(),
                compacted_at: None,
            },
        ];
        let msg = make_msg(parts);
        let total = counter.count_message_tokens(&msg);
        let sum: usize = msg.parts.iter().map(|p| counter.count_part_tokens(p)).sum();
        assert_eq!(total, sum);
    }

    #[test]
    fn count_message_tokens_parts_take_priority_over_content() {
        // R-2: primary regression guard — when parts is non-empty, content is ignored.
        let counter = TokenCounter::new();
        let parts_text = "hello from parts";
        let msg = Message {
            role: Role::User,
            content: "completely different content that should be ignored".to_string(),
            parts: vec![MessagePart::Text {
                text: parts_text.to_string(),
            }],
            metadata: MessageMetadata::default(),
        };
        let parts_based = counter.count_tokens(parts_text);
        let content_based = counter.count_tokens(&msg.content);
        assert_ne!(
            parts_based, content_based,
            "test setup: parts and content must differ"
        );
        assert_eq!(counter.count_message_tokens(&msg), parts_based);
    }

    #[test]
    fn count_part_tokens_tool_result() {
        // R-3: verify ToolResult arm counting
        let counter = TokenCounter::new();
        let tool_use_id = "toolu_xyz";
        let content = "result text";
        let part = MessagePart::ToolResult {
            tool_use_id: tool_use_id.to_string(),
            content: content.to_string(),
            is_error: false,
        };
        let expected = TOOL_RESULT_OVERHEAD
            + counter.count_tokens(tool_use_id)
            + counter.count_tokens(content);
        assert_eq!(counter.count_part_tokens(&part), expected);
    }

    #[test]
    fn count_tokens_empty() {
        let counter = TokenCounter::new();
        assert_eq!(counter.count_tokens(""), 0);
    }

    #[test]
    fn count_tokens_non_empty() {
        let counter = TokenCounter::new();
        assert!(counter.count_tokens("hello world") > 0);
    }

    #[test]
    fn count_tokens_cache_hit_returns_same() {
        // Second call hits the cache; result must be identical.
        let counter = TokenCounter::new();
        let text = "the quick brown fox";
        let first = counter.count_tokens(text);
        let second = counter.count_tokens(text);
        assert_eq!(first, second);
    }

    #[test]
    fn count_tokens_fallback_mode() {
        let counter = counter_with_no_bpe(CACHE_CAP);
        // 8 chars / 4 = 2
        assert_eq!(counter.count_tokens("abcdefgh"), 2);
        assert_eq!(counter.count_tokens(""), 0);
    }

    #[test]
    fn count_tokens_oversized_input_uses_fallback_without_cache() {
        let counter = TokenCounter::new();
        // Generate input larger than MAX_INPUT_LEN (65536 bytes)
        let large = "a".repeat(MAX_INPUT_LEN + 1);
        let result = counter.count_tokens(&large);
        // chars/4 fallback: (65537 chars) / 4
        assert_eq!(result, zeph_common::text::estimate_tokens(&large));
        // Must not be cached
        assert!(counter.cache.is_empty());
    }

    #[test]
    fn count_tokens_unicode_bpe_differs_from_fallback() {
        let counter = TokenCounter::new();
        let text = "Привет мир! 你好世界! こんにちは! 🌍";
        let bpe_count = counter.count_tokens(text);
        let fallback_count = zeph_common::text::estimate_tokens(text);
        // BPE should return > 0
        assert!(bpe_count > 0, "BPE count must be positive");
        // BPE result should differ from naive chars/4 for multi-byte text
        assert_ne!(
            bpe_count, fallback_count,
            "BPE tokenization should differ from chars/4 for unicode text"
        );
    }

    #[test]
    fn count_tool_schema_tokens_sample() {
        let counter = TokenCounter::new();
        let schema = serde_json::json!({
            "name": "get_weather",
            "description": "Get the current weather for a location",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The city name"
                    }
                },
                "required": ["location"]
            }
        });
        let tokens = counter.count_tool_schema_tokens(&schema);
        // Pinned value: computed by running the formula with cl100k_base on this exact schema.
        // If this fails, a tokenizer or formula change likely occurred.
        assert_eq!(tokens, 82);
    }

    #[test]
    fn two_instances_share_bpe_pointer() {
        // The OnceLock must hand every counter the same &'static BPE.
        let a = TokenCounter::new();
        let b = TokenCounter::new();
        assert!(std::ptr::eq(a.bpe, b.bpe));
    }

    #[test]
    fn cache_eviction_at_capacity() {
        let counter = counter_with_no_bpe(3);
        let _ = counter.count_tokens("aaaa");
        let _ = counter.count_tokens("bbbb");
        let _ = counter.count_tokens("cccc");
        assert_eq!(counter.cache.len(), 3);
        // This should evict one entry
        let _ = counter.count_tokens("dddd");
        assert_eq!(counter.cache.len(), 3);
    }
}
533}