Skip to main content

symbi_runtime/reasoning/
context_manager.rs

1//! Runtime context management for reasoning loops
2//!
3//! Manages the conversation context to keep it within token budgets.
4//! Provides multiple strategies for context compression:
5//! - SlidingWindow: keep the most recent messages
6//! - ObservationMasking: replace old tool outputs but keep reasoning
7//! - AnchoredSummary: keep system + first user + summarize middle + recent
8
9use tracing::{debug, info, warn};
10
11use crate::reasoning::conversation::{Conversation, MessageRole};
12
/// Strategy for managing context within token budgets.
///
/// The default is [`ContextStrategy::SlidingWindow`].
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub enum ContextStrategy {
    /// Keep the most recent messages that fit the budget.
    /// Simple and predictable.
    #[default]
    SlidingWindow,

    /// Replace the content of old tool results with a placeholder of the form
    /// `"[Previous <tool name> result omitted for context management]"`, but keep
    /// the reasoning chain (user and assistant messages) intact. Preserves
    /// decision history. Falls back to `SlidingWindow` if masking alone does not
    /// bring the conversation under budget.
    ObservationMasking,

    /// Keep system prompt + first user message as anchors,
    /// summarize the middle, keep recent messages.
    AnchoredSummary {
        /// Target number of most recent messages kept verbatim. A final
        /// sliding-window pass may still drop some of them if the compacted
        /// conversation is over budget.
        recent_count: usize,
    },
}
32
33/// Manages conversation context to stay within token budgets.
34pub trait ContextManager: Send + Sync {
35    /// Apply context management to keep the conversation within budget.
36    fn manage_context(&self, conversation: &mut Conversation, max_tokens: usize);
37
38    /// Get the strategy name for logging.
39    fn strategy_name(&self) -> &str;
40}
41
42/// Default context manager using configurable strategies.
43pub struct DefaultContextManager {
44    strategy: ContextStrategy,
45}
46
47impl DefaultContextManager {
48    /// Create a new context manager with the given strategy.
49    pub fn new(strategy: ContextStrategy) -> Self {
50        Self { strategy }
51    }
52
53    /// Apply sliding window: keep system + most recent messages.
54    fn apply_sliding_window(conversation: &mut Conversation, max_tokens: usize) {
55        conversation.truncate_to_budget(max_tokens);
56    }
57
58    /// Apply observation masking: replace old tool results with placeholders.
59    fn apply_observation_masking(conversation: &mut Conversation, max_tokens: usize) {
60        let estimated = conversation.estimate_tokens();
61        if estimated <= max_tokens {
62            return;
63        }
64
65        info!(
66            estimated_tokens = estimated,
67            max_tokens,
68            over_by = estimated - max_tokens,
69            "ObservationMasking: context exceeds budget, masking old tool results"
70        );
71
72        let messages = conversation.messages().to_vec();
73        let total = messages.len();
74        if total <= 3 {
75            warn!("ObservationMasking: only {} messages, cannot mask", total);
76            return;
77        }
78
79        // Find tool result messages to mask, starting from oldest
80        // Keep the most recent 6 messages (3 turns) intact
81        let keep_recent = 6.min(total);
82        let mut new_messages = Vec::new();
83        let mut masked_count = 0usize;
84
85        for (i, msg) in messages.iter().enumerate() {
86            if i >= total - keep_recent {
87                // Keep recent messages as-is
88                new_messages.push(msg.clone());
89            } else if msg.role == MessageRole::Tool {
90                // Replace old tool results with a placeholder
91                let mut masked = msg.clone();
92                masked.content = format!(
93                    "[Previous {} result omitted for context management]",
94                    msg.tool_name.as_deref().unwrap_or("tool")
95                );
96                masked_count += 1;
97                new_messages.push(masked);
98            } else {
99                // Keep non-tool messages (reasoning, user input)
100                new_messages.push(msg.clone());
101            }
102        }
103
104        info!(
105            masked_tool_results = masked_count,
106            kept_recent = keep_recent,
107            total_messages = total,
108            "ObservationMasking: masked old tool results"
109        );
110
111        *conversation = Conversation::new();
112        for msg in new_messages {
113            conversation.push(msg);
114        }
115
116        // If still over budget, fall back to sliding window
117        if conversation.estimate_tokens() > max_tokens {
118            let still_estimated = conversation.estimate_tokens();
119            warn!(
120                still_estimated,
121                max_tokens, "ObservationMasking insufficient, falling back to SlidingWindow"
122            );
123            Self::apply_sliding_window(conversation, max_tokens);
124        }
125    }
126
127    /// Apply anchored summary: keep anchors + summarize middle + recent.
128    fn apply_anchored_summary(
129        conversation: &mut Conversation,
130        max_tokens: usize,
131        recent_count: usize,
132    ) {
133        if conversation.estimate_tokens() <= max_tokens {
134            return;
135        }
136
137        let messages = conversation.messages().to_vec();
138        let total = messages.len();
139
140        // Determine anchor messages (system + first user)
141        let mut anchor_end = 0;
142        for (i, msg) in messages.iter().enumerate() {
143            if msg.role == MessageRole::System || (msg.role == MessageRole::User && i <= 1) {
144                anchor_end = i + 1;
145            } else {
146                break;
147            }
148        }
149
150        let keep_recent = recent_count.min(total.saturating_sub(anchor_end));
151        let recent_start = total.saturating_sub(keep_recent);
152
153        // If there's a middle section, create a summary placeholder
154        let mut new_messages: Vec<_> = messages[..anchor_end].to_vec();
155
156        if anchor_end < recent_start {
157            let middle_count = recent_start - anchor_end;
158            let tool_calls_in_middle = messages[anchor_end..recent_start]
159                .iter()
160                .filter(|m| !m.tool_calls.is_empty())
161                .count();
162            let tool_results_in_middle = messages[anchor_end..recent_start]
163                .iter()
164                .filter(|m| m.role == MessageRole::Tool)
165                .count();
166
167            let summary = format!(
168                "[Context summary: {} messages omitted ({} tool calls, {} tool results). The conversation continued with the agent working on the task.]",
169                middle_count, tool_calls_in_middle, tool_results_in_middle
170            );
171            new_messages.push(crate::reasoning::conversation::ConversationMessage::user(
172                summary,
173            ));
174        }
175
176        new_messages.extend(messages[recent_start..].to_vec());
177
178        *conversation = Conversation::new();
179        for msg in new_messages {
180            conversation.push(msg);
181        }
182
183        // Final fallback
184        if conversation.estimate_tokens() > max_tokens {
185            Self::apply_sliding_window(conversation, max_tokens);
186        }
187    }
188}
189
190impl Default for DefaultContextManager {
191    fn default() -> Self {
192        Self::new(ContextStrategy::SlidingWindow)
193    }
194}
195
196impl ContextManager for DefaultContextManager {
197    fn manage_context(&self, conversation: &mut Conversation, max_tokens: usize) {
198        let before_tokens = conversation.estimate_tokens();
199        let before_len = conversation.len();
200        debug!(
201            strategy = self.strategy_name(),
202            estimated_tokens = before_tokens,
203            max_tokens,
204            message_count = before_len,
205            "Context management check"
206        );
207
208        match &self.strategy {
209            ContextStrategy::SlidingWindow => {
210                Self::apply_sliding_window(conversation, max_tokens);
211            }
212            ContextStrategy::ObservationMasking => {
213                Self::apply_observation_masking(conversation, max_tokens);
214            }
215            ContextStrategy::AnchoredSummary { recent_count } => {
216                Self::apply_anchored_summary(conversation, max_tokens, *recent_count);
217            }
218        }
219
220        let after_tokens = conversation.estimate_tokens();
221        let after_len = conversation.len();
222        if after_tokens < before_tokens {
223            info!(
224                strategy = self.strategy_name(),
225                before_tokens,
226                after_tokens,
227                tokens_saved = before_tokens - after_tokens,
228                messages_before = before_len,
229                messages_after = after_len,
230                "Context compaction triggered"
231            );
232        }
233    }
234
235    fn strategy_name(&self) -> &str {
236        match self.strategy {
237            ContextStrategy::SlidingWindow => "sliding_window",
238            ContextStrategy::ObservationMasking => "observation_masking",
239            ContextStrategy::AnchoredSummary { .. } => "anchored_summary",
240        }
241    }
242}
243
#[cfg(test)]
mod tests {
    use super::*;
    use crate::reasoning::conversation::{ConversationMessage, ToolCall};

    /// Builds a research-agent transcript.
    ///
    /// It contains a system prompt followed by 20 turns. Each turn is: user
    /// question, assistant `web_search` call, tool result, assistant reply.
    fn build_long_conversation() -> Conversation {
        let mut conv = Conversation::with_system("You are a research agent.");
        for turn in 0..20 {
            let call_id = format!("call_{}", turn);

            let question = format!(
                "Research question {} about a topic that requires multiple paragraphs of text to describe properly",
                turn
            );
            let search = ToolCall {
                id: call_id.clone(),
                name: "web_search".into(),
                arguments: format!(r#"{{"query": "topic {} detailed information"}}"#, turn),
            };
            let results = format!("Here are the detailed results for query {}. Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.", turn);
            let answer = format!(
                "Based on the search results for question {}, I found that the topic involves multiple interesting aspects that we should discuss in detail.",
                turn
            );

            conv.push(ConversationMessage::user(question));
            conv.push(ConversationMessage::assistant_tool_calls(vec![search]));
            conv.push(ConversationMessage::tool_result(call_id, "web_search", results));
            conv.push(ConversationMessage::assistant(answer));
        }
        conv
    }

    #[test]
    fn test_sliding_window_no_truncation_needed() {
        let mut conv = Conversation::with_system("sys");
        conv.push(ConversationMessage::user("hi"));
        conv.push(ConversationMessage::assistant("hello"));
        let tokens_before = conv.estimate_tokens();

        DefaultContextManager::new(ContextStrategy::SlidingWindow)
            .manage_context(&mut conv, 10000);

        assert_eq!(conv.estimate_tokens(), tokens_before);
    }

    #[test]
    fn test_sliding_window_truncation() {
        let mut conv = build_long_conversation();
        let len_before = conv.len();

        DefaultContextManager::new(ContextStrategy::SlidingWindow).manage_context(&mut conv, 200);

        assert!(conv.len() < len_before);
        assert!(conv.estimate_tokens() <= 200);
        // The system prompt must survive truncation.
        assert_eq!(conv.messages()[0].role, MessageRole::System);
    }

    #[test]
    fn test_observation_masking() {
        let mut conv = build_long_conversation();

        DefaultContextManager::new(ContextStrategy::ObservationMasking)
            .manage_context(&mut conv, 500);

        // Either an old tool result was replaced by a placeholder, or the
        // sliding-window fallback brought the context under budget.
        let found_masked = conv
            .messages()
            .iter()
            .any(|msg| msg.role == MessageRole::Tool && msg.content.contains("omitted"));
        assert!(found_masked || conv.estimate_tokens() <= 500);
    }

    #[test]
    fn test_anchored_summary() {
        let mut conv = build_long_conversation();
        let len_before = conv.len();

        DefaultContextManager::new(ContextStrategy::AnchoredSummary { recent_count: 6 })
            .manage_context(&mut conv, 500);

        assert!(conv.len() < len_before);
        // The system prompt must survive compaction.
        assert_eq!(conv.messages()[0].role, MessageRole::System);

        // Either a summary placeholder was inserted, or the context is already small enough.
        let has_summary = conv
            .messages()
            .iter()
            .any(|msg| msg.content.contains("Context summary"));
        assert!(has_summary || conv.estimate_tokens() <= 500);
    }

    #[test]
    fn test_strategy_name() {
        let cases = [
            (ContextStrategy::SlidingWindow, "sliding_window"),
            (ContextStrategy::ObservationMasking, "observation_masking"),
            (
                ContextStrategy::AnchoredSummary { recent_count: 4 },
                "anchored_summary",
            ),
        ];
        for (strategy, expected) in cases {
            assert_eq!(DefaultContextManager::new(strategy).strategy_name(), expected);
        }
    }

    #[test]
    fn test_context_within_budget_untouched() {
        let mut conv = Conversation::with_system("sys");
        conv.push(ConversationMessage::user("short"));
        let len_before = conv.len();

        DefaultContextManager::new(ContextStrategy::ObservationMasking)
            .manage_context(&mut conv, 100_000);

        assert_eq!(conv.len(), len_before);
    }
}