// traitclaw_core/context_managers.rs
//! Built-in [`ContextManager`] implementations for common compression strategies.
//!
//! - [`RuleBasedCompressor`]: importance-scored message pruning
//! - [`LlmCompressor`]: LLM-powered summarization of old messages
//! - [`TieredCompressor`]: chained keep-recent → rule-compress → LLM-summarize

use std::sync::Arc;

use async_trait::async_trait;

use crate::traits::context_manager::ContextManager;
use crate::traits::provider::Provider;
use crate::types::agent_state::AgentState;
use crate::types::completion::{CompletionRequest, ResponseContent};
use crate::types::message::{Message, MessageRole};
17/// Estimate token count for a message list (4 chars ≈ 1 token).
18fn estimate_tokens(messages: &[Message]) -> usize {
19    messages.iter().map(|m| m.content.len() / 4 + 1).sum()
20}

// ---------------------------------------------------------------------------
// RuleBasedCompressor
// ---------------------------------------------------------------------------

/// Scores messages by importance and removes lowest-scored first.
///
/// Scoring table (configurable):
/// - System messages: **never removed** (score ∞)
/// - Last `recent_count` messages: 0.9
/// - Tool-result messages: 0.7
/// - Older user/assistant messages: 0.3
///
/// # Example
///
/// ```rust
/// use traitclaw_core::context_managers::RuleBasedCompressor;
///
/// let compressor = RuleBasedCompressor::new(0.85, 3);
/// ```
pub struct RuleBasedCompressor {
    /// Fraction of `context_window` at which pruning begins (clamped to 0.0–1.0 by `new`).
    threshold: f64,
    /// Number of most-recent non-system messages to protect (score 0.9).
    recent_count: usize,
}
48impl RuleBasedCompressor {
49    /// Create a compressor with custom threshold and recent-message protection.
50    ///
51    /// - `threshold`: 0.0–1.0, fraction of `context_window` triggering compression.
52    /// - `recent_count`: number of most-recent messages to protect from removal.
53    #[must_use]
54    pub fn new(threshold: f64, recent_count: usize) -> Self {
55        Self {
56            threshold: threshold.clamp(0.0, 1.0),
57            recent_count,
58        }
59    }
60
61    /// Score a message by importance. Higher = more important.
62    fn score_message(msg: &Message, is_recent: bool) -> f64 {
63        if msg.role == MessageRole::System {
64            return f64::INFINITY; // never remove
65        }
66        if is_recent {
67            return 0.9;
68        }
69        if msg.tool_call_id.is_some() || msg.role == MessageRole::Tool {
70            return 0.7;
71        }
72        0.3
73    }
74}
76impl Default for RuleBasedCompressor {
77    fn default() -> Self {
78        Self::new(0.85, 3)
79    }
80}
82#[async_trait]
83impl ContextManager for RuleBasedCompressor {
84    #[allow(
85        clippy::cast_possible_truncation,
86        clippy::cast_sign_loss,
87        clippy::cast_precision_loss
88    )]
89    async fn prepare(
90        &self,
91        messages: &mut Vec<Message>,
92        context_window: usize,
93        state: &mut AgentState,
94    ) {
95        let max_tokens = (context_window as f64 * self.threshold) as usize;
96
97        if estimate_tokens(messages) <= max_tokens {
98            return;
99        }
100
101        // Score all messages, remembering their indices and per-message token costs
102        let total_non_system = messages
103            .iter()
104            .filter(|m| m.role != MessageRole::System)
105            .count();
106        let recent_start = total_non_system.saturating_sub(self.recent_count);
107
108        let mut scored: Vec<(usize, f64, usize)> = Vec::new(); // (idx, score, tokens)
109        let mut non_system_idx = 0usize;
110        for (i, msg) in messages.iter().enumerate() {
111            if msg.role == MessageRole::System {
112                continue;
113            }
114            let is_recent = non_system_idx >= recent_start;
115            let tokens = msg.content.len() / 4 + 1;
116            scored.push((i, Self::score_message(msg, is_recent), tokens));
117            non_system_idx += 1;
118        }
119
120        // Sort by score ascending (lowest importance first)
121        scored.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
122
123        // Project removals: track projected token total without mutating messages
124        let mut projected_tokens = estimate_tokens(messages);
125        let mut remove_set: Vec<usize> = Vec::new();
126        for &(idx, score, tokens) in &scored {
127            if projected_tokens <= max_tokens {
128                break;
129            }
130            if score.is_infinite() {
131                continue; // never remove system
132            }
133            remove_set.push(idx);
134            projected_tokens = projected_tokens.saturating_sub(tokens);
135        }
136
137        if !remove_set.is_empty() {
138            // Remove in reverse index order to preserve earlier indices
139            remove_set.sort_unstable();
140            for &idx in remove_set.iter().rev() {
141                messages.remove(idx);
142            }
143            state.last_output_truncated = true;
144        }
145    }
146}

// ---------------------------------------------------------------------------
// LlmCompressor
// ---------------------------------------------------------------------------

/// Summarizes old messages using an LLM provider.
///
/// Makes exactly **one** LLM call per compression event. If the call fails,
/// falls back to rule-based pruning (remove oldest non-system messages).
///
/// # Example
///
/// ```rust,no_run
/// use traitclaw_core::context_managers::LlmCompressor;
/// # fn example(provider: std::sync::Arc<dyn traitclaw_core::Provider>) {
/// let compressor = LlmCompressor::new(provider);
/// # }
/// ```
pub struct LlmCompressor {
    /// Provider used for the single summarization call per compression event.
    provider: Arc<dyn Provider>,
    /// System prompt sent ahead of the flattened old-message text.
    summary_prompt: String,
    /// Fraction of `context_window` at which compression triggers (default 0.80).
    threshold: f64,
    /// Number of recent messages to keep verbatim (default 4).
    keep_recent: usize,
}
176impl LlmCompressor {
177    /// Default summarization prompt.
178    const DEFAULT_PROMPT: &str = "Summarize the following conversation messages \
179        into a concise paragraph. Preserve key facts, decisions, and context. \
180        Omit greetings and filler.";
181
182    /// Create a compressor with a given provider.
183    #[must_use]
184    pub fn new(provider: Arc<dyn Provider>) -> Self {
185        Self {
186            provider,
187            summary_prompt: Self::DEFAULT_PROMPT.to_string(),
188            threshold: 0.80,
189            keep_recent: 4,
190        }
191    }
192
193    /// Set a custom summarization prompt template.
194    #[must_use]
195    pub fn with_prompt(mut self, prompt: impl Into<String>) -> Self {
196        self.summary_prompt = prompt.into();
197        self
198    }
199
200    /// Set the compression threshold (0.0–1.0).
201    #[must_use]
202    pub fn with_threshold(mut self, threshold: f64) -> Self {
203        self.threshold = threshold.clamp(0.0, 1.0);
204        self
205    }
206
207    /// Set the number of recent messages to keep verbatim.
208    #[must_use]
209    pub fn with_keep_recent(mut self, count: usize) -> Self {
210        self.keep_recent = count;
211        self
212    }
213}
215#[async_trait]
216impl ContextManager for LlmCompressor {
217    #[allow(
218        clippy::cast_possible_truncation,
219        clippy::cast_sign_loss,
220        clippy::cast_precision_loss
221    )]
222    async fn prepare(
223        &self,
224        messages: &mut Vec<Message>,
225        context_window: usize,
226        state: &mut AgentState,
227    ) {
228        let max_tokens = (context_window as f64 * self.threshold) as usize;
229
230        if estimate_tokens(messages) <= max_tokens {
231            return;
232        }
233
234        // Partition: system messages | old messages (to summarize) | recent messages (keep)
235        let system_msgs: Vec<Message> = messages
236            .iter()
237            .filter(|m| m.role == MessageRole::System)
238            .cloned()
239            .collect();
240
241        let non_system: Vec<Message> = messages
242            .iter()
243            .filter(|m| m.role != MessageRole::System)
244            .cloned()
245            .collect();
246
247        if non_system.len() <= self.keep_recent {
248            return; // not enough messages to summarize
249        }
250
251        let split_at = non_system.len() - self.keep_recent;
252        let old_messages = &non_system[..split_at];
253        let recent_messages = &non_system[split_at..];
254
255        // Build text of old messages for summarization
256        let old_text: String = old_messages
257            .iter()
258            .map(|m| format!("{:?}: {}", m.role, m.content))
259            .collect::<Vec<_>>()
260            .join("\n");
261
262        // Single LLM call for summarization
263        let req = CompletionRequest {
264            model: self.provider.model_info().name.clone(),
265            messages: vec![
266                Message {
267                    role: MessageRole::System,
268                    content: self.summary_prompt.clone(),
269                    tool_call_id: None,
270                },
271                Message {
272                    role: MessageRole::User,
273                    content: old_text,
274                    tool_call_id: None,
275                },
276            ],
277            tools: vec![],
278            max_tokens: Some(500),
279            temperature: Some(0.3),
280            response_format: None,
281            stream: false,
282        };
283
284        let summary_text = match self.provider.complete(req).await {
285            Ok(response) => match response.content {
286                ResponseContent::Text(text) => text,
287                ResponseContent::ToolCalls(_) => {
288                    tracing::warn!("LlmCompressor: provider returned tool calls instead of text");
289                    Self::fallback_summary(old_messages)
290                }
291            },
292            Err(e) => {
293                tracing::warn!("LlmCompressor: summarization failed ({e}), using fallback");
294                Self::fallback_summary(old_messages)
295            }
296        };
297
298        // Rebuild messages: system + summary + recent
299        // NOTE: Summary uses Assistant role so it appears naturally in the
300        // conversation flow. If a provider rejects consecutive assistant
301        // messages, consider switching to System role with a marker prefix.
302        let summary_msg = Message {
303            role: MessageRole::Assistant,
304            content: format!("[Context Summary] {summary_text}"),
305            tool_call_id: None,
306        };
307
308        messages.clear();
309        messages.extend(system_msgs);
310        messages.push(summary_msg);
311        messages.extend(recent_messages.iter().cloned());
312
313        state.last_output_truncated = true;
314    }
315}
317impl LlmCompressor {
318    /// Fallback when LLM call fails: truncate old messages to a brief note.
319    fn fallback_summary(old_messages: &[Message]) -> String {
320        format!(
321            "{} earlier messages were removed to save context space.",
322            old_messages.len()
323        )
324    }
325}

// ---------------------------------------------------------------------------
// TieredCompressor
// ---------------------------------------------------------------------------

/// Chains compression tiers: keep recent → rule-compress mid → LLM-summarize old.
///
/// Without an LLM provider, uses only rule-based compression for the older messages.
///
/// # Example
///
/// ```rust
/// use traitclaw_core::context_managers::TieredCompressor;
///
/// let compressor = TieredCompressor::new(5); // keep last 5 messages
/// ```
pub struct TieredCompressor {
    /// Number of recent messages to keep verbatim (shared with both tiers).
    recent_count: usize,
    /// Rule-based compressor for the middle tier (always runs).
    rule_compressor: RuleBasedCompressor,
    /// Optional LLM compressor for the oldest tier; `None` means rule-only mode.
    llm_compressor: Option<LlmCompressor>,
}
351impl TieredCompressor {
352    /// Create a tiered compressor (rule-only mode).
353    #[must_use]
354    pub fn new(recent_count: usize) -> Self {
355        Self {
356            recent_count,
357            rule_compressor: RuleBasedCompressor::new(0.85, recent_count),
358            llm_compressor: None,
359        }
360    }
361
362    /// Enable the LLM summarization tier.
363    #[must_use]
364    pub fn with_llm(mut self, provider: Arc<dyn Provider>) -> Self {
365        self.llm_compressor =
366            Some(LlmCompressor::new(provider).with_keep_recent(self.recent_count));
367        self
368    }
369}
371#[async_trait]
372impl ContextManager for TieredCompressor {
373    async fn prepare(
374        &self,
375        messages: &mut Vec<Message>,
376        context_window: usize,
377        state: &mut AgentState,
378    ) {
379        // Tier 1: Try LLM summarization of oldest messages (if available)
380        if let Some(llm) = &self.llm_compressor {
381            llm.prepare(messages, context_window, state).await;
382        }
383
384        // Tier 2: Rule-based compression for remaining overflow
385        self.rule_compressor
386            .prepare(messages, context_window, state)
387            .await;
388    }
389}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::completion::{CompletionResponse, ResponseContent, Usage};
    use crate::types::model_info::{ModelInfo, ModelTier};
    use crate::types::stream::CompletionStream;

    /// Build a plain message with no tool-call id.
    fn msg(role: MessageRole, content: &str) -> Message {
        Message {
            role,
            content: content.to_string(),
            tool_call_id: None,
        }
    }

    /// Build a tool-result message (scored 0.7 by `RuleBasedCompressor`).
    fn tool_msg(content: &str) -> Message {
        Message {
            role: MessageRole::Tool,
            content: content.to_string(),
            tool_call_id: Some("call_1".to_string()),
        }
    }

    /// Fresh agent state with a large (128k) context window.
    fn default_state() -> AgentState {
        AgentState::new(ModelTier::Medium, 128_000)
    }

    // ── RuleBasedCompressor ─────────────────────────────────────────────

    #[tokio::test]
    async fn test_rule_compressor_no_pruning_under_threshold() {
        let comp = RuleBasedCompressor::default();
        let mut msgs = vec![
            msg(MessageRole::System, "system"),
            msg(MessageRole::User, "hello"),
        ];
        let mut state = default_state();

        // Tiny history vs. a 100k window: prepare must be a no-op.
        comp.prepare(&mut msgs, 100_000, &mut state).await;
        assert_eq!(msgs.len(), 2);
        assert!(!state.last_output_truncated);
    }

    #[tokio::test]
    async fn test_rule_compressor_removes_lowest_scored() {
        let comp = RuleBasedCompressor::new(0.85, 1);
        // Each repeated message is ~2500 chars ≈ 626 estimated tokens,
        // so the total far exceeds the 0.85 * 800 = 680-token budget.
        let mut msgs = vec![
            msg(MessageRole::System, "system"),
            msg(MessageRole::User, &"old1 ".repeat(500)), // old, score 0.3
            msg(MessageRole::Assistant, &"old2 ".repeat(500)), // old, score 0.3
            tool_msg(&"tool ".repeat(500)),               // tool, score 0.7
            msg(MessageRole::User, &"recent ".repeat(500)), // recent, score 0.9
        ];
        let mut state = default_state();

        // context_window small enough to trigger pruning
        comp.prepare(&mut msgs, 800, &mut state).await;

        // System must survive
        assert_eq!(msgs[0].role, MessageRole::System);
        // Old messages should be removed first
        assert!(msgs.len() < 5, "should have removed some messages");
        assert!(state.last_output_truncated);
    }

    #[tokio::test]
    async fn test_rule_compressor_never_removes_system() {
        let comp = RuleBasedCompressor::new(0.5, 0);
        let mut msgs = vec![
            msg(MessageRole::System, &"sys ".repeat(1000)),
            msg(MessageRole::User, "tiny"),
        ];
        let mut state = default_state();

        comp.prepare(&mut msgs, 100, &mut state).await;

        // System message must survive even if it alone exceeds the budget
        assert!(msgs.iter().any(|m| m.role == MessageRole::System));
    }

    #[tokio::test]
    async fn test_rule_compressor_updates_state() {
        let comp = RuleBasedCompressor::new(0.5, 0);
        let mut msgs = vec![
            msg(MessageRole::System, "sys"),
            msg(MessageRole::User, &"x".repeat(4000)),
            msg(MessageRole::Assistant, &"y".repeat(4000)),
        ];
        let mut state = default_state();

        // Overflow forces a removal, which must flag the state.
        comp.prepare(&mut msgs, 1000, &mut state).await;
        assert!(state.last_output_truncated);
    }

    // ── LlmCompressor ───────────────────────────────────────────────────

    /// Provider stub that always returns a fixed text summary.
    struct MockSummarizer {
        info: ModelInfo,
    }

    impl MockSummarizer {
        fn new() -> Self {
            Self {
                info: ModelInfo::new(
                    "mock-summarizer",
                    ModelTier::Small,
                    4096,
                    false,
                    false,
                    false,
                ),
            }
        }
    }

    #[async_trait]
    impl Provider for MockSummarizer {
        async fn complete(&self, _req: CompletionRequest) -> crate::Result<CompletionResponse> {
            Ok(CompletionResponse {
                content: ResponseContent::Text(
                    "User asked about Rust. Assistant explained traits.".to_string(),
                ),
                usage: Usage {
                    prompt_tokens: 50,
                    completion_tokens: 20,
                    total_tokens: 70,
                },
            })
        }

        async fn stream(&self, _req: CompletionRequest) -> crate::Result<CompletionStream> {
            // Streaming is never exercised by these tests.
            unimplemented!()
        }

        fn model_info(&self) -> &ModelInfo {
            &self.info
        }
    }

    /// Provider stub whose `complete` always errors, to exercise the fallback path.
    struct FailingSummarizer {
        info: ModelInfo,
    }

    impl FailingSummarizer {
        fn new() -> Self {
            Self {
                info: ModelInfo::new("failing", ModelTier::Small, 4096, false, false, false),
            }
        }
    }

    #[async_trait]
    impl Provider for FailingSummarizer {
        async fn complete(&self, _req: CompletionRequest) -> crate::Result<CompletionResponse> {
            Err(crate::Error::Provider {
                message: "network error".into(),
                status_code: None,
            })
        }

        async fn stream(&self, _req: CompletionRequest) -> crate::Result<CompletionStream> {
            // Streaming is never exercised by these tests.
            unimplemented!()
        }

        fn model_info(&self) -> &ModelInfo {
            &self.info
        }
    }

    #[tokio::test]
    async fn test_llm_compressor_summarizes_old_messages() {
        let provider: Arc<dyn Provider> = Arc::new(MockSummarizer::new());
        let comp = LlmCompressor::new(provider).with_keep_recent(2);

        let mut msgs = vec![
            msg(MessageRole::System, "You are helpful"),
            msg(MessageRole::User, &"old question ".repeat(500)),
            msg(MessageRole::Assistant, &"old answer ".repeat(500)),
            msg(MessageRole::User, "recent question"),
            msg(MessageRole::Assistant, "recent answer"),
        ];
        let mut state = default_state();

        comp.prepare(&mut msgs, 800, &mut state).await;

        // System message preserved
        assert_eq!(msgs[0].role, MessageRole::System);
        // Summary inserted
        assert!(
            msgs[1].content.contains("[Context Summary]"),
            "should have summary: {}",
            msgs[1].content
        );
        // Recent messages kept
        assert_eq!(msgs.len(), 4); // system + summary + 2 recent
        assert!(state.last_output_truncated);
    }

    #[tokio::test]
    async fn test_llm_compressor_no_compression_under_threshold() {
        let provider: Arc<dyn Provider> = Arc::new(MockSummarizer::new());
        let comp = LlmCompressor::new(provider).with_keep_recent(2);

        let mut msgs = vec![
            msg(MessageRole::System, "sys"),
            msg(MessageRole::User, "hi"),
            msg(MessageRole::Assistant, "hello"),
        ];
        let mut state = default_state();

        // Under budget: no provider call should matter, history untouched.
        comp.prepare(&mut msgs, 100_000, &mut state).await;
        assert_eq!(msgs.len(), 3);
        assert!(!state.last_output_truncated);
    }

    #[tokio::test]
    async fn test_llm_compressor_custom_prompt() {
        let provider: Arc<dyn Provider> = Arc::new(MockSummarizer::new());
        let comp = LlmCompressor::new(provider)
            .with_prompt("Custom prompt")
            .with_keep_recent(1);

        let mut msgs = vec![
            msg(MessageRole::System, "sys"),
            msg(MessageRole::User, &"old ".repeat(2000)),
            msg(MessageRole::User, "recent"),
        ];
        let mut state = default_state();

        comp.prepare(&mut msgs, 500, &mut state).await;
        assert!(msgs[1].content.contains("[Context Summary]"));
    }

    #[tokio::test]
    async fn test_llm_compressor_fallback_on_failure() {
        let provider: Arc<dyn Provider> = Arc::new(FailingSummarizer::new());
        let comp = LlmCompressor::new(provider).with_keep_recent(1);

        let mut msgs = vec![
            msg(MessageRole::System, "sys"),
            msg(MessageRole::User, &"old ".repeat(2000)),
            msg(MessageRole::Assistant, &"old ".repeat(2000)),
            msg(MessageRole::User, "recent"),
        ];
        let mut state = default_state();

        comp.prepare(&mut msgs, 500, &mut state).await;

        // Should still compress with fallback
        assert!(msgs[1].content.contains("[Context Summary]"));
        assert!(msgs[1].content.contains("removed to save context"));
        assert!(state.last_output_truncated);
    }

    // ── TieredCompressor ────────────────────────────────────────────────

    #[tokio::test]
    async fn test_tiered_compressor_rule_only() {
        let comp = TieredCompressor::new(2);

        let mut msgs = vec![
            msg(MessageRole::System, "sys"),
            msg(MessageRole::User, &"old1 ".repeat(500)),
            msg(MessageRole::Assistant, &"old2 ".repeat(500)),
            msg(MessageRole::User, &"recent1 ".repeat(500)),
            msg(MessageRole::Assistant, &"recent2 ".repeat(500)),
        ];
        let mut state = default_state();

        comp.prepare(&mut msgs, 1500, &mut state).await;

        assert_eq!(msgs[0].role, MessageRole::System);
        assert!(msgs.len() < 5);
        assert!(state.last_output_truncated);
    }

    #[tokio::test]
    async fn test_tiered_compressor_with_llm() {
        let provider: Arc<dyn Provider> = Arc::new(MockSummarizer::new());
        let comp = TieredCompressor::new(2).with_llm(provider);

        let mut msgs = vec![
            msg(MessageRole::System, "sys"),
            msg(MessageRole::User, &"old ".repeat(1000)),
            msg(MessageRole::Assistant, &"old ".repeat(1000)),
            msg(MessageRole::User, "recent1"),
            msg(MessageRole::Assistant, "recent2"),
        ];
        let mut state = default_state();

        comp.prepare(&mut msgs, 800, &mut state).await;

        assert_eq!(msgs[0].role, MessageRole::System);
        assert!(
            msgs.iter().any(|m| m.content.contains("[Context Summary]")),
            "should have LLM summary"
        );
        assert!(state.last_output_truncated);
    }

    // ── 50-message stress test ──────────────────────────────────────────

    #[tokio::test]
    async fn test_rule_compressor_50_messages_within_budget() {
        let comp = RuleBasedCompressor::new(0.85, 5);

        let mut msgs = vec![msg(MessageRole::System, "You are a helpful assistant")];
        for i in 0..50 {
            msgs.push(msg(
                if i % 2 == 0 {
                    MessageRole::User
                } else {
                    MessageRole::Assistant
                },
                &format!("Message number {i}: {}", "content ".repeat(100)),
            ));
        }
        let mut state = default_state();

        // Small window to force heavy compression
        let window = 2000;
        comp.prepare(&mut msgs, window, &mut state).await;

        // Verify within budget (recompute with the same 4-chars-per-token heuristic)
        let tokens: usize = msgs.iter().map(|m| m.content.len() / 4 + 1).sum();
        let max = (window as f64 * 0.85) as usize;
        assert!(tokens <= max, "should be within budget: {tokens} <= {max}");
        // System message must survive
        assert_eq!(msgs[0].role, MessageRole::System);
        assert!(state.last_output_truncated);
    }
}