// symbi_runtime/context/token_counter.rs
1//! Multi-model token counting for context compaction.
2//!
3//! Provides a [`TokenCounter`] trait with implementations for various LLM
4//! providers. Uses tiktoken-rs for OpenAI/Claude models and falls back to
5//! a character-based heuristic for unknown models.
6
7use super::types::ConversationItem;
8
9/// Trait for counting tokens in text and messages.
10pub trait TokenCounter: Send + Sync {
11    /// Count tokens in a single string.
12    fn count_tokens(&self, text: &str) -> usize;
13
14    /// Count tokens across a slice of conversation items.
15    fn count_messages(&self, messages: &[ConversationItem]) -> usize {
16        messages
17            .iter()
18            .map(|m| self.count_tokens(&m.content) + 4) // 4 tokens per-message overhead
19            .sum()
20    }
21
22    /// Return the model's maximum context window size in tokens.
23    fn model_context_limit(&self) -> usize;
24}
25
/// Look up the context window limit for a model by name.
///
/// Matching is case-insensitive substring matching, checked in declaration
/// order, so more specific families must appear before general ones.
/// Unknown models get a conservative 32k default.
pub fn context_limit_for_model(model: &str) -> usize {
    // (substring patterns, context window in tokens), checked in order.
    // The previously duplicated gpt-4o / gpt-4-turbo / o1 / o3 and gpt-4
    // branches (all 128k) are merged into a single entry: "gpt-4o" and
    // "gpt-4-turbo" both contain "gpt-4", so behavior is unchanged.
    const LIMITS: &[(&[&str], usize)] = &[
        (&["claude"], 200_000),
        (&["gpt-4", "o1", "o3"], 128_000),
        (&["gemini"], 1_000_000),
        (&["qwen"], 131_072),
        (&["llama"], 128_000),
        (&["mistral", "mixtral"], 32_000),
        (&["deepseek"], 128_000),
        (&["kimi", "moonshot"], 128_000),
        (&["command-r"], 128_000),
    ];

    let m = model.to_lowercase();
    LIMITS
        .iter()
        .find(|(patterns, _)| patterns.iter().any(|p| m.contains(p)))
        .map(|&(_, limit)| limit)
        // Conservative default for unrecognized models.
        .unwrap_or(32_000)
}
64
/// Token counter using tiktoken-rs (cl100k_base or o200k_base).
///
/// Works natively for OpenAI models. For Claude, uses cl100k_base as an
/// approximation (both are BPE with similar vocab sizes).
pub struct TiktokenCounter {
    // BPE encoder selected at construction; used for exact token counts.
    bpe: tiktoken_rs::CoreBPE,
    // Context window resolved from the model name via `context_limit_for_model`.
    context_limit: usize,
}
73
74impl TiktokenCounter {
75    /// Create a counter for the given model name.
76    ///
77    /// Resolution order:
78    /// 1. o200k_base for GPT-4o family
79    /// 2. cl100k_base for GPT-4, Claude, and everything else
80    pub fn for_model(model: &str) -> Self {
81        let model_lower = model.to_lowercase();
82
83        // Try o200k_base for GPT-4o family
84        if model_lower.contains("gpt-4o")
85            || model_lower.contains("o1")
86            || model_lower.contains("o3")
87        {
88            if let Ok(bpe) = tiktoken_rs::o200k_base() {
89                return Self {
90                    bpe,
91                    context_limit: context_limit_for_model(model),
92                };
93            }
94        }
95
96        // cl100k_base for GPT-4, Claude, and everything else tiktoken supports
97        let bpe = tiktoken_rs::cl100k_base().expect("tiktoken-rs failed to load cl100k_base");
98        Self {
99            bpe,
100            context_limit: context_limit_for_model(model),
101        }
102    }
103}
104
105impl TokenCounter for TiktokenCounter {
106    fn count_tokens(&self, text: &str) -> usize {
107        self.bpe.encode_with_special_tokens(text).len()
108    }
109
110    fn model_context_limit(&self) -> usize {
111        self.context_limit
112    }
113}
114
115/// Create the best available token counter for the given model.
116///
117/// Resolution:
118/// 1. tiktoken-rs for OpenAI, Claude, and well-known models
119/// 2. Heuristic fallback for unknown models
120pub fn create_token_counter(model: &str) -> Box<dyn TokenCounter> {
121    let m = model.to_lowercase();
122
123    // tiktoken works well for OpenAI, Claude (cl100k approx), and most major models
124    let use_tiktoken = m.contains("gpt")
125        || m.contains("claude")
126        || m.contains("o1")
127        || m.contains("o3")
128        || m.contains("text-embedding");
129
130    if use_tiktoken {
131        Box::new(TiktokenCounter::for_model(model))
132    } else {
133        // For Qwen, Llama, Mistral, Gemini, etc. — use heuristic
134        // (HuggingFace tokenizer loading requires network/cache and is deferred to a future PR)
135        Box::new(HeuristicTokenCounter::new(context_limit_for_model(model)))
136    }
137}
138
139/// Heuristic token counter: chars / 3.5, rounded up, +15% safety margin.
140pub struct HeuristicTokenCounter {
141    context_limit: usize,
142}
143
144impl HeuristicTokenCounter {
145    pub fn new(context_limit: usize) -> Self {
146        Self { context_limit }
147    }
148}
149
150impl TokenCounter for HeuristicTokenCounter {
151    fn count_tokens(&self, text: &str) -> usize {
152        let raw = (text.len() as f64 / 3.5).ceil() as usize;
153        raw + raw / 7 // +~15% safety margin
154    }
155
156    fn model_context_limit(&self) -> usize {
157        self.context_limit
158    }
159}
160
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn heuristic_counter_counts_tokens() {
        let c = HeuristicTokenCounter::new(32_000);
        let n = c.count_tokens("hello world");
        assert!(n > 0, "should count some tokens");
        assert!(n < 20, "heuristic should be reasonable for short text");
    }

    #[test]
    fn heuristic_counter_empty_string() {
        // Empty input must cost zero tokens.
        assert_eq!(HeuristicTokenCounter::new(32_000).count_tokens(""), 0);
    }

    #[test]
    fn heuristic_counter_context_limit() {
        assert_eq!(
            HeuristicTokenCounter::new(128_000).model_context_limit(),
            128_000
        );
    }

    #[test]
    fn tiktoken_counter_counts_gpt4o() {
        let c = TiktokenCounter::for_model("gpt-4o");
        let count = c.count_tokens("Hello, world!");
        assert!(count > 0);
        assert!(
            count < 10,
            "short greeting should be under 10 tokens, got {count}"
        );
        assert_eq!(c.model_context_limit(), 128_000);
    }

    #[test]
    fn tiktoken_counter_counts_claude() {
        let c = TiktokenCounter::for_model("claude-sonnet-4-5-20250929");
        assert!(c.count_tokens("Hello, world!") > 0);
        assert_eq!(c.model_context_limit(), 200_000);
    }

    #[test]
    fn factory_returns_tiktoken_for_openai() {
        let c = create_token_counter("gpt-4o");
        assert!(c.count_tokens("Hello") > 0);
        assert_eq!(c.model_context_limit(), 128_000);
    }

    #[test]
    fn factory_returns_tiktoken_for_claude() {
        let c = create_token_counter("claude-haiku-4-5-20251001");
        assert!(c.count_tokens("Hello") > 0);
        assert_eq!(c.model_context_limit(), 200_000);
    }

    #[test]
    fn factory_returns_heuristic_for_unknown() {
        let c = create_token_counter("my-custom-local-model");
        assert!(c.count_tokens("Hello") > 0);
        assert_eq!(c.model_context_limit(), 32_000);
    }
}