use crate::llm::core::{LLMError, LLMMessage, LLMResult, MessageRole};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

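/// Token counts for a single LLM request/response, with optional
/// cached-prompt and image token counts for providers that report them.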
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct TokenUsage {
    /// Tokens consumed by the prompt (input).
    pub prompt_tokens: usize,

    /// Tokens generated in the completion (output).
    pub completion_tokens: usize,

    /// Sum of prompt and completion tokens.
    pub total_tokens: usize,

    /// Prompt tokens served from a provider-side cache, if reported.
    pub cached_tokens: Option<usize>,

    /// Tokens attributed to image content, if reported.
    pub image_tokens: Option<usize>,
}

impl TokenUsage {
    pub fn new(prompt_tokens: usize, completion_tokens: usize) -> Self {
        Self {
            prompt_tokens,
            completion_tokens,
            total_tokens: prompt_tokens + completion_tokens,
            cached_tokens: None,
            image_tokens: None,
        }
    }

    pub fn empty() -> Self {
        Self::new(0, 0)
    }

    pub fn with_cached_tokens(mut self, cached_tokens: usize) -> Self {
        self.cached_tokens = Some(cached_tokens);
        self
    }

    pub fn with_image_tokens(mut self, image_tokens: usize) -> Self {
        self.image_tokens = Some(image_tokens);
        self
    }

    pub fn add(&mut self, other: &TokenUsage) {
        self.prompt_tokens += other.prompt_tokens;
        self.completion_tokens += other.completion_tokens;
        self.total_tokens += other.total_tokens;

        if let Some(other_cached) = other.cached_tokens {
            self.cached_tokens = Some(self.cached_tokens.unwrap_or(0) + other_cached);
        }

        if let Some(other_image) = other.image_tokens {
            self.image_tokens = Some(self.image_tokens.unwrap_or(0) + other_image);
        }
    }

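    /// Estimates the dollar cost of this usage given per-1,000-token input and
    /// output prices, e.g. `estimate_cost(0.01, 0.03)` for $0.01/1K prompt
    /// tokens and $0.03/1K completion tokens.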
    pub fn estimate_cost(&self, input_cost_per_1k: f64, output_cost_per_1k: f64) -> f64 {
        let prompt_cost = (self.prompt_tokens as f64 / 1000.0) * input_cost_per_1k;
        let completion_cost = (self.completion_tokens as f64 / 1000.0) * output_cost_per_1k;
        prompt_cost + completion_cost
    }
}

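/// Token budget for a model's context window. The space available for
/// conversation history is `max_tokens - response_reserve - system_reserve`;
/// `history_minimum` records the smallest history budget the caller intends to keep.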
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContextWindow {
    /// Total tokens the model's context window can hold.
    pub max_tokens: usize,

    /// Tokens reserved for the model's response.
    pub response_reserve: usize,

    /// Tokens reserved for system messages.
    pub system_reserve: usize,

    /// Minimum token budget intended for conversation history.
    pub history_minimum: usize,
}

impl ContextWindow {
    pub fn new(max_tokens: usize) -> Self {
        Self {
            max_tokens,
            response_reserve: max_tokens / 4,
            system_reserve: 500,
            history_minimum: 1000,
        }
    }

    pub fn available_for_history(&self) -> usize {
        self.max_tokens
            .saturating_sub(self.response_reserve)
            .saturating_sub(self.system_reserve)
    }

    pub fn fits(&self, token_count: usize) -> bool {
        token_count <= self.available_for_history()
    }

    pub fn tokens_to_truncate(&self, current_tokens: usize) -> usize {
        if self.fits(current_tokens) {
            0
        } else {
            current_tokens - self.available_for_history()
        }
    }
}

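/// Heuristic token counter. Counts are estimated from character length using
/// per-model tokens-per-character ratios and per-provider correction
/// multipliers; no real tokenizer is invoked, so results are approximate.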
#[derive(Debug)]
pub struct TokenCounter {
    /// Tokens-per-character estimates keyed by model name.
    model_estimates: HashMap<String, f64>,

    /// Correction multipliers keyed by provider name.
    provider_multipliers: HashMap<String, f64>,
}

impl Default for TokenCounter {
    fn default() -> Self {
        let mut model_estimates = HashMap::new();
        let mut provider_multipliers = HashMap::new();

        model_estimates.insert("gpt-3.5-turbo".to_string(), 0.25);
        model_estimates.insert("gpt-4".to_string(), 0.25);
        model_estimates.insert("gpt-4-turbo".to_string(), 0.25);
        model_estimates.insert("gpt-4o".to_string(), 0.25);

        model_estimates.insert("claude-3-haiku-20240307".to_string(), 0.24);
        model_estimates.insert("claude-3-sonnet-20240229".to_string(), 0.24);
        model_estimates.insert("claude-3-opus-20240229".to_string(), 0.24);
        model_estimates.insert("claude-3-5-sonnet-20240620".to_string(), 0.24);

        provider_multipliers.insert("openai".to_string(), 1.1);
        provider_multipliers.insert("anthropic".to_string(), 1.05);
        provider_multipliers.insert("ollama".to_string(), 1.0);

        Self {
            model_estimates,
            provider_multipliers,
        }
    }
}

impl TokenCounter {
    pub fn new() -> Self {
        Self::default()
    }

    pub fn add_model_estimate(&mut self, model: String, tokens_per_char: f64) {
        self.model_estimates.insert(model, tokens_per_char);
    }

    pub fn add_provider_multiplier(&mut self, provider: String, multiplier: f64) {
        self.provider_multipliers.insert(provider, multiplier);
    }

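    /// Estimates the token count of plain text. Known models use their
    /// configured tokens-per-character ratio (defaulting to 0.25, roughly four
    /// characters per token); when no model is given, falls back to
    /// `len / 4` rounded up. Always returns at least 1.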
    pub fn estimate_text_tokens(&self, text: &str, model: Option<&str>) -> usize {
        let base_estimate = if let Some(model) = model {
            let tokens_per_char = self.model_estimates.get(model).copied().unwrap_or(0.25);
            (text.len() as f64 * tokens_per_char) as usize
        } else {
            text.len().div_ceil(4)
        };

        base_estimate.max(1)
    }

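    /// Estimates tokens for a single message: text is estimated from length,
    /// images use flat per-provider constants, and tool calls/results add a
    /// small fixed overhead. A per-role overhead and an optional provider
    /// multiplier are applied on top.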
    pub fn estimate_message_tokens(
        &self,
        message: &LLMMessage,
        model: Option<&str>,
        provider: Option<&str>,
    ) -> usize {
        let base_tokens = match &message.content {
            crate::llm::core::MessageContent::Text { text } => {
                self.estimate_text_tokens(text, model)
            }
            crate::llm::core::MessageContent::Image { .. } => {
                match provider {
                    Some("openai") => 765,
                    Some("anthropic") => 1568,
                    _ => 1000,
                }
            }
            crate::llm::core::MessageContent::ToolCall { arguments, .. } => {
                let args_str = arguments.to_string();
                self.estimate_text_tokens(&args_str, model) + 10
            }
            crate::llm::core::MessageContent::ToolResult { result, .. } => {
                let result_str = result.to_string();
                self.estimate_text_tokens(&result_str, model) + 5
            }
        };

        let message_overhead = match message.role {
            MessageRole::System => 10,
            MessageRole::User => 5,
            MessageRole::Assistant => 5,
            MessageRole::Function => 15,
        };

        let total_tokens = base_tokens + message_overhead;

        if let Some(provider) = provider {
            let multiplier = self
                .provider_multipliers
                .get(provider)
                .copied()
                .unwrap_or(1.0);
            (total_tokens as f64 * multiplier) as usize
        } else {
            total_tokens
        }
    }

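    /// Estimates tokens for a whole conversation: the sum of per-message
    /// estimates plus roughly two tokens of framing overhead per message.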
    pub fn estimate_conversation_tokens(
        &self,
        messages: &[LLMMessage],
        model: Option<&str>,
        provider: Option<&str>,
    ) -> usize {
        let message_tokens: usize = messages
            .iter()
            .map(|msg| self.estimate_message_tokens(msg, model, provider))
            .sum();

        let conversation_overhead = messages.len() * 2;

        message_tokens + conversation_overhead
    }

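    /// Truncates a conversation to fit the context window. System messages are
    /// always kept; the oldest non-system messages are dropped until enough
    /// tokens have been removed. Returns an error if nothing but system
    /// messages would remain.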
    pub fn truncate_to_fit(
        &self,
        messages: Vec<LLMMessage>,
        context_window: &ContextWindow,
        model: Option<&str>,
        provider: Option<&str>,
    ) -> LLMResult<Vec<LLMMessage>> {
        let total_tokens = self.estimate_conversation_tokens(&messages, model, provider);

        if context_window.fits(total_tokens) {
            return Ok(messages);
        }

        let tokens_to_remove = context_window.tokens_to_truncate(total_tokens);

        let mut result = Vec::new();
        let mut removed_tokens = 0;

        let mut system_messages = Vec::new();
        let mut conversation_messages = Vec::new();

        for message in messages {
            match message.role {
                MessageRole::System => system_messages.push(message),
                _ => conversation_messages.push(message),
            }
        }

        result.extend(system_messages);

        let mut skip_count = 0;
        for message in &conversation_messages {
            let message_tokens = self.estimate_message_tokens(message, model, provider);
            if removed_tokens + message_tokens >= tokens_to_remove {
                break;
            }
            removed_tokens += message_tokens;
            skip_count += 1;
        }

        result.extend(conversation_messages.into_iter().skip(skip_count));

        if result.iter().all(|msg| msg.role == MessageRole::System) {
            return Err(LLMError::generic(
                "Cannot fit conversation in context window even after truncation",
            ));
        }

        Ok(result)
    }

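    /// Returns a default `ContextWindow` for a known model name, falling back
    /// to a conservative 4,096-token window for unknown models.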
    pub fn get_context_window(&self, model: &str) -> ContextWindow {
        match model {
            "gpt-3.5-turbo" => ContextWindow::new(16385),
            "gpt-4" => ContextWindow::new(8192),
            "gpt-4-turbo" | "gpt-4-turbo-preview" => ContextWindow::new(128000),
            "gpt-4o" => ContextWindow::new(128000),

            m if m.starts_with("claude-3") => ContextWindow::new(200000),

            _ => ContextWindow::new(4096),
        }
    }

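    /// Truncates the conversation, then retries with a more aggressive window
    /// (halved `history_minimum`) if the first pass still does not fit.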
    pub fn optimize_messages(
        &self,
        messages: Vec<LLMMessage>,
        context_window: &ContextWindow,
        model: Option<&str>,
        provider: Option<&str>,
    ) -> LLMResult<Vec<LLMMessage>> {
        let truncated = self.truncate_to_fit(messages.clone(), context_window, model, provider)?;
        let truncated_tokens = self.estimate_conversation_tokens(&truncated, model, provider);

        if context_window.fits(truncated_tokens) {
            return Ok(truncated);
        }

        let aggressive_window = ContextWindow {
            max_tokens: context_window.max_tokens,
            response_reserve: context_window.response_reserve,
            system_reserve: context_window.system_reserve,
            history_minimum: context_window.history_minimum / 2,
        };

        self.truncate_to_fit(messages, &aggressive_window, model, provider)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::core::{LLMMessage, MessageRole};

    #[test]
    fn test_token_usage() {
        let mut usage = TokenUsage::new(100, 50);
        assert_eq!(usage.total_tokens, 150);

        let other = TokenUsage::new(20, 10).with_cached_tokens(5);
        usage.add(&other);

        assert_eq!(usage.prompt_tokens, 120);
        assert_eq!(usage.completion_tokens, 60);
        assert_eq!(usage.total_tokens, 180);
        assert_eq!(usage.cached_tokens, Some(5));
    }

    #[test]
    fn test_token_usage_cost() {
        let usage = TokenUsage::new(1000, 500);
        // 1000 prompt tokens at $0.01/1K plus 500 completion tokens at $0.03/1K.
        let cost = usage.estimate_cost(0.01, 0.03);
        assert!((cost - 0.025).abs() < 1e-9);
    }

    #[test]
    fn test_context_window() {
        let window = ContextWindow::new(4000);
        assert_eq!(window.available_for_history(), 2500);
        assert!(window.fits(2000));
        assert!(!window.fits(3000));

        assert_eq!(window.tokens_to_truncate(3000), 500);
        assert_eq!(window.tokens_to_truncate(2000), 0);
    }

    #[test]
    fn test_token_counter_text_estimation() {
        let counter = TokenCounter::new();

        let text = "Hello, world!";
        let tokens = counter.estimate_text_tokens(text, Some("gpt-4"));
        assert!(tokens > 0);
        assert!(tokens < 20);

        let long_text = "This is a much longer text that should result in more tokens being estimated by the token counter system.";
        let long_tokens = counter.estimate_text_tokens(long_text, Some("gpt-4"));
        assert!(long_tokens > tokens);
    }

    #[test]
    fn test_message_token_estimation() {
        let counter = TokenCounter::new();

        let message = LLMMessage::user("Hello, world!");
        let tokens = counter.estimate_message_tokens(&message, Some("gpt-4"), Some("openai"));
        assert!(tokens > 0);

        let system_message = LLMMessage::system("You are a helpful assistant.");
        let system_tokens =
            counter.estimate_message_tokens(&system_message, Some("gpt-4"), Some("openai"));
        assert!(system_tokens > tokens);
    }

    #[test]
    fn test_conversation_token_estimation() {
        let counter = TokenCounter::new();

        let messages = vec![
            LLMMessage::system("You are a helpful assistant."),
            LLMMessage::user("What's 2+2?"),
            LLMMessage::assistant("2+2 equals 4."),
        ];

        let tokens =
            counter.estimate_conversation_tokens(&messages, Some("gpt-4"), Some("openai"));
        assert!(tokens > 0);

        let single_message_tokens =
            counter.estimate_message_tokens(&messages[0], Some("gpt-4"), Some("openai"));
        assert!(tokens > single_message_tokens);
    }

    #[test]
    fn test_message_truncation() {
        let counter = TokenCounter::new();
        let window = ContextWindow::new(1000);

        let messages = vec![
            LLMMessage::system("You are a helpful assistant."),
            LLMMessage::user("First question"),
            LLMMessage::assistant("First answer"),
            LLMMessage::user("Second question"),
            LLMMessage::assistant("Second answer"),
            LLMMessage::user("Final question"),
        ];

        let truncated = counter
            .truncate_to_fit(messages.clone(), &window, Some("gpt-4"), Some("openai"))
            .unwrap();

        assert!(!truncated.is_empty());

        assert!(truncated.iter().any(|msg| msg.role == MessageRole::System));

        assert!(truncated.iter().any(|msg| msg.role != MessageRole::System));
    }

    #[test]
    fn test_context_window_for_models() {
        let counter = TokenCounter::new();

        let gpt4_window = counter.get_context_window("gpt-4");
        assert_eq!(gpt4_window.max_tokens, 8192);

        let gpt4_turbo_window = counter.get_context_window("gpt-4-turbo");
        assert_eq!(gpt4_turbo_window.max_tokens, 128000);

        let claude_window = counter.get_context_window("claude-3-sonnet-20240229");
        assert_eq!(claude_window.max_tokens, 200000);

        let unknown_window = counter.get_context_window("unknown-model");
        assert_eq!(unknown_window.max_tokens, 4096);
    }
}