Skip to main content

zeph_memory/
token_counter.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use std::collections::hash_map::DefaultHasher;
5use std::hash::{Hash, Hasher};
6use std::sync::OnceLock;
7
8use dashmap::DashMap;
9use tiktoken_rs::CoreBPE;
10use zeph_common::text::estimate_tokens;
11use zeph_llm::provider::{Message, MessagePart};
12
13static BPE: OnceLock<Option<CoreBPE>> = OnceLock::new();
14
/// Default number of cached `hash -> token count` entries per [`TokenCounter`].
const CACHE_CAP: usize = 10_000;
/// Inputs longer than this many bytes (64 KiB) bypass BPE encoding and use the
/// chars/4 fallback, so pathologically large inputs cannot amplify CPU cost.
const MAX_INPUT_LEN: usize = 64 * 1024;

// OpenAI function-calling token overhead constants.
/// Added once per schema in `count_tool_schema_tokens`.
const FUNC_INIT: usize = 7;
/// Starting count for every JSON object visited by `count_schema_value`.
const PROP_INIT: usize = 3;
/// Added for every key of a JSON object, on top of the key's own tokens.
const PROP_KEY: usize = 3;
/// Signed adjustment added once per schema. It is the only negative term, which is
/// why the schema total is computed in `isize` and clamped at zero.
const ENUM_INIT: isize = -3;
/// Starting count for every JSON array visited by `count_schema_value`.
const ENUM_ITEM: usize = 3;
/// Added once per schema in `count_tool_schema_tokens`.
const FUNC_END: usize = 12;

// Structural overhead per part type (approximate token counts for JSON framing).
/// Framing of `{"type":"tool_use","id":"","name":"","input":}`.
const TOOL_USE_OVERHEAD: usize = 20;
/// Framing of `{"type":"tool_result","tool_use_id":"","content":""}`.
const TOOL_RESULT_OVERHEAD: usize = 15;
/// The `[tool output: <name>]\n` wrapper.
const TOOL_OUTPUT_OVERHEAD: usize = 8;
/// JSON structure around an image block (dimension-based counting is unavailable).
const IMAGE_OVERHEAD: usize = 50;
/// Flat estimate for a typical image, since `ImageData` carries no dimensions.
const IMAGE_DEFAULT_TOKENS: usize = 1000;
/// Framing of thinking and redacted-thinking blocks.
const THINKING_OVERHEAD: usize = 10;
41
42/// Token counter backed by `tiktoken` `cl100k_base` BPE encoding.
43///
44/// Estimates how many tokens a piece of text or a full [`Message`] will consume when
45/// sent to an LLM API.  Uses a process-scoped [`OnceLock`] so BPE data is loaded once
46/// and shared across all instances.
47///
48/// Falls back to a `chars/4` heuristic when tiktoken init fails or when the input
49/// exceeds `MAX_INPUT_LEN` bytes (64 KiB).
50///
51/// # Examples
52///
53/// ```
54/// use zeph_memory::TokenCounter;
55///
56/// let counter = TokenCounter::new();
57/// let n = counter.count_tokens("Hello, world!");
58/// assert!(n > 0);
59/// ```
60pub struct TokenCounter {
61    bpe: &'static Option<CoreBPE>,
62    cache: DashMap<u64, usize>,
63    cache_cap: usize,
64}
65
66impl TokenCounter {
67    /// Create a new counter. Falls back to chars/4 if tiktoken init fails.
68    ///
69    /// BPE data is loaded once and cached in a `OnceLock` for the process lifetime.
70    #[must_use]
71    pub fn new() -> Self {
72        let bpe = BPE.get_or_init(|| match tiktoken_rs::cl100k_base() {
73            Ok(b) => Some(b),
74            Err(e) => {
75                tracing::warn!("tiktoken cl100k_base init failed, using chars/4 fallback: {e}");
76                None
77            }
78        });
79        Self {
80            bpe,
81            cache: DashMap::new(),
82            cache_cap: CACHE_CAP,
83        }
84    }
85
86    /// Count tokens in text. Uses cache, falls back to heuristic.
87    ///
88    /// Inputs exceeding 64 KiB bypass BPE and use chars/4 without caching to
89    /// avoid CPU amplification from oversized inputs.
90    #[must_use]
91    pub fn count_tokens(&self, text: &str) -> usize {
92        if text.is_empty() {
93            return 0;
94        }
95
96        if text.len() > MAX_INPUT_LEN {
97            return zeph_common::text::estimate_tokens(text);
98        }
99
100        let key = hash_text(text);
101
102        if let Some(cached) = self.cache.get(&key) {
103            return *cached;
104        }
105
106        let count = match self.bpe {
107            Some(bpe) => bpe.encode_with_special_tokens(text).len(),
108            None => zeph_common::text::estimate_tokens(text),
109        };
110
111        // TOCTOU between len() check and insert is benign: worst case we evict
112        // one extra entry and temporarily exceed the cap by one slot.
113        if self.cache.len() >= self.cache_cap {
114            let key_to_evict = self.cache.iter().next().map(|e| *e.key());
115            if let Some(k) = key_to_evict {
116                self.cache.remove(&k);
117            }
118        }
119        self.cache.insert(key, count);
120
121        count
122    }
123
124    /// Estimate token count for a message the way the LLM API will see it.
125    ///
126    /// When structured parts exist, counts from parts matching the API payload
127    /// structure. Falls back to `content` (flattened text) when parts is empty.
128    #[must_use]
129    pub fn count_message_tokens(&self, msg: &Message) -> usize {
130        if msg.parts.is_empty() {
131            return self.count_tokens(&msg.content);
132        }
133        msg.parts.iter().map(|p| self.count_part_tokens(p)).sum()
134    }
135
136    /// Estimate tokens for a single [`MessagePart`] matching the API payload structure.
137    #[must_use]
138    fn count_part_tokens(&self, part: &MessagePart) -> usize {
139        match part {
140            MessagePart::Text { text }
141            | MessagePart::Recall { text }
142            | MessagePart::CodeContext { text }
143            | MessagePart::Summary { text }
144            | MessagePart::CrossSession { text } => {
145                if text.trim().is_empty() {
146                    return 0;
147                }
148                self.count_tokens(text)
149            }
150
151            // API always emits `[tool output: {name}]\n{body}` regardless of compacted_at.
152            // When body is emptied by compaction, count_tokens(body) returns 0 naturally.
153            MessagePart::ToolOutput {
154                tool_name, body, ..
155            } => {
156                TOOL_OUTPUT_OVERHEAD
157                    + self.count_tokens(tool_name.as_str())
158                    + self.count_tokens(body)
159            }
160
161            // API sends structured JSON block: `{"type":"tool_use","id":"...","name":"...","input":...}`
162            MessagePart::ToolUse { id, name, input } => {
163                TOOL_USE_OVERHEAD
164                    + self.count_tokens(id)
165                    + self.count_tokens(name)
166                    + self.count_tokens(&input.to_string())
167            }
168
169            // API sends structured block: `{"type":"tool_result","tool_use_id":"...","content":"..."}`
170            MessagePart::ToolResult {
171                tool_use_id,
172                content,
173                ..
174            } => TOOL_RESULT_OVERHEAD + self.count_tokens(tool_use_id) + self.count_tokens(content),
175
176            // Image token count depends on pixel dimensions, which are unavailable in ImageData.
177            // Using a fixed constant is more accurate than bytes-based formula because
178            // Claude's actual formula is width*height based, not payload-size based.
179            MessagePart::Image(_) => IMAGE_OVERHEAD + IMAGE_DEFAULT_TOKENS,
180
181            // ThinkingBlock is preserved verbatim in multi-turn requests.
182            MessagePart::ThinkingBlock {
183                thinking,
184                signature,
185            } => THINKING_OVERHEAD + self.count_tokens(thinking) + self.count_tokens(signature),
186
187            // RedactedThinkingBlock is an opaque base64 blob — BPE is not meaningful here.
188            MessagePart::RedactedThinkingBlock { data } => {
189                THINKING_OVERHEAD + estimate_tokens(data)
190            }
191
192            // Compaction summary is sent back verbatim to the API.
193            MessagePart::Compaction { summary } => self.count_tokens(summary),
194        }
195    }
196
197    /// Count tokens for an `OpenAI` tool/function schema `JSON` value.
198    #[must_use]
199    pub fn count_tool_schema_tokens(&self, schema: &serde_json::Value) -> usize {
200        let base = count_schema_value(self, schema);
201        let total =
202            base.cast_signed() + ENUM_INIT + FUNC_INIT.cast_signed() + FUNC_END.cast_signed();
203        total.max(0).cast_unsigned()
204    }
205}
206
207impl Default for TokenCounter {
208    fn default() -> Self {
209        Self::new()
210    }
211}
212
213impl zeph_common::memory::TokenCounting for TokenCounter {
214    fn count_tokens(&self, text: &str) -> usize {
215        self.count_tokens(text)
216    }
217
218    fn count_tool_schema_tokens(&self, schema: &serde_json::Value) -> usize {
219        self.count_tool_schema_tokens(schema)
220    }
221}
222
/// Hash `text` with a freshly created [`DefaultHasher`]; the result is the cache key.
fn hash_text(text: &str) -> u64 {
    let mut state = DefaultHasher::new();
    Hash::hash(text, &mut state);
    state.finish()
}
228
229fn count_schema_value(counter: &TokenCounter, value: &serde_json::Value) -> usize {
230    match value {
231        serde_json::Value::Object(map) => {
232            let mut tokens = PROP_INIT;
233            for (key, val) in map {
234                tokens += PROP_KEY + counter.count_tokens(key);
235                tokens += count_schema_value(counter, val);
236            }
237            tokens
238        }
239        serde_json::Value::Array(arr) => {
240            let mut tokens = ENUM_ITEM;
241            for item in arr {
242                tokens += count_schema_value(counter, item);
243            }
244            tokens
245        }
246        serde_json::Value::String(s) => counter.count_tokens(s),
247        serde_json::Value::Bool(_) | serde_json::Value::Number(_) | serde_json::Value::Null => 1,
248    }
249}
250
#[cfg(test)]
mod tests {
    use super::*;
    use zeph_llm::provider::{ImageData, Message, MessageMetadata, MessagePart, Role};

    /// Stand-in for a failed tiktoken init; forces the chars/4 fallback path.
    static BPE_NONE: Option<CoreBPE> = None;

    /// Counter without an encoder and with a caller-chosen cache capacity.
    fn counter_with_no_bpe(cache_cap: usize) -> TokenCounter {
        TokenCounter {
            bpe: &BPE_NONE,
            cache: DashMap::new(),
            cache_cap,
        }
    }

    /// User message built from structured parts.
    fn make_msg(parts: Vec<MessagePart>) -> Message {
        Message::from_parts(Role::User, parts)
    }

    /// User message that only has flattened `content` and no parts.
    fn make_msg_no_parts(content: &str) -> Message {
        Message {
            role: Role::User,
            content: content.to_string(),
            parts: vec![],
            metadata: MessageMetadata::default(),
        }
    }

    #[test]
    fn count_message_tokens_empty_parts_falls_back_to_content() {
        let counter = TokenCounter::new();
        let msg = make_msg_no_parts("hello world");
        let expected = counter.count_tokens("hello world");
        assert_eq!(counter.count_message_tokens(&msg), expected);
    }

    #[test]
    fn count_message_tokens_text_part_matches_count_tokens() {
        let counter = TokenCounter::new();
        let text = "the quick brown fox jumps over the lazy dog";
        let msg = make_msg(vec![MessagePart::Text {
            text: text.to_string(),
        }]);
        let expected = counter.count_tokens(text);
        assert_eq!(counter.count_message_tokens(&msg), expected);
    }

    #[test]
    fn count_message_tokens_tool_use_exceeds_flattened_content() {
        let counter = TokenCounter::new();
        // Large JSON input: structured counting should come out higher than the
        // flattened "[tool_use: bash(id)]" text.
        let input = serde_json::json!({"command": "find /home -name '*.rs' -type f | head -100"});
        let msg = make_msg(vec![MessagePart::ToolUse {
            id: "toolu_abc".into(),
            name: "bash".into(),
            input,
        }]);
        let structured = counter.count_message_tokens(&msg);
        let flattened = counter.count_tokens(&msg.content);
        assert!(
            structured > flattened,
            "structured={structured} should exceed flattened={flattened}"
        );
    }

    #[test]
    fn count_message_tokens_compacted_tool_output_is_small() {
        let counter = TokenCounter::new();
        // A compacted ToolOutput has an empty body, so only
        // TOOL_OUTPUT_OVERHEAD + count_tokens("bash") + 0 should remain.
        let msg = make_msg(vec![MessagePart::ToolOutput {
            tool_name: "bash".into(),
            body: String::new(),
            compacted_at: Some(1_700_000_000),
        }]);
        let tokens = counter.count_message_tokens(&msg);
        assert!(
            tokens <= 15,
            "compacted tool output should be small, got {tokens}"
        );
    }

    #[test]
    fn count_message_tokens_image_returns_constant() {
        let counter = TokenCounter::new();
        let image = ImageData {
            data: vec![0u8; 1000],
            mime_type: "image/jpeg".into(),
        };
        let msg = make_msg(vec![MessagePart::Image(Box::new(image))]);
        assert_eq!(
            counter.count_message_tokens(&msg),
            IMAGE_OVERHEAD + IMAGE_DEFAULT_TOKENS
        );
    }

    #[test]
    fn count_message_tokens_thinking_block_counts_text() {
        let counter = TokenCounter::new();
        let thinking = "step by step reasoning about the problem";
        let signature = "sig";
        let msg = make_msg(vec![MessagePart::ThinkingBlock {
            thinking: thinking.to_string(),
            signature: signature.to_string(),
        }]);
        let expected =
            THINKING_OVERHEAD + counter.count_tokens(thinking) + counter.count_tokens(signature);
        assert_eq!(counter.count_message_tokens(&msg), expected);
    }

    #[test]
    fn count_part_tokens_empty_text_returns_zero() {
        let counter = TokenCounter::new();

        let empty = MessagePart::Text {
            text: String::new(),
        };
        assert_eq!(counter.count_part_tokens(&empty), 0);

        let spaces = MessagePart::Text {
            text: "   ".to_string(),
        };
        assert_eq!(counter.count_part_tokens(&spaces), 0);

        let control_whitespace = MessagePart::Recall {
            text: "\n\t".to_string(),
        };
        assert_eq!(counter.count_part_tokens(&control_whitespace), 0);
    }

    #[test]
    fn count_message_tokens_push_recompute_consistency() {
        // The message total must equal the sum of the per-part counts.
        let counter = TokenCounter::new();
        let msg = make_msg(vec![
            MessagePart::Text {
                text: "hello".into(),
            },
            MessagePart::ToolOutput {
                tool_name: "bash".into(),
                body: "output data".into(),
                compacted_at: None,
            },
        ]);
        let total = counter.count_message_tokens(&msg);
        let sum: usize = msg.parts.iter().map(|p| counter.count_part_tokens(p)).sum();
        assert_eq!(total, sum);
    }

    #[test]
    fn count_message_tokens_parts_take_priority_over_content() {
        // R-2: primary regression guard: with non-empty parts, content is ignored.
        let counter = TokenCounter::new();
        let parts_text = "hello from parts";
        let msg = Message {
            role: Role::User,
            content: "completely different content that should be ignored".to_string(),
            parts: vec![MessagePart::Text {
                text: parts_text.to_string(),
            }],
            metadata: MessageMetadata::default(),
        };
        let parts_based = counter.count_tokens(parts_text);
        let content_based = counter.count_tokens(&msg.content);
        assert_ne!(
            parts_based, content_based,
            "test setup: parts and content must differ"
        );
        assert_eq!(counter.count_message_tokens(&msg), parts_based);
    }

    #[test]
    fn count_part_tokens_tool_result() {
        // R-3: verify the ToolResult arm.
        let counter = TokenCounter::new();
        let tool_use_id = "toolu_xyz";
        let content = "result text";
        let part = MessagePart::ToolResult {
            tool_use_id: tool_use_id.to_string(),
            content: content.to_string(),
            is_error: false,
        };
        let expected = TOOL_RESULT_OVERHEAD
            + counter.count_tokens(tool_use_id)
            + counter.count_tokens(content);
        assert_eq!(counter.count_part_tokens(&part), expected);
    }

    #[test]
    fn count_tokens_empty() {
        let counter = TokenCounter::new();
        assert_eq!(counter.count_tokens(""), 0);
    }

    #[test]
    fn count_tokens_non_empty() {
        let counter = TokenCounter::new();
        assert!(counter.count_tokens("hello world") > 0);
    }

    #[test]
    fn count_tokens_cache_hit_returns_same() {
        let counter = TokenCounter::new();
        let text = "the quick brown fox";

        let first = counter.count_tokens(text);
        assert_eq!(counter.cache.len(), 1, "first call must populate the cache");

        let second = counter.count_tokens(text);
        assert_eq!(first, second);
        assert_eq!(
            counter.cache.len(),
            1,
            "repeat call must not add a second entry"
        );

        // Plant a sentinel under the same key: if the lookup is really served from
        // the cache, the sentinel comes back instead of a recomputed count.
        counter.cache.insert(hash_text(text), usize::MAX);
        assert_eq!(counter.count_tokens(text), usize::MAX);
    }

    #[test]
    fn count_tokens_fallback_mode() {
        let counter = counter_with_no_bpe(CACHE_CAP);
        // 8 chars / 4 = 2
        assert_eq!(counter.count_tokens("abcdefgh"), 2);
        assert_eq!(counter.count_tokens(""), 0);
    }

    #[test]
    fn count_tokens_oversized_input_uses_fallback_without_cache() {
        let counter = TokenCounter::new();
        // One byte past MAX_INPUT_LEN (65536 bytes).
        let large = "a".repeat(MAX_INPUT_LEN + 1);
        let result = counter.count_tokens(&large);
        // chars/4 fallback: (65537 chars) / 4
        assert_eq!(result, zeph_common::text::estimate_tokens(&large));
        // Oversized inputs must not be cached.
        assert!(counter.cache.is_empty());
    }

    #[test]
    fn count_tokens_unicode_bpe_differs_from_fallback() {
        let counter = TokenCounter::new();
        let text = "Привет мир! 你好世界! こんにちは! 🌍";
        let bpe_count = counter.count_tokens(text);
        let fallback_count = zeph_common::text::estimate_tokens(text);
        assert!(bpe_count > 0, "BPE count must be positive");
        // For multi-byte text, BPE should not coincide with the naive chars/4 value.
        assert_ne!(
            bpe_count, fallback_count,
            "BPE tokenization should differ from chars/4 for unicode text"
        );
    }

    #[test]
    fn count_tool_schema_tokens_sample() {
        let counter = TokenCounter::new();
        let schema = serde_json::json!({
            "name": "get_weather",
            "description": "Get the current weather for a location",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The city name"
                    }
                },
                "required": ["location"]
            }
        });
        let tokens = counter.count_tool_schema_tokens(&schema);
        // Pinned value: computed by running the formula with cl100k_base on this
        // exact schema. A failure here points at a tokenizer or formula change.
        assert_eq!(tokens, 82);
    }

    #[test]
    fn two_instances_share_bpe_pointer() {
        let a = TokenCounter::new();
        let b = TokenCounter::new();
        assert!(std::ptr::eq(a.bpe, b.bpe));
    }

    #[test]
    fn cache_eviction_at_capacity() {
        let counter = counter_with_no_bpe(3);
        for text in ["aaaa", "bbbb", "cccc"] {
            let _ = counter.count_tokens(text);
        }
        assert_eq!(counter.cache.len(), 3);

        // At capacity: the next insert must evict exactly one older entry.
        let _ = counter.count_tokens("dddd");
        assert_eq!(counter.cache.len(), 3);
        // The entry that was just counted must be the one that stays.
        assert!(
            counter.cache.contains_key(&hash_text("dddd")),
            "newly counted text must be cached after eviction"
        );
    }
}
543}