1use anyhow::{Context, Result, anyhow};
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::sync::Arc;
12use std::time::{SystemTime, UNIX_EPOCH};
13use tiktoken_rs::get_bpe_from_model;
14use tokio::sync::RwLock;
15use tracing::debug;
16
17#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct TokenBudgetConfig {
20 pub max_context_tokens: usize,
22 pub warning_threshold: f64,
24 pub compaction_threshold: f64,
26 pub model: String,
28 pub detailed_tracking: bool,
30}
31
32impl Default for TokenBudgetConfig {
33 fn default() -> Self {
34 Self {
35 max_context_tokens: 128_000,
36 warning_threshold: 0.75,
37 compaction_threshold: 0.85,
38 model: "gpt-4".to_string(),
39 detailed_tracking: false,
40 }
41 }
42}
43
44impl TokenBudgetConfig {
45 pub fn for_model(model: &str, max_tokens: usize) -> Self {
47 Self {
48 max_context_tokens: max_tokens,
49 model: model.to_string(),
50 ..Default::default()
51 }
52 }
53}
54
55#[derive(Debug, Clone, Serialize, Deserialize)]
57pub struct TokenUsageStats {
58 pub total_tokens: usize,
59 pub system_prompt_tokens: usize,
60 pub user_messages_tokens: usize,
61 pub assistant_messages_tokens: usize,
62 pub tool_results_tokens: usize,
63 pub decision_ledger_tokens: usize,
64 pub timestamp: u64,
65}
66
67impl TokenUsageStats {
68 pub fn new() -> Self {
69 Self {
70 total_tokens: 0,
71 system_prompt_tokens: 0,
72 user_messages_tokens: 0,
73 assistant_messages_tokens: 0,
74 tool_results_tokens: 0,
75 decision_ledger_tokens: 0,
76 timestamp: SystemTime::now()
77 .duration_since(UNIX_EPOCH)
78 .unwrap_or_default()
79 .as_secs(),
80 }
81 }
82
83 pub fn usage_percentage(&self, max_tokens: usize) -> f64 {
85 if max_tokens == 0 {
86 return 0.0;
87 }
88 (self.total_tokens as f64 / max_tokens as f64) * 100.0
89 }
90
91 pub fn needs_compaction(&self, max_tokens: usize, threshold: f64) -> bool {
93 let usage = self.total_tokens as f64 / max_tokens as f64;
94 usage >= threshold
95 }
96}
97
98impl Default for TokenUsageStats {
99 fn default() -> Self {
100 Self::new()
101 }
102}
103
104#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
106pub enum ContextComponent {
107 SystemPrompt,
108 UserMessage,
109 AssistantMessage,
110 ToolResult,
111 DecisionLedger,
112 ProjectGuidelines,
113 FileContent,
114}
115
116pub struct TokenBudgetManager {
118 config: Arc<RwLock<TokenBudgetConfig>>,
119 stats: Arc<RwLock<TokenUsageStats>>,
120 component_tokens: Arc<RwLock<HashMap<String, usize>>>,
121 tokenizer_cache: Arc<RwLock<Option<tiktoken_rs::CoreBPE>>>,
122}
123
124impl TokenBudgetManager {
125 pub fn new(config: TokenBudgetConfig) -> Self {
127 Self {
128 config: Arc::new(RwLock::new(config)),
129 stats: Arc::new(RwLock::new(TokenUsageStats::new())),
130 component_tokens: Arc::new(RwLock::new(HashMap::new())),
131 tokenizer_cache: Arc::new(RwLock::new(None)),
132 }
133 }
134
135 async fn ensure_tokenizer(&self) -> Result<()> {
137 let mut cache = self.tokenizer_cache.write().await;
138 if cache.is_none() {
139 let config = self.config.read().await;
140 let bpe = get_bpe_from_model(&config.model)
141 .with_context(|| format!("Failed to get tokenizer for model: {}", config.model))?;
142 *cache = Some(bpe);
143 }
144 Ok(())
145 }
146
147 pub async fn count_tokens(&self, text: &str) -> Result<usize> {
149 self.ensure_tokenizer().await?;
150 let cache = self.tokenizer_cache.read().await;
151 let bpe = cache
152 .as_ref()
153 .ok_or_else(|| anyhow!("Tokenizer not initialized"))?;
154 Ok(bpe.encode_with_special_tokens(text).len())
155 }
156
157 pub async fn count_tokens_for_component(
159 &self,
160 text: &str,
161 component: ContextComponent,
162 component_id: Option<&str>,
163 ) -> Result<usize> {
164 let token_count = self.count_tokens(text).await?;
165
166 if self.config.read().await.detailed_tracking {
168 let key = if let Some(id) = component_id {
169 format!("{:?}:{}", component, id)
170 } else {
171 format!("{:?}", component)
172 };
173 let mut components = self.component_tokens.write().await;
174 *components.entry(key).or_insert(0) += token_count;
175 }
176
177 let mut stats = self.stats.write().await;
179 stats.total_tokens += token_count;
180
181 match component {
182 ContextComponent::SystemPrompt => stats.system_prompt_tokens += token_count,
183 ContextComponent::UserMessage => stats.user_messages_tokens += token_count,
184 ContextComponent::AssistantMessage => stats.assistant_messages_tokens += token_count,
185 ContextComponent::ToolResult => stats.tool_results_tokens += token_count,
186 ContextComponent::DecisionLedger => stats.decision_ledger_tokens += token_count,
187 _ => {}
188 }
189
190 stats.timestamp = SystemTime::now()
191 .duration_since(UNIX_EPOCH)
192 .unwrap_or_default()
193 .as_secs();
194
195 Ok(token_count)
196 }
197
198 pub async fn get_stats(&self) -> TokenUsageStats {
200 self.stats.read().await.clone()
201 }
202
203 pub async fn get_component_breakdown(&self) -> HashMap<String, usize> {
205 self.component_tokens.read().await.clone()
206 }
207
208 pub async fn is_warning_threshold_exceeded(&self) -> bool {
210 let stats = self.stats.read().await;
211 let config = self.config.read().await;
212 stats.needs_compaction(config.max_context_tokens, config.warning_threshold)
213 }
214
215 pub async fn is_compaction_threshold_exceeded(&self) -> bool {
217 let stats = self.stats.read().await;
218 let config = self.config.read().await;
219 stats.needs_compaction(config.max_context_tokens, config.compaction_threshold)
220 }
221
222 pub async fn usage_percentage(&self) -> f64 {
224 let stats = self.stats.read().await;
225 let config = self.config.read().await;
226 stats.usage_percentage(config.max_context_tokens)
227 }
228
229 pub async fn remaining_tokens(&self) -> usize {
231 let stats = self.stats.read().await;
232 let config = self.config.read().await;
233 config.max_context_tokens.saturating_sub(stats.total_tokens)
234 }
235
236 pub async fn reset(&self) {
238 let mut stats = self.stats.write().await;
239 *stats = TokenUsageStats::new();
240 let mut components = self.component_tokens.write().await;
241 components.clear();
242 debug!("Token budget reset");
243 }
244
245 pub async fn deduct_tokens(&self, component: ContextComponent, tokens: usize) {
247 let mut stats = self.stats.write().await;
248 stats.total_tokens = stats.total_tokens.saturating_sub(tokens);
249
250 match component {
251 ContextComponent::SystemPrompt => {
252 stats.system_prompt_tokens = stats.system_prompt_tokens.saturating_sub(tokens)
253 }
254 ContextComponent::UserMessage => {
255 stats.user_messages_tokens = stats.user_messages_tokens.saturating_sub(tokens)
256 }
257 ContextComponent::AssistantMessage => {
258 stats.assistant_messages_tokens =
259 stats.assistant_messages_tokens.saturating_sub(tokens)
260 }
261 ContextComponent::ToolResult => {
262 stats.tool_results_tokens = stats.tool_results_tokens.saturating_sub(tokens)
263 }
264 ContextComponent::DecisionLedger => {
265 stats.decision_ledger_tokens = stats.decision_ledger_tokens.saturating_sub(tokens)
266 }
267 _ => {}
268 }
269
270 debug!("Deducted {} tokens from {:?}", tokens, component);
271 }
272
273 pub async fn generate_report(&self) -> String {
275 let stats = self.stats.read().await;
276 let config = self.config.read().await;
277 let components = self.component_tokens.read().await;
278
279 let usage_pct = stats.usage_percentage(config.max_context_tokens);
280 let remaining = config.max_context_tokens.saturating_sub(stats.total_tokens);
281
282 let mut report = format!(
283 "Token Budget Report\n\
284 ==================\n\
285 Total Tokens: {}/{} ({:.1}%)\n\
286 Remaining: {} tokens\n\n\
287 Breakdown by Category:\n\
288 - System Prompt: {} tokens\n\
289 - User Messages: {} tokens\n\
290 - Assistant Messages: {} tokens\n\
291 - Tool Results: {} tokens\n\
292 - Decision Ledger: {} tokens\n",
293 stats.total_tokens,
294 config.max_context_tokens,
295 usage_pct,
296 remaining,
297 stats.system_prompt_tokens,
298 stats.user_messages_tokens,
299 stats.assistant_messages_tokens,
300 stats.tool_results_tokens,
301 stats.decision_ledger_tokens
302 );
303
304 if config.detailed_tracking && !components.is_empty() {
305 report.push_str("\nDetailed Component Tracking:\n");
306 let mut sorted: Vec<_> = components.iter().collect();
307 sorted.sort_by(|a, b| b.1.cmp(a.1));
308 for (component, tokens) in sorted.iter().take(10) {
309 report.push_str(&format!(" - {}: {} tokens\n", component, tokens));
310 }
311 }
312
313 if usage_pct >= config.compaction_threshold * 100.0 {
314 report.push_str("\nALERT: Compaction threshold exceeded");
315 } else if usage_pct >= config.warning_threshold * 100.0 {
316 report.push_str("\nWARNING: Approaching token limit");
317 }
318
319 report
320 }
321}
322
323impl Default for TokenBudgetManager {
324 fn default() -> Self {
325 Self::new(TokenBudgetConfig::default())
326 }
327}
328
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_token_counting() {
        let manager = TokenBudgetManager::new(TokenBudgetConfig::default());

        let count = manager.count_tokens("Hello, world!").await.unwrap();
        assert!(count > 0);
    }

    #[tokio::test]
    async fn test_component_tracking() {
        let config = TokenBudgetConfig {
            detailed_tracking: true,
            ..Default::default()
        };
        let manager = TokenBudgetManager::new(config);

        let count = manager
            .count_tokens_for_component(
                "This is a test message",
                ContextComponent::UserMessage,
                Some("msg1"),
            )
            .await
            .unwrap();
        assert!(count > 0);

        let stats = manager.get_stats().await;
        assert_eq!(stats.user_messages_tokens, count);
        assert_eq!(stats.total_tokens, count);

        // Detailed tracking keys entries as "<Component Debug name>:<id>".
        let breakdown = manager.get_component_breakdown().await;
        assert_eq!(breakdown.get("UserMessage:msg1"), Some(&count));
    }

    #[tokio::test]
    async fn test_threshold_detection() {
        let config = TokenBudgetConfig {
            max_context_tokens: 100,
            compaction_threshold: 0.8,
            ..Default::default()
        };
        let manager = TokenBudgetManager::new(config);

        // Needs at least 80 tokens to cross 80 % of a 100-token window. Assumes
        // each "word " encodes to about one token; the check below guards that.
        let text = "word ".repeat(100);
        let count = manager
            .count_tokens_for_component(&text, ContextComponent::UserMessage, None)
            .await
            .unwrap();
        assert!(
            count >= 80,
            "fixture too small to exceed the threshold: {} tokens",
            count
        );

        assert!(manager.is_compaction_threshold_exceeded().await);
    }

    #[tokio::test]
    async fn test_token_deduction() {
        let manager = TokenBudgetManager::new(TokenBudgetConfig::default());

        let count = manager
            .count_tokens_for_component("Hello, world!", ContextComponent::ToolResult, None)
            .await
            .unwrap();
        let initial_total = manager.get_stats().await.total_tokens;

        manager
            .deduct_tokens(ContextComponent::ToolResult, count)
            .await;

        let after = manager.get_stats().await;
        assert_eq!(after.total_tokens, initial_total - count);
        assert_eq!(after.tool_results_tokens, 0);
    }

    #[tokio::test]
    async fn test_reset_clears_usage() {
        let config = TokenBudgetConfig {
            detailed_tracking: true,
            ..Default::default()
        };
        let manager = TokenBudgetManager::new(config);

        manager
            .count_tokens_for_component("Hello, world!", ContextComponent::SystemPrompt, None)
            .await
            .unwrap();
        manager.reset().await;

        assert_eq!(manager.get_stats().await.total_tokens, 0);
        assert!(manager.get_component_breakdown().await.is_empty());
    }
}