turbomcp_client/llm/tokens.rs

//! Token counting and context management utilities
//!
//! Provides utilities for counting tokens, managing context windows, and optimizing
//! prompt length for different LLM providers.
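//!
//! # Example
//!
//! A rough end-to-end sketch (token counts are character-based heuristics, not exact
//! tokenizer output; assumes `LLMMessage` from `crate::llm::core` is in scope):
//!
//! ```ignore
//! let counter = TokenCounter::new();
//! let messages = vec![
//!     LLMMessage::system("You are a helpful assistant."),
//!     LLMMessage::user("What's 2+2?"),
//! ];
//!
//! // Estimate the conversation size for a specific model/provider pair.
//! let tokens = counter.estimate_conversation_tokens(&messages, Some("gpt-4"), Some("openai"));
//! assert!(tokens > 0);
//!
//! // Look up the model's context window and trim the history if it does not fit.
//! let window = counter.get_context_window("gpt-4");
//! let trimmed = counter
//!     .truncate_to_fit(messages, &window, Some("gpt-4"), Some("openai"))
//!     .unwrap();
//! assert!(!trimmed.is_empty());
//! ```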

use crate::llm::core::{LLMError, LLMMessage, LLMResult, MessageRole};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

/// Token usage information with detailed breakdown
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct TokenUsage {
    /// Input/prompt tokens
    pub prompt_tokens: usize,

    /// Output/completion tokens
    pub completion_tokens: usize,

    /// Total tokens used
    pub total_tokens: usize,

    /// Cached tokens (if applicable)
    pub cached_tokens: Option<usize>,

    /// Tokens from images (if applicable)
    pub image_tokens: Option<usize>,
}

impl TokenUsage {
    /// Create new token usage
    pub fn new(prompt_tokens: usize, completion_tokens: usize) -> Self {
        Self {
            prompt_tokens,
            completion_tokens,
            total_tokens: prompt_tokens + completion_tokens,
            cached_tokens: None,
            image_tokens: None,
        }
    }

    /// Create empty token usage
    pub fn empty() -> Self {
        Self::new(0, 0)
    }

    /// Add cached tokens
    pub fn with_cached_tokens(mut self, cached_tokens: usize) -> Self {
        self.cached_tokens = Some(cached_tokens);
        self
    }

    /// Add image tokens
    pub fn with_image_tokens(mut self, image_tokens: usize) -> Self {
        self.image_tokens = Some(image_tokens);
        self
    }

    /// Add to existing usage
    pub fn add(&mut self, other: &TokenUsage) {
        self.prompt_tokens += other.prompt_tokens;
        self.completion_tokens += other.completion_tokens;
        self.total_tokens += other.total_tokens;

        if let Some(other_cached) = other.cached_tokens {
            self.cached_tokens = Some(self.cached_tokens.unwrap_or(0) + other_cached);
        }

        if let Some(other_image) = other.image_tokens {
            self.image_tokens = Some(self.image_tokens.unwrap_or(0) + other_image);
        }
    }

    /// Calculate cost estimate (in USD)
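    ///
    /// Quick worked example (the per-1k rates are made-up placeholders, not real pricing):
    ///
    /// ```ignore
    /// let usage = TokenUsage::new(1000, 500);
    /// // (1000 / 1000) * 0.01 + (500 / 1000) * 0.03 = 0.025
    /// assert!((usage.estimate_cost(0.01, 0.03) - 0.025).abs() < 1e-9);
    /// ```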
    pub fn estimate_cost(&self, input_cost_per_1k: f64, output_cost_per_1k: f64) -> f64 {
        let prompt_cost = (self.prompt_tokens as f64 / 1000.0) * input_cost_per_1k;
        let completion_cost = (self.completion_tokens as f64 / 1000.0) * output_cost_per_1k;
        prompt_cost + completion_cost
    }
}

/// Context window configuration and management
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContextWindow {
    /// Maximum tokens in context window
    pub max_tokens: usize,

    /// Reserved tokens for the response
    pub response_reserve: usize,

    /// Reserved tokens for system message
    pub system_reserve: usize,

    /// Minimum tokens to keep from conversation history
    pub history_minimum: usize,
}

impl ContextWindow {
    /// Create new context window configuration
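    ///
    /// Illustrative budget split using the defaults below:
    ///
    /// ```ignore
    /// let window = ContextWindow::new(4000);
    /// // 4000 - 1000 (25% response reserve) - 500 (system reserve) = 2500
    /// assert_eq!(window.available_for_history(), 2500);
    /// ```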
    pub fn new(max_tokens: usize) -> Self {
        Self {
            max_tokens,
            response_reserve: max_tokens / 4, // Reserve 25% for response
            system_reserve: 500,              // Reserve for system message
            history_minimum: 1000,            // Keep at least 1k tokens of history
        }
    }

    /// Get available tokens for conversation history
    pub fn available_for_history(&self) -> usize {
        self.max_tokens
            .saturating_sub(self.response_reserve)
            .saturating_sub(self.system_reserve)
    }

    /// Check if token count fits in context window
    pub fn fits(&self, token_count: usize) -> bool {
        token_count <= self.available_for_history()
    }

    /// Calculate how many tokens to truncate
    pub fn tokens_to_truncate(&self, current_tokens: usize) -> usize {
        if self.fits(current_tokens) {
            0
        } else {
            current_tokens - self.available_for_history()
        }
    }
}

/// Token counter for different providers and models
#[derive(Debug)]
pub struct TokenCounter {
    /// Model-specific token estimates
    model_estimates: HashMap<String, f64>,

    /// Provider-specific multipliers
    provider_multipliers: HashMap<String, f64>,
}

impl Default for TokenCounter {
    fn default() -> Self {
        let mut model_estimates = HashMap::new();
        let mut provider_multipliers = HashMap::new();

        // OpenAI models - tokens per character
        model_estimates.insert("gpt-3.5-turbo".to_string(), 0.25);
        model_estimates.insert("gpt-4".to_string(), 0.25);
        model_estimates.insert("gpt-4-turbo".to_string(), 0.25);
        model_estimates.insert("gpt-4o".to_string(), 0.25);

        // Anthropic models - slightly different tokenization
        model_estimates.insert("claude-3-haiku-20240307".to_string(), 0.24);
        model_estimates.insert("claude-3-sonnet-20240229".to_string(), 0.24);
        model_estimates.insert("claude-3-opus-20240229".to_string(), 0.24);
        model_estimates.insert("claude-3-5-sonnet-20240620".to_string(), 0.24);

        // Provider multipliers for conversation overhead
        provider_multipliers.insert("openai".to_string(), 1.1);
        provider_multipliers.insert("anthropic".to_string(), 1.05);
        provider_multipliers.insert("ollama".to_string(), 1.0);

        Self {
            model_estimates,
            provider_multipliers,
        }
    }
}

impl TokenCounter {
    /// Create a new token counter
    pub fn new() -> Self {
        Self::default()
    }

    /// Add custom model estimate
    pub fn add_model_estimate(&mut self, model: String, tokens_per_char: f64) {
        self.model_estimates.insert(model, tokens_per_char);
    }

    /// Add provider multiplier
    pub fn add_provider_multiplier(&mut self, provider: String, multiplier: f64) {
        self.provider_multipliers.insert(provider, multiplier);
    }

    /// Estimate tokens for text
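    ///
    /// Rough character-based heuristic, not exact tokenizer output:
    ///
    /// ```ignore
    /// let counter = TokenCounter::new();
    /// // "Hello, world!" is 13 characters; at gpt-4's 0.25 tokens/char that is ~3 tokens.
    /// assert_eq!(counter.estimate_text_tokens("Hello, world!", Some("gpt-4")), 3);
    /// // Without a model hint, the fallback is ~4 characters per token (rounded up).
    /// assert_eq!(counter.estimate_text_tokens("Hello, world!", None), 4);
    /// ```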
    pub fn estimate_text_tokens(&self, text: &str, model: Option<&str>) -> usize {
        let base_estimate = if let Some(model) = model {
            let tokens_per_char = self.model_estimates.get(model).copied().unwrap_or(0.25); // Default fallback
            (text.len() as f64 * tokens_per_char) as usize
        } else {
            // Simple fallback: ~4 chars per token
            text.len().div_ceil(4)
        };

        base_estimate.max(1) // At least 1 token
    }

    /// Estimate tokens for a message
    pub fn estimate_message_tokens(
        &self,
        message: &LLMMessage,
        model: Option<&str>,
        provider: Option<&str>,
    ) -> usize {
        let base_tokens = match &message.content {
            crate::llm::core::MessageContent::Text { text } => {
                self.estimate_text_tokens(text, model)
            }
            crate::llm::core::MessageContent::Image { .. } => {
                // Image tokens vary by provider and detail level
                match provider {
                    Some("openai") => 765,     // GPT-4V standard image cost
                    Some("anthropic") => 1568, // Claude 3 image cost
                    _ => 1000,                 // Conservative estimate
                }
            }
            crate::llm::core::MessageContent::ToolCall { arguments, .. } => {
                let args_str = arguments.to_string();
                self.estimate_text_tokens(&args_str, model) + 10 // Tool call overhead
            }
            crate::llm::core::MessageContent::ToolResult { result, .. } => {
                let result_str = result.to_string();
                self.estimate_text_tokens(&result_str, model) + 5 // Tool result overhead
            }
        };

        // Add message overhead (role, formatting, etc.)
        let message_overhead = match message.role {
            MessageRole::System => 10,
            MessageRole::User => 5,
            MessageRole::Assistant => 5,
            MessageRole::Function => 15,
        };

        let total_tokens = base_tokens + message_overhead;

        // Apply provider multiplier
        if let Some(provider) = provider {
            let multiplier = self
                .provider_multipliers
                .get(provider)
                .copied()
                .unwrap_or(1.0);
            (total_tokens as f64 * multiplier) as usize
        } else {
            total_tokens
        }
    }

    /// Estimate tokens for a conversation
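    ///
    /// The total is the sum of the per-message estimates plus a small per-message
    /// conversation overhead, so it always exceeds any single message's estimate:
    ///
    /// ```ignore
    /// let counter = TokenCounter::new();
    /// let messages = vec![
    ///     LLMMessage::system("You are a helpful assistant."),
    ///     LLMMessage::user("What's 2+2?"),
    /// ];
    /// let total = counter.estimate_conversation_tokens(&messages, Some("gpt-4"), Some("openai"));
    /// let first = counter.estimate_message_tokens(&messages[0], Some("gpt-4"), Some("openai"));
    /// assert!(total > first);
    /// ```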
    pub fn estimate_conversation_tokens(
        &self,
        messages: &[LLMMessage],
        model: Option<&str>,
        provider: Option<&str>,
    ) -> usize {
        let message_tokens: usize = messages
            .iter()
            .map(|msg| self.estimate_message_tokens(msg, model, provider))
            .sum();

        // Add conversation overhead
        let conversation_overhead = messages.len() * 2;

        message_tokens + conversation_overhead
    }

    /// Truncate messages to fit in context window
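    ///
    /// Sketch of typical use (`messages` here is any `Vec<LLMMessage>` built elsewhere):
    /// system messages are always kept, and the oldest non-system messages are dropped first.
    ///
    /// ```ignore
    /// let counter = TokenCounter::new();
    /// let window = counter.get_context_window("gpt-4");
    /// let trimmed = counter
    ///     .truncate_to_fit(messages, &window, Some("gpt-4"), Some("openai"))
    ///     .unwrap();
    /// ```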
    pub fn truncate_to_fit(
        &self,
        messages: Vec<LLMMessage>,
        context_window: &ContextWindow,
        model: Option<&str>,
        provider: Option<&str>,
    ) -> LLMResult<Vec<LLMMessage>> {
        let total_tokens = self.estimate_conversation_tokens(&messages, model, provider);

        if context_window.fits(total_tokens) {
            return Ok(messages);
        }

        let tokens_to_remove = context_window.tokens_to_truncate(total_tokens);

        // Strategy: Keep system message, remove oldest user/assistant pairs
        let mut result = Vec::new();
        let mut removed_tokens = 0;

        // First pass: separate system messages and conversation
        let mut system_messages = Vec::new();
        let mut conversation_messages = Vec::new();

        for message in messages {
            match message.role {
                MessageRole::System => system_messages.push(message),
                _ => conversation_messages.push(message),
            }
        }

        // Keep all system messages
        result.extend(system_messages);

        // Remove messages from the beginning until we fit
        let mut skip_count = 0;
        for message in &conversation_messages {
            // Stop once enough tokens have been dropped for the remainder to fit.
            if removed_tokens >= tokens_to_remove {
                break;
            }
            let message_tokens = self.estimate_message_tokens(message, model, provider);
            removed_tokens += message_tokens;
            skip_count += 1;
        }

        // Add remaining conversation messages
        result.extend(conversation_messages.into_iter().skip(skip_count));

        // Ensure we have at least one non-system message
        if result.iter().all(|msg| msg.role == MessageRole::System) {
            return Err(LLMError::generic(
                "Cannot fit conversation in context window even after truncation",
            ));
        }

        Ok(result)
    }

    /// Get context window for a model
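    ///
    /// Known models map to the limits hard-coded below; anything unrecognized falls back
    /// to a conservative 4096-token window:
    ///
    /// ```ignore
    /// let counter = TokenCounter::new();
    /// assert_eq!(counter.get_context_window("gpt-4").max_tokens, 8192);
    /// assert_eq!(counter.get_context_window("unknown-model").max_tokens, 4096);
    /// ```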
    pub fn get_context_window(&self, model: &str) -> ContextWindow {
        match model {
            // OpenAI models
            "gpt-3.5-turbo" => ContextWindow::new(16385),
            "gpt-4" => ContextWindow::new(8192),
            "gpt-4-turbo" | "gpt-4-turbo-preview" => ContextWindow::new(128000),
            "gpt-4o" => ContextWindow::new(128000),

            // Anthropic models
            m if m.starts_with("claude-3") => ContextWindow::new(200000),

            // Ollama models (varies)
            _ => ContextWindow::new(4096), // Conservative default
        }
    }

    /// Optimize message history for token efficiency
    pub fn optimize_messages(
        &self,
        messages: Vec<LLMMessage>,
        context_window: &ContextWindow,
        model: Option<&str>,
        provider: Option<&str>,
    ) -> LLMResult<Vec<LLMMessage>> {
        // First try: Simple truncation
        let truncated = self.truncate_to_fit(messages.clone(), context_window, model, provider)?;
        let truncated_tokens = self.estimate_conversation_tokens(&truncated, model, provider);

        if context_window.fits(truncated_tokens) {
            return Ok(truncated);
        }

        // Second try: Summarization (placeholder for future implementation)
        // For now, just use more aggressive truncation
        let aggressive_window = ContextWindow {
            max_tokens: context_window.max_tokens,
            response_reserve: context_window.response_reserve,
            system_reserve: context_window.system_reserve,
            history_minimum: context_window.history_minimum / 2, // More aggressive
        };

        self.truncate_to_fit(messages, &aggressive_window, model, provider)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::core::{LLMMessage, MessageRole};

    #[test]
    fn test_token_usage() {
        let mut usage = TokenUsage::new(100, 50);
        assert_eq!(usage.total_tokens, 150);

        let other = TokenUsage::new(20, 10).with_cached_tokens(5);
        usage.add(&other);

        assert_eq!(usage.prompt_tokens, 120);
        assert_eq!(usage.completion_tokens, 60);
        assert_eq!(usage.total_tokens, 180);
        assert_eq!(usage.cached_tokens, Some(5));
    }

    #[test]
    fn test_token_usage_cost() {
        let usage = TokenUsage::new(1000, 500);
        let cost = usage.estimate_cost(0.01, 0.03); // $0.01 per 1k input, $0.03 per 1k output
        assert_eq!(cost, 0.025); // (1000/1000 * 0.01) + (500/1000 * 0.03)
    }

    #[test]
    fn test_context_window() {
        let window = ContextWindow::new(4000);
        assert_eq!(window.available_for_history(), 2500); // 4000 - 1000 - 500

        assert!(window.fits(2000));
        assert!(!window.fits(3000));

        assert_eq!(window.tokens_to_truncate(3000), 500);
        assert_eq!(window.tokens_to_truncate(2000), 0);
    }

    #[test]
    fn test_token_counter_text_estimation() {
        let counter = TokenCounter::new();

        let text = "Hello, world!";
        let tokens = counter.estimate_text_tokens(text, Some("gpt-4"));
        assert!(tokens > 0);
        assert!(tokens < 20); // Reasonable estimate for short text

        let long_text = "This is a much longer text that should result in more tokens being estimated by the token counter system.";
        let long_tokens = counter.estimate_text_tokens(long_text, Some("gpt-4"));
        assert!(long_tokens > tokens);
    }

    #[test]
    fn test_message_token_estimation() {
        let counter = TokenCounter::new();

        let message = LLMMessage::user("Hello, world!");
        let tokens = counter.estimate_message_tokens(&message, Some("gpt-4"), Some("openai"));
        assert!(tokens > 0);

        let system_message = LLMMessage::system("You are a helpful assistant.");
        let system_tokens =
            counter.estimate_message_tokens(&system_message, Some("gpt-4"), Some("openai"));
        assert!(system_tokens > tokens); // System messages have more overhead
    }

    #[test]
    fn test_conversation_token_estimation() {
        let counter = TokenCounter::new();

        let messages = vec![
            LLMMessage::system("You are a helpful assistant."),
            LLMMessage::user("What's 2+2?"),
            LLMMessage::assistant("2+2 equals 4."),
        ];

        let tokens = counter.estimate_conversation_tokens(&messages, Some("gpt-4"), Some("openai"));
        assert!(tokens > 0);

        let single_message_tokens =
            counter.estimate_message_tokens(&messages[0], Some("gpt-4"), Some("openai"));
        assert!(tokens > single_message_tokens); // Should be more than just one message
    }

    #[test]
    fn test_message_truncation() {
        let counter = TokenCounter::new();
        let window = ContextWindow::new(1000); // Large enough that this short conversation still fits after reserves

        let messages = vec![
            LLMMessage::system("You are a helpful assistant."),
            LLMMessage::user("First question"),
            LLMMessage::assistant("First answer"),
            LLMMessage::user("Second question"),
            LLMMessage::assistant("Second answer"),
            LLMMessage::user("Final question"),
        ];

        let truncated = counter
            .truncate_to_fit(messages.clone(), &window, Some("gpt-4"), Some("openai"))
            .unwrap();

        // Should keep system message and some conversation
        assert!(!truncated.is_empty());

        // Should preserve system message
        assert!(truncated.iter().any(|msg| msg.role == MessageRole::System));

        // Should have at least one non-system message
        assert!(truncated.iter().any(|msg| msg.role != MessageRole::System));
    }

    #[test]
    fn test_context_window_for_models() {
        let counter = TokenCounter::new();

        let gpt4_window = counter.get_context_window("gpt-4");
        assert_eq!(gpt4_window.max_tokens, 8192);

        let gpt4_turbo_window = counter.get_context_window("gpt-4-turbo");
        assert_eq!(gpt4_turbo_window.max_tokens, 128000);

        let claude_window = counter.get_context_window("claude-3-sonnet-20240229");
        assert_eq!(claude_window.max_tokens, 200000);

        let unknown_window = counter.get_context_window("unknown-model");
        assert_eq!(unknown_window.max_tokens, 4096); // Default
    }
}