1use anyhow::{Context, Result, anyhow};
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::sync::Arc;
12use std::time::{SystemTime, UNIX_EPOCH};
13use tiktoken_rs::get_bpe_from_model;
14use tokio::sync::RwLock;
15use tracing::debug;
16
17#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct TokenBudgetConfig {
20 pub max_context_tokens: usize,
22 pub warning_threshold: f64,
24 pub compaction_threshold: f64,
26 pub model: String,
28 pub detailed_tracking: bool,
30}
31
32impl Default for TokenBudgetConfig {
33 fn default() -> Self {
34 Self {
35 max_context_tokens: 128_000,
36 warning_threshold: 0.75,
37 compaction_threshold: 0.85,
38 model: "gpt-4".to_string(),
39 detailed_tracking: false,
40 }
41 }
42}
43
44impl TokenBudgetConfig {
45 pub fn for_model(model: &str, max_tokens: usize) -> Self {
47 Self {
48 max_context_tokens: max_tokens,
49 model: model.to_string(),
50 ..Default::default()
51 }
52 }
53}
54
55#[derive(Debug, Clone, Serialize, Deserialize)]
57pub struct TokenUsageStats {
58 pub total_tokens: usize,
59 pub system_prompt_tokens: usize,
60 pub user_messages_tokens: usize,
61 pub assistant_messages_tokens: usize,
62 pub tool_results_tokens: usize,
63 pub decision_ledger_tokens: usize,
64 pub timestamp: u64,
65}
66
67impl TokenUsageStats {
68 pub fn new() -> Self {
69 Self {
70 total_tokens: 0,
71 system_prompt_tokens: 0,
72 user_messages_tokens: 0,
73 assistant_messages_tokens: 0,
74 tool_results_tokens: 0,
75 decision_ledger_tokens: 0,
76 timestamp: SystemTime::now()
77 .duration_since(UNIX_EPOCH)
78 .unwrap_or_default()
79 .as_secs(),
80 }
81 }
82
83 pub fn usage_percentage(&self, max_tokens: usize) -> f64 {
85 if max_tokens == 0 {
86 return 0.0;
87 }
88 (self.total_tokens as f64 / max_tokens as f64) * 100.0
89 }
90
91 pub fn needs_compaction(&self, max_tokens: usize, threshold: f64) -> bool {
93 let usage = self.total_tokens as f64 / max_tokens as f64;
94 usage >= threshold
95 }
96}
97
98impl Default for TokenUsageStats {
99 fn default() -> Self {
100 Self::new()
101 }
102}
103
104#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
106pub enum ContextComponent {
107 SystemPrompt,
108 UserMessage,
109 AssistantMessage,
110 ToolResult,
111 DecisionLedger,
112 ProjectGuidelines,
113 FileContent,
114}
115
116pub struct TokenBudgetManager {
118 config: Arc<RwLock<TokenBudgetConfig>>,
119 stats: Arc<RwLock<TokenUsageStats>>,
120 component_tokens: Arc<RwLock<HashMap<String, usize>>>,
121 tokenizer_cache: Arc<RwLock<Option<tiktoken_rs::CoreBPE>>>,
122}
123
124impl TokenBudgetManager {
125 pub fn new(config: TokenBudgetConfig) -> Self {
127 Self {
128 config: Arc::new(RwLock::new(config)),
129 stats: Arc::new(RwLock::new(TokenUsageStats::new())),
130 component_tokens: Arc::new(RwLock::new(HashMap::new())),
131 tokenizer_cache: Arc::new(RwLock::new(None)),
132 }
133 }
134
135 async fn ensure_tokenizer(&self) -> Result<()> {
137 let mut cache = self.tokenizer_cache.write().await;
138 if cache.is_none() {
139 let config = self.config.read().await;
140 let bpe = get_bpe_from_model(&config.model)
141 .with_context(|| format!("Failed to get tokenizer for model: {}", config.model))?;
142 *cache = Some(bpe);
143 }
144 Ok(())
145 }
146
147 pub async fn count_tokens(&self, text: &str) -> Result<usize> {
149 self.ensure_tokenizer().await?;
150 let cache = self.tokenizer_cache.read().await;
151 let bpe = cache
152 .as_ref()
153 .ok_or_else(|| anyhow!("Tokenizer not initialized"))?;
154 Ok(bpe.encode_with_special_tokens(text).len())
155 }
156
157 pub async fn count_tokens_for_component(
159 &self,
160 text: &str,
161 component: ContextComponent,
162 component_id: Option<&str>,
163 ) -> Result<usize> {
164 let token_count = self.count_tokens(text).await?;
165
166 self.record_tokens_for_component(component, token_count, component_id)
167 .await;
168
169 Ok(token_count)
170 }
171
172 pub async fn record_tokens_for_component(
174 &self,
175 component: ContextComponent,
176 tokens: usize,
177 component_id: Option<&str>,
178 ) {
179 if tokens == 0 {
180 return;
181 }
182
183 let detailed_tracking = {
184 let config = self.config.read().await;
185 config.detailed_tracking
186 };
187
188 if detailed_tracking {
189 let key = if let Some(id) = component_id {
190 format!("{:?}:{}", component, id)
191 } else {
192 format!("{:?}", component)
193 };
194 let mut components = self.component_tokens.write().await;
195 *components.entry(key).or_insert(0) += tokens;
196 }
197
198 let mut stats = self.stats.write().await;
199 stats.total_tokens += tokens;
200
201 match component {
202 ContextComponent::SystemPrompt => stats.system_prompt_tokens += tokens,
203 ContextComponent::UserMessage => stats.user_messages_tokens += tokens,
204 ContextComponent::AssistantMessage => stats.assistant_messages_tokens += tokens,
205 ContextComponent::ToolResult => stats.tool_results_tokens += tokens,
206 ContextComponent::DecisionLedger => stats.decision_ledger_tokens += tokens,
207 _ => {}
208 }
209
210 stats.timestamp = SystemTime::now()
211 .duration_since(UNIX_EPOCH)
212 .unwrap_or_default()
213 .as_secs();
214 }
215
216 pub async fn get_stats(&self) -> TokenUsageStats {
218 self.stats.read().await.clone()
219 }
220
221 pub async fn get_component_breakdown(&self) -> HashMap<String, usize> {
223 self.component_tokens.read().await.clone()
224 }
225
226 pub async fn is_warning_threshold_exceeded(&self) -> bool {
228 let stats = self.stats.read().await;
229 let config = self.config.read().await;
230 stats.needs_compaction(config.max_context_tokens, config.warning_threshold)
231 }
232
233 pub async fn is_compaction_threshold_exceeded(&self) -> bool {
235 let stats = self.stats.read().await;
236 let config = self.config.read().await;
237 stats.needs_compaction(config.max_context_tokens, config.compaction_threshold)
238 }
239
240 pub async fn usage_percentage(&self) -> f64 {
242 let stats = self.stats.read().await;
243 let config = self.config.read().await;
244 stats.usage_percentage(config.max_context_tokens)
245 }
246
247 pub async fn remaining_tokens(&self) -> usize {
249 let stats = self.stats.read().await;
250 let config = self.config.read().await;
251 config.max_context_tokens.saturating_sub(stats.total_tokens)
252 }
253
254 pub async fn reset(&self) {
256 let mut stats = self.stats.write().await;
257 *stats = TokenUsageStats::new();
258 let mut components = self.component_tokens.write().await;
259 components.clear();
260 debug!("Token budget reset");
261 }
262
263 pub async fn deduct_tokens(&self, component: ContextComponent, tokens: usize) {
265 let mut stats = self.stats.write().await;
266 stats.total_tokens = stats.total_tokens.saturating_sub(tokens);
267
268 match component {
269 ContextComponent::SystemPrompt => {
270 stats.system_prompt_tokens = stats.system_prompt_tokens.saturating_sub(tokens)
271 }
272 ContextComponent::UserMessage => {
273 stats.user_messages_tokens = stats.user_messages_tokens.saturating_sub(tokens)
274 }
275 ContextComponent::AssistantMessage => {
276 stats.assistant_messages_tokens =
277 stats.assistant_messages_tokens.saturating_sub(tokens)
278 }
279 ContextComponent::ToolResult => {
280 stats.tool_results_tokens = stats.tool_results_tokens.saturating_sub(tokens)
281 }
282 ContextComponent::DecisionLedger => {
283 stats.decision_ledger_tokens = stats.decision_ledger_tokens.saturating_sub(tokens)
284 }
285 _ => {}
286 }
287
288 debug!("Deducted {} tokens from {:?}", tokens, component);
289 }
290
291 pub async fn generate_report(&self) -> String {
293 let stats = self.stats.read().await;
294 let config = self.config.read().await;
295 let components = self.component_tokens.read().await;
296
297 let usage_pct = stats.usage_percentage(config.max_context_tokens);
298 let remaining = config.max_context_tokens.saturating_sub(stats.total_tokens);
299
300 let mut report = format!(
301 "Token Budget Report\n\
302 ==================\n\
303 Total Tokens: {}/{} ({:.1}%)\n\
304 Remaining: {} tokens\n\n\
305 Breakdown by Category:\n\
306 - System Prompt: {} tokens\n\
307 - User Messages: {} tokens\n\
308 - Assistant Messages: {} tokens\n\
309 - Tool Results: {} tokens\n\
310 - Decision Ledger: {} tokens\n",
311 stats.total_tokens,
312 config.max_context_tokens,
313 usage_pct,
314 remaining,
315 stats.system_prompt_tokens,
316 stats.user_messages_tokens,
317 stats.assistant_messages_tokens,
318 stats.tool_results_tokens,
319 stats.decision_ledger_tokens
320 );
321
322 if config.detailed_tracking && !components.is_empty() {
323 report.push_str("\nDetailed Component Tracking:\n");
324 let mut sorted: Vec<_> = components.iter().collect();
325 sorted.sort_by(|a, b| b.1.cmp(a.1));
326 for (component, tokens) in sorted.iter().take(10) {
327 report.push_str(&format!(" - {}: {} tokens\n", component, tokens));
328 }
329 }
330
331 if usage_pct >= config.compaction_threshold * 100.0 {
332 report.push_str("\nALERT: Compaction threshold exceeded");
333 } else if usage_pct >= config.warning_threshold * 100.0 {
334 report.push_str("\nWARNING: Approaching token limit");
335 }
336
337 report
338 }
339}
340
341impl Default for TokenBudgetManager {
342 fn default() -> Self {
343 Self::new(TokenBudgetConfig::default())
344 }
345}
346
#[cfg(test)]
mod tests {
    use super::*;

    /// Counting a non-empty string yields at least one token.
    #[tokio::test]
    async fn test_token_counting() {
        let manager = TokenBudgetManager::new(TokenBudgetConfig::default());

        let count = manager.count_tokens("Hello, world!").await.unwrap();
        assert!(count > 0);
    }

    /// Tokens counted for a user message land in the user-message counter.
    #[tokio::test]
    async fn test_component_tracking() {
        let config = TokenBudgetConfig {
            detailed_tracking: true,
            ..Default::default()
        };
        let manager = TokenBudgetManager::new(config);

        let count = manager
            .count_tokens_for_component(
                "This is a test message",
                ContextComponent::UserMessage,
                Some("msg1"),
            )
            .await
            .unwrap();

        assert!(count > 0);

        let stats = manager.get_stats().await;
        assert_eq!(stats.user_messages_tokens, count);
    }

    /// Recording at least 80 tokens against a 100-token budget trips the
    /// 0.8 compaction threshold.
    #[tokio::test]
    async fn test_threshold_detection() {
        let config = TokenBudgetConfig {
            max_context_tokens: 100,
            compaction_threshold: 0.8,
            ..Default::default()
        };
        let manager = TokenBudgetManager::new(config);

        // 100 repetitions give at least one token per word. The previous
        // 25 repetitions produced only about 26 tokens, well short of the
        // 80 this assertion needs.
        let text = "word ".repeat(100);
        let count = manager
            .count_tokens_for_component(&text, ContextComponent::UserMessage, None)
            .await
            .unwrap();

        assert!(
            count >= 80,
            "fixture must reach the 80-token threshold, got {}",
            count
        );
        assert!(manager.is_compaction_threshold_exceeded().await);
    }

    /// Deducting exactly what was recorded lowers the total by that amount.
    #[tokio::test]
    async fn test_token_deduction() {
        let manager = TokenBudgetManager::new(TokenBudgetConfig::default());

        let count = manager
            .count_tokens_for_component("Hello, world!", ContextComponent::ToolResult, None)
            .await
            .unwrap();

        let initial_total = manager.get_stats().await.total_tokens;

        manager
            .deduct_tokens(ContextComponent::ToolResult, count)
            .await;

        let after_total = manager.get_stats().await.total_tokens;
        assert_eq!(after_total, initial_total - count);
    }
}