vtcode_core/core/
context_compression.rs

1use crate::config::constants::models;
2use crate::llm::provider::{LLMProvider, LLMRequest, Message, MessageRole};
3use serde::{Deserialize, Serialize};
4// std::collections::HashMap import removed as it's not used
5
6/// Context compression configuration
/// Context compression configuration
///
/// Tunables that control when compression triggers and which messages are
/// kept verbatim instead of being folded into the summary.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContextCompressionConfig {
    /// Hard budget for the conversation, in approximate tokens.
    pub max_context_length: usize,
    pub compression_threshold: f64, // Percentage of max length to trigger compression
    /// Upper bound (characters) requested for the generated summary.
    pub summary_max_length: usize,
    pub preserve_recent_turns: usize, // Number of recent turns to always keep
    /// When true, system-role messages are never summarized away.
    pub preserve_system_messages: bool,
    /// When true, messages containing error indicators are kept verbatim.
    pub preserve_error_messages: bool,
}
16
17impl Default for ContextCompressionConfig {
18    fn default() -> Self {
19        Self {
20            max_context_length: 128000, // ~128K tokens
21            compression_threshold: 0.8, // 80% of max length
22            summary_max_length: 2000,
23            preserve_recent_turns: 5,
24            preserve_system_messages: true,
25            preserve_error_messages: true,
26        }
27    }
28}
29
30/// Compressed context representation
/// Compressed context representation
///
/// The result of one compression pass: an LLM-written summary of the dropped
/// messages plus the messages kept verbatim, with bookkeeping metrics.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompressedContext {
    /// LLM-generated summary of the messages that were compressed away
    /// (empty when nothing needed summarizing).
    pub summary: String,
    /// Messages retained in the context, including the injected summary
    /// system message when a summary was produced.
    pub preserved_messages: Vec<Message>,
    /// compressed_length / original_length; 1.0 means no reduction.
    pub compression_ratio: f64,
    /// Approximate token count of the context before compression.
    pub original_length: usize,
    /// Approximate token count of the context after compression.
    pub compressed_length: usize,
    /// Seconds since the Unix epoch when this compression was performed.
    pub timestamp: u64,
}
40
41/// Context compression engine
/// Context compression engine
///
/// Wraps an LLM provider (used to generate summaries) together with the
/// configuration that decides when and how to compress.
pub struct ContextCompressor {
    // Compression tunables; defaults come from `ContextCompressionConfig::default()`.
    config: ContextCompressionConfig,
    // Provider used for the summarization request; boxed so any backend works.
    llm_provider: Box<dyn LLMProvider>,
}
46
47impl ContextCompressor {
48    pub fn new(llm_provider: Box<dyn LLMProvider>) -> Self {
49        Self {
50            config: ContextCompressionConfig::default(),
51            llm_provider,
52        }
53    }
54
55    pub fn with_config(mut self, config: ContextCompressionConfig) -> Self {
56        self.config = config;
57        self
58    }
59
60    /// Check if context needs compression
61    pub fn needs_compression(&self, messages: &[Message]) -> bool {
62        let total_length = self.calculate_context_length(messages);
63        total_length
64            > (self.config.max_context_length as f64 * self.config.compression_threshold) as usize
65    }
66
67    /// Compress context by summarizing older messages
68    pub async fn compress_context(
69        &self,
70        messages: &[Message],
71    ) -> Result<CompressedContext, ContextCompressionError> {
72        if messages.is_empty() {
73            return Err(ContextCompressionError::EmptyContext);
74        }
75
76        let total_length = self.calculate_context_length(messages);
77
78        // Separate messages to preserve and summarize
79        let (to_preserve, to_summarize) = self.partition_messages(messages);
80
81        if to_summarize.is_empty() {
82            // No messages to summarize, return original
83            return Ok(CompressedContext {
84                summary: String::new(),
85                preserved_messages: messages.to_vec(),
86                compression_ratio: 1.0,
87                original_length: total_length,
88                compressed_length: total_length,
89                timestamp: std::time::SystemTime::now()
90                    .duration_since(std::time::UNIX_EPOCH)
91                    .unwrap()
92                    .as_secs(),
93            });
94        }
95
96        // Generate summary of messages to compress
97        let summary = self.generate_summary(&to_summarize).await?;
98
99        // Combine summary with preserved messages
100        let mut compressed_messages = Vec::new();
101
102        // Add summary as a system message if we have content to summarize
103        if !summary.is_empty() {
104            compressed_messages.push(Message {
105                role: MessageRole::System,
106                content: format!("Previous conversation summary: {}", summary),
107                tool_calls: None,
108                tool_call_id: None,
109            });
110        }
111
112        // Add preserved messages
113        compressed_messages.extend_from_slice(&to_preserve);
114
115        let compressed_length = self.calculate_context_length(&compressed_messages);
116        let compression_ratio = if total_length > 0 {
117            compressed_length as f64 / total_length as f64
118        } else {
119            1.0
120        };
121
122        Ok(CompressedContext {
123            summary,
124            preserved_messages: compressed_messages,
125            compression_ratio,
126            original_length: total_length,
127            compressed_length,
128            timestamp: std::time::SystemTime::now()
129                .duration_since(std::time::UNIX_EPOCH)
130                .unwrap()
131                .as_secs(),
132        })
133    }
134
135    /// Partition messages into those to preserve and those to summarize
136    fn partition_messages(&self, messages: &[Message]) -> (Vec<Message>, Vec<Message>) {
137        let mut to_preserve = Vec::new();
138        let mut to_summarize = Vec::new();
139
140        let len = messages.len();
141
142        for (i, message) in messages.iter().enumerate() {
143            let should_preserve = self.should_preserve_message(message, i, len);
144
145            if should_preserve {
146                to_preserve.push(message.clone());
147            } else {
148                to_summarize.push(message.clone());
149            }
150        }
151
152        (to_preserve, to_summarize)
153    }
154
155    /// Determine if a message should be preserved
156    fn should_preserve_message(&self, message: &Message, index: usize, total_len: usize) -> bool {
157        // Always preserve recent messages
158        if index >= total_len.saturating_sub(self.config.preserve_recent_turns) {
159            return true;
160        }
161
162        // Preserve decision ledger summaries explicitly
163        if message.content.contains("[Decision Ledger]")
164            || message
165                .content
166                .contains("Decision Ledger (most recent first)")
167        {
168            return true;
169        }
170
171        // Preserve system messages if configured
172        if self.config.preserve_system_messages && matches!(message.role, MessageRole::System) {
173            return true;
174        }
175
176        // Preserve messages that contain errors if configured
177        if self.config.preserve_error_messages && self.contains_error_indicators(&message.content) {
178            return true;
179        }
180
181        // Preserve tool calls and their results
182        if message.tool_calls.is_some() || message.tool_call_id.is_some() {
183            return true;
184        }
185
186        false
187    }
188
189    /// Check if message content contains error indicators
190    fn contains_error_indicators(&self, content: &str) -> bool {
191        let error_keywords = [
192            "error",
193            "failed",
194            "exception",
195            "crash",
196            "bug",
197            "issue",
198            "problem",
199            "unable",
200            "cannot",
201            "failed",
202            "timeout",
203            "connection refused",
204        ];
205
206        let content_lower = content.to_lowercase();
207        error_keywords
208            .iter()
209            .any(|&keyword| content_lower.contains(keyword))
210    }
211
212    /// Generate summary of messages using LLM
213    async fn generate_summary(
214        &self,
215        messages: &[Message],
216    ) -> Result<String, ContextCompressionError> {
217        if messages.is_empty() {
218            return Ok(String::new());
219        }
220
221        // Create a prompt for summarization
222        let conversation_text = self.messages_to_text(messages);
223
224        let system_prompt = "You are a helpful assistant that summarizes conversations. \
225                           Create a concise summary of the following conversation, \
226                           focusing on key decisions, completed tasks, and important context. \
227                           Keep the summary under 500 words."
228            .to_string();
229
230        let user_prompt = format!(
231            "Please summarize the following conversation:\n\n{}",
232            conversation_text
233        );
234
235        let request = LLMRequest {
236            messages: vec![
237                Message {
238                    role: MessageRole::System,
239                    content: system_prompt,
240                    tool_calls: None,
241                    tool_call_id: None,
242                },
243                Message {
244                    role: MessageRole::User,
245                    content: user_prompt,
246                    tool_calls: None,
247                    tool_call_id: None,
248                },
249            ],
250            system_prompt: None,
251            tools: None,
252            model: models::GPT_5_MINI.to_string(), // Use a lightweight model for summarization
253            max_tokens: Some(1000),
254            temperature: Some(0.3),
255            stream: false,
256            tool_choice: None,
257            parallel_tool_calls: None,
258            parallel_tool_config: None,
259            reasoning_effort: None,
260        };
261
262        let response = self
263            .llm_provider
264            .generate(request)
265            .await
266            .map_err(|e| ContextCompressionError::LLMError(e.to_string()))?;
267
268        Ok(response.content.unwrap_or_default())
269    }
270
271    /// Convert messages to readable text
272    fn messages_to_text(&self, messages: &[Message]) -> String {
273        let mut text = String::new();
274
275        for message in messages {
276            let role = match message.role {
277                MessageRole::System => "System",
278                MessageRole::User => "User",
279                MessageRole::Assistant => "Assistant",
280                MessageRole::Tool => "Tool",
281            };
282
283            text.push_str(&format!("{}: {}\n\n", role, message.content));
284
285            if let Some(tool_calls) = &message.tool_calls {
286                for tool_call in tool_calls {
287                    text.push_str(&format!(
288                        "Tool Call: {}({})\n",
289                        tool_call.function.name, tool_call.function.arguments
290                    ));
291                }
292            }
293        }
294
295        text
296    }
297
298    /// Calculate total context length (approximate token count)
299    fn calculate_context_length(&self, messages: &[Message]) -> usize {
300        let mut total_chars = 0;
301
302        for message in messages {
303            total_chars += message.content.len();
304
305            if let Some(tool_calls) = &message.tool_calls {
306                for tool_call in tool_calls {
307                    total_chars += tool_call.function.name.len();
308                    total_chars += tool_call.function.arguments.len();
309                }
310            }
311        }
312
313        // Rough approximation: 1 token ≈ 4 characters
314        total_chars / 4
315    }
316}
317
318/// Context compression errors
/// Context compression errors
#[derive(Debug, thiserror::Error)]
pub enum ContextCompressionError {
    /// Compression was requested on an empty message slice.
    #[error("Empty context provided")]
    EmptyContext,

    /// The summarization request to the LLM provider failed; the payload is
    /// the provider's error rendered as a string.
    #[error("LLM error: {0}")]
    LLMError(String),

    /// JSON (de)serialization failed; converted automatically from
    /// `serde_json::Error` via `#[from]`.
    #[error("Serialization error: {0}")]
    SerializationError(#[from] serde_json::Error),
}
330
#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::provider::{
        FinishReason, LLMError, LLMProvider, LLMRequest, LLMResponse, Message, MessageRole,
    };

    /// Build a text-only message; deduplicates fixture boilerplate.
    fn text_message(role: MessageRole, content: &str) -> Message {
        Message {
            role,
            content: content.to_string(),
            tool_calls: None,
            tool_call_id: None,
        }
    }

    #[test]
    fn test_context_length_calculation() {
        let compressor = ContextCompressor::new(Box::new(MockProvider::new()));

        let messages = vec![
            text_message(MessageRole::User, "Hello world"),
            text_message(MessageRole::Assistant, "Hi there! How can I help you?"),
        ];

        let length = compressor.calculate_context_length(&messages);
        assert_eq!(
            length,
            ("Hello worldHi there! How can I help you?".len()) / 4
        );
    }

    #[test]
    fn test_needs_compression() {
        // Struct-update syntax instead of mutating a default instance
        // (avoids clippy::field_reassign_with_default).
        let config = ContextCompressionConfig {
            max_context_length: 100,
            compression_threshold: 0.8,
            ..ContextCompressionConfig::default()
        };

        let compressor = ContextCompressor::new(Box::new(MockProvider::new())).with_config(config);

        // 400 chars ≈ 100 tokens, which exceeds 80% of the 100-token budget.
        let messages = vec![text_message(MessageRole::User, &"x".repeat(400))];

        assert!(compressor.needs_compression(&messages));
    }

    // Mock provider for testing; returns a canned summary.
    struct MockProvider;

    impl MockProvider {
        fn new() -> Self {
            Self
        }
    }

    #[async_trait::async_trait]
    impl LLMProvider for MockProvider {
        fn name(&self) -> &str {
            "mock"
        }

        async fn generate(&self, _request: LLMRequest) -> Result<LLMResponse, LLMError> {
            Ok(LLMResponse {
                content: Some("Mock summary".to_string()),
                tool_calls: None,
                usage: None,
                finish_reason: FinishReason::Stop,
                reasoning: None,
            })
        }

        fn supported_models(&self) -> Vec<String> {
            vec!["mock".to_string()]
        }

        fn validate_request(&self, _request: &LLMRequest) -> Result<(), LLMError> {
            Ok(())
        }
    }
}
415}