1use std::sync::Arc;
8
9use async_trait::async_trait;
10
11use crate::traits::context_manager::ContextManager;
12use crate::traits::provider::Provider;
13use crate::types::agent_state::AgentState;
14use crate::types::completion::{CompletionRequest, ResponseContent};
15use crate::types::message::{Message, MessageRole};
16
17fn estimate_tokens(messages: &[Message]) -> usize {
19 messages.iter().map(|m| m.content.len() / 4 + 1).sum()
20}
21
/// Deterministic context compressor that drops low-value messages
/// (as scored by `score_message`) until the estimated token count
/// fits under `threshold * context_window`.
pub struct RuleBasedCompressor {
    // Fraction of the context window (clamped to 0.0..=1.0 in `new`)
    // that triggers pruning once exceeded.
    threshold: f64,
    // Number of trailing non-system messages treated as "recent" and
    // scored higher (0.9) so they are pruned last.
    recent_count: usize,
}
47
48impl RuleBasedCompressor {
49 #[must_use]
54 pub fn new(threshold: f64, recent_count: usize) -> Self {
55 Self {
56 threshold: threshold.clamp(0.0, 1.0),
57 recent_count,
58 }
59 }
60
61 fn score_message(msg: &Message, is_recent: bool) -> f64 {
63 if msg.role == MessageRole::System {
64 return f64::INFINITY; }
66 if is_recent {
67 return 0.9;
68 }
69 if msg.tool_call_id.is_some() || msg.role == MessageRole::Tool {
70 return 0.7;
71 }
72 0.3
73 }
74}
75
impl Default for RuleBasedCompressor {
    /// Defaults: prune above 85% of the context window, and protect
    /// the 3 most recent non-system messages.
    fn default() -> Self {
        Self::new(0.85, 3)
    }
}
81
82#[async_trait]
83impl ContextManager for RuleBasedCompressor {
84 #[allow(
85 clippy::cast_possible_truncation,
86 clippy::cast_sign_loss,
87 clippy::cast_precision_loss
88 )]
89 async fn prepare(
90 &self,
91 messages: &mut Vec<Message>,
92 context_window: usize,
93 state: &mut AgentState,
94 ) {
95 let max_tokens = (context_window as f64 * self.threshold) as usize;
96
97 if estimate_tokens(messages) <= max_tokens {
98 return;
99 }
100
101 let total_non_system = messages
103 .iter()
104 .filter(|m| m.role != MessageRole::System)
105 .count();
106 let recent_start = total_non_system.saturating_sub(self.recent_count);
107
108 let mut scored: Vec<(usize, f64, usize)> = Vec::new(); let mut non_system_idx = 0usize;
110 for (i, msg) in messages.iter().enumerate() {
111 if msg.role == MessageRole::System {
112 continue;
113 }
114 let is_recent = non_system_idx >= recent_start;
115 let tokens = msg.content.len() / 4 + 1;
116 scored.push((i, Self::score_message(msg, is_recent), tokens));
117 non_system_idx += 1;
118 }
119
120 scored.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
122
123 let mut projected_tokens = estimate_tokens(messages);
125 let mut remove_set: Vec<usize> = Vec::new();
126 for &(idx, score, tokens) in &scored {
127 if projected_tokens <= max_tokens {
128 break;
129 }
130 if score.is_infinite() {
131 continue; }
133 remove_set.push(idx);
134 projected_tokens = projected_tokens.saturating_sub(tokens);
135 }
136
137 if !remove_set.is_empty() {
138 remove_set.sort_unstable();
140 for &idx in remove_set.iter().rev() {
141 messages.remove(idx);
142 }
143 state.last_output_truncated = true;
144 }
145 }
146}
147
/// Context compressor that replaces older conversation messages with
/// an LLM-generated summary, keeping system messages and the most
/// recent `keep_recent` non-system messages intact.
pub struct LlmCompressor {
    // Provider used to generate the summary completion.
    provider: Arc<dyn Provider>,
    // System prompt sent to the summarizer request.
    summary_prompt: String,
    // Fraction of the context window that triggers compression.
    threshold: f64,
    // Number of trailing non-system messages preserved verbatim.
    keep_recent: usize,
}
175
176impl LlmCompressor {
177 const DEFAULT_PROMPT: &str = "Summarize the following conversation messages \
179 into a concise paragraph. Preserve key facts, decisions, and context. \
180 Omit greetings and filler.";
181
182 #[must_use]
184 pub fn new(provider: Arc<dyn Provider>) -> Self {
185 Self {
186 provider,
187 summary_prompt: Self::DEFAULT_PROMPT.to_string(),
188 threshold: 0.80,
189 keep_recent: 4,
190 }
191 }
192
193 #[must_use]
195 pub fn with_prompt(mut self, prompt: impl Into<String>) -> Self {
196 self.summary_prompt = prompt.into();
197 self
198 }
199
200 #[must_use]
202 pub fn with_threshold(mut self, threshold: f64) -> Self {
203 self.threshold = threshold.clamp(0.0, 1.0);
204 self
205 }
206
207 #[must_use]
209 pub fn with_keep_recent(mut self, count: usize) -> Self {
210 self.keep_recent = count;
211 self
212 }
213}
214
215#[async_trait]
216impl ContextManager for LlmCompressor {
217 #[allow(
218 clippy::cast_possible_truncation,
219 clippy::cast_sign_loss,
220 clippy::cast_precision_loss
221 )]
222 async fn prepare(
223 &self,
224 messages: &mut Vec<Message>,
225 context_window: usize,
226 state: &mut AgentState,
227 ) {
228 let max_tokens = (context_window as f64 * self.threshold) as usize;
229
230 if estimate_tokens(messages) <= max_tokens {
231 return;
232 }
233
234 let system_msgs: Vec<Message> = messages
236 .iter()
237 .filter(|m| m.role == MessageRole::System)
238 .cloned()
239 .collect();
240
241 let non_system: Vec<Message> = messages
242 .iter()
243 .filter(|m| m.role != MessageRole::System)
244 .cloned()
245 .collect();
246
247 if non_system.len() <= self.keep_recent {
248 return; }
250
251 let split_at = non_system.len() - self.keep_recent;
252 let old_messages = &non_system[..split_at];
253 let recent_messages = &non_system[split_at..];
254
255 let old_text: String = old_messages
257 .iter()
258 .map(|m| format!("{:?}: {}", m.role, m.content))
259 .collect::<Vec<_>>()
260 .join("\n");
261
262 let req = CompletionRequest {
264 model: self.provider.model_info().name.clone(),
265 messages: vec![
266 Message {
267 role: MessageRole::System,
268 content: self.summary_prompt.clone(),
269 tool_call_id: None,
270 },
271 Message {
272 role: MessageRole::User,
273 content: old_text,
274 tool_call_id: None,
275 },
276 ],
277 tools: vec![],
278 max_tokens: Some(500),
279 temperature: Some(0.3),
280 response_format: None,
281 stream: false,
282 };
283
284 let summary_text = match self.provider.complete(req).await {
285 Ok(response) => match response.content {
286 ResponseContent::Text(text) => text,
287 ResponseContent::ToolCalls(_) => {
288 tracing::warn!("LlmCompressor: provider returned tool calls instead of text");
289 Self::fallback_summary(old_messages)
290 }
291 },
292 Err(e) => {
293 tracing::warn!("LlmCompressor: summarization failed ({e}), using fallback");
294 Self::fallback_summary(old_messages)
295 }
296 };
297
298 let summary_msg = Message {
303 role: MessageRole::Assistant,
304 content: format!("[Context Summary] {summary_text}"),
305 tool_call_id: None,
306 };
307
308 messages.clear();
309 messages.extend(system_msgs);
310 messages.push(summary_msg);
311 messages.extend(recent_messages.iter().cloned());
312
313 state.last_output_truncated = true;
314 }
315}
316
impl LlmCompressor {
    // Static summary used when the provider cannot produce one;
    // records only how many messages were dropped so the model still
    // sees that context was elided.
    fn fallback_summary(old_messages: &[Message]) -> String {
        format!(
            "{} earlier messages were removed to save context space.",
            old_messages.len()
        )
    }
}
326
/// Two-stage compressor: an optional LLM summarization pass runs
/// first, then a rule-based pruning pass acts as a deterministic
/// backstop.
pub struct TieredCompressor {
    // Number of trailing non-system messages both stages protect.
    recent_count: usize,
    // Deterministic fallback stage (85% threshold).
    rule_compressor: RuleBasedCompressor,
    // Optional LLM summarization stage; `None` until `with_llm`.
    llm_compressor: Option<LlmCompressor>,
}
350
351impl TieredCompressor {
352 #[must_use]
354 pub fn new(recent_count: usize) -> Self {
355 Self {
356 recent_count,
357 rule_compressor: RuleBasedCompressor::new(0.85, recent_count),
358 llm_compressor: None,
359 }
360 }
361
362 #[must_use]
364 pub fn with_llm(mut self, provider: Arc<dyn Provider>) -> Self {
365 self.llm_compressor =
366 Some(LlmCompressor::new(provider).with_keep_recent(self.recent_count));
367 self
368 }
369}
370
371#[async_trait]
372impl ContextManager for TieredCompressor {
373 async fn prepare(
374 &self,
375 messages: &mut Vec<Message>,
376 context_window: usize,
377 state: &mut AgentState,
378 ) {
379 if let Some(llm) = &self.llm_compressor {
381 llm.prepare(messages, context_window, state).await;
382 }
383
384 self.rule_compressor
386 .prepare(messages, context_window, state)
387 .await;
388 }
389}
390
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::completion::{CompletionResponse, ResponseContent, Usage};
    use crate::types::model_info::{ModelInfo, ModelTier};
    use crate::types::stream::CompletionStream;

    /// Convenience constructor for a plain message with no tool-call id.
    fn msg(role: MessageRole, content: &str) -> Message {
        Message {
            role,
            content: content.to_string(),
            tool_call_id: None,
        }
    }

    /// Convenience constructor for a tool-role message carrying a
    /// tool-call id.
    fn tool_msg(content: &str) -> Message {
        Message {
            role: MessageRole::Tool,
            content: content.to_string(),
            tool_call_id: Some("call_1".to_string()),
        }
    }

    /// Fresh agent state with a large (128k) context window.
    fn default_state() -> AgentState {
        AgentState::new(ModelTier::Medium, 128_000)
    }

    // A tiny conversation under a huge window must be left untouched.
    #[tokio::test]
    async fn test_rule_compressor_no_pruning_under_threshold() {
        let comp = RuleBasedCompressor::default();
        let mut msgs = vec![
            msg(MessageRole::System, "system"),
            msg(MessageRole::User, "hello"),
        ];
        let mut state = default_state();

        comp.prepare(&mut msgs, 100_000, &mut state).await;
        assert_eq!(msgs.len(), 2);
        assert!(!state.last_output_truncated);
    }

    // Over budget: lowest-scored (old, non-tool) messages go first,
    // and the system message always survives.
    #[tokio::test]
    async fn test_rule_compressor_removes_lowest_scored() {
        let comp = RuleBasedCompressor::new(0.85, 1);
        let mut msgs = vec![
            msg(MessageRole::System, "system"),
            msg(MessageRole::User, &"old1 ".repeat(500)),
            msg(MessageRole::Assistant, &"old2 ".repeat(500)),
            tool_msg(&"tool ".repeat(500)),
            msg(MessageRole::User, &"recent ".repeat(500)),
        ];
        let mut state = default_state();

        comp.prepare(&mut msgs, 800, &mut state).await;

        assert_eq!(msgs[0].role, MessageRole::System);
        assert!(msgs.len() < 5, "should have removed some messages");
        assert!(state.last_output_truncated);
    }

    // System messages score infinity, so they survive even when the
    // budget cannot be met at all.
    #[tokio::test]
    async fn test_rule_compressor_never_removes_system() {
        let comp = RuleBasedCompressor::new(0.5, 0);
        let mut msgs = vec![
            msg(MessageRole::System, &"sys ".repeat(1000)),
            msg(MessageRole::User, "tiny"),
        ];
        let mut state = default_state();

        comp.prepare(&mut msgs, 100, &mut state).await;

        assert!(msgs.iter().any(|m| m.role == MessageRole::System));
    }

    // Pruning must be reflected in the agent-state flag.
    #[tokio::test]
    async fn test_rule_compressor_updates_state() {
        let comp = RuleBasedCompressor::new(0.5, 0);
        let mut msgs = vec![
            msg(MessageRole::System, "sys"),
            msg(MessageRole::User, &"x".repeat(4000)),
            msg(MessageRole::Assistant, &"y".repeat(4000)),
        ];
        let mut state = default_state();

        comp.prepare(&mut msgs, 1000, &mut state).await;
        assert!(state.last_output_truncated);
    }

    /// Provider stub whose completions always return a fixed summary.
    struct MockSummarizer {
        info: ModelInfo,
    }

    impl MockSummarizer {
        fn new() -> Self {
            Self {
                info: ModelInfo::new(
                    "mock-summarizer",
                    ModelTier::Small,
                    4096,
                    false,
                    false,
                    false,
                ),
            }
        }
    }

    #[async_trait]
    impl Provider for MockSummarizer {
        async fn complete(&self, _req: CompletionRequest) -> crate::Result<CompletionResponse> {
            Ok(CompletionResponse {
                content: ResponseContent::Text(
                    "User asked about Rust. Assistant explained traits.".to_string(),
                ),
                usage: Usage {
                    prompt_tokens: 50,
                    completion_tokens: 20,
                    total_tokens: 70,
                },
            })
        }

        async fn stream(&self, _req: CompletionRequest) -> crate::Result<CompletionStream> {
            unimplemented!()
        }

        fn model_info(&self) -> &ModelInfo {
            &self.info
        }
    }

    /// Provider stub whose completions always fail, to exercise the
    /// fallback-summary path.
    struct FailingSummarizer {
        info: ModelInfo,
    }

    impl FailingSummarizer {
        fn new() -> Self {
            Self {
                info: ModelInfo::new("failing", ModelTier::Small, 4096, false, false, false),
            }
        }
    }

    #[async_trait]
    impl Provider for FailingSummarizer {
        async fn complete(&self, _req: CompletionRequest) -> crate::Result<CompletionResponse> {
            Err(crate::Error::Provider {
                message: "network error".into(),
                status_code: None,
            })
        }

        async fn stream(&self, _req: CompletionRequest) -> crate::Result<CompletionStream> {
            unimplemented!()
        }

        fn model_info(&self) -> &ModelInfo {
            &self.info
        }
    }

    // Old messages collapse into one "[Context Summary]" entry:
    // system + summary + 2 recent = 4 messages remain.
    #[tokio::test]
    async fn test_llm_compressor_summarizes_old_messages() {
        let provider: Arc<dyn Provider> = Arc::new(MockSummarizer::new());
        let comp = LlmCompressor::new(provider).with_keep_recent(2);

        let mut msgs = vec![
            msg(MessageRole::System, "You are helpful"),
            msg(MessageRole::User, &"old question ".repeat(500)),
            msg(MessageRole::Assistant, &"old answer ".repeat(500)),
            msg(MessageRole::User, "recent question"),
            msg(MessageRole::Assistant, "recent answer"),
        ];
        let mut state = default_state();

        comp.prepare(&mut msgs, 800, &mut state).await;

        assert_eq!(msgs[0].role, MessageRole::System);
        assert!(
            msgs[1].content.contains("[Context Summary]"),
            "should have summary: {}",
            msgs[1].content
        );
        assert_eq!(msgs.len(), 4);
        assert!(state.last_output_truncated);
    }

    // Under the threshold the LLM compressor must not touch anything.
    #[tokio::test]
    async fn test_llm_compressor_no_compression_under_threshold() {
        let provider: Arc<dyn Provider> = Arc::new(MockSummarizer::new());
        let comp = LlmCompressor::new(provider).with_keep_recent(2);

        let mut msgs = vec![
            msg(MessageRole::System, "sys"),
            msg(MessageRole::User, "hi"),
            msg(MessageRole::Assistant, "hello"),
        ];
        let mut state = default_state();

        comp.prepare(&mut msgs, 100_000, &mut state).await;
        assert_eq!(msgs.len(), 3);
        assert!(!state.last_output_truncated);
    }

    // A custom prompt still produces the summary marker in place of
    // the old messages.
    #[tokio::test]
    async fn test_llm_compressor_custom_prompt() {
        let provider: Arc<dyn Provider> = Arc::new(MockSummarizer::new());
        let comp = LlmCompressor::new(provider)
            .with_prompt("Custom prompt")
            .with_keep_recent(1);

        let mut msgs = vec![
            msg(MessageRole::System, "sys"),
            msg(MessageRole::User, &"old ".repeat(2000)),
            msg(MessageRole::User, "recent"),
        ];
        let mut state = default_state();

        comp.prepare(&mut msgs, 500, &mut state).await;
        assert!(msgs[1].content.contains("[Context Summary]"));
    }

    // A failing provider falls back to the static "removed" summary.
    #[tokio::test]
    async fn test_llm_compressor_fallback_on_failure() {
        let provider: Arc<dyn Provider> = Arc::new(FailingSummarizer::new());
        let comp = LlmCompressor::new(provider).with_keep_recent(1);

        let mut msgs = vec![
            msg(MessageRole::System, "sys"),
            msg(MessageRole::User, &"old ".repeat(2000)),
            msg(MessageRole::Assistant, &"old ".repeat(2000)),
            msg(MessageRole::User, "recent"),
        ];
        let mut state = default_state();

        comp.prepare(&mut msgs, 500, &mut state).await;

        assert!(msgs[1].content.contains("[Context Summary]"));
        assert!(msgs[1].content.contains("removed to save context"));
        assert!(state.last_output_truncated);
    }

    // Without an LLM stage the tiered compressor still prunes via rules.
    #[tokio::test]
    async fn test_tiered_compressor_rule_only() {
        let comp = TieredCompressor::new(2);

        let mut msgs = vec![
            msg(MessageRole::System, "sys"),
            msg(MessageRole::User, &"old1 ".repeat(500)),
            msg(MessageRole::Assistant, &"old2 ".repeat(500)),
            msg(MessageRole::User, &"recent1 ".repeat(500)),
            msg(MessageRole::Assistant, &"recent2 ".repeat(500)),
        ];
        let mut state = default_state();

        comp.prepare(&mut msgs, 1500, &mut state).await;

        assert_eq!(msgs[0].role, MessageRole::System);
        assert!(msgs.len() < 5);
        assert!(state.last_output_truncated);
    }

    // With an LLM stage configured, the summary message appears.
    #[tokio::test]
    async fn test_tiered_compressor_with_llm() {
        let provider: Arc<dyn Provider> = Arc::new(MockSummarizer::new());
        let comp = TieredCompressor::new(2).with_llm(provider);

        let mut msgs = vec![
            msg(MessageRole::System, "sys"),
            msg(MessageRole::User, &"old ".repeat(1000)),
            msg(MessageRole::Assistant, &"old ".repeat(1000)),
            msg(MessageRole::User, "recent1"),
            msg(MessageRole::Assistant, "recent2"),
        ];
        let mut state = default_state();

        comp.prepare(&mut msgs, 800, &mut state).await;

        assert_eq!(msgs[0].role, MessageRole::System);
        assert!(
            msgs.iter().any(|m| m.content.contains("[Context Summary]")),
            "should have LLM summary"
        );
        assert!(state.last_output_truncated);
    }

    // Larger conversation: after pruning, the token estimate must
    // actually fit under the budget.
    #[tokio::test]
    async fn test_rule_compressor_50_messages_within_budget() {
        let comp = RuleBasedCompressor::new(0.85, 5);

        let mut msgs = vec![msg(MessageRole::System, "You are a helpful assistant")];
        for i in 0..50 {
            msgs.push(msg(
                if i % 2 == 0 {
                    MessageRole::User
                } else {
                    MessageRole::Assistant
                },
                &format!("Message number {i}: {}", "content ".repeat(100)),
            ));
        }
        let mut state = default_state();

        let window = 2000;
        comp.prepare(&mut msgs, window, &mut state).await;

        // Recompute the same estimate the compressor uses.
        let tokens: usize = msgs.iter().map(|m| m.content.len() / 4 + 1).sum();
        let max = (window as f64 * 0.85) as usize;
        assert!(tokens <= max, "should be within budget: {tokens} <= {max}");
        assert_eq!(msgs[0].role, MessageRole::System);
        assert!(state.last_output_truncated);
    }
}