// zeph_memory/token_counter.rs

// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
// SPDX-License-Identifier: MIT OR Apache-2.0

use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use std::sync::OnceLock;

use dashmap::DashMap;
use tiktoken_rs::CoreBPE;
use zeph_llm::provider::{Message, MessagePart};

12static BPE: OnceLock<Option<CoreBPE>> = OnceLock::new();
13
14const CACHE_CAP: usize = 10_000;
15/// Inputs larger than this limit bypass BPE encoding and use the chars/4 fallback.
16/// Prevents CPU amplification from pathologically large inputs.
17const MAX_INPUT_LEN: usize = 65_536;
18
19// OpenAI function-calling token overhead constants
20const FUNC_INIT: usize = 7;
21const PROP_INIT: usize = 3;
22const PROP_KEY: usize = 3;
23const ENUM_INIT: isize = -3;
24const ENUM_ITEM: usize = 3;
25const FUNC_END: usize = 12;
26
27// Structural overhead per part type (approximate token counts for JSON framing)
28/// `{"type":"tool_use","id":"","name":"","input":}`
29const TOOL_USE_OVERHEAD: usize = 20;
30/// `{"type":"tool_result","tool_use_id":"","content":""}`
31const TOOL_RESULT_OVERHEAD: usize = 15;
32/// `[tool output: <name>]\n` wrapping
33const TOOL_OUTPUT_OVERHEAD: usize = 8;
34/// Image block JSON structure overhead (dimension-based counting unavailable)
35const IMAGE_OVERHEAD: usize = 50;
36/// Default token estimate for a typical image (dimensions not available in `ImageData`)
37const IMAGE_DEFAULT_TOKENS: usize = 1000;
38/// Thinking/redacted block framing
39const THINKING_OVERHEAD: usize = 10;
40
41pub struct TokenCounter {
42    bpe: &'static Option<CoreBPE>,
43    cache: DashMap<u64, usize>,
44    cache_cap: usize,
45}
46
47impl TokenCounter {
48    /// Create a new counter. Falls back to chars/4 if tiktoken init fails.
49    ///
50    /// BPE data is loaded once and cached in a `OnceLock` for the process lifetime.
51    #[must_use]
52    pub fn new() -> Self {
53        let bpe = BPE.get_or_init(|| match tiktoken_rs::cl100k_base() {
54            Ok(b) => Some(b),
55            Err(e) => {
56                tracing::warn!("tiktoken cl100k_base init failed, using chars/4 fallback: {e}");
57                None
58            }
59        });
60        Self {
61            bpe,
62            cache: DashMap::new(),
63            cache_cap: CACHE_CAP,
64        }
65    }
66
67    /// Count tokens in text. Uses cache, falls back to heuristic.
68    ///
69    /// Inputs exceeding 64 KiB bypass BPE and use chars/4 without caching to
70    /// avoid CPU amplification from oversized inputs.
71    #[must_use]
72    pub fn count_tokens(&self, text: &str) -> usize {
73        if text.is_empty() {
74            return 0;
75        }
76
77        if text.len() > MAX_INPUT_LEN {
78            return zeph_common::text::estimate_tokens(text);
79        }
80
81        let key = hash_text(text);
82
83        if let Some(cached) = self.cache.get(&key) {
84            return *cached;
85        }
86
87        let count = match self.bpe {
88            Some(bpe) => bpe.encode_with_special_tokens(text).len(),
89            None => zeph_common::text::estimate_tokens(text),
90        };
91
92        // TOCTOU between len() check and insert is benign: worst case we evict
93        // one extra entry and temporarily exceed the cap by one slot.
94        if self.cache.len() >= self.cache_cap {
95            let key_to_evict = self.cache.iter().next().map(|e| *e.key());
96            if let Some(k) = key_to_evict {
97                self.cache.remove(&k);
98            }
99        }
100        self.cache.insert(key, count);
101
102        count
103    }
104
105    /// Estimate token count for a message the way the LLM API will see it.
106    ///
107    /// When structured parts exist, counts from parts matching the API payload
108    /// structure. Falls back to `content` (flattened text) when parts is empty.
109    #[must_use]
110    pub fn count_message_tokens(&self, msg: &Message) -> usize {
111        if msg.parts.is_empty() {
112            return self.count_tokens(&msg.content);
113        }
114        msg.parts.iter().map(|p| self.count_part_tokens(p)).sum()
115    }
116
117    /// Estimate tokens for a single [`MessagePart`] matching the API payload structure.
118    #[must_use]
119    fn count_part_tokens(&self, part: &MessagePart) -> usize {
120        match part {
121            MessagePart::Text { text }
122            | MessagePart::Recall { text }
123            | MessagePart::CodeContext { text }
124            | MessagePart::Summary { text }
125            | MessagePart::CrossSession { text } => {
126                if text.trim().is_empty() {
127                    return 0;
128                }
129                self.count_tokens(text)
130            }
131
132            // API always emits `[tool output: {name}]\n{body}` regardless of compacted_at.
133            // When body is emptied by compaction, count_tokens(body) returns 0 naturally.
134            MessagePart::ToolOutput {
135                tool_name, body, ..
136            } => TOOL_OUTPUT_OVERHEAD + self.count_tokens(tool_name) + self.count_tokens(body),
137
138            // API sends structured JSON block: `{"type":"tool_use","id":"...","name":"...","input":...}`
139            MessagePart::ToolUse { id, name, input } => {
140                TOOL_USE_OVERHEAD
141                    + self.count_tokens(id)
142                    + self.count_tokens(name)
143                    + self.count_tokens(&input.to_string())
144            }
145
146            // API sends structured block: `{"type":"tool_result","tool_use_id":"...","content":"..."}`
147            MessagePart::ToolResult {
148                tool_use_id,
149                content,
150                ..
151            } => TOOL_RESULT_OVERHEAD + self.count_tokens(tool_use_id) + self.count_tokens(content),
152
153            // Image token count depends on pixel dimensions, which are unavailable in ImageData.
154            // Using a fixed constant is more accurate than bytes-based formula because
155            // Claude's actual formula is width*height based, not payload-size based.
156            MessagePart::Image(_) => IMAGE_OVERHEAD + IMAGE_DEFAULT_TOKENS,
157
158            // ThinkingBlock is preserved verbatim in multi-turn requests.
159            MessagePart::ThinkingBlock {
160                thinking,
161                signature,
162            } => THINKING_OVERHEAD + self.count_tokens(thinking) + self.count_tokens(signature),
163
164            // RedactedThinkingBlock is an opaque base64 blob — BPE is not meaningful here.
165            MessagePart::RedactedThinkingBlock { data } => THINKING_OVERHEAD + data.len() / 4,
166
167            // Compaction summary is sent back verbatim to the API.
168            MessagePart::Compaction { summary } => self.count_tokens(summary),
169        }
170    }
171
172    /// Count tokens for an `OpenAI` tool/function schema `JSON` value.
173    #[must_use]
174    pub fn count_tool_schema_tokens(&self, schema: &serde_json::Value) -> usize {
175        let base = count_schema_value(self, schema);
176        let total =
177            base.cast_signed() + ENUM_INIT + FUNC_INIT.cast_signed() + FUNC_END.cast_signed();
178        total.max(0).cast_unsigned()
179    }
180}
181
182impl Default for TokenCounter {
183    fn default() -> Self {
184        Self::new()
185    }
186}
187
188fn hash_text(text: &str) -> u64 {
189    let mut hasher = DefaultHasher::new();
190    text.hash(&mut hasher);
191    hasher.finish()
192}
193
194fn count_schema_value(counter: &TokenCounter, value: &serde_json::Value) -> usize {
195    match value {
196        serde_json::Value::Object(map) => {
197            let mut tokens = PROP_INIT;
198            for (key, val) in map {
199                tokens += PROP_KEY + counter.count_tokens(key);
200                tokens += count_schema_value(counter, val);
201            }
202            tokens
203        }
204        serde_json::Value::Array(arr) => {
205            let mut tokens = ENUM_ITEM;
206            for item in arr {
207                tokens += count_schema_value(counter, item);
208            }
209            tokens
210        }
211        serde_json::Value::String(s) => counter.count_tokens(s),
212        serde_json::Value::Bool(_) | serde_json::Value::Number(_) | serde_json::Value::Null => 1,
213    }
214}
215
216#[cfg(test)]
217mod tests {
218    use super::*;
219    use zeph_llm::provider::{ImageData, Message, MessageMetadata, MessagePart, Role};
220
221    static BPE_NONE: Option<CoreBPE> = None;
222
223    fn counter_with_no_bpe(cache_cap: usize) -> TokenCounter {
224        TokenCounter {
225            bpe: &BPE_NONE,
226            cache: DashMap::new(),
227            cache_cap,
228        }
229    }
230
231    fn make_msg(parts: Vec<MessagePart>) -> Message {
232        Message::from_parts(Role::User, parts)
233    }
234
235    fn make_msg_no_parts(content: &str) -> Message {
236        Message {
237            role: Role::User,
238            content: content.to_string(),
239            parts: vec![],
240            metadata: MessageMetadata::default(),
241        }
242    }
243
244    #[test]
245    fn count_message_tokens_empty_parts_falls_back_to_content() {
246        let counter = TokenCounter::new();
247        let msg = make_msg_no_parts("hello world");
248        assert_eq!(
249            counter.count_message_tokens(&msg),
250            counter.count_tokens("hello world")
251        );
252    }
253
254    #[test]
255    fn count_message_tokens_text_part_matches_count_tokens() {
256        let counter = TokenCounter::new();
257        let text = "the quick brown fox jumps over the lazy dog";
258        let msg = make_msg(vec![MessagePart::Text {
259            text: text.to_string(),
260        }]);
261        assert_eq!(
262            counter.count_message_tokens(&msg),
263            counter.count_tokens(text)
264        );
265    }
266
267    #[test]
268    fn count_message_tokens_tool_use_exceeds_flattened_content() {
269        let counter = TokenCounter::new();
270        // Large JSON input: structured counting should be higher than flattened "[tool_use: bash(id)]"
271        let input = serde_json::json!({"command": "find /home -name '*.rs' -type f | head -100"});
272        let msg = make_msg(vec![MessagePart::ToolUse {
273            id: "toolu_abc".into(),
274            name: "bash".into(),
275            input,
276        }]);
277        let structured = counter.count_message_tokens(&msg);
278        let flattened = counter.count_tokens(&msg.content);
279        assert!(
280            structured > flattened,
281            "structured={structured} should exceed flattened={flattened}"
282        );
283    }
284
285    #[test]
286    fn count_message_tokens_compacted_tool_output_is_small() {
287        let counter = TokenCounter::new();
288        // Compacted ToolOutput has empty body — should count close to overhead only
289        let msg = make_msg(vec![MessagePart::ToolOutput {
290            tool_name: "bash".into(),
291            body: String::new(),
292            compacted_at: Some(1_700_000_000),
293        }]);
294        let tokens = counter.count_message_tokens(&msg);
295        // Should be small: TOOL_OUTPUT_OVERHEAD + count_tokens("bash") + 0
296        assert!(
297            tokens <= 15,
298            "compacted tool output should be small, got {tokens}"
299        );
300    }
301
302    #[test]
303    fn count_message_tokens_image_returns_constant() {
304        let counter = TokenCounter::new();
305        let msg = make_msg(vec![MessagePart::Image(Box::new(ImageData {
306            data: vec![0u8; 1000],
307            mime_type: "image/jpeg".into(),
308        }))]);
309        assert_eq!(
310            counter.count_message_tokens(&msg),
311            IMAGE_OVERHEAD + IMAGE_DEFAULT_TOKENS
312        );
313    }
314
315    #[test]
316    fn count_message_tokens_thinking_block_counts_text() {
317        let counter = TokenCounter::new();
318        let thinking = "step by step reasoning about the problem";
319        let signature = "sig";
320        let msg = make_msg(vec![MessagePart::ThinkingBlock {
321            thinking: thinking.to_string(),
322            signature: signature.to_string(),
323        }]);
324        let expected =
325            THINKING_OVERHEAD + counter.count_tokens(thinking) + counter.count_tokens(signature);
326        assert_eq!(counter.count_message_tokens(&msg), expected);
327    }
328
329    #[test]
330    fn count_part_tokens_empty_text_returns_zero() {
331        let counter = TokenCounter::new();
332        assert_eq!(
333            counter.count_part_tokens(&MessagePart::Text {
334                text: String::new()
335            }),
336            0
337        );
338        assert_eq!(
339            counter.count_part_tokens(&MessagePart::Text {
340                text: "   ".to_string()
341            }),
342            0
343        );
344        assert_eq!(
345            counter.count_part_tokens(&MessagePart::Recall {
346                text: "\n\t".to_string()
347            }),
348            0
349        );
350    }
351
352    #[test]
353    fn count_message_tokens_push_recompute_consistency() {
354        // Verify that sum of count_message_tokens per part equals recompute result
355        let counter = TokenCounter::new();
356        let parts = vec![
357            MessagePart::Text {
358                text: "hello".into(),
359            },
360            MessagePart::ToolOutput {
361                tool_name: "bash".into(),
362                body: "output data".into(),
363                compacted_at: None,
364            },
365        ];
366        let msg = make_msg(parts);
367        let total = counter.count_message_tokens(&msg);
368        let sum: usize = msg.parts.iter().map(|p| counter.count_part_tokens(p)).sum();
369        assert_eq!(total, sum);
370    }
371
372    #[test]
373    fn count_message_tokens_parts_take_priority_over_content() {
374        // R-2: primary regression guard — when parts is non-empty, content is ignored.
375        let counter = TokenCounter::new();
376        let parts_text = "hello from parts";
377        let msg = Message {
378            role: Role::User,
379            content: "completely different content that should be ignored".to_string(),
380            parts: vec![MessagePart::Text {
381                text: parts_text.to_string(),
382            }],
383            metadata: MessageMetadata::default(),
384        };
385        let parts_based = counter.count_tokens(parts_text);
386        let content_based = counter.count_tokens(&msg.content);
387        assert_ne!(
388            parts_based, content_based,
389            "test setup: parts and content must differ"
390        );
391        assert_eq!(counter.count_message_tokens(&msg), parts_based);
392    }
393
394    #[test]
395    fn count_part_tokens_tool_result() {
396        // R-3: verify ToolResult arm counting
397        let counter = TokenCounter::new();
398        let tool_use_id = "toolu_xyz";
399        let content = "result text";
400        let part = MessagePart::ToolResult {
401            tool_use_id: tool_use_id.to_string(),
402            content: content.to_string(),
403            is_error: false,
404        };
405        let expected = TOOL_RESULT_OVERHEAD
406            + counter.count_tokens(tool_use_id)
407            + counter.count_tokens(content);
408        assert_eq!(counter.count_part_tokens(&part), expected);
409    }
410
411    #[test]
412    fn count_tokens_empty() {
413        let counter = TokenCounter::new();
414        assert_eq!(counter.count_tokens(""), 0);
415    }
416
417    #[test]
418    fn count_tokens_non_empty() {
419        let counter = TokenCounter::new();
420        assert!(counter.count_tokens("hello world") > 0);
421    }
422
423    #[test]
424    fn count_tokens_cache_hit_returns_same() {
425        let counter = TokenCounter::new();
426        let text = "the quick brown fox";
427        let first = counter.count_tokens(text);
428        let second = counter.count_tokens(text);
429        assert_eq!(first, second);
430    }
431
432    #[test]
433    fn count_tokens_fallback_mode() {
434        let counter = counter_with_no_bpe(CACHE_CAP);
435        // 8 chars / 4 = 2
436        assert_eq!(counter.count_tokens("abcdefgh"), 2);
437        assert_eq!(counter.count_tokens(""), 0);
438    }
439
440    #[test]
441    fn count_tokens_oversized_input_uses_fallback_without_cache() {
442        let counter = TokenCounter::new();
443        // Generate input larger than MAX_INPUT_LEN (65536 bytes)
444        let large = "a".repeat(MAX_INPUT_LEN + 1);
445        let result = counter.count_tokens(&large);
446        // chars/4 fallback: (65537 chars) / 4
447        assert_eq!(result, zeph_common::text::estimate_tokens(&large));
448        // Must not be cached
449        assert!(counter.cache.is_empty());
450    }
451
452    #[test]
453    fn count_tokens_unicode_bpe_differs_from_fallback() {
454        let counter = TokenCounter::new();
455        let text = "Привет мир! 你好世界! こんにちは! 🌍";
456        let bpe_count = counter.count_tokens(text);
457        let fallback_count = zeph_common::text::estimate_tokens(text);
458        // BPE should return > 0
459        assert!(bpe_count > 0, "BPE count must be positive");
460        // BPE result should differ from naive chars/4 for multi-byte text
461        assert_ne!(
462            bpe_count, fallback_count,
463            "BPE tokenization should differ from chars/4 for unicode text"
464        );
465    }
466
467    #[test]
468    fn count_tool_schema_tokens_sample() {
469        let counter = TokenCounter::new();
470        let schema = serde_json::json!({
471            "name": "get_weather",
472            "description": "Get the current weather for a location",
473            "parameters": {
474                "type": "object",
475                "properties": {
476                    "location": {
477                        "type": "string",
478                        "description": "The city name"
479                    }
480                },
481                "required": ["location"]
482            }
483        });
484        let tokens = counter.count_tool_schema_tokens(&schema);
485        // Pinned value: computed by running the formula with cl100k_base on this exact schema.
486        // If this fails, a tokenizer or formula change likely occurred.
487        assert_eq!(tokens, 82);
488    }
489
490    #[test]
491    fn two_instances_share_bpe_pointer() {
492        let a = TokenCounter::new();
493        let b = TokenCounter::new();
494        assert!(std::ptr::eq(a.bpe, b.bpe));
495    }
496
497    #[test]
498    fn cache_eviction_at_capacity() {
499        let counter = counter_with_no_bpe(3);
500        let _ = counter.count_tokens("aaaa");
501        let _ = counter.count_tokens("bbbb");
502        let _ = counter.count_tokens("cccc");
503        assert_eq!(counter.cache.len(), 3);
504        // This should evict one entry
505        let _ = counter.count_tokens("dddd");
506        assert_eq!(counter.cache.len(), 3);
507    }
508}