Skip to main content

zeph_memory/
token_counter.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use std::collections::hash_map::DefaultHasher;
5use std::hash::{Hash, Hasher};
6use std::sync::OnceLock;
7
8use dashmap::DashMap;
9use tiktoken_rs::CoreBPE;
10use zeph_common::text::estimate_tokens;
11use zeph_llm::provider::{Message, MessagePart};
12
13static BPE: OnceLock<Option<CoreBPE>> = OnceLock::new();
14
/// Maximum number of memoised `text → token count` entries kept per counter.
const CACHE_CAP: usize = 10_000;
/// Inputs larger than this limit (64 KiB, measured in bytes) bypass BPE encoding and
/// use the chars/4 fallback, preventing CPU amplification from pathologically large
/// inputs.
const MAX_INPUT_LEN: usize = 64 * 1024;

// OpenAI function-calling overhead constants used by the tool-schema estimate.
/// Added once per tool schema.
const FUNC_INIT: usize = 7;
/// Added for every JSON object in the schema.
const PROP_INIT: usize = 3;
/// Added for every key of a JSON object, on top of the key's own tokens.
const PROP_KEY: usize = 3;
/// Negative adjustment applied once per tool schema (hence the signed type).
const ENUM_INIT: isize = -3;
/// Added once for every JSON array in the schema (per array, not per element).
const ENUM_ITEM: usize = 3;
/// Added once per tool schema.
const FUNC_END: usize = 12;

// Structural overhead per message-part type (approximate tokens for the JSON framing).
/// Framing for `{"type":"tool_use","id":"","name":"","input":}`.
const TOOL_USE_OVERHEAD: usize = 20;
/// Framing for `{"type":"tool_result","tool_use_id":"","content":""}`.
const TOOL_RESULT_OVERHEAD: usize = 15;
/// Framing for the `[tool output: <name>]\n` wrapper.
const TOOL_OUTPUT_OVERHEAD: usize = 8;
/// JSON structure overhead of an image block (dimension-based counting unavailable).
const IMAGE_OVERHEAD: usize = 50;
/// Fixed per-image estimate, used because pixel dimensions are not available in `ImageData`.
const IMAGE_DEFAULT_TOKENS: usize = 1000;
/// Framing for thinking and redacted-thinking blocks.
const THINKING_OVERHEAD: usize = 10;
41
42/// Token counter backed by `tiktoken` `cl100k_base` BPE encoding.
43///
44/// Estimates how many tokens a piece of text or a full [`Message`] will consume when
45/// sent to an LLM API.  Uses a process-scoped [`OnceLock`] so BPE data is loaded once
46/// and shared across all instances.
47///
48/// Falls back to a `chars/4` heuristic when tiktoken init fails or when the input
49/// exceeds `MAX_INPUT_LEN` bytes (64 KiB).
50///
51/// # Examples
52///
53/// ```
54/// use zeph_memory::TokenCounter;
55///
56/// let counter = TokenCounter::new();
57/// let n = counter.count_tokens("Hello, world!");
58/// assert!(n > 0);
59/// ```
60pub struct TokenCounter {
61    bpe: &'static Option<CoreBPE>,
62    cache: DashMap<u64, usize>,
63    cache_cap: usize,
64}
65
66impl TokenCounter {
67    /// Create a new counter. Falls back to chars/4 if tiktoken init fails.
68    ///
69    /// BPE data is loaded once and cached in a `OnceLock` for the process lifetime.
70    #[must_use]
71    pub fn new() -> Self {
72        let bpe = BPE.get_or_init(|| match tiktoken_rs::cl100k_base() {
73            Ok(b) => Some(b),
74            Err(e) => {
75                tracing::warn!("tiktoken cl100k_base init failed, using chars/4 fallback: {e}");
76                None
77            }
78        });
79        Self {
80            bpe,
81            cache: DashMap::new(),
82            cache_cap: CACHE_CAP,
83        }
84    }
85
86    /// Count tokens in text. Uses cache, falls back to heuristic.
87    ///
88    /// Inputs exceeding 64 KiB bypass BPE and use chars/4 without caching to
89    /// avoid CPU amplification from oversized inputs.
90    #[must_use]
91    pub fn count_tokens(&self, text: &str) -> usize {
92        if text.is_empty() {
93            return 0;
94        }
95
96        if text.len() > MAX_INPUT_LEN {
97            return zeph_common::text::estimate_tokens(text);
98        }
99
100        let key = hash_text(text);
101
102        if let Some(cached) = self.cache.get(&key) {
103            return *cached;
104        }
105
106        let count = match self.bpe {
107            Some(bpe) => bpe.encode_with_special_tokens(text).len(),
108            None => zeph_common::text::estimate_tokens(text),
109        };
110
111        // TOCTOU between len() check and insert is benign: worst case we evict
112        // one extra entry and temporarily exceed the cap by one slot.
113        if self.cache.len() >= self.cache_cap {
114            let key_to_evict = self.cache.iter().next().map(|e| *e.key());
115            if let Some(k) = key_to_evict {
116                self.cache.remove(&k);
117            }
118        }
119        self.cache.insert(key, count);
120
121        count
122    }
123
124    /// Estimate token count for a message the way the LLM API will see it.
125    ///
126    /// When structured parts exist, counts from parts matching the API payload
127    /// structure. Falls back to `content` (flattened text) when parts is empty.
128    #[must_use]
129    pub fn count_message_tokens(&self, msg: &Message) -> usize {
130        if msg.parts.is_empty() {
131            return self.count_tokens(&msg.content);
132        }
133        msg.parts.iter().map(|p| self.count_part_tokens(p)).sum()
134    }
135
136    /// Estimate tokens for a single [`MessagePart`] matching the API payload structure.
137    #[must_use]
138    fn count_part_tokens(&self, part: &MessagePart) -> usize {
139        match part {
140            MessagePart::Text { text }
141            | MessagePart::Recall { text }
142            | MessagePart::CodeContext { text }
143            | MessagePart::Summary { text }
144            | MessagePart::CrossSession { text } => {
145                if text.trim().is_empty() {
146                    return 0;
147                }
148                self.count_tokens(text)
149            }
150
151            // API always emits `[tool output: {name}]\n{body}` regardless of compacted_at.
152            // When body is emptied by compaction, count_tokens(body) returns 0 naturally.
153            MessagePart::ToolOutput {
154                tool_name, body, ..
155            } => {
156                TOOL_OUTPUT_OVERHEAD
157                    + self.count_tokens(tool_name.as_str())
158                    + self.count_tokens(body)
159            }
160
161            // API sends structured JSON block: `{"type":"tool_use","id":"...","name":"...","input":...}`
162            MessagePart::ToolUse { id, name, input } => {
163                TOOL_USE_OVERHEAD
164                    + self.count_tokens(id)
165                    + self.count_tokens(name)
166                    + self.count_tokens(&input.to_string())
167            }
168
169            // API sends structured block: `{"type":"tool_result","tool_use_id":"...","content":"..."}`
170            MessagePart::ToolResult {
171                tool_use_id,
172                content,
173                ..
174            } => TOOL_RESULT_OVERHEAD + self.count_tokens(tool_use_id) + self.count_tokens(content),
175
176            // Image token count depends on pixel dimensions, which are unavailable in ImageData.
177            // Using a fixed constant is more accurate than bytes-based formula because
178            // Claude's actual formula is width*height based, not payload-size based.
179            MessagePart::Image(_) => IMAGE_OVERHEAD + IMAGE_DEFAULT_TOKENS,
180
181            // ThinkingBlock is preserved verbatim in multi-turn requests.
182            MessagePart::ThinkingBlock {
183                thinking,
184                signature,
185            } => THINKING_OVERHEAD + self.count_tokens(thinking) + self.count_tokens(signature),
186
187            // RedactedThinkingBlock is an opaque base64 blob — BPE is not meaningful here.
188            MessagePart::RedactedThinkingBlock { data } => {
189                THINKING_OVERHEAD + estimate_tokens(data)
190            }
191
192            // Compaction summary is sent back verbatim to the API.
193            MessagePart::Compaction { summary } => self.count_tokens(summary),
194
195            // Forward-compatibility: unknown future variants have no known token count.
196            _ => 0,
197        }
198    }
199
200    /// Count tokens for an `OpenAI` tool/function schema `JSON` value.
201    #[must_use]
202    pub fn count_tool_schema_tokens(&self, schema: &serde_json::Value) -> usize {
203        let base = count_schema_value(self, schema);
204        let total =
205            base.cast_signed() + ENUM_INIT + FUNC_INIT.cast_signed() + FUNC_END.cast_signed();
206        total.max(0).cast_unsigned()
207    }
208}
209
210impl Default for TokenCounter {
211    fn default() -> Self {
212        Self::new()
213    }
214}
215
216impl zeph_common::memory::TokenCounting for TokenCounter {
217    fn count_tokens(&self, text: &str) -> usize {
218        self.count_tokens(text)
219    }
220
221    fn count_tool_schema_tokens(&self, schema: &serde_json::Value) -> usize {
222        self.count_tool_schema_tokens(schema)
223    }
224}
225
/// Hashes `text` into the 64-bit key used by the token-count cache.
///
/// `DefaultHasher::new()` always starts from the same state, so a given string maps to
/// the same key on every call within a build. Distinct strings can still collide, in
/// which case the cache would return the other string's count.
fn hash_text(text: &str) -> u64 {
    let mut state = DefaultHasher::new();
    Hash::hash(text, &mut state);
    state.finish()
}
231
232fn count_schema_value(counter: &TokenCounter, value: &serde_json::Value) -> usize {
233    match value {
234        serde_json::Value::Object(map) => {
235            let mut tokens = PROP_INIT;
236            for (key, val) in map {
237                tokens += PROP_KEY + counter.count_tokens(key);
238                tokens += count_schema_value(counter, val);
239            }
240            tokens
241        }
242        serde_json::Value::Array(arr) => {
243            let mut tokens = ENUM_ITEM;
244            for item in arr {
245                tokens += count_schema_value(counter, item);
246            }
247            tokens
248        }
249        serde_json::Value::String(s) => counter.count_tokens(s),
250        serde_json::Value::Bool(_) | serde_json::Value::Number(_) | serde_json::Value::Null => 1,
251    }
252}
253
#[cfg(test)]
mod tests {
    use super::*;
    use zeph_llm::provider::{ImageData, Message, MessageMetadata, MessagePart, Role};

    // "No encoder" slot so tests can exercise the chars/4 fallback deterministically.
    static BPE_NONE: Option<CoreBPE> = None;

    /// Build a counter that always uses the `chars/4` heuristic, with a custom cache cap.
    fn counter_with_no_bpe(cache_cap: usize) -> TokenCounter {
        TokenCounter {
            bpe: &BPE_NONE,
            cache: DashMap::new(),
            cache_cap,
        }
    }

    /// A user message whose structured `parts` drive token counting.
    fn make_msg(parts: Vec<MessagePart>) -> Message {
        Message::from_parts(Role::User, parts)
    }

    /// A user message with only flattened `content` (no structured parts).
    fn make_msg_no_parts(content: &str) -> Message {
        Message {
            role: Role::User,
            content: content.to_string(),
            parts: vec![],
            metadata: MessageMetadata::default(),
        }
    }

    #[test]
    fn count_message_tokens_empty_parts_falls_back_to_content() {
        let counter = TokenCounter::new();
        let msg = make_msg_no_parts("hello world");
        assert_eq!(
            counter.count_message_tokens(&msg),
            counter.count_tokens("hello world")
        );
    }

    #[test]
    fn count_message_tokens_text_part_matches_count_tokens() {
        let counter = TokenCounter::new();
        let text = "the quick brown fox jumps over the lazy dog";
        let msg = make_msg(vec![MessagePart::Text {
            text: text.to_string(),
        }]);
        assert_eq!(
            counter.count_message_tokens(&msg),
            counter.count_tokens(text)
        );
    }

    #[test]
    fn count_message_tokens_tool_use_exceeds_flattened_content() {
        let counter = TokenCounter::new();
        // Large JSON input: structured counting should be higher than flattened "[tool_use: bash(id)]"
        let input = serde_json::json!({"command": "find /home -name '*.rs' -type f | head -100"});
        let msg = make_msg(vec![MessagePart::ToolUse {
            id: "toolu_abc".into(),
            name: "bash".into(),
            input,
        }]);
        let structured = counter.count_message_tokens(&msg);
        let flattened = counter.count_tokens(&msg.content);
        assert!(
            structured > flattened,
            "structured={structured} should exceed flattened={flattened}"
        );
    }

    #[test]
    fn count_message_tokens_compacted_tool_output_is_small() {
        let counter = TokenCounter::new();
        // Compacted ToolOutput has empty body — should count close to overhead only
        let msg = make_msg(vec![MessagePart::ToolOutput {
            tool_name: "bash".into(),
            body: String::new(),
            compacted_at: Some(1_700_000_000),
        }]);
        let tokens = counter.count_message_tokens(&msg);
        // Should be small: TOOL_OUTPUT_OVERHEAD + count_tokens("bash") + 0
        assert!(
            tokens <= 15,
            "compacted tool output should be small, got {tokens}"
        );
    }

    #[test]
    fn count_message_tokens_image_returns_constant() {
        let counter = TokenCounter::new();
        let msg = make_msg(vec![MessagePart::Image(Box::new(ImageData {
            data: vec![0u8; 1000],
            mime_type: "image/jpeg".into(),
        }))]);
        assert_eq!(
            counter.count_message_tokens(&msg),
            IMAGE_OVERHEAD + IMAGE_DEFAULT_TOKENS
        );
    }

    #[test]
    fn count_message_tokens_thinking_block_counts_text() {
        let counter = TokenCounter::new();
        let thinking = "step by step reasoning about the problem";
        let signature = "sig";
        let msg = make_msg(vec![MessagePart::ThinkingBlock {
            thinking: thinking.to_string(),
            signature: signature.to_string(),
        }]);
        let expected =
            THINKING_OVERHEAD + counter.count_tokens(thinking) + counter.count_tokens(signature);
        assert_eq!(counter.count_message_tokens(&msg), expected);
    }

    #[test]
    fn count_part_tokens_empty_text_returns_zero() {
        let counter = TokenCounter::new();
        assert_eq!(
            counter.count_part_tokens(&MessagePart::Text {
                text: String::new()
            }),
            0
        );
        assert_eq!(
            counter.count_part_tokens(&MessagePart::Text {
                text: "   ".to_string()
            }),
            0
        );
        assert_eq!(
            counter.count_part_tokens(&MessagePart::Recall {
                text: "\n\t".to_string()
            }),
            0
        );
    }

    #[test]
    fn count_message_tokens_push_recompute_consistency() {
        // Verify that sum of count_message_tokens per part equals recompute result
        let counter = TokenCounter::new();
        let parts = vec![
            MessagePart::Text {
                text: "hello".into(),
            },
            MessagePart::ToolOutput {
                tool_name: "bash".into(),
                body: "output data".into(),
                compacted_at: None,
            },
        ];
        let msg = make_msg(parts);
        let total = counter.count_message_tokens(&msg);
        let sum: usize = msg.parts.iter().map(|p| counter.count_part_tokens(p)).sum();
        assert_eq!(total, sum);
    }

    #[test]
    fn count_message_tokens_parts_take_priority_over_content() {
        // R-2: primary regression guard — when parts is non-empty, content is ignored.
        let counter = TokenCounter::new();
        let parts_text = "hello from parts";
        let msg = Message {
            role: Role::User,
            content: "completely different content that should be ignored".to_string(),
            parts: vec![MessagePart::Text {
                text: parts_text.to_string(),
            }],
            metadata: MessageMetadata::default(),
        };
        let parts_based = counter.count_tokens(parts_text);
        let content_based = counter.count_tokens(&msg.content);
        assert_ne!(
            parts_based, content_based,
            "test setup: parts and content must differ"
        );
        assert_eq!(counter.count_message_tokens(&msg), parts_based);
    }

    #[test]
    fn count_part_tokens_tool_result() {
        // R-3: verify ToolResult arm counting
        let counter = TokenCounter::new();
        let tool_use_id = "toolu_xyz";
        let content = "result text";
        let part = MessagePart::ToolResult {
            tool_use_id: tool_use_id.to_string(),
            content: content.to_string(),
            is_error: false,
        };
        let expected = TOOL_RESULT_OVERHEAD
            + counter.count_tokens(tool_use_id)
            + counter.count_tokens(content);
        assert_eq!(counter.count_part_tokens(&part), expected);
    }

    #[test]
    fn count_tokens_empty() {
        let counter = TokenCounter::new();
        assert_eq!(counter.count_tokens(""), 0);
    }

    #[test]
    fn count_tokens_non_empty() {
        let counter = TokenCounter::new();
        assert!(counter.count_tokens("hello world") > 0);
    }

    #[test]
    fn count_tokens_cache_hit_returns_same() {
        let counter = TokenCounter::new();
        let text = "the quick brown fox";
        let first = counter.count_tokens(text);
        let second = counter.count_tokens(text);
        assert_eq!(first, second);
    }

    #[test]
    fn count_tokens_fallback_mode() {
        let counter = counter_with_no_bpe(CACHE_CAP);
        // 8 chars / 4 = 2
        assert_eq!(counter.count_tokens("abcdefgh"), 2);
        assert_eq!(counter.count_tokens(""), 0);
    }

    #[test]
    fn count_tokens_oversized_input_uses_fallback_without_cache() {
        let counter = TokenCounter::new();
        // Generate input larger than MAX_INPUT_LEN (65536 bytes)
        let large = "a".repeat(MAX_INPUT_LEN + 1);
        let result = counter.count_tokens(&large);
        // chars/4 fallback: (65537 chars) / 4
        assert_eq!(result, zeph_common::text::estimate_tokens(&large));
        // Must not be cached
        assert!(counter.cache.is_empty());
    }

    #[test]
    fn count_tokens_unicode_bpe_differs_from_fallback() {
        let counter = TokenCounter::new();
        let text = "Привет мир! 你好世界! こんにちは! 🌍";
        let bpe_count = counter.count_tokens(text);
        let fallback_count = zeph_common::text::estimate_tokens(text);
        // BPE should return > 0
        assert!(bpe_count > 0, "BPE count must be positive");
        // BPE result should differ from naive chars/4 for multi-byte text
        assert_ne!(
            bpe_count, fallback_count,
            "BPE tokenization should differ from chars/4 for unicode text"
        );
    }

    #[test]
    fn count_tool_schema_tokens_sample() {
        let counter = TokenCounter::new();
        let schema = serde_json::json!({
            "name": "get_weather",
            "description": "Get the current weather for a location",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The city name"
                    }
                },
                "required": ["location"]
            }
        });
        let tokens = counter.count_tool_schema_tokens(&schema);
        // Pinned value: computed by running the formula with cl100k_base on this exact schema.
        // If this fails, a tokenizer or formula change likely occurred.
        assert_eq!(tokens, 82);
    }

    #[test]
    fn two_instances_share_bpe_pointer() {
        let a = TokenCounter::new();
        let b = TokenCounter::new();
        assert!(std::ptr::eq(a.bpe, b.bpe));
    }

    #[test]
    fn cache_eviction_at_capacity() {
        let counter = counter_with_no_bpe(3);
        let _ = counter.count_tokens("aaaa");
        let _ = counter.count_tokens("bbbb");
        let _ = counter.count_tokens("cccc");
        assert_eq!(counter.cache.len(), 3);
        // This should evict one entry
        let _ = counter.count_tokens("dddd");
        assert_eq!(counter.cache.len(), 3);
    }

    // --- Coverage for arms and paths not exercised above ---

    #[test]
    fn count_part_tokens_whitespace_only_variants_return_zero() {
        let counter = TokenCounter::new();
        let parts = [
            MessagePart::CodeContext { text: " \n".into() },
            MessagePart::Summary { text: "\t".into() },
            MessagePart::CrossSession { text: "  ".into() },
        ];
        for part in &parts {
            assert_eq!(counter.count_part_tokens(part), 0);
        }
    }

    #[test]
    fn count_part_tokens_redacted_thinking_uses_heuristic() {
        // Opaque blobs are estimated with chars/4, never BPE-encoded.
        let counter = TokenCounter::new();
        let data = "c29tZSBvcGFxdWUgcmVkYWN0ZWQgdGhpbmtpbmcgYmxvYg==";
        let part = MessagePart::RedactedThinkingBlock { data: data.into() };
        assert_eq!(
            counter.count_part_tokens(&part),
            THINKING_OVERHEAD + zeph_common::text::estimate_tokens(data)
        );
    }

    #[test]
    fn count_part_tokens_compaction_counts_summary() {
        let counter = TokenCounter::new();
        let summary = "earlier turns discussed the build failure";
        let part = MessagePart::Compaction {
            summary: summary.into(),
        };
        assert_eq!(
            counter.count_part_tokens(&part),
            counter.count_tokens(summary)
        );
    }

    #[test]
    fn count_part_tokens_tool_use_matches_formula() {
        let counter = TokenCounter::new();
        let input = serde_json::json!({"path": "src/main.rs"});
        let expected = TOOL_USE_OVERHEAD
            + counter.count_tokens("toolu_1")
            + counter.count_tokens("read_file")
            + counter.count_tokens(&input.to_string());
        let part = MessagePart::ToolUse {
            id: "toolu_1".into(),
            name: "read_file".into(),
            input,
        };
        assert_eq!(counter.count_part_tokens(&part), expected);
    }

    #[test]
    fn count_tool_schema_tokens_structural_only() {
        // No strings, so the result is tokenizer-independent and pins the fixed
        // overheads: FUNC_INIT + FUNC_END + ENUM_INIT = 7 + 12 - 3 = 16.
        let counter = counter_with_no_bpe(CACHE_CAP);
        assert_eq!(counter.count_tool_schema_tokens(&serde_json::json!(null)), 17);
        assert_eq!(counter.count_tool_schema_tokens(&serde_json::json!({})), 19);
        assert_eq!(counter.count_tool_schema_tokens(&serde_json::json!([])), 19);
        assert_eq!(
            counter.count_tool_schema_tokens(&serde_json::json!([1, true])),
            21
        );
    }

    #[test]
    fn trait_impl_delegates_to_inherent_methods() {
        let counter = TokenCounter::new();
        let schema = serde_json::json!({"type": "object"});
        assert_eq!(
            <TokenCounter as zeph_common::memory::TokenCounting>::count_tokens(
                &counter,
                "hello world"
            ),
            counter.count_tokens("hello world")
        );
        assert_eq!(
            <TokenCounter as zeph_common::memory::TokenCounting>::count_tool_schema_tokens(
                &counter, &schema
            ),
            counter.count_tool_schema_tokens(&schema)
        );
    }

    #[test]
    fn default_shares_bpe_with_new() {
        let a = TokenCounter::default();
        let b = TokenCounter::new();
        assert!(std::ptr::eq(a.bpe, b.bpe));
        assert!(a.cache.is_empty());
    }

    #[test]
    fn count_tokens_fallback_caches_result() {
        let counter = counter_with_no_bpe(CACHE_CAP);
        assert_eq!(counter.count_tokens("abcdefgh"), 2);
        assert_eq!(counter.cache.len(), 1);
        // A repeat lookup is served from the cache without adding an entry.
        assert_eq!(counter.count_tokens("abcdefgh"), 2);
        assert_eq!(counter.cache.len(), 1);
    }

    #[test]
    fn cache_eviction_keeps_newly_inserted_entry() {
        let counter = counter_with_no_bpe(2);
        let _ = counter.count_tokens("aaaa");
        let _ = counter.count_tokens("bbbb");
        let _ = counter.count_tokens("cccc");
        assert_eq!(counter.cache.len(), 2);
        assert!(counter.cache.contains_key(&hash_text("cccc")));
    }
}