symbi_runtime/context/token_counter.rs

use super::types::ConversationItem;

/// Counts tokens for text and conversation messages against a model's context window.
pub trait TokenCounter: Send + Sync {
    /// Counts the tokens in a single string.
    fn count_tokens(&self, text: &str) -> usize;

    /// Counts the tokens across all messages, adding a small fixed per-message
    /// overhead for role and formatting tokens.
    fn count_messages(&self, messages: &[ConversationItem]) -> usize {
        messages
            .iter()
            .map(|m| self.count_tokens(&m.content) + 4)
            .sum()
    }

    /// The maximum context window (in tokens) of the model this counter targets.
    fn model_context_limit(&self) -> usize;
}

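/// Best-effort lookup of a model's context window size (in tokens) based on
/// substrings of the model name, with a conservative default for unknown models.
///
/// Illustrative usage, mirroring the match arms and tests below:
///
/// ```ignore
/// assert_eq!(context_limit_for_model("claude-sonnet-4-5-20250929"), 200_000);
/// assert_eq!(context_limit_for_model("my-custom-local-model"), 32_000);
/// ```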
pub fn context_limit_for_model(model: &str) -> usize {
    let m = model.to_lowercase();

    if m.contains("claude") {
        return 200_000;
    }
    if m.contains("gpt-4o") || m.contains("gpt-4-turbo") || m.contains("o1") || m.contains("o3") {
        return 128_000;
    }
    if m.contains("gpt-4") {
        return 128_000;
    }
    if m.contains("gemini") {
        return 1_000_000;
    }
    if m.contains("qwen") {
        return 131_072;
    }
    if m.contains("llama") {
        return 128_000;
    }
    if m.contains("mistral") || m.contains("mixtral") {
        return 32_000;
    }
    if m.contains("deepseek") {
        return 128_000;
    }
    if m.contains("kimi") || m.contains("moonshot") {
        return 128_000;
    }
    if m.contains("command-r") {
        return 128_000;
    }

    32_000
}

/// Token counter backed by a tiktoken BPE vocabulary.
pub struct TiktokenCounter {
    bpe: tiktoken_rs::CoreBPE,
    context_limit: usize,
}

impl TiktokenCounter {
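    /// Builds a counter for the given model, preferring the o200k_base vocabulary
    /// for newer OpenAI model names (gpt-4o, o1, o3) and falling back to
    /// cl100k_base for everything else.
    ///
    /// Illustrative usage, mirroring the tests below:
    ///
    /// ```ignore
    /// let counter = TiktokenCounter::for_model("gpt-4o");
    /// assert!(counter.count_tokens("Hello, world!") > 0);
    /// assert_eq!(counter.model_context_limit(), 128_000);
    /// ```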
    pub fn for_model(model: &str) -> Self {
        let model_lower = model.to_lowercase();

        // Newer OpenAI models use the o200k_base vocabulary.
        if model_lower.contains("gpt-4o")
            || model_lower.contains("o1")
            || model_lower.contains("o3")
        {
            if let Ok(bpe) = tiktoken_rs::o200k_base() {
                return Self {
                    bpe,
                    context_limit: context_limit_for_model(model),
                };
            }
        }

        // Everything else handled by this counter falls back to cl100k_base.
        let bpe = tiktoken_rs::cl100k_base().expect("tiktoken-rs failed to load cl100k_base");
        Self {
            bpe,
            context_limit: context_limit_for_model(model),
        }
    }
}

impl TokenCounter for TiktokenCounter {
    fn count_tokens(&self, text: &str) -> usize {
        self.bpe.encode_with_special_tokens(text).len()
    }

    fn model_context_limit(&self) -> usize {
        self.context_limit
    }
}

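/// Factory that picks a counter implementation from the model name: tiktoken-based
/// BPE counting for OpenAI/Anthropic-style names, and a character-based heuristic
/// for everything else.
///
/// Illustrative usage, mirroring the tests below:
///
/// ```ignore
/// let counter = create_token_counter("gpt-4o");
/// assert!(counter.count_tokens("Hello") > 0);
/// assert_eq!(counter.model_context_limit(), 128_000);
/// ```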
pub fn create_token_counter(model: &str) -> Box<dyn TokenCounter> {
    let m = model.to_lowercase();

    let use_tiktoken = m.contains("gpt")
        || m.contains("claude")
        || m.contains("o1")
        || m.contains("o3")
        || m.contains("text-embedding");

    if use_tiktoken {
        Box::new(TiktokenCounter::for_model(model))
    } else {
        Box::new(HeuristicTokenCounter::new(context_limit_for_model(model)))
    }
}

/// Fallback counter that estimates tokens from character count when no exact
/// tokenizer is available for the model.
pub struct HeuristicTokenCounter {
    context_limit: usize,
}

impl HeuristicTokenCounter {
    pub fn new(context_limit: usize) -> Self {
        Self { context_limit }
    }
}

impl TokenCounter for HeuristicTokenCounter {
    fn count_tokens(&self, text: &str) -> usize {
        // Rough estimate: ~3.5 characters per token, plus a ~14% safety margin.
        let raw = (text.len() as f64 / 3.5).ceil() as usize;
        raw + raw / 7
    }

    fn model_context_limit(&self) -> usize {
        self.context_limit
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn heuristic_counter_counts_tokens() {
        let counter = HeuristicTokenCounter::new(32_000);
        let count = counter.count_tokens("hello world");
        assert!(count > 0, "should count some tokens");
        assert!(count < 20, "heuristic should be reasonable for short text");
    }

    #[test]
    fn heuristic_counter_empty_string() {
        let counter = HeuristicTokenCounter::new(32_000);
        assert_eq!(counter.count_tokens(""), 0);
    }

    #[test]
    fn heuristic_counter_context_limit() {
        let counter = HeuristicTokenCounter::new(128_000);
        assert_eq!(counter.model_context_limit(), 128_000);
    }

    #[test]
    fn tiktoken_counter_counts_gpt4o() {
        let counter = TiktokenCounter::for_model("gpt-4o");
        let count = counter.count_tokens("Hello, world!");
        assert!(count > 0);
        assert!(
            count < 10,
            "short greeting should be under 10 tokens, got {count}"
        );
        assert_eq!(counter.model_context_limit(), 128_000);
    }

    #[test]
    fn tiktoken_counter_counts_claude() {
        let counter = TiktokenCounter::for_model("claude-sonnet-4-5-20250929");
        let count = counter.count_tokens("Hello, world!");
        assert!(count > 0);
        assert_eq!(counter.model_context_limit(), 200_000);
    }

    #[test]
    fn factory_returns_tiktoken_for_openai() {
        let counter = create_token_counter("gpt-4o");
        let count = counter.count_tokens("Hello");
        assert!(count > 0);
        assert_eq!(counter.model_context_limit(), 128_000);
    }

    #[test]
    fn factory_returns_tiktoken_for_claude() {
        let counter = create_token_counter("claude-haiku-4-5-20251001");
        let count = counter.count_tokens("Hello");
        assert!(count > 0);
        assert_eq!(counter.model_context_limit(), 200_000);
    }

    #[test]
    fn factory_returns_heuristic_for_unknown() {
        let counter = create_token_counter("my-custom-local-model");
        let count = counter.count_tokens("Hello");
        assert!(count > 0);
        assert_eq!(counter.model_context_limit(), 32_000);
    }
}