Skip to main content

stakpak_agent_core/
context.rs

1use crate::types::ContextConfig;
2use stakai::{ContentPart, Message, MessageContent, Model, Role, Tool};
3use std::collections::{HashMap, HashSet};
4
/// Text that stands in for the content of an assistant message once
/// `truncate_old_assistant_messages` has stripped it.
const TRUNCATED_ASSISTANT_PLACEHOLDER: &str = "[assistant message truncated]";
6
7/// Pluggable strategy for reducing context before each inference turn.
8pub trait ContextReducer: Send + Sync {
9    fn reduce(
10        &self,
11        messages: Vec<Message>,
12        model: &Model,
13        max_output_tokens: u32,
14        tools: &[Tool],
15        metadata: &mut serde_json::Value,
16    ) -> Vec<Message>;
17}
18
19#[derive(Debug, Clone)]
20pub struct DefaultContextReducer {
21    config: ContextConfig,
22}
23
24impl DefaultContextReducer {
25    pub fn new(config: ContextConfig) -> Self {
26        Self { config }
27    }
28}
29
30impl Default for DefaultContextReducer {
31    fn default() -> Self {
32        Self::new(ContextConfig::default())
33    }
34}
35
36impl ContextReducer for DefaultContextReducer {
37    fn reduce(
38        &self,
39        messages: Vec<Message>,
40        _model: &Model,
41        _max_output_tokens: u32,
42        _tools: &[Tool],
43        _metadata: &mut serde_json::Value,
44    ) -> Vec<Message> {
45        reduce_context(messages, &self.config)
46    }
47}
48
49pub fn reduce_context(messages: Vec<Message>, config: &ContextConfig) -> Vec<Message> {
50    let messages = dedup_tool_results(messages);
51    let messages = merge_consecutive_same_role(messages);
52    let messages = truncate_old_tool_results(messages, config.keep_last_messages);
53    let messages = truncate_old_assistant_messages(messages, config.keep_last_messages);
54    let messages = strip_dangling_tool_calls(messages);
55    remove_orphaned_tool_results(messages)
56}
57
58pub fn dedup_tool_results(mut messages: Vec<Message>) -> Vec<Message> {
59    let mut last_positions: HashMap<String, (usize, usize)> = HashMap::new();
60
61    for (message_idx, message) in messages.iter().enumerate() {
62        if let MessageContent::Parts(parts) = &message.content {
63            for (part_idx, part) in parts.iter().enumerate() {
64                if let ContentPart::ToolResult { tool_call_id, .. } = part {
65                    last_positions.insert(tool_call_id.clone(), (message_idx, part_idx));
66                }
67            }
68        }
69    }
70
71    for (message_idx, message) in messages.iter_mut().enumerate() {
72        if let MessageContent::Parts(parts) = &mut message.content {
73            let mut part_idx = 0usize;
74            parts.retain(|part| {
75                let should_keep = match part {
76                    ContentPart::ToolResult { tool_call_id, .. } => last_positions
77                        .get(tool_call_id)
78                        .is_some_and(|(last_message_idx, last_part_idx)| {
79                            *last_message_idx == message_idx && *last_part_idx == part_idx
80                        }),
81                    _ => true,
82                };
83                part_idx += 1;
84                should_keep
85            });
86            normalize_message_content(message);
87        }
88    }
89
90    remove_empty_messages(messages)
91}
92
93pub fn merge_consecutive_same_role(messages: Vec<Message>) -> Vec<Message> {
94    let mut merged: Vec<Message> = Vec::with_capacity(messages.len());
95
96    for message in messages {
97        let Some(previous) = merged.last_mut() else {
98            merged.push(message);
99            continue;
100        };
101
102        if previous.role == message.role {
103            let mut previous_parts = message_parts(previous).unwrap_or_default();
104            previous_parts.extend(message_parts(&message).unwrap_or_default());
105            previous.content = MessageContent::Parts(previous_parts);
106            normalize_message_content(previous);
107        } else {
108            merged.push(message);
109        }
110    }
111
112    remove_empty_messages(merged)
113}
114
115pub fn truncate_old_tool_results(messages: Vec<Message>, keep_last_n: usize) -> Vec<Message> {
116    if keep_last_n == usize::MAX {
117        return messages;
118    }
119
120    let mut positions: Vec<(usize, usize, String)> = Vec::new();
121
122    for (message_idx, message) in messages.iter().enumerate() {
123        if let MessageContent::Parts(parts) = &message.content {
124            for (part_idx, part) in parts.iter().enumerate() {
125                if let ContentPart::ToolResult { tool_call_id, .. } = part {
126                    positions.push((message_idx, part_idx, tool_call_id.clone()));
127                }
128            }
129        }
130    }
131
132    if positions.len() <= keep_last_n {
133        return messages;
134    }
135
136    let keep_from = positions.len().saturating_sub(keep_last_n);
137    let keep_set: HashSet<(usize, usize)> = positions
138        .into_iter()
139        .skip(keep_from)
140        .map(|(message_idx, part_idx, _)| (message_idx, part_idx))
141        .collect();
142
143    let mut truncated = messages;
144    for (message_idx, message) in truncated.iter_mut().enumerate() {
145        if let MessageContent::Parts(parts) = &mut message.content {
146            let mut part_idx = 0usize;
147            parts.retain(|part| {
148                let keep = match part {
149                    ContentPart::ToolResult { .. } => keep_set.contains(&(message_idx, part_idx)),
150                    _ => true,
151                };
152                part_idx += 1;
153                keep
154            });
155            normalize_message_content(message);
156        }
157    }
158
159    remove_empty_messages(truncated)
160}
161
162pub fn truncate_old_assistant_messages(
163    mut messages: Vec<Message>,
164    keep_last_n: usize,
165) -> Vec<Message> {
166    if keep_last_n == usize::MAX {
167        return messages;
168    }
169
170    let assistant_indices: Vec<usize> = messages
171        .iter()
172        .enumerate()
173        .filter_map(|(idx, message)| {
174            if message.role == Role::Assistant {
175                Some(idx)
176            } else {
177                None
178            }
179        })
180        .collect();
181
182    if assistant_indices.len() <= keep_last_n {
183        return messages;
184    }
185
186    let keep_start = assistant_indices.len().saturating_sub(keep_last_n);
187    let keep_indices: HashSet<usize> = assistant_indices.into_iter().skip(keep_start).collect();
188
189    for (idx, message) in messages.iter_mut().enumerate() {
190        if message.role != Role::Assistant || keep_indices.contains(&idx) {
191            continue;
192        }
193
194        match &mut message.content {
195            MessageContent::Text(text) => {
196                if !text.is_empty() {
197                    *text = TRUNCATED_ASSISTANT_PLACEHOLDER.to_string();
198                }
199            }
200            MessageContent::Parts(parts) => {
201                parts.retain(|part| matches!(part, ContentPart::ToolCall { .. }));
202
203                if parts.is_empty() {
204                    message.content =
205                        MessageContent::Text(TRUNCATED_ASSISTANT_PLACEHOLDER.to_string());
206                }
207            }
208        }
209    }
210
211    remove_empty_messages(messages)
212}
213
214pub fn strip_dangling_tool_calls(mut messages: Vec<Message>) -> Vec<Message> {
215    for idx in 0..messages.len() {
216        let tool_call_ids: Vec<String> = match &messages[idx].content {
217            MessageContent::Parts(parts) => parts
218                .iter()
219                .filter_map(|part| match part {
220                    ContentPart::ToolCall { id, .. } => Some(id.clone()),
221                    _ => None,
222                })
223                .collect(),
224            MessageContent::Text(_) => Vec::new(),
225        };
226
227        if tool_call_ids.is_empty() {
228            continue;
229        }
230
231        let next_results: HashSet<String> = messages
232            .get(idx + 1)
233            .and_then(|message| match &message.content {
234                MessageContent::Parts(parts) => Some(
235                    parts
236                        .iter()
237                        .filter_map(|part| match part {
238                            ContentPart::ToolResult { tool_call_id, .. } => {
239                                Some(tool_call_id.clone())
240                            }
241                            _ => None,
242                        })
243                        .collect::<HashSet<_>>(),
244                ),
245                MessageContent::Text(_) => None,
246            })
247            .unwrap_or_default();
248
249        let has_immediate_results = !next_results.is_empty()
250            && tool_call_ids
251                .iter()
252                .all(|tool_call_id| next_results.contains(tool_call_id));
253
254        if has_immediate_results {
255            continue;
256        }
257
258        if let MessageContent::Parts(parts) = &mut messages[idx].content {
259            parts.retain(|part| !matches!(part, ContentPart::ToolCall { .. }));
260            normalize_message_content(&mut messages[idx]);
261        }
262    }
263
264    remove_empty_messages(messages)
265}
266
267pub fn remove_orphaned_tool_results(mut messages: Vec<Message>) -> Vec<Message> {
268    let mut seen_tool_calls: HashSet<String> = HashSet::new();
269
270    for message in &mut messages {
271        if let MessageContent::Parts(parts) = &mut message.content {
272            for part in parts.iter() {
273                if let ContentPart::ToolCall { id, .. } = part {
274                    seen_tool_calls.insert(id.clone());
275                }
276            }
277
278            parts.retain(|part| match part {
279                ContentPart::ToolResult { tool_call_id, .. } => {
280                    seen_tool_calls.contains(tool_call_id)
281                }
282                _ => true,
283            });
284
285            normalize_message_content(message);
286        }
287    }
288
289    remove_empty_messages(messages)
290}
291
292fn message_parts(message: &Message) -> Option<Vec<ContentPart>> {
293    match &message.content {
294        MessageContent::Text(text) => {
295            if text.is_empty() {
296                None
297            } else {
298                Some(vec![ContentPart::text(text.clone())])
299            }
300        }
301        MessageContent::Parts(parts) => Some(parts.clone()),
302    }
303}
304
305fn normalize_message_content(message: &mut Message) {
306    match &message.content {
307        MessageContent::Parts(parts) if parts.is_empty() => {
308            message.content = MessageContent::Text(String::new());
309        }
310        _ => {}
311    }
312}
313
314fn remove_empty_messages(messages: Vec<Message>) -> Vec<Message> {
315    messages
316        .into_iter()
317        .filter(|message| match &message.content {
318            MessageContent::Text(text) => !text.is_empty(),
319            MessageContent::Parts(parts) => !parts.is_empty(),
320        })
321        .collect()
322}
323
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Assistant message holding one `stakpak__view` tool call with id `id`.
    fn tool_call_message(id: &str) -> Message {
        let call = ContentPart::ToolCall {
            id: id.to_string(),
            name: "stakpak__view".to_string(),
            arguments: json!({"path":"README.md"}),
            provider_options: None,
            metadata: None,
        };

        Message {
            role: Role::Assistant,
            content: MessageContent::Parts(vec![call]),
            name: None,
            provider_options: None,
        }
    }

    /// Tool message holding one result `value` for the tool call `id`.
    fn tool_result_message(id: &str, value: &str) -> Message {
        let result = ContentPart::ToolResult {
            tool_call_id: id.to_string(),
            content: json!(value),
            provider_options: None,
        };

        Message {
            role: Role::Tool,
            content: MessageContent::Parts(vec![result]),
            name: None,
            provider_options: None,
        }
    }

    #[test]
    fn dedup_keeps_last_tool_result_per_tool_call_id() {
        let conversation = vec![
            tool_call_message("tc_1"),
            tool_result_message("tc_1", "old"),
            tool_result_message("tc_1", "new"),
        ];

        let reduced = dedup_tool_results(conversation);

        assert_eq!(reduced.len(), 2);

        let last = &reduced[1];
        assert_eq!(last.role, Role::Tool);

        let MessageContent::Parts(parts) = &last.content else {
            panic!("expected parts content for tool message");
        };
        assert_eq!(parts.len(), 1);
        assert!(matches!(
            &parts[0],
            ContentPart::ToolResult { content, .. } if content == &json!("new")
        ));
    }

    #[test]
    fn merge_consecutive_same_role_merges_tool_messages() {
        let conversation = vec![
            tool_call_message("tc_1"),
            tool_result_message("tc_1", "result_1"),
            tool_result_message("tc_2", "result_2"),
        ];

        let merged = merge_consecutive_same_role(conversation);

        assert_eq!(merged.len(), 2);
        assert_eq!(merged[1].role, Role::Tool);

        let MessageContent::Parts(parts) = &merged[1].content else {
            panic!("expected merged tool parts");
        };
        let tool_results = parts
            .iter()
            .filter(|part| matches!(part, ContentPart::ToolResult { .. }))
            .count();
        assert_eq!(tool_results, 2);
    }

    #[test]
    fn remove_orphaned_tool_results_removes_missing_references() {
        let conversation = vec![
            tool_result_message("tc_missing", "orphan"),
            tool_call_message("tc_1"),
            tool_result_message("tc_1", "ok"),
        ];

        let reduced = remove_orphaned_tool_results(conversation);

        assert_eq!(reduced.len(), 2);
        assert_eq!(reduced[0].role, Role::Assistant);
        assert_eq!(reduced[1].role, Role::Tool);
    }

    #[test]
    fn truncate_old_assistant_messages_keeps_recent_context() {
        let conversation = vec![
            Message::new(Role::Assistant, "older"),
            Message::new(Role::Assistant, "newer"),
            Message::new(Role::Assistant, "latest"),
        ];

        let truncated = truncate_old_assistant_messages(conversation, 2);

        assert_eq!(truncated.len(), 3);

        let placeholder = Some(TRUNCATED_ASSISTANT_PLACEHOLDER.to_string());
        assert_eq!(truncated[0].text(), placeholder);
        assert_eq!(truncated[1].text(), Some("newer".to_string()));
        assert_eq!(truncated[2].text(), Some("latest".to_string()));
    }

    #[test]
    fn strip_dangling_tool_calls_removes_unresolved_tool_uses() {
        let assistant_with_tool_call = Message {
            role: Role::Assistant,
            content: MessageContent::Parts(vec![
                ContentPart::text("let me check"),
                ContentPart::tool_call("tc_1", "stakpak__view", json!({"path":"README.md"})),
            ]),
            name: None,
            provider_options: None,
        };
        let conversation = vec![
            assistant_with_tool_call,
            Message::new(Role::User, "new user prompt"),
            tool_result_message("tc_1", "late result"),
        ];

        let reduced = reduce_context(conversation, &ContextConfig::default());

        // The tool call goes away because the very next message carries no
        // tool result for it; the late result is then orphaned and removed by
        // remove_orphaned_tool_results().
        assert_eq!(reduced.len(), 2);
        assert_eq!(reduced[0].role, Role::Assistant);
        assert_eq!(reduced[1].role, Role::User);

        let MessageContent::Parts(parts) = &reduced[0].content else {
            panic!("expected assistant message parts");
        };
        assert!(
            !parts
                .iter()
                .any(|part| matches!(part, ContentPart::ToolCall { .. }))
        );
    }

    #[test]
    fn full_reduce_pipeline_runs_in_expected_order() {
        let config = ContextConfig {
            keep_last_messages: 2,
        };
        let conversation = vec![
            tool_call_message("tc_1"),
            tool_result_message("tc_1", "old"),
            tool_result_message("tc_1", "new"),
            Message::new(Role::Assistant, "analysis"),
        ];

        let reduced = reduce_context(conversation, &config);

        // assistant tool call + last deduped tool result + assistant analysis
        assert_eq!(reduced.len(), 3);
    }
}