symbi_runtime/reasoning/
context_manager.rs1use tracing::{debug, info, warn};
10
11use crate::reasoning::conversation::{Conversation, MessageRole};
12
/// Strategy a [`DefaultContextManager`] uses to bring a conversation back
/// under its token budget.
///
/// The default is [`ContextStrategy::SlidingWindow`].
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub enum ContextStrategy {
    /// Delegates directly to `Conversation::truncate_to_budget`.
    #[default]
    SlidingWindow,

    /// Replaces the content of older tool-result messages with a short
    /// placeholder, leaving the most recent messages untouched. Falls back to
    /// the sliding window when masking alone does not fit the budget.
    ObservationMasking,

    /// Keeps the leading anchor messages and the most recent messages, and
    /// replaces everything in between with a single summary message. Falls
    /// back to the sliding window when that is still over budget.
    AnchoredSummary {
        /// Number of most recent messages to keep verbatim.
        recent_count: usize,
    },
}
32
33pub trait ContextManager: Send + Sync {
35 fn manage_context(&self, conversation: &mut Conversation, max_tokens: usize);
37
38 fn strategy_name(&self) -> &str;
40}
41
42pub struct DefaultContextManager {
44 strategy: ContextStrategy,
45}
46
47impl DefaultContextManager {
48 pub fn new(strategy: ContextStrategy) -> Self {
50 Self { strategy }
51 }
52
53 fn apply_sliding_window(conversation: &mut Conversation, max_tokens: usize) {
55 conversation.truncate_to_budget(max_tokens);
56 }
57
58 fn apply_observation_masking(conversation: &mut Conversation, max_tokens: usize) {
60 let estimated = conversation.estimate_tokens();
61 if estimated <= max_tokens {
62 return;
63 }
64
65 info!(
66 estimated_tokens = estimated,
67 max_tokens,
68 over_by = estimated - max_tokens,
69 "ObservationMasking: context exceeds budget, masking old tool results"
70 );
71
72 let messages = conversation.messages().to_vec();
73 let total = messages.len();
74 if total <= 3 {
75 warn!("ObservationMasking: only {} messages, cannot mask", total);
76 return;
77 }
78
79 let keep_recent = 6.min(total);
82 let mut new_messages = Vec::new();
83 let mut masked_count = 0usize;
84
85 for (i, msg) in messages.iter().enumerate() {
86 if i >= total - keep_recent {
87 new_messages.push(msg.clone());
89 } else if msg.role == MessageRole::Tool {
90 let mut masked = msg.clone();
92 masked.content = format!(
93 "[Previous {} result omitted for context management]",
94 msg.tool_name.as_deref().unwrap_or("tool")
95 );
96 masked_count += 1;
97 new_messages.push(masked);
98 } else {
99 new_messages.push(msg.clone());
101 }
102 }
103
104 info!(
105 masked_tool_results = masked_count,
106 kept_recent = keep_recent,
107 total_messages = total,
108 "ObservationMasking: masked old tool results"
109 );
110
111 *conversation = Conversation::new();
112 for msg in new_messages {
113 conversation.push(msg);
114 }
115
116 if conversation.estimate_tokens() > max_tokens {
118 let still_estimated = conversation.estimate_tokens();
119 warn!(
120 still_estimated,
121 max_tokens, "ObservationMasking insufficient, falling back to SlidingWindow"
122 );
123 Self::apply_sliding_window(conversation, max_tokens);
124 }
125 }
126
127 fn apply_anchored_summary(
129 conversation: &mut Conversation,
130 max_tokens: usize,
131 recent_count: usize,
132 ) {
133 if conversation.estimate_tokens() <= max_tokens {
134 return;
135 }
136
137 let messages = conversation.messages().to_vec();
138 let total = messages.len();
139
140 let mut anchor_end = 0;
142 for (i, msg) in messages.iter().enumerate() {
143 if msg.role == MessageRole::System || (msg.role == MessageRole::User && i <= 1) {
144 anchor_end = i + 1;
145 } else {
146 break;
147 }
148 }
149
150 let keep_recent = recent_count.min(total.saturating_sub(anchor_end));
151 let recent_start = total.saturating_sub(keep_recent);
152
153 let mut new_messages: Vec<_> = messages[..anchor_end].to_vec();
155
156 if anchor_end < recent_start {
157 let middle_count = recent_start - anchor_end;
158 let tool_calls_in_middle = messages[anchor_end..recent_start]
159 .iter()
160 .filter(|m| !m.tool_calls.is_empty())
161 .count();
162 let tool_results_in_middle = messages[anchor_end..recent_start]
163 .iter()
164 .filter(|m| m.role == MessageRole::Tool)
165 .count();
166
167 let summary = format!(
168 "[Context summary: {} messages omitted ({} tool calls, {} tool results). The conversation continued with the agent working on the task.]",
169 middle_count, tool_calls_in_middle, tool_results_in_middle
170 );
171 new_messages.push(crate::reasoning::conversation::ConversationMessage::user(
172 summary,
173 ));
174 }
175
176 new_messages.extend(messages[recent_start..].to_vec());
177
178 *conversation = Conversation::new();
179 for msg in new_messages {
180 conversation.push(msg);
181 }
182
183 if conversation.estimate_tokens() > max_tokens {
185 Self::apply_sliding_window(conversation, max_tokens);
186 }
187 }
188}
189
190impl Default for DefaultContextManager {
191 fn default() -> Self {
192 Self::new(ContextStrategy::SlidingWindow)
193 }
194}
195
196impl ContextManager for DefaultContextManager {
197 fn manage_context(&self, conversation: &mut Conversation, max_tokens: usize) {
198 let before_tokens = conversation.estimate_tokens();
199 let before_len = conversation.len();
200 debug!(
201 strategy = self.strategy_name(),
202 estimated_tokens = before_tokens,
203 max_tokens,
204 message_count = before_len,
205 "Context management check"
206 );
207
208 match &self.strategy {
209 ContextStrategy::SlidingWindow => {
210 Self::apply_sliding_window(conversation, max_tokens);
211 }
212 ContextStrategy::ObservationMasking => {
213 Self::apply_observation_masking(conversation, max_tokens);
214 }
215 ContextStrategy::AnchoredSummary { recent_count } => {
216 Self::apply_anchored_summary(conversation, max_tokens, *recent_count);
217 }
218 }
219
220 let after_tokens = conversation.estimate_tokens();
221 let after_len = conversation.len();
222 if after_tokens < before_tokens {
223 info!(
224 strategy = self.strategy_name(),
225 before_tokens,
226 after_tokens,
227 tokens_saved = before_tokens - after_tokens,
228 messages_before = before_len,
229 messages_after = after_len,
230 "Context compaction triggered"
231 );
232 }
233 }
234
235 fn strategy_name(&self) -> &str {
236 match self.strategy {
237 ContextStrategy::SlidingWindow => "sliding_window",
238 ContextStrategy::ObservationMasking => "observation_masking",
239 ContextStrategy::AnchoredSummary { .. } => "anchored_summary",
240 }
241 }
242}
243
#[cfg(test)]
mod tests {
    use super::*;
    use crate::reasoning::conversation::{ConversationMessage, ToolCall};

    /// Shorthand for constructing the manager under test.
    fn manager(strategy: ContextStrategy) -> DefaultContextManager {
        DefaultContextManager::new(strategy)
    }

    /// Builds a system prompt followed by 20 rounds of
    /// user → assistant tool call → tool result → assistant answer.
    fn build_long_conversation() -> Conversation {
        let mut conv = Conversation::with_system("You are a research agent.");
        for i in 0..20 {
            let call_id = format!("call_{}", i);

            let question = format!(
                "Research question {} about a topic that requires multiple paragraphs of text to describe properly",
                i
            );
            conv.push(ConversationMessage::user(question));

            let call = ToolCall {
                id: call_id.clone(),
                name: "web_search".into(),
                arguments: format!(r#"{{"query": "topic {} detailed information"}}"#, i),
            };
            conv.push(ConversationMessage::assistant_tool_calls(vec![call]));

            let observation = format!(
                "Here are the detailed results for query {}. Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.",
                i
            );
            conv.push(ConversationMessage::tool_result(
                call_id,
                "web_search",
                observation,
            ));

            let answer = format!(
                "Based on the search results for question {}, I found that the topic involves multiple interesting aspects that we should discuss in detail.",
                i
            );
            conv.push(ConversationMessage::assistant(answer));
        }
        conv
    }

    /// A conversation well under budget keeps its token estimate.
    #[test]
    fn test_sliding_window_no_truncation_needed() {
        let mut conv = Conversation::with_system("sys");
        conv.push(ConversationMessage::user("hi"));
        conv.push(ConversationMessage::assistant("hello"));
        let original_tokens = conv.estimate_tokens();

        manager(ContextStrategy::SlidingWindow).manage_context(&mut conv, 10000);

        assert_eq!(conv.estimate_tokens(), original_tokens);
    }

    /// The sliding window drops messages, fits the budget and keeps the
    /// system message first.
    #[test]
    fn test_sliding_window_truncation() {
        let mut conv = build_long_conversation();
        let original_len = conv.len();

        manager(ContextStrategy::SlidingWindow).manage_context(&mut conv, 200);

        assert!(conv.len() < original_len);
        assert!(conv.estimate_tokens() <= 200);
        assert_eq!(conv.messages()[0].role, MessageRole::System);
    }

    /// Masking either leaves a masked tool result behind or fits the budget.
    #[test]
    fn test_observation_masking() {
        let mut conv = build_long_conversation();

        manager(ContextStrategy::ObservationMasking).manage_context(&mut conv, 500);

        let found_masked = conv
            .messages()
            .iter()
            .any(|msg| msg.role == MessageRole::Tool && msg.content.contains("omitted"));
        assert!(found_masked || conv.estimate_tokens() <= 500);
    }

    /// The anchored summary shortens the conversation, keeps the system
    /// message first, and either inserts a summary or fits the budget.
    #[test]
    fn test_anchored_summary() {
        let mut conv = build_long_conversation();
        let original_len = conv.len();

        manager(ContextStrategy::AnchoredSummary { recent_count: 6 })
            .manage_context(&mut conv, 500);

        assert!(conv.len() < original_len);
        assert_eq!(conv.messages()[0].role, MessageRole::System);

        let has_summary = conv
            .messages()
            .iter()
            .any(|m| m.content.contains("Context summary"));
        assert!(has_summary || conv.estimate_tokens() <= 500);
    }

    /// Each strategy reports its snake_case name.
    #[test]
    fn test_strategy_name() {
        let sliding = manager(ContextStrategy::SlidingWindow);
        assert_eq!(sliding.strategy_name(), "sliding_window");

        let masking = manager(ContextStrategy::ObservationMasking);
        assert_eq!(masking.strategy_name(), "observation_masking");

        let anchored = manager(ContextStrategy::AnchoredSummary { recent_count: 4 });
        assert_eq!(anchored.strategy_name(), "anchored_summary");
    }

    /// A conversation within budget keeps its message count.
    #[test]
    fn test_context_within_budget_untouched() {
        let mut conv = Conversation::with_system("sys");
        conv.push(ConversationMessage::user("short"));
        let before = conv.len();

        manager(ContextStrategy::ObservationMasking).manage_context(&mut conv, 100_000);

        assert_eq!(conv.len(), before);
    }
}
367}