zeph_memory/
token_counter.rs1use std::collections::hash_map::DefaultHasher;
5use std::hash::{Hash, Hasher};
6
7use dashmap::DashMap;
8use tiktoken_rs::CoreBPE;
9
/// Maximum number of memoized token counts a `TokenCounter` keeps.
const CACHE_CAP: usize = 10_000;
/// Inputs longer than this many bytes are estimated as `chars / 4` without
/// BPE encoding and without touching the cache.
const MAX_INPUT_LEN: usize = 65_536;

// Fixed token overheads used when estimating the cost of a tool schema.
// NOTE(review): these read as empirical per-structure overheads for a
// function-calling schema; their derivation is not shown here — confirm
// against the source of the heuristic before tuning.

/// Added once per schema by `count_tool_schema_tokens`.
const FUNC_INIT: usize = 7;
/// Added once for every JSON object visited by `count_schema_value`.
const PROP_INIT: usize = 3;
/// Added for every object key, on top of the key's own token count.
const PROP_KEY: usize = 3;
/// Signed adjustment applied once per schema by `count_tool_schema_tokens`.
const ENUM_INIT: isize = -3;
/// Added once for every JSON array visited (once per array, not per element).
/// NOTE(review): the name suggests a per-item cost — confirm the intent.
const ENUM_ITEM: usize = 3;
/// Trailing overhead added once per schema by `count_tool_schema_tokens`.
const FUNC_END: usize = 12;
22
23pub struct TokenCounter {
24 bpe: Option<CoreBPE>,
25 cache: DashMap<u64, usize>,
26 cache_cap: usize,
27}
28
29impl TokenCounter {
30 #[must_use]
32 pub fn new() -> Self {
33 let bpe = match tiktoken_rs::cl100k_base() {
34 Ok(b) => Some(b),
35 Err(e) => {
36 tracing::warn!("tiktoken cl100k_base init failed, using chars/4 fallback: {e}");
37 None
38 }
39 };
40 Self {
41 bpe,
42 cache: DashMap::new(),
43 cache_cap: CACHE_CAP,
44 }
45 }
46
47 #[must_use]
52 pub fn count_tokens(&self, text: &str) -> usize {
53 if text.is_empty() {
54 return 0;
55 }
56
57 if text.len() > MAX_INPUT_LEN {
58 return text.chars().count() / 4;
59 }
60
61 let key = hash_text(text);
62
63 if let Some(cached) = self.cache.get(&key) {
64 return *cached;
65 }
66
67 let count = match &self.bpe {
68 Some(bpe) => bpe.encode_with_special_tokens(text).len(),
69 None => text.chars().count() / 4,
70 };
71
72 if self.cache.len() >= self.cache_cap {
75 let key_to_evict = self.cache.iter().next().map(|e| *e.key());
76 if let Some(k) = key_to_evict {
77 self.cache.remove(&k);
78 }
79 }
80 self.cache.insert(key, count);
81
82 count
83 }
84
85 #[must_use]
87 pub fn count_tool_schema_tokens(&self, schema: &serde_json::Value) -> usize {
88 let base = count_schema_value(self, schema);
89 let total =
90 base.cast_signed() + ENUM_INIT + FUNC_INIT.cast_signed() + FUNC_END.cast_signed();
91 total.max(0).cast_unsigned()
92 }
93}
94
95impl Default for TokenCounter {
96 fn default() -> Self {
97 Self::new()
98 }
99}
100
/// Computes the 64-bit cache key for `text`.
///
/// A fresh `DefaultHasher::new()` is used on every call, so equal strings
/// always produce equal keys.
fn hash_text(text: &str) -> u64 {
    let mut state = DefaultHasher::new();
    Hash::hash(text, &mut state);
    state.finish()
}
106
107fn count_schema_value(counter: &TokenCounter, value: &serde_json::Value) -> usize {
108 match value {
109 serde_json::Value::Object(map) => {
110 let mut tokens = PROP_INIT;
111 for (key, val) in map {
112 tokens += PROP_KEY + counter.count_tokens(key);
113 tokens += count_schema_value(counter, val);
114 }
115 tokens
116 }
117 serde_json::Value::Array(arr) => {
118 let mut tokens = ENUM_ITEM;
119 for item in arr {
120 tokens += count_schema_value(counter, item);
121 }
122 tokens
123 }
124 serde_json::Value::String(s) => counter.count_tokens(s),
125 serde_json::Value::Bool(_) | serde_json::Value::Number(_) | serde_json::Value::Null => 1,
126 }
127}
128
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a counter that never uses BPE, with the given cache capacity.
    fn fallback_counter(cache_cap: usize) -> TokenCounter {
        TokenCounter {
            bpe: None,
            cache: DashMap::new(),
            cache_cap,
        }
    }

    #[test]
    fn count_tokens_empty() {
        assert_eq!(TokenCounter::new().count_tokens(""), 0);
    }

    #[test]
    fn count_tokens_non_empty() {
        let tokens = TokenCounter::new().count_tokens("hello world");
        assert!(tokens > 0);
    }

    #[test]
    fn count_tokens_cache_hit_returns_same() {
        let counter = TokenCounter::new();
        let sentence = "the quick brown fox";
        let cold = counter.count_tokens(sentence);
        let warm = counter.count_tokens(sentence);
        assert_eq!(cold, warm);
    }

    #[test]
    fn count_tokens_fallback_mode() {
        let counter = fallback_counter(CACHE_CAP);
        // 8 chars / 4 = 2 under the fallback heuristic.
        assert_eq!(counter.count_tokens("abcdefgh"), 2);
        assert_eq!(counter.count_tokens(""), 0);
    }

    #[test]
    fn count_tokens_oversized_input_uses_fallback_without_cache() {
        let counter = TokenCounter::new();
        let oversized = "a".repeat(MAX_INPUT_LEN + 1);
        let expected = oversized.chars().count() / 4;
        assert_eq!(counter.count_tokens(&oversized), expected);
        // Oversized input must bypass the cache entirely.
        assert!(counter.cache.is_empty());
    }

    #[test]
    fn count_tokens_unicode_bpe_differs_from_fallback() {
        let counter = TokenCounter::new();
        let text = "Привет мир! 你好世界! こんにちは! 🌍";
        let bpe_count = counter.count_tokens(text);
        let heuristic = text.chars().count() / 4;
        assert!(bpe_count > 0, "BPE count must be positive");
        assert_ne!(
            bpe_count, heuristic,
            "BPE tokenization should differ from chars/4 for unicode text"
        );
    }

    #[test]
    fn count_tool_schema_tokens_sample() {
        let schema = serde_json::json!({
            "name": "get_weather",
            "description": "Get the current weather for a location",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The city name"
                    }
                },
                "required": ["location"]
            }
        });
        let total = TokenCounter::new().count_tool_schema_tokens(&schema);
        assert_eq!(total, 82);
    }

    #[test]
    fn cache_eviction_at_capacity() {
        let counter = fallback_counter(3);
        for text in ["aaaa", "bbbb", "cccc"] {
            let _ = counter.count_tokens(text);
        }
        assert_eq!(counter.cache.len(), 3);
        let _ = counter.count_tokens("dddd");
        // A fourth distinct entry evicts one, keeping the size at the cap.
        assert_eq!(counter.cache.len(), 3);
    }
}