vtcode_core/core/token_budget.rs

//! Token budget management for context engineering
//!
//! This module implements token counting and budget tracking to manage
//! the attention budget of LLMs. Following Anthropic's context engineering
//! principles, it helps prevent context rot by tracking token usage and
//! signaling when compaction thresholds are exceeded.
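//!
//! A minimal usage sketch (not part of the original source). The
//! `vtcode_core::core::token_budget` import path is assumed from this file's
//! location, so the example is marked `ignore` rather than compiled as a doc test:
//!
//! ```ignore
//! use vtcode_core::core::token_budget::{
//!     ContextComponent, TokenBudgetConfig, TokenBudgetManager,
//! };
//!
//! async fn track_budget() -> anyhow::Result<()> {
//!     let manager = TokenBudgetManager::new(TokenBudgetConfig::for_model("gpt-4", 128_000));
//!
//!     // Count tokens and attribute them to a context component.
//!     let tokens = manager
//!         .count_tokens_for_component("user prompt text", ContextComponent::UserMessage, None)
//!         .await?;
//!
//!     // React before the context window degrades: summarize or drop content,
//!     // then reflect the removal in the budget.
//!     if manager.is_compaction_threshold_exceeded().await {
//!         manager
//!             .deduct_tokens(ContextComponent::UserMessage, tokens)
//!             .await;
//!     }
//!     Ok(())
//! }
//! ```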

use anyhow::{Context, Result, anyhow};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{SystemTime, UNIX_EPOCH};
use tiktoken_rs::get_bpe_from_model;
use tokio::sync::RwLock;
use tracing::debug;

/// Token budget configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenBudgetConfig {
    /// Maximum tokens allowed in context window
    pub max_context_tokens: usize,
    /// Threshold fraction (0.0-1.0) that triggers warnings
    pub warning_threshold: f64,
    /// Threshold fraction (0.0-1.0) that triggers compaction
    pub compaction_threshold: f64,
    /// Model name for tokenizer selection
    pub model: String,
    /// Enable detailed token tracking
    pub detailed_tracking: bool,
}

impl Default for TokenBudgetConfig {
    fn default() -> Self {
        Self {
            max_context_tokens: 128_000,
            warning_threshold: 0.75,
            compaction_threshold: 0.85,
            model: "gpt-4".to_string(),
            detailed_tracking: false,
        }
    }
}

impl TokenBudgetConfig {
    /// Create config for specific model
    pub fn for_model(model: &str, max_tokens: usize) -> Self {
        Self {
            max_context_tokens: max_tokens,
            model: model.to_string(),
            ..Default::default()
        }
    }
}

/// Token usage statistics
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenUsageStats {
    pub total_tokens: usize,
    pub system_prompt_tokens: usize,
    pub user_messages_tokens: usize,
    pub assistant_messages_tokens: usize,
    pub tool_results_tokens: usize,
    pub decision_ledger_tokens: usize,
    pub timestamp: u64,
}

impl TokenUsageStats {
    pub fn new() -> Self {
        Self {
            total_tokens: 0,
            system_prompt_tokens: 0,
            user_messages_tokens: 0,
            assistant_messages_tokens: 0,
            tool_results_tokens: 0,
            decision_ledger_tokens: 0,
            timestamp: SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .unwrap_or_default()
                .as_secs(),
        }
    }

    /// Calculate percentage of max context used
    pub fn usage_percentage(&self, max_tokens: usize) -> f64 {
        if max_tokens == 0 {
            return 0.0;
        }
        (self.total_tokens as f64 / max_tokens as f64) * 100.0
    }

    /// Check if compaction is needed (threshold is a fraction of `max_tokens`, 0.0-1.0)
    pub fn needs_compaction(&self, max_tokens: usize, threshold: f64) -> bool {
        // Guard against division by zero: a zero-sized budget is always over threshold.
        if max_tokens == 0 {
            return true;
        }
        let usage = self.total_tokens as f64 / max_tokens as f64;
        usage >= threshold
    }
}

impl Default for TokenUsageStats {
    fn default() -> Self {
        Self::new()
    }
}

/// Component types for detailed tracking
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ContextComponent {
    SystemPrompt,
    UserMessage,
    AssistantMessage,
    ToolResult,
    DecisionLedger,
    ProjectGuidelines,
    FileContent,
}

/// Token budget manager
pub struct TokenBudgetManager {
    config: Arc<RwLock<TokenBudgetConfig>>,
    stats: Arc<RwLock<TokenUsageStats>>,
    component_tokens: Arc<RwLock<HashMap<String, usize>>>,
    tokenizer_cache: Arc<RwLock<Option<tiktoken_rs::CoreBPE>>>,
}

impl TokenBudgetManager {
    /// Create a new token budget manager
    pub fn new(config: TokenBudgetConfig) -> Self {
        Self {
            config: Arc::new(RwLock::new(config)),
            stats: Arc::new(RwLock::new(TokenUsageStats::new())),
            component_tokens: Arc::new(RwLock::new(HashMap::new())),
            tokenizer_cache: Arc::new(RwLock::new(None)),
        }
    }

    /// Lazily initialize the tokenizer for the configured model
    async fn ensure_tokenizer(&self) -> Result<()> {
        let mut cache = self.tokenizer_cache.write().await;
        if cache.is_none() {
            let config = self.config.read().await;
            let bpe = get_bpe_from_model(&config.model)
                .with_context(|| format!("Failed to get tokenizer for model: {}", config.model))?;
            *cache = Some(bpe);
        }
        Ok(())
    }

    /// Count tokens in text
    pub async fn count_tokens(&self, text: &str) -> Result<usize> {
        self.ensure_tokenizer().await?;
        let cache = self.tokenizer_cache.read().await;
        let bpe = cache
            .as_ref()
            .ok_or_else(|| anyhow!("Tokenizer not initialized"))?;
        Ok(bpe.encode_with_special_tokens(text).len())
    }

    /// Count tokens with component tracking
    pub async fn count_tokens_for_component(
        &self,
        text: &str,
        component: ContextComponent,
        component_id: Option<&str>,
    ) -> Result<usize> {
        let token_count = self.count_tokens(text).await?;

        // Update component tracking
        if self.config.read().await.detailed_tracking {
            let key = if let Some(id) = component_id {
                format!("{:?}:{}", component, id)
            } else {
                format!("{:?}", component)
            };
            let mut components = self.component_tokens.write().await;
            *components.entry(key).or_insert(0) += token_count;
        }

        // Update stats
        let mut stats = self.stats.write().await;
        stats.total_tokens += token_count;

        match component {
            ContextComponent::SystemPrompt => stats.system_prompt_tokens += token_count,
            ContextComponent::UserMessage => stats.user_messages_tokens += token_count,
            ContextComponent::AssistantMessage => stats.assistant_messages_tokens += token_count,
            ContextComponent::ToolResult => stats.tool_results_tokens += token_count,
            ContextComponent::DecisionLedger => stats.decision_ledger_tokens += token_count,
            // ProjectGuidelines and FileContent have no dedicated stats field;
            // they only count toward the total (and the detailed map above).
            _ => {}
        }

        stats.timestamp = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs();

        Ok(token_count)
    }

    /// Get current usage statistics
    pub async fn get_stats(&self) -> TokenUsageStats {
        self.stats.read().await.clone()
    }

    /// Get component-level token breakdown
    pub async fn get_component_breakdown(&self) -> HashMap<String, usize> {
        self.component_tokens.read().await.clone()
    }

    /// Check if warning threshold is exceeded
    pub async fn is_warning_threshold_exceeded(&self) -> bool {
        let stats = self.stats.read().await;
        let config = self.config.read().await;
        stats.needs_compaction(config.max_context_tokens, config.warning_threshold)
    }

    /// Check if compaction threshold is exceeded
    pub async fn is_compaction_threshold_exceeded(&self) -> bool {
        let stats = self.stats.read().await;
        let config = self.config.read().await;
        stats.needs_compaction(config.max_context_tokens, config.compaction_threshold)
    }

    /// Get current usage percentage
    pub async fn usage_percentage(&self) -> f64 {
        let stats = self.stats.read().await;
        let config = self.config.read().await;
        stats.usage_percentage(config.max_context_tokens)
    }

    /// Get remaining tokens in budget
    pub async fn remaining_tokens(&self) -> usize {
        let stats = self.stats.read().await;
        let config = self.config.read().await;
        config.max_context_tokens.saturating_sub(stats.total_tokens)
    }

    /// Reset token counts (e.g., after compaction)
    pub async fn reset(&self) {
        let mut stats = self.stats.write().await;
        *stats = TokenUsageStats::new();
        let mut components = self.component_tokens.write().await;
        components.clear();
        debug!("Token budget reset");
    }

    /// Deduct tokens (after compaction/removal)
    pub async fn deduct_tokens(&self, component: ContextComponent, tokens: usize) {
        let mut stats = self.stats.write().await;
        stats.total_tokens = stats.total_tokens.saturating_sub(tokens);

        match component {
            ContextComponent::SystemPrompt => {
                stats.system_prompt_tokens = stats.system_prompt_tokens.saturating_sub(tokens)
            }
            ContextComponent::UserMessage => {
                stats.user_messages_tokens = stats.user_messages_tokens.saturating_sub(tokens)
            }
            ContextComponent::AssistantMessage => {
                stats.assistant_messages_tokens =
                    stats.assistant_messages_tokens.saturating_sub(tokens)
            }
            ContextComponent::ToolResult => {
                stats.tool_results_tokens = stats.tool_results_tokens.saturating_sub(tokens)
            }
            ContextComponent::DecisionLedger => {
                stats.decision_ledger_tokens = stats.decision_ledger_tokens.saturating_sub(tokens)
            }
            _ => {}
        }

        debug!("Deducted {} tokens from {:?}", tokens, component);
    }

    /// Generate a budget report
    pub async fn generate_report(&self) -> String {
        let stats = self.stats.read().await;
        let config = self.config.read().await;
        let components = self.component_tokens.read().await;

        let usage_pct = stats.usage_percentage(config.max_context_tokens);
        let remaining = config.max_context_tokens.saturating_sub(stats.total_tokens);

        let mut report = format!(
            "Token Budget Report\n\
             ==================\n\
             Total Tokens: {}/{} ({:.1}%)\n\
             Remaining: {} tokens\n\n\
             Breakdown by Category:\n\
             - System Prompt: {} tokens\n\
             - User Messages: {} tokens\n\
             - Assistant Messages: {} tokens\n\
             - Tool Results: {} tokens\n\
             - Decision Ledger: {} tokens\n",
            stats.total_tokens,
            config.max_context_tokens,
            usage_pct,
            remaining,
            stats.system_prompt_tokens,
            stats.user_messages_tokens,
            stats.assistant_messages_tokens,
            stats.tool_results_tokens,
            stats.decision_ledger_tokens
        );

        if config.detailed_tracking && !components.is_empty() {
            report.push_str("\nDetailed Component Tracking:\n");
            let mut sorted: Vec<_> = components.iter().collect();
            sorted.sort_by(|a, b| b.1.cmp(a.1));
            for (component, tokens) in sorted.iter().take(10) {
                report.push_str(&format!("  - {}: {} tokens\n", component, tokens));
            }
        }

        if usage_pct >= config.compaction_threshold * 100.0 {
            report.push_str("\nALERT: Compaction threshold exceeded");
        } else if usage_pct >= config.warning_threshold * 100.0 {
            report.push_str("\nWARNING: Approaching token limit");
        }

        report
    }
}

impl Default for TokenBudgetManager {
    fn default() -> Self {
        Self::new(TokenBudgetConfig::default())
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_token_counting() {
        let config = TokenBudgetConfig::default();
        let manager = TokenBudgetManager::new(config);

        let text = "Hello, world!";
        let count = manager.count_tokens(text).await.unwrap();
        assert!(count > 0);
    }

    #[tokio::test]
    async fn test_component_tracking() {
        let mut config = TokenBudgetConfig::default();
        config.detailed_tracking = true;
        let manager = TokenBudgetManager::new(config);

        let text = "This is a test message";
        let count = manager
            .count_tokens_for_component(text, ContextComponent::UserMessage, Some("msg1"))
            .await
            .unwrap();

        assert!(count > 0);

        let stats = manager.get_stats().await;
        assert_eq!(stats.user_messages_tokens, count);
    }

    #[tokio::test]
    async fn test_threshold_detection() {
        let mut config = TokenBudgetConfig::default();
        config.max_context_tokens = 100;
        config.compaction_threshold = 0.8;
        let manager = TokenBudgetManager::new(config);

        // Add enough tokens to exceed the 80-token compaction threshold.
        // With the gpt-4 (cl100k) tokenizer each repetition is roughly one token,
        // so 200 repetitions lands well above 80.
        let text = "word ".repeat(200);
        manager
            .count_tokens_for_component(&text, ContextComponent::UserMessage, None)
            .await
            .unwrap();

        assert!(manager.is_compaction_threshold_exceeded().await);
    }

    #[tokio::test]
    async fn test_token_deduction() {
        let manager = TokenBudgetManager::new(TokenBudgetConfig::default());

        let text = "Hello, world!";
        let count = manager
            .count_tokens_for_component(text, ContextComponent::ToolResult, None)
            .await
            .unwrap();

        let initial_total = manager.get_stats().await.total_tokens;

        manager
            .deduct_tokens(ContextComponent::ToolResult, count)
            .await;

        let after_total = manager.get_stats().await.total_tokens;
        assert_eq!(after_total, initial_total - count);
    }
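
    // Added example (not in the original file): a minimal sketch exercising the
    // budget math end to end, assuming the default 128_000-token context window.
    #[tokio::test]
    async fn test_remaining_tokens_and_reset() {
        let manager = TokenBudgetManager::new(TokenBudgetConfig::default());

        let count = manager
            .count_tokens_for_component("Hello, world!", ContextComponent::UserMessage, None)
            .await
            .unwrap();

        // remaining_tokens is the budget minus everything counted so far.
        assert_eq!(manager.remaining_tokens().await, 128_000 - count);
        assert!(manager.usage_percentage().await > 0.0);

        // reset clears both the summary stats and the per-component map.
        manager.reset().await;
        assert_eq!(manager.get_stats().await.total_tokens, 0);
        assert_eq!(manager.remaining_tokens().await, 128_000);
    }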
}