use crate::config::constants::models;
use crate::llm::provider::{LLMProvider, LLMRequest, Message, MessageRole};
use serde::{Deserialize, Serialize};
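
/// Configuration for conversation-context compression. `max_context_length`
/// is an approximate token budget; compression is triggered once the estimated
/// context size exceeds `max_context_length * compression_threshold`. The
/// `preserve_*` fields control which messages are kept verbatim instead of
/// being folded into the summary.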
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContextCompressionConfig {
    pub max_context_length: usize,
    pub compression_threshold: f64,
    pub summary_max_length: usize,
    pub preserve_recent_turns: usize,
    pub preserve_system_messages: bool,
    pub preserve_error_messages: bool,
}

impl Default for ContextCompressionConfig {
    fn default() -> Self {
        Self {
            max_context_length: 128000,
            compression_threshold: 0.8,
            summary_max_length: 2000,
            preserve_recent_turns: 5,
            preserve_system_messages: true,
            preserve_error_messages: true,
        }
    }
}

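/// Result of one compression pass: an LLM-generated summary plus the messages
/// preserved verbatim, with bookkeeping on how much the context shrank.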
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompressedContext {
    pub summary: String,
    pub preserved_messages: Vec<Message>,
    pub compression_ratio: f64,
    pub original_length: usize,
    pub compressed_length: usize,
    pub timestamp: u64,
}

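/// Compresses long conversations by summarizing older turns with an LLM while
/// keeping recent, system, error, and tool-call messages verbatim.
///
/// A minimal usage sketch (assuming `provider` is any boxed [`LLMProvider`]
/// and `messages` is the current history):
///
/// ```ignore
/// let compressor = ContextCompressor::new(provider)
///     .with_config(ContextCompressionConfig::default());
/// if compressor.needs_compression(&messages) {
///     let compressed = compressor.compress_context(&messages).await?;
///     // `compressed.preserved_messages` is the new, shorter history.
/// }
/// ```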
pub struct ContextCompressor {
    config: ContextCompressionConfig,
    llm_provider: Box<dyn LLMProvider>,
}

impl ContextCompressor {
    pub fn new(llm_provider: Box<dyn LLMProvider>) -> Self {
        Self {
            config: ContextCompressionConfig::default(),
            llm_provider,
        }
    }

    pub fn with_config(mut self, config: ContextCompressionConfig) -> Self {
        self.config = config;
        self
    }

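    /// Returns true once the estimated context length crosses the configured
    /// fraction (`compression_threshold`) of `max_context_length`.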
    pub fn needs_compression(&self, messages: &[Message]) -> bool {
        let total_length = self.calculate_context_length(messages);
        total_length
            > (self.config.max_context_length as f64 * self.config.compression_threshold) as usize
    }

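    /// Summarizes the non-preserved portion of `messages` via the LLM and
    /// returns the summary (as a leading system message) together with the
    /// preserved messages. When nothing qualifies for summarization, the input
    /// is returned unchanged with a compression ratio of 1.0.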
    pub async fn compress_context(
        &self,
        messages: &[Message],
    ) -> Result<CompressedContext, ContextCompressionError> {
        if messages.is_empty() {
            return Err(ContextCompressionError::EmptyContext);
        }

        let total_length = self.calculate_context_length(messages);

        let (to_preserve, to_summarize) = self.partition_messages(messages);

        if to_summarize.is_empty() {
            return Ok(CompressedContext {
                summary: String::new(),
                preserved_messages: messages.to_vec(),
                compression_ratio: 1.0,
                original_length: total_length,
                compressed_length: total_length,
                timestamp: std::time::SystemTime::now()
                    .duration_since(std::time::UNIX_EPOCH)
                    .unwrap()
                    .as_secs(),
            });
        }

        let summary = self.generate_summary(&to_summarize).await?;

        let mut compressed_messages = Vec::new();

        if !summary.is_empty() {
            compressed_messages.push(Message {
                role: MessageRole::System,
                content: format!("Previous conversation summary: {}", summary),
                tool_calls: None,
                tool_call_id: None,
            });
        }

        compressed_messages.extend_from_slice(&to_preserve);

        let compressed_length = self.calculate_context_length(&compressed_messages);
        let compression_ratio = if total_length > 0 {
            compressed_length as f64 / total_length as f64
        } else {
            1.0
        };

        Ok(CompressedContext {
            summary,
            preserved_messages: compressed_messages,
            compression_ratio,
            original_length: total_length,
            compressed_length,
            timestamp: std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap()
                .as_secs(),
        })
    }

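    /// Splits `messages` into (kept verbatim, to be summarized), preserving
    /// the original order within each group.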
    fn partition_messages(&self, messages: &[Message]) -> (Vec<Message>, Vec<Message>) {
        let mut to_preserve = Vec::new();
        let mut to_summarize = Vec::new();

        let len = messages.len();

        for (i, message) in messages.iter().enumerate() {
            let should_preserve = self.should_preserve_message(message, i, len);

            if should_preserve {
                to_preserve.push(message.clone());
            } else {
                to_summarize.push(message.clone());
            }
        }

        (to_preserve, to_summarize)
    }

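    /// Preservation rules, checked in order: recent turns, decision-ledger
    /// markers, system messages (if configured), error-bearing messages (if
    /// configured), and any message participating in a tool call.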
    fn should_preserve_message(&self, message: &Message, index: usize, total_len: usize) -> bool {
        if index >= total_len.saturating_sub(self.config.preserve_recent_turns) {
            return true;
        }

        if message.content.contains("[Decision Ledger]")
            || message
                .content
                .contains("Decision Ledger (most recent first)")
        {
            return true;
        }

        if self.config.preserve_system_messages && matches!(message.role, MessageRole::System) {
            return true;
        }

        if self.config.preserve_error_messages && self.contains_error_indicators(&message.content) {
            return true;
        }

        if message.tool_calls.is_some() || message.tool_call_id.is_some() {
            return true;
        }

        false
    }

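    /// Cheap keyword scan for error-like content. Deliberately broad: a false
    /// positive only means a message is preserved rather than summarized.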
    fn contains_error_indicators(&self, content: &str) -> bool {
        let error_keywords = [
            "error",
            "failed",
            "exception",
            "crash",
            "bug",
            "issue",
            "problem",
            "unable",
            "cannot",
            "timeout",
            "connection refused",
        ];

        let content_lower = content.to_lowercase();
        error_keywords
            .iter()
            .any(|&keyword| content_lower.contains(keyword))
    }

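    /// Asks the LLM for a concise summary of `messages`. Note that the
    /// prompt's word limit and the `max_tokens` cap bound the summary here;
    /// `summary_max_length` is not consulted.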
    async fn generate_summary(
        &self,
        messages: &[Message],
    ) -> Result<String, ContextCompressionError> {
        if messages.is_empty() {
            return Ok(String::new());
        }

        let conversation_text = self.messages_to_text(messages);

        let system_prompt = "You are a helpful assistant that summarizes conversations. \
            Create a concise summary of the following conversation, \
            focusing on key decisions, completed tasks, and important context. \
            Keep the summary under 500 words."
            .to_string();

        let user_prompt = format!(
            "Please summarize the following conversation:\n\n{}",
            conversation_text
        );

        let request = LLMRequest {
            messages: vec![
                Message {
                    role: MessageRole::System,
                    content: system_prompt,
                    tool_calls: None,
                    tool_call_id: None,
                },
                Message {
                    role: MessageRole::User,
                    content: user_prompt,
                    tool_calls: None,
                    tool_call_id: None,
                },
            ],
            system_prompt: None,
            tools: None,
            model: models::GPT_5_MINI.to_string(),
            max_tokens: Some(1000),
            temperature: Some(0.3),
            stream: false,
            tool_choice: None,
            parallel_tool_calls: None,
            parallel_tool_config: None,
            reasoning_effort: None,
        };

        let response = self
            .llm_provider
            .generate(request)
            .await
            .map_err(|e| ContextCompressionError::LLMError(e.to_string()))?;

        Ok(response.content.unwrap_or_default())
    }

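    /// Flattens messages (including tool calls) into a plain-text transcript
    /// for the summarization prompt.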
    fn messages_to_text(&self, messages: &[Message]) -> String {
        let mut text = String::new();

        for message in messages {
            let role = match message.role {
                MessageRole::System => "System",
                MessageRole::User => "User",
                MessageRole::Assistant => "Assistant",
                MessageRole::Tool => "Tool",
            };

            text.push_str(&format!("{}: {}\n\n", role, message.content));

            if let Some(tool_calls) = &message.tool_calls {
                for tool_call in tool_calls {
                    text.push_str(&format!(
                        "Tool Call: {}({})\n",
                        tool_call.function.name, tool_call.function.arguments
                    ));
                }
            }
        }

        text
    }

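    /// Estimates token count from character count using the common ~4
    /// characters-per-token heuristic; an approximation, not a real tokenizer.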
    fn calculate_context_length(&self, messages: &[Message]) -> usize {
        let mut total_chars = 0;

        for message in messages {
            total_chars += message.content.len();

            if let Some(tool_calls) = &message.tool_calls {
                for tool_call in tool_calls {
                    total_chars += tool_call.function.name.len();
                    total_chars += tool_call.function.arguments.len();
                }
            }
        }

        total_chars / 4
    }
}

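/// Errors that can occur during context compression.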
#[derive(Debug, thiserror::Error)]
pub enum ContextCompressionError {
    #[error("Empty context provided")]
    EmptyContext,

    #[error("LLM error: {0}")]
    LLMError(String),

    #[error("Serialization error: {0}")]
    SerializationError(#[from] serde_json::Error),
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::provider::{
        FinishReason, LLMError, LLMProvider, LLMRequest, LLMResponse, Message, MessageRole,
    };

    #[test]
    fn test_context_length_calculation() {
        let compressor = ContextCompressor::new(Box::new(MockProvider::new()));

        let messages = vec![
            Message {
                role: MessageRole::User,
                content: "Hello world".to_string(),
                tool_calls: None,
                tool_call_id: None,
            },
            Message {
                role: MessageRole::Assistant,
                content: "Hi there! How can I help you?".to_string(),
                tool_calls: None,
                tool_call_id: None,
            },
        ];

        let length = compressor.calculate_context_length(&messages);
        assert_eq!(
            length,
            ("Hello worldHi there! How can I help you?".len()) / 4
        );
    }

    #[test]
    fn test_needs_compression() {
        let config = ContextCompressionConfig {
            max_context_length: 100,
            compression_threshold: 0.8,
            ..Default::default()
        };

        let compressor = ContextCompressor::new(Box::new(MockProvider::new())).with_config(config);

        let messages = vec![Message {
            role: MessageRole::User,
            content: "x".repeat(400),
            tool_calls: None,
            tool_call_id: None,
        }];

        assert!(compressor.needs_compression(&messages));
    }
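
    // Additional check (not in the original suite): with `preserve_recent_turns
    // = 1` and the other preservation rules disabled, only the newest message
    // should be kept verbatim while the older one is routed to summarization.
    #[test]
    fn test_partition_preserves_recent_turns() {
        let config = ContextCompressionConfig {
            preserve_recent_turns: 1,
            preserve_system_messages: false,
            preserve_error_messages: false,
            ..Default::default()
        };

        let compressor = ContextCompressor::new(Box::new(MockProvider::new())).with_config(config);

        let messages = vec![
            Message {
                role: MessageRole::User,
                content: "old question".to_string(),
                tool_calls: None,
                tool_call_id: None,
            },
            Message {
                role: MessageRole::User,
                content: "latest question".to_string(),
                tool_calls: None,
                tool_call_id: None,
            },
        ];

        let (preserved, summarized) = compressor.partition_messages(&messages);
        assert_eq!(preserved.len(), 1);
        assert_eq!(preserved[0].content, "latest question");
        assert_eq!(summarized.len(), 1);
        assert_eq!(summarized[0].content, "old question");
    }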

    struct MockProvider;

    impl MockProvider {
        fn new() -> Self {
            Self
        }
    }

    #[async_trait::async_trait]
    impl LLMProvider for MockProvider {
        fn name(&self) -> &str {
            "mock"
        }

        async fn generate(&self, _request: LLMRequest) -> Result<LLMResponse, LLMError> {
            Ok(LLMResponse {
                content: Some("Mock summary".to_string()),
                tool_calls: None,
                usage: None,
                finish_reason: FinishReason::Stop,
                reasoning: None,
            })
        }

        fn supported_models(&self) -> Vec<String> {
            vec!["mock".to_string()]
        }

        fn validate_request(&self, _request: &LLMRequest) -> Result<(), LLMError> {
            Ok(())
        }
    }
}