vtcode_core/core/
token_budget.rs

//! Token budget management for context engineering
//!
//! This module implements token counting and budget tracking to manage
//! the attention budget of LLMs. Following Anthropic's context engineering
//! principles, it helps prevent context rot by tracking token usage and
//! reporting when the warning or compaction thresholds have been reached.
7
8use anyhow::{Context, Result, anyhow};
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::sync::Arc;
12use std::time::{SystemTime, UNIX_EPOCH};
13use tiktoken_rs::get_bpe_from_model;
14use tokio::sync::RwLock;
15use tracing::debug;
16
17/// Token budget configuration
18#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct TokenBudgetConfig {
20    /// Maximum tokens allowed in context window
21    pub max_context_tokens: usize,
22    /// Threshold percentage to trigger warnings (0.0-1.0)
23    pub warning_threshold: f64,
24    /// Threshold percentage to trigger compaction (0.0-1.0)
25    pub compaction_threshold: f64,
26    /// Model name for tokenizer selection
27    pub model: String,
28    /// Enable detailed token tracking
29    pub detailed_tracking: bool,
30}
31
32impl Default for TokenBudgetConfig {
33    fn default() -> Self {
34        Self {
35            max_context_tokens: 128_000,
36            warning_threshold: 0.75,
37            compaction_threshold: 0.85,
38            model: "gpt-4".to_string(),
39            detailed_tracking: false,
40        }
41    }
42}
43
44impl TokenBudgetConfig {
45    /// Create config for specific model
46    pub fn for_model(model: &str, max_tokens: usize) -> Self {
47        Self {
48            max_context_tokens: max_tokens,
49            model: model.to_string(),
50            ..Default::default()
51        }
52    }
53}
54
55/// Token usage statistics
56#[derive(Debug, Clone, Serialize, Deserialize)]
57pub struct TokenUsageStats {
58    pub total_tokens: usize,
59    pub system_prompt_tokens: usize,
60    pub user_messages_tokens: usize,
61    pub assistant_messages_tokens: usize,
62    pub tool_results_tokens: usize,
63    pub decision_ledger_tokens: usize,
64    pub timestamp: u64,
65}
66
67impl TokenUsageStats {
68    pub fn new() -> Self {
69        Self {
70            total_tokens: 0,
71            system_prompt_tokens: 0,
72            user_messages_tokens: 0,
73            assistant_messages_tokens: 0,
74            tool_results_tokens: 0,
75            decision_ledger_tokens: 0,
76            timestamp: SystemTime::now()
77                .duration_since(UNIX_EPOCH)
78                .unwrap_or_default()
79                .as_secs(),
80        }
81    }
82
83    /// Calculate percentage of max context used
84    pub fn usage_percentage(&self, max_tokens: usize) -> f64 {
85        if max_tokens == 0 {
86            return 0.0;
87        }
88        (self.total_tokens as f64 / max_tokens as f64) * 100.0
89    }
90
91    /// Check if compaction is needed
92    pub fn needs_compaction(&self, max_tokens: usize, threshold: f64) -> bool {
93        let usage = self.total_tokens as f64 / max_tokens as f64;
94        usage >= threshold
95    }
96}
97
98impl Default for TokenUsageStats {
99    fn default() -> Self {
100        Self::new()
101    }
102}
103
104/// Component types for detailed tracking
105#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
106pub enum ContextComponent {
107    SystemPrompt,
108    UserMessage,
109    AssistantMessage,
110    ToolResult,
111    DecisionLedger,
112    ProjectGuidelines,
113    FileContent,
114}
115
116/// Token budget manager
117pub struct TokenBudgetManager {
118    config: Arc<RwLock<TokenBudgetConfig>>,
119    stats: Arc<RwLock<TokenUsageStats>>,
120    component_tokens: Arc<RwLock<HashMap<String, usize>>>,
121    tokenizer_cache: Arc<RwLock<Option<tiktoken_rs::CoreBPE>>>,
122}
123
124impl TokenBudgetManager {
125    /// Create a new token budget manager
126    pub fn new(config: TokenBudgetConfig) -> Self {
127        Self {
128            config: Arc::new(RwLock::new(config)),
129            stats: Arc::new(RwLock::new(TokenUsageStats::new())),
130            component_tokens: Arc::new(RwLock::new(HashMap::new())),
131            tokenizer_cache: Arc::new(RwLock::new(None)),
132        }
133    }
134
135    /// Initialize or update tokenizer for the current model
136    async fn ensure_tokenizer(&self) -> Result<()> {
137        let mut cache = self.tokenizer_cache.write().await;
138        if cache.is_none() {
139            let config = self.config.read().await;
140            let bpe = get_bpe_from_model(&config.model)
141                .with_context(|| format!("Failed to get tokenizer for model: {}", config.model))?;
142            *cache = Some(bpe);
143        }
144        Ok(())
145    }
146
147    /// Count tokens in text
148    pub async fn count_tokens(&self, text: &str) -> Result<usize> {
149        self.ensure_tokenizer().await?;
150        let cache = self.tokenizer_cache.read().await;
151        let bpe = cache
152            .as_ref()
153            .ok_or_else(|| anyhow!("Tokenizer not initialized"))?;
154        Ok(bpe.encode_with_special_tokens(text).len())
155    }
156
157    /// Count tokens with component tracking
158    pub async fn count_tokens_for_component(
159        &self,
160        text: &str,
161        component: ContextComponent,
162        component_id: Option<&str>,
163    ) -> Result<usize> {
164        let token_count = self.count_tokens(text).await?;
165
166        self.record_tokens_for_component(component, token_count, component_id)
167            .await;
168
169        Ok(token_count)
170    }
171
172    /// Record token usage for a component using a provided token count.
173    pub async fn record_tokens_for_component(
174        &self,
175        component: ContextComponent,
176        tokens: usize,
177        component_id: Option<&str>,
178    ) {
179        if tokens == 0 {
180            return;
181        }
182
183        let detailed_tracking = {
184            let config = self.config.read().await;
185            config.detailed_tracking
186        };
187
188        if detailed_tracking {
189            let key = if let Some(id) = component_id {
190                format!("{:?}:{}", component, id)
191            } else {
192                format!("{:?}", component)
193            };
194            let mut components = self.component_tokens.write().await;
195            *components.entry(key).or_insert(0) += tokens;
196        }
197
198        let mut stats = self.stats.write().await;
199        stats.total_tokens += tokens;
200
201        match component {
202            ContextComponent::SystemPrompt => stats.system_prompt_tokens += tokens,
203            ContextComponent::UserMessage => stats.user_messages_tokens += tokens,
204            ContextComponent::AssistantMessage => stats.assistant_messages_tokens += tokens,
205            ContextComponent::ToolResult => stats.tool_results_tokens += tokens,
206            ContextComponent::DecisionLedger => stats.decision_ledger_tokens += tokens,
207            _ => {}
208        }
209
210        stats.timestamp = SystemTime::now()
211            .duration_since(UNIX_EPOCH)
212            .unwrap_or_default()
213            .as_secs();
214    }
215
216    /// Get current usage statistics
217    pub async fn get_stats(&self) -> TokenUsageStats {
218        self.stats.read().await.clone()
219    }
220
221    /// Get component-level token breakdown
222    pub async fn get_component_breakdown(&self) -> HashMap<String, usize> {
223        self.component_tokens.read().await.clone()
224    }
225
226    /// Check if warning threshold is exceeded
227    pub async fn is_warning_threshold_exceeded(&self) -> bool {
228        let stats = self.stats.read().await;
229        let config = self.config.read().await;
230        stats.needs_compaction(config.max_context_tokens, config.warning_threshold)
231    }
232
233    /// Check if compaction threshold is exceeded
234    pub async fn is_compaction_threshold_exceeded(&self) -> bool {
235        let stats = self.stats.read().await;
236        let config = self.config.read().await;
237        stats.needs_compaction(config.max_context_tokens, config.compaction_threshold)
238    }
239
240    /// Get current usage percentage
241    pub async fn usage_percentage(&self) -> f64 {
242        let stats = self.stats.read().await;
243        let config = self.config.read().await;
244        stats.usage_percentage(config.max_context_tokens)
245    }
246
247    /// Get remaining tokens in budget
248    pub async fn remaining_tokens(&self) -> usize {
249        let stats = self.stats.read().await;
250        let config = self.config.read().await;
251        config.max_context_tokens.saturating_sub(stats.total_tokens)
252    }
253
254    /// Reset token counts (e.g., after compaction)
255    pub async fn reset(&self) {
256        let mut stats = self.stats.write().await;
257        *stats = TokenUsageStats::new();
258        let mut components = self.component_tokens.write().await;
259        components.clear();
260        debug!("Token budget reset");
261    }
262
263    /// Deduct tokens (after compaction/removal)
264    pub async fn deduct_tokens(&self, component: ContextComponent, tokens: usize) {
265        let mut stats = self.stats.write().await;
266        stats.total_tokens = stats.total_tokens.saturating_sub(tokens);
267
268        match component {
269            ContextComponent::SystemPrompt => {
270                stats.system_prompt_tokens = stats.system_prompt_tokens.saturating_sub(tokens)
271            }
272            ContextComponent::UserMessage => {
273                stats.user_messages_tokens = stats.user_messages_tokens.saturating_sub(tokens)
274            }
275            ContextComponent::AssistantMessage => {
276                stats.assistant_messages_tokens =
277                    stats.assistant_messages_tokens.saturating_sub(tokens)
278            }
279            ContextComponent::ToolResult => {
280                stats.tool_results_tokens = stats.tool_results_tokens.saturating_sub(tokens)
281            }
282            ContextComponent::DecisionLedger => {
283                stats.decision_ledger_tokens = stats.decision_ledger_tokens.saturating_sub(tokens)
284            }
285            _ => {}
286        }
287
288        debug!("Deducted {} tokens from {:?}", tokens, component);
289    }
290
291    /// Generate a budget report
292    pub async fn generate_report(&self) -> String {
293        let stats = self.stats.read().await;
294        let config = self.config.read().await;
295        let components = self.component_tokens.read().await;
296
297        let usage_pct = stats.usage_percentage(config.max_context_tokens);
298        let remaining = config.max_context_tokens.saturating_sub(stats.total_tokens);
299
300        let mut report = format!(
301            "Token Budget Report\n\
302             ==================\n\
303             Total Tokens: {}/{} ({:.1}%)\n\
304             Remaining: {} tokens\n\n\
305             Breakdown by Category:\n\
306             - System Prompt: {} tokens\n\
307             - User Messages: {} tokens\n\
308             - Assistant Messages: {} tokens\n\
309             - Tool Results: {} tokens\n\
310             - Decision Ledger: {} tokens\n",
311            stats.total_tokens,
312            config.max_context_tokens,
313            usage_pct,
314            remaining,
315            stats.system_prompt_tokens,
316            stats.user_messages_tokens,
317            stats.assistant_messages_tokens,
318            stats.tool_results_tokens,
319            stats.decision_ledger_tokens
320        );
321
322        if config.detailed_tracking && !components.is_empty() {
323            report.push_str("\nDetailed Component Tracking:\n");
324            let mut sorted: Vec<_> = components.iter().collect();
325            sorted.sort_by(|a, b| b.1.cmp(a.1));
326            for (component, tokens) in sorted.iter().take(10) {
327                report.push_str(&format!("  - {}: {} tokens\n", component, tokens));
328            }
329        }
330
331        if usage_pct >= config.compaction_threshold * 100.0 {
332            report.push_str("\nALERT: Compaction threshold exceeded");
333        } else if usage_pct >= config.warning_threshold * 100.0 {
334            report.push_str("\nWARNING: Approaching token limit");
335        }
336
337        report
338    }
339}
340
341impl Default for TokenBudgetManager {
342    fn default() -> Self {
343        Self::new(TokenBudgetConfig::default())
344    }
345}
346
#[cfg(test)]
mod tests {
    use super::*;

    /// The tokenizer loads for the default model and yields a non-zero count.
    #[tokio::test]
    async fn test_token_counting() {
        let manager = TokenBudgetManager::new(TokenBudgetConfig::default());

        let count = manager.count_tokens("Hello, world!").await.unwrap();
        assert!(count > 0);
    }

    /// Counting for a component updates both the category total and, with
    /// detailed tracking on, the per-id breakdown.
    #[tokio::test]
    async fn test_component_tracking() {
        let config = TokenBudgetConfig {
            detailed_tracking: true,
            ..Default::default()
        };
        let manager = TokenBudgetManager::new(config);

        let text = "This is a test message";
        let count = manager
            .count_tokens_for_component(text, ContextComponent::UserMessage, Some("msg1"))
            .await
            .unwrap();

        assert!(count > 0);

        let stats = manager.get_stats().await;
        assert_eq!(stats.user_messages_tokens, count);
        assert_eq!(stats.total_tokens, count);

        let breakdown = manager.get_component_breakdown().await;
        assert_eq!(breakdown.get("UserMessage:msg1"), Some(&count));
    }

    /// Threshold detection is exercised with an explicit token count so the
    /// outcome does not depend on how the tokenizer splits a sample string.
    #[tokio::test]
    async fn test_threshold_detection() {
        let config = TokenBudgetConfig {
            max_context_tokens: 100,
            compaction_threshold: 0.8,
            ..Default::default()
        };
        let manager = TokenBudgetManager::new(config);

        assert!(!manager.is_compaction_threshold_exceeded().await);

        // 85 of 100 tokens is above the 0.8 compaction threshold.
        manager
            .record_tokens_for_component(ContextComponent::UserMessage, 85, None)
            .await;

        assert!(manager.is_warning_threshold_exceeded().await);
        assert!(manager.is_compaction_threshold_exceeded().await);
    }

    /// Deducting exactly what was recorded brings the total back down.
    #[tokio::test]
    async fn test_token_deduction() {
        let manager = TokenBudgetManager::new(TokenBudgetConfig::default());

        let text = "Hello, world!";
        let count = manager
            .count_tokens_for_component(text, ContextComponent::ToolResult, None)
            .await
            .unwrap();

        let initial_total = manager.get_stats().await.total_tokens;

        manager
            .deduct_tokens(ContextComponent::ToolResult, count)
            .await;

        let after_total = manager.get_stats().await.total_tokens;
        assert_eq!(after_total, initial_total - count);
    }
}