// zeph_memory/token_counter.rs
// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
// SPDX-License-Identifier: MIT OR Apache-2.0

use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};

use dashmap::DashMap;
use tiktoken_rs::CoreBPE;

/// Maximum number of (hash, count) entries held in the token-count cache.
const CACHE_CAP: usize = 10_000;
/// Inputs larger than this limit bypass BPE encoding and use the chars/4 fallback.
/// Prevents CPU amplification from pathologically large inputs.
const MAX_INPUT_LEN: usize = 65_536;

// OpenAI function-calling token overhead constants.
// NOTE(review): these mirror the community-reverse-engineered overheads for
// OpenAI function-calling prompts — confirm against the upstream heuristic
// if token estimates drift.
const FUNC_INIT: usize = 7; // fixed cost for opening a function declaration
const PROP_INIT: usize = 3; // cost for opening a properties object
const PROP_KEY: usize = 3; // per-property key overhead
const ENUM_INIT: isize = -3; // signed correction applied once per schema
const ENUM_ITEM: usize = 3; // base cost for an array/enum listing
const FUNC_END: usize = 12; // fixed cost for closing the function block

/// Token counter backed by tiktoken's `cl100k_base` BPE, with a bounded
/// concurrent cache of previously counted texts. Degrades to a chars/4
/// heuristic when the tokenizer is unavailable.
pub struct TokenCounter {
    // `None` when tiktoken initialization failed; all counts then use chars/4.
    bpe: Option<CoreBPE>,
    // Maps hash_text(text) -> token count; keyed by hash, so a 64-bit hash
    // collision would return the colliding entry's count (accepted trade-off).
    cache: DashMap<u64, usize>,
    // Soft capacity; insertion evicts one arbitrary entry once reached.
    cache_cap: usize,
}

29impl TokenCounter {
30    /// Create a new counter. Falls back to chars/4 if tiktoken init fails.
31    #[must_use]
32    pub fn new() -> Self {
33        let bpe = match tiktoken_rs::cl100k_base() {
34            Ok(b) => Some(b),
35            Err(e) => {
36                tracing::warn!("tiktoken cl100k_base init failed, using chars/4 fallback: {e}");
37                None
38            }
39        };
40        Self {
41            bpe,
42            cache: DashMap::new(),
43            cache_cap: CACHE_CAP,
44        }
45    }
46
47    /// Count tokens in text. Uses cache, falls back to heuristic.
48    ///
49    /// Inputs exceeding 64 KiB bypass BPE and use chars/4 without caching to
50    /// avoid CPU amplification from oversized inputs.
51    #[must_use]
52    pub fn count_tokens(&self, text: &str) -> usize {
53        if text.is_empty() {
54            return 0;
55        }
56
57        if text.len() > MAX_INPUT_LEN {
58            return text.chars().count() / 4;
59        }
60
61        let key = hash_text(text);
62
63        if let Some(cached) = self.cache.get(&key) {
64            return *cached;
65        }
66
67        let count = match &self.bpe {
68            Some(bpe) => bpe.encode_with_special_tokens(text).len(),
69            None => text.chars().count() / 4,
70        };
71
72        // TOCTOU between len() check and insert is benign: worst case we evict
73        // one extra entry and temporarily exceed the cap by one slot.
74        if self.cache.len() >= self.cache_cap {
75            let key_to_evict = self.cache.iter().next().map(|e| *e.key());
76            if let Some(k) = key_to_evict {
77                self.cache.remove(&k);
78            }
79        }
80        self.cache.insert(key, count);
81
82        count
83    }
84
85    /// Count tokens for an `OpenAI` tool/function schema `JSON` value.
86    #[must_use]
87    pub fn count_tool_schema_tokens(&self, schema: &serde_json::Value) -> usize {
88        let base = count_schema_value(self, schema);
89        let total =
90            base.cast_signed() + ENUM_INIT + FUNC_INIT.cast_signed() + FUNC_END.cast_signed();
91        total.max(0).cast_unsigned()
92    }
93}
94
95impl Default for TokenCounter {
96    fn default() -> Self {
97        Self::new()
98    }
99}
100
/// Derive a 64-bit cache key from the text content.
///
/// Deterministic within a process; NOT stable across Rust versions or runs,
/// which is fine because the cache lives only in memory.
fn hash_text(text: &str) -> u64 {
    let mut state = DefaultHasher::default();
    text.hash(&mut state);
    state.finish()
}

107fn count_schema_value(counter: &TokenCounter, value: &serde_json::Value) -> usize {
108    match value {
109        serde_json::Value::Object(map) => {
110            let mut tokens = PROP_INIT;
111            for (key, val) in map {
112                tokens += PROP_KEY + counter.count_tokens(key);
113                tokens += count_schema_value(counter, val);
114            }
115            tokens
116        }
117        serde_json::Value::Array(arr) => {
118            let mut tokens = ENUM_ITEM;
119            for item in arr {
120                tokens += count_schema_value(counter, item);
121            }
122            tokens
123        }
124        serde_json::Value::String(s) => counter.count_tokens(s),
125        serde_json::Value::Bool(_) | serde_json::Value::Number(_) | serde_json::Value::Null => 1,
126    }
127}
128
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a counter that always takes the chars/4 fallback path,
    /// with an explicit cache capacity.
    fn fallback_only(cap: usize) -> TokenCounter {
        TokenCounter {
            bpe: None,
            cache: DashMap::new(),
            cache_cap: cap,
        }
    }

    #[test]
    fn count_tokens_empty() {
        assert_eq!(TokenCounter::new().count_tokens(""), 0);
    }

    #[test]
    fn count_tokens_non_empty() {
        assert!(TokenCounter::new().count_tokens("hello world") > 0);
    }

    #[test]
    fn count_tokens_cache_hit_returns_same() {
        let tc = TokenCounter::new();
        let text = "the quick brown fox";
        // Second call is served from cache and must agree with the first.
        assert_eq!(tc.count_tokens(text), tc.count_tokens(text));
    }

    #[test]
    fn count_tokens_fallback_mode() {
        let tc = fallback_only(CACHE_CAP);
        // 8 chars / 4 = 2
        assert_eq!(tc.count_tokens("abcdefgh"), 2);
        assert_eq!(tc.count_tokens(""), 0);
    }

    #[test]
    fn count_tokens_oversized_input_uses_fallback_without_cache() {
        let tc = TokenCounter::new();
        // One byte past MAX_INPUT_LEN (65536) triggers the bypass.
        let big = "a".repeat(MAX_INPUT_LEN + 1);
        // chars/4 fallback: (65537 chars) / 4
        assert_eq!(tc.count_tokens(&big), big.chars().count() / 4);
        // Must not be cached
        assert!(tc.cache.is_empty());
    }

    #[test]
    fn count_tokens_unicode_bpe_differs_from_fallback() {
        let tc = TokenCounter::new();
        let text = "Привет мир! 你好世界! こんにちは! 🌍";
        let via_bpe = tc.count_tokens(text);
        let via_heuristic = text.chars().count() / 4;
        assert!(via_bpe > 0, "BPE count must be positive");
        // BPE result should differ from naive chars/4 for multi-byte text.
        assert_ne!(
            via_bpe, via_heuristic,
            "BPE tokenization should differ from chars/4 for unicode text"
        );
    }

    #[test]
    fn count_tool_schema_tokens_sample() {
        let tc = TokenCounter::new();
        let schema = serde_json::json!({
            "name": "get_weather",
            "description": "Get the current weather for a location",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The city name"
                    }
                },
                "required": ["location"]
            }
        });
        // Pinned value: computed by running the formula with cl100k_base on this exact schema.
        // If this fails, a tokenizer or formula change likely occurred.
        assert_eq!(tc.count_tool_schema_tokens(&schema), 82);
    }

    #[test]
    fn cache_eviction_at_capacity() {
        let tc = fallback_only(3);
        for text in ["aaaa", "bbbb", "cccc"] {
            let _ = tc.count_tokens(text);
        }
        assert_eq!(tc.cache.len(), 3);
        // A fourth distinct input should evict exactly one entry.
        let _ = tc.count_tokens("dddd");
        assert_eq!(tc.cache.len(), 3);
    }
}