1use crate::types::ContextConfig;
2use stakai::{ContentPart, Message, MessageContent, Model, Role, Tool};
3use std::collections::{HashMap, HashSet};
4
/// Text written in place of an assistant message's content when
/// `truncate_old_assistant_messages` truncates it and no tool calls remain.
const TRUNCATED_ASSISTANT_PLACEHOLDER: &str = "[assistant message truncated]";
6
7pub trait ContextReducer: Send + Sync {
9 fn reduce(
10 &self,
11 messages: Vec<Message>,
12 model: &Model,
13 max_output_tokens: u32,
14 tools: &[Tool],
15 metadata: &mut serde_json::Value,
16 ) -> Vec<Message>;
17}
18
19#[derive(Debug, Clone)]
20pub struct DefaultContextReducer {
21 config: ContextConfig,
22}
23
24impl DefaultContextReducer {
25 pub fn new(config: ContextConfig) -> Self {
26 Self { config }
27 }
28}
29
30impl Default for DefaultContextReducer {
31 fn default() -> Self {
32 Self::new(ContextConfig::default())
33 }
34}
35
36impl ContextReducer for DefaultContextReducer {
37 fn reduce(
38 &self,
39 messages: Vec<Message>,
40 _model: &Model,
41 _max_output_tokens: u32,
42 _tools: &[Tool],
43 _metadata: &mut serde_json::Value,
44 ) -> Vec<Message> {
45 reduce_context(messages, &self.config)
46 }
47}
48
49pub fn reduce_context(messages: Vec<Message>, config: &ContextConfig) -> Vec<Message> {
50 let messages = dedup_tool_results(messages);
51 let messages = merge_consecutive_same_role(messages);
52 let messages = truncate_old_tool_results(messages, config.keep_last_messages);
53 let messages = truncate_old_assistant_messages(messages, config.keep_last_messages);
54 let messages = strip_dangling_tool_calls(messages);
55 remove_orphaned_tool_results(messages)
56}
57
58pub fn dedup_tool_results(mut messages: Vec<Message>) -> Vec<Message> {
59 let mut last_positions: HashMap<String, (usize, usize)> = HashMap::new();
60
61 for (message_idx, message) in messages.iter().enumerate() {
62 if let MessageContent::Parts(parts) = &message.content {
63 for (part_idx, part) in parts.iter().enumerate() {
64 if let ContentPart::ToolResult { tool_call_id, .. } = part {
65 last_positions.insert(tool_call_id.clone(), (message_idx, part_idx));
66 }
67 }
68 }
69 }
70
71 for (message_idx, message) in messages.iter_mut().enumerate() {
72 if let MessageContent::Parts(parts) = &mut message.content {
73 let mut part_idx = 0usize;
74 parts.retain(|part| {
75 let should_keep = match part {
76 ContentPart::ToolResult { tool_call_id, .. } => last_positions
77 .get(tool_call_id)
78 .is_some_and(|(last_message_idx, last_part_idx)| {
79 *last_message_idx == message_idx && *last_part_idx == part_idx
80 }),
81 _ => true,
82 };
83 part_idx += 1;
84 should_keep
85 });
86 normalize_message_content(message);
87 }
88 }
89
90 remove_empty_messages(messages)
91}
92
93pub fn merge_consecutive_same_role(messages: Vec<Message>) -> Vec<Message> {
94 let mut merged: Vec<Message> = Vec::with_capacity(messages.len());
95
96 for message in messages {
97 let Some(previous) = merged.last_mut() else {
98 merged.push(message);
99 continue;
100 };
101
102 if previous.role == message.role {
103 let mut previous_parts = message_parts(previous).unwrap_or_default();
104 previous_parts.extend(message_parts(&message).unwrap_or_default());
105 previous.content = MessageContent::Parts(previous_parts);
106 normalize_message_content(previous);
107 } else {
108 merged.push(message);
109 }
110 }
111
112 remove_empty_messages(merged)
113}
114
115pub fn truncate_old_tool_results(messages: Vec<Message>, keep_last_n: usize) -> Vec<Message> {
116 if keep_last_n == usize::MAX {
117 return messages;
118 }
119
120 let mut positions: Vec<(usize, usize, String)> = Vec::new();
121
122 for (message_idx, message) in messages.iter().enumerate() {
123 if let MessageContent::Parts(parts) = &message.content {
124 for (part_idx, part) in parts.iter().enumerate() {
125 if let ContentPart::ToolResult { tool_call_id, .. } = part {
126 positions.push((message_idx, part_idx, tool_call_id.clone()));
127 }
128 }
129 }
130 }
131
132 if positions.len() <= keep_last_n {
133 return messages;
134 }
135
136 let keep_from = positions.len().saturating_sub(keep_last_n);
137 let keep_set: HashSet<(usize, usize)> = positions
138 .into_iter()
139 .skip(keep_from)
140 .map(|(message_idx, part_idx, _)| (message_idx, part_idx))
141 .collect();
142
143 let mut truncated = messages;
144 for (message_idx, message) in truncated.iter_mut().enumerate() {
145 if let MessageContent::Parts(parts) = &mut message.content {
146 let mut part_idx = 0usize;
147 parts.retain(|part| {
148 let keep = match part {
149 ContentPart::ToolResult { .. } => keep_set.contains(&(message_idx, part_idx)),
150 _ => true,
151 };
152 part_idx += 1;
153 keep
154 });
155 normalize_message_content(message);
156 }
157 }
158
159 remove_empty_messages(truncated)
160}
161
162pub fn truncate_old_assistant_messages(
163 mut messages: Vec<Message>,
164 keep_last_n: usize,
165) -> Vec<Message> {
166 if keep_last_n == usize::MAX {
167 return messages;
168 }
169
170 let assistant_indices: Vec<usize> = messages
171 .iter()
172 .enumerate()
173 .filter_map(|(idx, message)| {
174 if message.role == Role::Assistant {
175 Some(idx)
176 } else {
177 None
178 }
179 })
180 .collect();
181
182 if assistant_indices.len() <= keep_last_n {
183 return messages;
184 }
185
186 let keep_start = assistant_indices.len().saturating_sub(keep_last_n);
187 let keep_indices: HashSet<usize> = assistant_indices.into_iter().skip(keep_start).collect();
188
189 for (idx, message) in messages.iter_mut().enumerate() {
190 if message.role != Role::Assistant || keep_indices.contains(&idx) {
191 continue;
192 }
193
194 match &mut message.content {
195 MessageContent::Text(text) => {
196 if !text.is_empty() {
197 *text = TRUNCATED_ASSISTANT_PLACEHOLDER.to_string();
198 }
199 }
200 MessageContent::Parts(parts) => {
201 parts.retain(|part| matches!(part, ContentPart::ToolCall { .. }));
202
203 if parts.is_empty() {
204 message.content =
205 MessageContent::Text(TRUNCATED_ASSISTANT_PLACEHOLDER.to_string());
206 }
207 }
208 }
209 }
210
211 remove_empty_messages(messages)
212}
213
214pub fn strip_dangling_tool_calls(mut messages: Vec<Message>) -> Vec<Message> {
215 for idx in 0..messages.len() {
216 let tool_call_ids: Vec<String> = match &messages[idx].content {
217 MessageContent::Parts(parts) => parts
218 .iter()
219 .filter_map(|part| match part {
220 ContentPart::ToolCall { id, .. } => Some(id.clone()),
221 _ => None,
222 })
223 .collect(),
224 MessageContent::Text(_) => Vec::new(),
225 };
226
227 if tool_call_ids.is_empty() {
228 continue;
229 }
230
231 let next_results: HashSet<String> = messages
232 .get(idx + 1)
233 .and_then(|message| match &message.content {
234 MessageContent::Parts(parts) => Some(
235 parts
236 .iter()
237 .filter_map(|part| match part {
238 ContentPart::ToolResult { tool_call_id, .. } => {
239 Some(tool_call_id.clone())
240 }
241 _ => None,
242 })
243 .collect::<HashSet<_>>(),
244 ),
245 MessageContent::Text(_) => None,
246 })
247 .unwrap_or_default();
248
249 let has_immediate_results = !next_results.is_empty()
250 && tool_call_ids
251 .iter()
252 .all(|tool_call_id| next_results.contains(tool_call_id));
253
254 if has_immediate_results {
255 continue;
256 }
257
258 if let MessageContent::Parts(parts) = &mut messages[idx].content {
259 parts.retain(|part| !matches!(part, ContentPart::ToolCall { .. }));
260 normalize_message_content(&mut messages[idx]);
261 }
262 }
263
264 remove_empty_messages(messages)
265}
266
267pub fn remove_orphaned_tool_results(mut messages: Vec<Message>) -> Vec<Message> {
268 let mut seen_tool_calls: HashSet<String> = HashSet::new();
269
270 for message in &mut messages {
271 if let MessageContent::Parts(parts) = &mut message.content {
272 for part in parts.iter() {
273 if let ContentPart::ToolCall { id, .. } = part {
274 seen_tool_calls.insert(id.clone());
275 }
276 }
277
278 parts.retain(|part| match part {
279 ContentPart::ToolResult { tool_call_id, .. } => {
280 seen_tool_calls.contains(tool_call_id)
281 }
282 _ => true,
283 });
284
285 normalize_message_content(message);
286 }
287 }
288
289 remove_empty_messages(messages)
290}
291
292fn message_parts(message: &Message) -> Option<Vec<ContentPart>> {
293 match &message.content {
294 MessageContent::Text(text) => {
295 if text.is_empty() {
296 None
297 } else {
298 Some(vec![ContentPart::text(text.clone())])
299 }
300 }
301 MessageContent::Parts(parts) => Some(parts.clone()),
302 }
303}
304
305fn normalize_message_content(message: &mut Message) {
306 match &message.content {
307 MessageContent::Parts(parts) if parts.is_empty() => {
308 message.content = MessageContent::Text(String::new());
309 }
310 _ => {}
311 }
312}
313
314fn remove_empty_messages(messages: Vec<Message>) -> Vec<Message> {
315 messages
316 .into_iter()
317 .filter(|message| match &message.content {
318 MessageContent::Text(text) => !text.is_empty(),
319 MessageContent::Parts(parts) => !parts.is_empty(),
320 })
321 .collect()
322}
323
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Assistant message containing a single `stakpak__view` tool call.
    fn tool_call_message(id: &str) -> Message {
        let call = ContentPart::ToolCall {
            id: id.to_string(),
            name: "stakpak__view".to_string(),
            arguments: json!({"path":"README.md"}),
            provider_options: None,
            metadata: None,
        };

        Message {
            role: Role::Assistant,
            content: MessageContent::Parts(vec![call]),
            name: None,
            provider_options: None,
        }
    }

    /// Tool message containing a single result for the given call id.
    fn tool_result_message(id: &str, value: &str) -> Message {
        let result = ContentPart::ToolResult {
            tool_call_id: id.to_string(),
            content: json!(value),
            provider_options: None,
        };

        Message {
            role: Role::Tool,
            content: MessageContent::Parts(vec![result]),
            name: None,
            provider_options: None,
        }
    }

    #[test]
    fn dedup_keeps_last_tool_result_per_tool_call_id() {
        let input = vec![
            tool_call_message("tc_1"),
            tool_result_message("tc_1", "old"),
            tool_result_message("tc_1", "new"),
        ];

        let reduced = dedup_tool_results(input);

        assert_eq!(reduced.len(), 2);

        let survivor = &reduced[1];
        assert_eq!(survivor.role, Role::Tool);

        let MessageContent::Parts(parts) = &survivor.content else {
            panic!("expected parts content for tool message");
        };
        assert_eq!(parts.len(), 1);
        assert!(matches!(
            &parts[0],
            ContentPart::ToolResult { content, .. } if content == &json!("new")
        ));
    }

    #[test]
    fn merge_consecutive_same_role_merges_tool_messages() {
        let input = vec![
            tool_call_message("tc_1"),
            tool_result_message("tc_1", "result_1"),
            tool_result_message("tc_2", "result_2"),
        ];

        let merged = merge_consecutive_same_role(input);

        assert_eq!(merged.len(), 2);
        assert_eq!(merged[1].role, Role::Tool);

        let MessageContent::Parts(parts) = &merged[1].content else {
            panic!("expected merged tool parts");
        };
        let tool_results = parts
            .iter()
            .filter(|part| matches!(part, ContentPart::ToolResult { .. }))
            .count();
        assert_eq!(tool_results, 2);
    }

    #[test]
    fn remove_orphaned_tool_results_removes_missing_references() {
        let input = vec![
            tool_result_message("tc_missing", "orphan"),
            tool_call_message("tc_1"),
            tool_result_message("tc_1", "ok"),
        ];

        let reduced = remove_orphaned_tool_results(input);

        assert_eq!(reduced.len(), 2);
        assert_eq!(reduced[0].role, Role::Assistant);
        assert_eq!(reduced[1].role, Role::Tool);
    }

    #[test]
    fn truncate_old_assistant_messages_keeps_recent_context() {
        let input = vec![
            Message::new(Role::Assistant, "older"),
            Message::new(Role::Assistant, "newer"),
            Message::new(Role::Assistant, "latest"),
        ];

        let truncated = truncate_old_assistant_messages(input, 2);

        let expected = [TRUNCATED_ASSISTANT_PLACEHOLDER, "newer", "latest"];
        assert_eq!(truncated.len(), expected.len());
        for (message, want) in truncated.iter().zip(expected) {
            assert_eq!(message.text(), Some(want.to_string()));
        }
    }

    #[test]
    fn strip_dangling_tool_calls_removes_unresolved_tool_uses() {
        let assistant_with_tool_call = Message {
            role: Role::Assistant,
            content: MessageContent::Parts(vec![
                ContentPart::text("let me check"),
                ContentPart::tool_call("tc_1", "stakpak__view", json!({"path":"README.md"})),
            ]),
            name: None,
            provider_options: None,
        };
        let input = vec![
            assistant_with_tool_call,
            Message::new(Role::User, "new user prompt"),
            tool_result_message("tc_1", "late result"),
        ];

        let reduced = reduce_context(input, &ContextConfig::default());

        assert_eq!(reduced.len(), 2);
        assert_eq!(reduced[0].role, Role::Assistant);
        assert_eq!(reduced[1].role, Role::User);

        let MessageContent::Parts(parts) = &reduced[0].content else {
            panic!("expected assistant message parts");
        };
        assert!(
            parts
                .iter()
                .all(|part| !matches!(part, ContentPart::ToolCall { .. }))
        );
    }

    #[test]
    fn full_reduce_pipeline_runs_in_expected_order() {
        let config = ContextConfig {
            keep_last_messages: 2,
        };
        let input = vec![
            tool_call_message("tc_1"),
            tool_result_message("tc_1", "old"),
            tool_result_message("tc_1", "new"),
            Message::new(Role::Assistant, "analysis"),
        ];

        let reduced = reduce_context(input, &config);

        assert_eq!(reduced.len(), 3);
    }
}