// agent_sdk/compact.rs
1//! Conversation compaction — summarize older messages when approaching context limits.
2//!
3//! Two-tier context management:
4//!
5//! 1. **Pruning** ([`prune_tool_results`]) — lightweight pass that truncates oversized
6//!    tool results in older messages. Triggers at a configurable percentage of
7//!    `context_budget` (default 70%). This often frees enough space to avoid full
8//!    compaction.
9//!
10//! 2. **Compaction** ([`call_summarizer`]) — when `input_tokens` exceeds `context_budget`,
11//!    older messages are summarized via a cheaper model call and replaced with a compact
12//!    summary, preserving the system prompt and recent turns.
13
14use tracing::{debug, warn};
15
16use crate::client::{ApiContentBlock, ApiMessage, CreateMessageRequest};
17use crate::error::Result;
18use crate::provider::LlmProvider;
19
/// Default model for compaction summaries (a cheaper model than the primary).
pub const DEFAULT_COMPACTION_MODEL: &str = "claude-haiku-4-5";

/// Default minimum number of messages to keep at the end (never compact below this).
pub const DEFAULT_MIN_KEEP_MESSAGES: usize = 4;

/// Default max tokens for the summarization response.
pub const DEFAULT_SUMMARY_MAX_TOKENS: u32 = 4096;
28
/// Check whether compaction should trigger.
///
/// Fires only when `input_tokens` strictly exceeds `context_budget`;
/// sitting exactly at the budget does not trigger compaction.
pub fn should_compact(input_tokens: u64, context_budget: u64) -> bool {
    context_budget < input_tokens
}
33
34/// Find the split point — index where old messages end and recent messages begin.
35///
36/// Rules:
37/// - Keep at least `min_keep` messages at the end (falls back to `DEFAULT_MIN_KEEP_MESSAGES`).
38/// - Never split inside a tool-use cycle (assistant with tool_use followed by
39///   user with tool_result must stay together).
40/// - Returns 0 if the conversation is too short to compact.
41pub fn find_split_point(conversation: &[ApiMessage], min_keep: usize) -> usize {
42    if conversation.len() <= min_keep {
43        return 0;
44    }
45
46    // Start candidate: keep min_keep from the end
47    let mut split = conversation.len() - min_keep;
48
49    // Walk backwards to find a clean boundary (not inside a tool cycle).
50    // A tool cycle is: assistant message with ToolUse blocks followed by a user
51    // message with ToolResult blocks. We must not split between them.
52    while split > 0 {
53        // Check if the message at `split` is a user message with tool results
54        // and the message before it is an assistant with tool uses — if so, the
55        // split would break a tool cycle, so move split backwards to before the
56        // assistant message.
57        if split < conversation.len() {
58            let msg = &conversation[split];
59            if msg.role == "user" && has_tool_results(&msg.content) {
60                // This user message has tool results — check if prior is assistant with tool_use
61                if split > 0 {
62                    let prev = &conversation[split - 1];
63                    if prev.role == "assistant" && has_tool_uses(&prev.content) {
64                        // Can't split here — move before the assistant message
65                        split -= 1;
66                        continue;
67                    }
68                }
69            }
70        }
71        break;
72    }
73
74    split
75}
76
77/// Build the summarization prompt from old messages.
78pub fn build_summary_prompt(old_messages: &[ApiMessage]) -> String {
79    let mut rendered = String::new();
80
81    for msg in old_messages {
82        rendered.push_str(&format!("[{}]\n", msg.role));
83        for block in &msg.content {
84            match block {
85                ApiContentBlock::Text { text, .. } => {
86                    rendered.push_str(text);
87                    rendered.push('\n');
88                }
89                ApiContentBlock::ToolUse { name, input, .. } => {
90                    rendered.push_str(&format!("Tool call: {} input: {}\n", name, input));
91                }
92                ApiContentBlock::ToolResult {
93                    content, is_error, ..
94                } => {
95                    let label = if *is_error == Some(true) {
96                        "error"
97                    } else {
98                        "result"
99                    };
100                    // Truncate long tool results
101                    let content_str = content.to_string();
102                    if content_str.len() > 500 {
103                        let mut end = 500;
104                        while end > 0 && !content_str.is_char_boundary(end) {
105                            end -= 1;
106                        }
107                        rendered.push_str(&format!("Tool {}: {}...\n", label, &content_str[..end]));
108                    } else {
109                        rendered.push_str(&format!("Tool {}: {}\n", label, content_str));
110                    }
111                }
112                ApiContentBlock::Thinking { thinking } => {
113                    // Skip thinking blocks in summary — they're internal reasoning
114                    if thinking.len() <= 200 {
115                        rendered.push_str(&format!("(thinking: {})\n", thinking));
116                    }
117                }
118                ApiContentBlock::Image { .. } => {
119                    rendered.push_str("[image]\n");
120                }
121            }
122        }
123        rendered.push('\n');
124    }
125
126    format!(
127        "Summarize the following conversation segment concisely. Preserve:\n\
128         - Key decisions made\n\
129         - Important facts and context established\n\
130         - File paths and code references mentioned\n\
131         - Tool results and their outcomes\n\
132         - Any commitments or action items\n\n\
133         Format as a structured summary with sections.\n\n\
134         <conversation>\n{rendered}</conversation>"
135    )
136}
137
138/// Call the summarizer model via an LLM provider. Falls back to `fallback_model` on failure.
139///
140/// If a separate `fallback_provider` is given it is used for the retry; otherwise
141/// the same `provider` is reused with `fallback_model`.
142pub async fn call_summarizer(
143    provider: &dyn LlmProvider,
144    summary_prompt: &str,
145    compaction_model: &str,
146    fallback_provider: Option<&dyn LlmProvider>,
147    fallback_model: &str,
148    summary_max_tokens: u32,
149) -> Result<String> {
150    let request = CreateMessageRequest {
151        model: compaction_model.to_string(),
152        max_tokens: summary_max_tokens,
153        messages: vec![ApiMessage {
154            role: "user".to_string(),
155            content: vec![ApiContentBlock::Text {
156                text: summary_prompt.to_string(),
157                cache_control: None,
158            }],
159        }],
160        system: None,
161        tools: None,
162        stream: false,
163        metadata: None,
164        thinking: None,
165    };
166
167    match provider.create_message(&request).await {
168        Ok(resp) => extract_text(&resp.content),
169        Err(e) => {
170            warn!(
171                model = compaction_model,
172                error = %e,
173                "Compaction model failed, falling back to primary model"
174            );
175            // Retry with the fallback (primary) model/provider
176            let mut fallback_req = request;
177            fallback_req.model = fallback_model.to_string();
178            let fb = fallback_provider.unwrap_or(provider);
179            let resp = fb.create_message(&fallback_req).await?;
180            extract_text(&resp.content)
181        }
182    }
183}
184
185/// Replace old messages with a summary message.
186pub fn splice_conversation(conversation: &mut Vec<ApiMessage>, split_point: usize, summary: &str) {
187    // Remove old messages
188    conversation.drain(..split_point);
189
190    // Insert summary as the first message
191    conversation.insert(
192        0,
193        ApiMessage {
194            role: "user".to_string(),
195            content: vec![ApiContentBlock::Text {
196                text: format!(
197                    "[Previous conversation summary]\n{summary}\n[End of summary — conversation continues below]"
198                ),
199                cache_control: None,
200            }],
201        },
202    );
203}
204
/// Result of a compaction operation.
#[derive(Debug)]
pub struct CompactResult {
    /// Input token count measured before compaction ran.
    pub pre_tokens: u64,
    /// The generated summary text that replaced the old messages.
    pub summary: String,
    /// Number of messages removed and folded into the summary.
    pub messages_compacted: usize,
}
212
// ── tool result pruning ──────────────────────────────────────────────────

/// Default: tool results longer than this (in chars) are candidates for pruning.
pub const DEFAULT_PRUNE_TOOL_RESULT_MAX_CHARS: usize = 2_000;

/// Default: pruning triggers at 70% of context budget (before compaction fires at 100%).
pub const DEFAULT_PRUNE_THRESHOLD_PCT: u8 = 70;
220
/// Check whether lightweight pruning should trigger.
///
/// Returns `true` when `input_tokens` exceeds `threshold_pct`% of `context_budget`
/// but is still within the budget (i.e. before full compaction fires).
///
/// The comparison is done in `u128` so `context_budget * threshold_pct` cannot
/// overflow `u64` for very large budgets. `input > floor(b*p/100)` is exactly
/// equivalent to `input*100 > b*p` over the integers, so results are unchanged
/// for all inputs the old code handled.
pub fn should_prune(input_tokens: u64, context_budget: u64, threshold_pct: u8) -> bool {
    u128::from(input_tokens) * 100 > u128::from(context_budget) * u128::from(threshold_pct)
}
229
230/// Prune oversized tool results in-place to free context space.
231///
232/// Walks the conversation from oldest to newest, skipping the last
233/// `preserve_tail` messages. For each `ToolResult` whose text content exceeds
234/// `max_chars`, replaces it with: first 500 chars + marker + last 200 chars.
235///
236/// Returns the total number of characters removed.
237pub fn prune_tool_results(
238    conversation: &mut [ApiMessage],
239    max_chars: usize,
240    preserve_tail: usize,
241) -> usize {
242    let len = conversation.len();
243    let end = len.saturating_sub(preserve_tail);
244    let mut total_removed = 0;
245
246    for msg in conversation[..end].iter_mut() {
247        for block in msg.content.iter_mut() {
248            if let ApiContentBlock::ToolResult { content, .. } = block {
249                let text = content.to_string();
250                if text.len() <= max_chars {
251                    continue;
252                }
253
254                let original_len = text.len();
255
256                // Build pruned version: head + marker + tail
257                let head_end = char_boundary(&text, 500);
258                let tail_start = char_boundary_rev(&text, 200);
259
260                let pruned = format!(
261                    "{}\n\n[...{} chars pruned...]\n\n{}",
262                    &text[..head_end],
263                    original_len - head_end - (original_len - tail_start),
264                    &text[tail_start..]
265                );
266
267                let removed = original_len - pruned.len();
268                total_removed += removed;
269                *content = serde_json::json!(pruned);
270
271                debug!(
272                    original = original_len,
273                    pruned = pruned.len(),
274                    saved = removed,
275                    "Pruned tool result"
276                );
277            }
278        }
279    }
280
281    if total_removed > 0 {
282        debug!(
283            total_chars_removed = total_removed,
284            "Tool result pruning complete"
285        );
286    }
287
288    total_removed
289}
290
/// Find a char boundary at or before `target` bytes from the start.
///
/// Returns the largest byte index `<= min(target, s.len())` that lies on a
/// UTF-8 character boundary, so `&s[..result]` never panics. Index 0 is always
/// a boundary, so the walk terminates.
fn char_boundary(s: &str, target: usize) -> usize {
    let mut pos = target.min(s.len());
    while !s.is_char_boundary(pos) {
        pos -= 1;
    }
    pos
}
300
/// Find a char boundary at or after `distance` bytes from the end.
///
/// Returns the smallest valid boundary `>= s.len().saturating_sub(distance)`,
/// so `&s[result..]` never panics; a `distance` covering the whole string
/// yields 0.
fn char_boundary_rev(s: &str, distance: usize) -> usize {
    let mut pos = s.len().saturating_sub(distance);
    while pos < s.len() && !s.is_char_boundary(pos) {
        pos += 1;
    }
    pos
}
312
313// ── helpers ──────────────────────────────────────────────────────────────
314
315fn has_tool_uses(blocks: &[ApiContentBlock]) -> bool {
316    blocks
317        .iter()
318        .any(|b| matches!(b, ApiContentBlock::ToolUse { .. }))
319}
320
321fn has_tool_results(blocks: &[ApiContentBlock]) -> bool {
322    blocks
323        .iter()
324        .any(|b| matches!(b, ApiContentBlock::ToolResult { .. }))
325}
326
327fn extract_text(content: &[ApiContentBlock]) -> Result<String> {
328    let text: String = content
329        .iter()
330        .filter_map(|b| match b {
331            ApiContentBlock::Text { text, .. } => Some(text.as_str()),
332            _ => None,
333        })
334        .collect::<Vec<_>>()
335        .join("");
336
337    if text.is_empty() {
338        Err(crate::error::AgentError::Api(
339            "Compaction response contained no text".into(),
340        ))
341    } else {
342        debug!(summary_len = text.len(), "Generated compaction summary");
343        Ok(text)
344    }
345}
346
#[cfg(test)]
mod tests {
    use super::*;

    // Fixture: a plain text message with the given role.
    fn text_msg(role: &str, text: &str) -> ApiMessage {
        ApiMessage {
            role: role.to_string(),
            content: vec![ApiContentBlock::Text {
                text: text.to_string(),
                cache_control: None,
            }],
        }
    }

    // Fixture: assistant message containing a text block plus a tool_use block.
    fn tool_use_msg() -> ApiMessage {
        ApiMessage {
            role: "assistant".to_string(),
            content: vec![
                ApiContentBlock::Text {
                    text: "Let me check.".to_string(),
                    cache_control: None,
                },
                ApiContentBlock::ToolUse {
                    id: "tu_1".to_string(),
                    name: "Bash".to_string(),
                    input: serde_json::json!({"command": "ls"}),
                },
            ],
        }
    }

    // Fixture: user message carrying the tool_result matching `tool_use_msg`.
    fn tool_result_msg() -> ApiMessage {
        ApiMessage {
            role: "user".to_string(),
            content: vec![ApiContentBlock::ToolResult {
                tool_use_id: "tu_1".to_string(),
                content: serde_json::json!("file1.rs\nfile2.rs"),
                is_error: None,
                cache_control: None,
                name: None,
            }],
        }
    }

    // Compaction uses strictly-greater-than semantics at the budget boundary.
    #[test]
    fn test_should_compact_threshold() {
        assert!(!should_compact(150_000, 160_000));
        assert!(should_compact(170_000, 160_000));
        assert!(should_compact(160_001, 160_000));
        assert!(!should_compact(160_000, 160_000));
    }

    #[test]
    fn test_find_split_point_preserves_recent() {
        let conv: Vec<ApiMessage> = (0..10)
            .map(|i| {
                let role = if i % 2 == 0 { "user" } else { "assistant" };
                text_msg(role, &format!("message {i}"))
            })
            .collect();

        let split = find_split_point(&conv, DEFAULT_MIN_KEEP_MESSAGES);
        // Should keep at least DEFAULT_MIN_KEEP_MESSAGES (4) at the end
        assert!(conv.len() - split >= DEFAULT_MIN_KEEP_MESSAGES);
        assert_eq!(split, 6); // 10 - 4 = 6
    }

    #[test]
    fn test_find_split_point_respects_tool_boundaries() {
        // Conversation: user, assistant, user, assistant(tool_use), user(tool_result), assistant, user
        // 7 messages total, MIN_KEEP = 4, so split candidate = 3
        // Index 3 = assistant(tool_use), index 4 = user(tool_result)
        // Split at 3 means keeping [3..7] which includes the tool pair — that's fine
        // But what if split would land at index 4 (user with tool_result)?
        let conv = vec![
            text_msg("user", "hello"),        // 0
            text_msg("assistant", "hi"),      // 1
            text_msg("user", "do something"), // 2
            tool_use_msg(),                   // 3 - assistant with tool_use
            tool_result_msg(),                // 4 - user with tool_result
            text_msg("assistant", "done"),    // 5
            text_msg("user", "thanks"),       // 6
        ];

        let split = find_split_point(&conv, DEFAULT_MIN_KEEP_MESSAGES);
        // split candidate = 7 - 4 = 3
        // index 3 is assistant(tool_use), index 4 is user(tool_result)
        // The kept portion [3..] includes both, so split=3 is clean
        assert_eq!(split, 3);
    }

    #[test]
    fn test_find_split_point_moves_back_when_splitting_tool_cycle() {
        // Force the split candidate to land ON a tool_result message
        // 5 messages: user, tool_use, tool_result, assistant, user
        // MIN_KEEP = 4, candidate split = 5 - 4 = 1
        // Index 1 is tool_use(assistant), which is fine — kept portion starts with it
        let conv = vec![
            text_msg("user", "start"),   // 0
            tool_use_msg(),              // 1
            tool_result_msg(),           // 2
            text_msg("assistant", "ok"), // 3
            text_msg("user", "next"),    // 4
        ];
        let split = find_split_point(&conv, DEFAULT_MIN_KEEP_MESSAGES);
        assert_eq!(split, 1);

        // Now: 6 messages where split=2 lands on tool_result
        let conv2 = vec![
            text_msg("user", "start"),     // 0
            text_msg("assistant", "ack"),  // 1
            tool_result_msg(),             // 2 - user with tool_result (split candidate)
            text_msg("assistant", "done"), // 3
            text_msg("user", "q1"),        // 4
            text_msg("assistant", "a1"),   // 5
        ];
        let split2 = find_split_point(&conv2, DEFAULT_MIN_KEEP_MESSAGES);
        // candidate = 6 - 4 = 2, which is a tool_result user msg
        // prev (index 1) is assistant but no tool_use → no cycle, so split stays at 2
        assert_eq!(split2, 2);
    }

    #[test]
    fn test_find_split_point_too_short() {
        // 3 messages <= MIN_KEEP (4) → nothing to compact.
        let conv = vec![
            text_msg("user", "hi"),
            text_msg("assistant", "hello"),
            text_msg("user", "bye"),
        ];
        assert_eq!(find_split_point(&conv, DEFAULT_MIN_KEEP_MESSAGES), 0);
    }

    #[test]
    fn test_splice_conversation() {
        let mut conv: Vec<ApiMessage> = (0..10)
            .map(|i| {
                let role = if i % 2 == 0 { "user" } else { "assistant" };
                text_msg(role, &format!("msg {i}"))
            })
            .collect();

        splice_conversation(&mut conv, 6, "Summary of messages 0-5");

        // 1 summary + 4 kept = 5
        assert_eq!(conv.len(), 5);
        // First message should be the summary
        match &conv[0].content[0] {
            ApiContentBlock::Text { text, .. } => {
                assert!(text.contains("Summary of messages 0-5"));
                assert!(text.contains("[Previous conversation summary]"));
            }
            _ => panic!("Expected text block"),
        }
        // Second message should be the old index 6 (msg 6)
        match &conv[1].content[0] {
            ApiContentBlock::Text { text, .. } => assert_eq!(text, "msg 6"),
            _ => panic!("Expected text block"),
        }
    }

    #[test]
    fn test_find_split_point_custom_min_keep() {
        let conv: Vec<ApiMessage> = (0..10)
            .map(|i| {
                let role = if i % 2 == 0 { "user" } else { "assistant" };
                text_msg(role, &format!("message {i}"))
            })
            .collect();

        // min_keep=2 → split at 8 (keeps last 2)
        assert_eq!(find_split_point(&conv, 2), 8);

        // min_keep=6 → split at 4 (keeps last 6)
        assert_eq!(find_split_point(&conv, 6), 4);

        // min_keep=1 → split at 9 (keeps last 1)
        assert_eq!(find_split_point(&conv, 1), 9);
    }

    #[test]
    fn test_build_summary_prompt_format() {
        let msgs = vec![
            text_msg("user", "Tell me about Rust"),
            text_msg("assistant", "Rust is a systems language."),
        ];

        let prompt = build_summary_prompt(&msgs);
        assert!(prompt.contains("Summarize the following"));
        assert!(prompt.contains("[user]"));
        assert!(prompt.contains("Tell me about Rust"));
        assert!(prompt.contains("[assistant]"));
        assert!(prompt.contains("Rust is a systems language."));
        assert!(prompt.contains("<conversation>"));
    }

    // ── should_prune ──────────────────────────────────────────────────

    #[test]
    fn test_should_prune_threshold() {
        // 70% of 160_000 = 112_000
        assert!(!should_prune(100_000, 160_000, 70));
        assert!(should_prune(120_000, 160_000, 70));
        assert!(!should_prune(112_000, 160_000, 70));
        assert!(should_prune(112_001, 160_000, 70));
    }

    // ── prune_tool_results ────────────────────────────────────────────

    // Fixture: user message with a tool_result of `size` repeated chars.
    fn large_tool_result_msg(size: usize) -> ApiMessage {
        ApiMessage {
            role: "user".to_string(),
            content: vec![ApiContentBlock::ToolResult {
                tool_use_id: "tu_big".to_string(),
                content: serde_json::json!("x".repeat(size)),
                is_error: None,
                cache_control: None,
                name: None,
            }],
        }
    }

    #[test]
    fn prune_truncates_large_tool_results() {
        let mut conv = vec![
            text_msg("user", "start"),
            tool_use_msg(),
            large_tool_result_msg(5000),
            text_msg("assistant", "ok"),
            text_msg("user", "next"),
        ];

        let removed = prune_tool_results(&mut conv, 2000, 2);
        assert!(removed > 0, "Should have pruned chars");

        // The tool result (index 2) should now contain the pruning marker
        if let ApiContentBlock::ToolResult { content, .. } = &conv[2].content[0] {
            let text = content.as_str().unwrap();
            assert!(text.contains("[..."), "Should contain prune marker");
            assert!(text.len() < 5000, "Should be smaller than original");
        } else {
            panic!("Expected tool result");
        }
    }

    #[test]
    fn prune_preserves_small_tool_results() {
        let mut conv = vec![
            text_msg("user", "start"),
            tool_use_msg(),
            tool_result_msg(), // Small result
            text_msg("assistant", "ok"),
            text_msg("user", "next"),
        ];

        let removed = prune_tool_results(&mut conv, 2000, 2);
        assert_eq!(removed, 0, "Small results should not be pruned");
    }

    #[test]
    fn prune_skips_tail_messages() {
        let mut conv = vec![
            text_msg("user", "old"),
            text_msg("assistant", "old reply"),
            tool_use_msg(),
            large_tool_result_msg(5000), // index 3 — in the tail (preserve_tail=2)
            text_msg("assistant", "done"),
        ];

        // preserve_tail=2 → only process [0..3], skipping indices 3 and 4
        let removed = prune_tool_results(&mut conv, 2000, 2);
        assert_eq!(removed, 0, "Tail messages should not be pruned");
    }

    #[test]
    fn prune_handles_empty_conversation() {
        let mut conv: Vec<ApiMessage> = vec![];
        let removed = prune_tool_results(&mut conv, 2000, 2);
        assert_eq!(removed, 0);
    }

    #[test]
    fn prune_multiple_large_tool_results() {
        let mut conv = vec![
            text_msg("user", "q1"),
            tool_use_msg(),
            large_tool_result_msg(5000), // index 2 — should be pruned
            text_msg("assistant", "a1"),
            text_msg("user", "q2"),
            tool_use_msg(),
            large_tool_result_msg(8000), // index 6 — should be pruned
            text_msg("assistant", "a2"),
            text_msg("user", "latest"),    // tail
            text_msg("assistant", "done"), // tail
        ];

        let removed = prune_tool_results(&mut conv, 2000, 2);
        assert!(removed > 0, "Should have pruned chars");

        // Both tool results should be pruned
        for idx in [2, 6] {
            if let ApiContentBlock::ToolResult { content, .. } = &conv[idx].content[0] {
                let text = content.as_str().unwrap();
                assert!(text.contains("[..."), "Index {} should be pruned", idx);
            }
        }
    }

    #[test]
    fn prune_all_messages_in_tail() {
        let mut conv = vec![
            text_msg("user", "hello"),
            tool_use_msg(),
            large_tool_result_msg(5000),
        ];

        // preserve_tail=10 > len=3 → nothing should be pruned
        let removed = prune_tool_results(&mut conv, 2000, 10);
        assert_eq!(removed, 0, "All messages in tail — nothing to prune");
    }
}