Skip to main content

sgr_agent/
compaction.rs

1//! LLM-based context compaction — summarize old messages to stay within limits.
2//!
3//! Unlike simple sliding window (trim_messages), compaction preserves key decisions
4//! and file changes by summarizing old messages through a fast LLM.
5
6use crate::client::LlmClient;
7use crate::types::Message;
8
9#[cfg(feature = "session")]
10use crate::session::{AgentMessage, MessageRole, Session};
11
/// Compacts conversation history using LLM summarization.
///
/// Holds only plain configuration, so it derives `Debug` (for logging and
/// test diagnostics) and `Clone` (so one configuration can be reused).
#[derive(Debug, Clone)]
pub struct Compactor {
    /// Token threshold — compact when the estimated token count exceeds this.
    pub threshold: usize,
    /// Number of most recent messages to keep uncompacted.
    pub keep_recent: usize,
    /// Number of initial messages to preserve (system + first user).
    pub keep_start: usize,
    /// Custom compaction prompt; when `None` the built-in default prompt is used.
    prompt: Option<String>,
}
23
24impl Compactor {
25    /// Create a compactor with the given token threshold.
26    pub fn new(threshold: usize) -> Self {
27        Self {
28            threshold,
29            keep_recent: 10,
30            keep_start: 2,
31            prompt: None,
32        }
33    }
34
35    /// Create with custom keep parameters.
36    pub fn with_keep(mut self, start: usize, recent: usize) -> Self {
37        self.keep_start = start;
38        self.keep_recent = recent;
39        self
40    }
41
42    /// Use a custom compaction prompt instead of the default.
43    ///
44    /// Useful for domain-specific compaction (e.g., sales coaching, code review).
45    pub fn with_prompt(mut self, prompt: impl Into<String>) -> Self {
46        self.prompt = Some(prompt.into());
47        self
48    }
49
50    /// Check if compaction is needed based on estimated token count.
51    pub fn needs_compaction(&self, messages: &[Message]) -> bool {
52        estimate_tokens(messages) > self.threshold
53    }
54
55    /// Compact messages using an LLM summarizer.
56    ///
57    /// Replaces messages[keep_start..len-keep_recent] with a single summary message.
58    /// Returns true if compaction was performed.
59    pub async fn compact(
60        &self,
61        summarizer: &dyn LlmClient,
62        messages: &mut Vec<Message>,
63    ) -> Result<bool, CompactionError> {
64        let est = estimate_tokens(messages);
65        if est <= self.threshold {
66            return Ok(false);
67        }
68
69        let total = messages.len();
70        if total <= self.keep_start + self.keep_recent + 1 {
71            // Not enough messages to compact
72            return Ok(false);
73        }
74
75        let compact_end = total - self.keep_recent;
76        let to_compact = &messages[self.keep_start..compact_end];
77
78        if to_compact.is_empty() {
79            return Ok(false);
80        }
81
82        // Format messages for summarization
83        let formatted = format_messages_for_summary(to_compact);
84
85        let prompt = self.prompt.as_deref().unwrap_or(COMPACTION_PROMPT);
86        let summary_prompt = vec![Message::system(prompt), Message::user(&formatted)];
87
88        let summary = summarizer
89            .complete(&summary_prompt)
90            .await
91            .map_err(|e| CompactionError::Llm(e.to_string()))?;
92
93        if summary.is_empty() {
94            return Err(CompactionError::EmptySummary);
95        }
96
97        // Replace compacted messages with summary
98        let compacted_count = compact_end - self.keep_start;
99        messages.drain(self.keep_start..compact_end);
100        messages.insert(
101            self.keep_start,
102            Message::system(format!(
103                "<compacted count=\"{}\">\n{}\n</compacted>",
104                compacted_count, summary
105            )),
106        );
107
108        Ok(true)
109    }
110}
111
/// Session-aware compaction methods (requires `session` feature).
#[cfg(feature = "session")]
impl Compactor {
    /// Estimate token count for a session's messages.
    ///
    /// Same heuristic as [`estimate_tokens`]: characters / 4, plus one per message.
    pub fn estimate_session_tokens<M: AgentMessage>(session: &Session<M>) -> usize {
        session
            .messages()
            .iter()
            .map(|m: &M| m.content().chars().count() / 4 + 1)
            .sum()
    }

    /// Check if compaction is needed for a session.
    ///
    /// Returns `true` only when the estimate is strictly above `threshold`.
    pub fn needs_session_compaction<M: AgentMessage>(&self, session: &Session<M>) -> bool {
        Self::estimate_session_tokens(session) > self.threshold
    }

    /// Compact a session using LLM summarization with incremental support.
    ///
    /// Messages in `keep_start..len - keep_recent` whose content starts with
    /// `<compacted` are treated as prior summaries: they are forwarded to the
    /// summarizer untruncated, with an instruction to preserve them verbatim,
    /// and only the remaining messages are formatted for summarization. Every
    /// prior summary found in the range is forwarded, in order.
    ///
    /// The whole range is then replaced by a single system message wrapping the
    /// new summary in a `<compacted turns="N">` tag.
    ///
    /// Returns the number of messages replaced, or `0` when nothing was done.
    /// On any error the session is left unmodified.
    ///
    /// # Errors
    ///
    /// * [`CompactionError::Llm`] — the summarizer call failed.
    /// * [`CompactionError::EmptySummary`] — the summarizer returned an empty
    ///   or whitespace-only summary.
    pub async fn compact_session<M: AgentMessage>(
        &self,
        summarizer: &dyn LlmClient,
        session: &mut Session<M>,
    ) -> Result<usize, CompactionError> {
        if !self.needs_session_compaction(session) {
            return Ok(0);
        }

        // `keep_start` / `keep_recent` are public and caller-controlled, so add
        // with saturation: an overflowing sum must read as "nothing to compact"
        // rather than wrapping around and producing an out-of-range slice.
        let total = session.messages().len();
        let preserved = self.keep_start.saturating_add(self.keep_recent);
        if total <= preserved.saturating_add(1) {
            return Ok(0);
        }

        // The guard above guarantees `keep_start + 2 <= compact_end <= total`,
        // so the range below is in bounds and never empty.
        let compact_end = total - self.keep_recent;
        let compacted_count = compact_end - self.keep_start;

        // Build the summarizer input inside its own scope so that every borrow
        // of the session ends before the `.await` below.
        let user_content = {
            let to_compact = &session.messages()[self.keep_start..compact_end];

            // Separate existing compacted summaries from new messages (incremental).
            let mut prior_summaries: Vec<&str> = Vec::new();
            let mut new_messages: Vec<(&str, &str)> = Vec::new();
            for m in to_compact.iter() {
                let content: &str = m.content();
                if content.starts_with("<compacted") {
                    prior_summaries.push(content);
                } else {
                    new_messages.push((m.role().as_str(), content));
                }
            }

            // Format new messages for summarization
            let formatted = format_agent_messages_for_summary(&new_messages);

            if prior_summaries.is_empty() {
                formatted
            } else {
                // Join every prior summary: forwarding only one of them would
                // silently discard the context captured by the others.
                let prev = prior_summaries.join("\n\n");
                format!(
                    "Previous summary (preserve verbatim, do not re-summarize):\n{prev}\n\nNew messages to summarize:\n{formatted}"
                )
            }
        };

        let prompt = self.prompt.as_deref().unwrap_or(COMPACTION_PROMPT);
        let summary_prompt = vec![Message::system(prompt), Message::user(&user_content)];

        let summary = summarizer
            .complete(&summary_prompt)
            .await
            .map_err(|e| CompactionError::Llm(e.to_string()))?;

        // A whitespace-only reply carries no information; accepting it would
        // throw away the compacted history in exchange for a blank summary.
        if summary.trim().is_empty() {
            return Err(CompactionError::EmptySummary);
        }

        // Replace compacted messages with summary
        let msgs = session.messages_mut();
        msgs.drain(self.keep_start..compact_end);
        let summary_content =
            format!("<compacted turns=\"{compacted_count}\">\n{summary}\n</compacted>");
        msgs.insert(self.keep_start, M::new(M::Role::system(), summary_content));

        Ok(compacted_count)
    }
}
199
200impl Default for Compactor {
201    fn default() -> Self {
202        // ~100K tokens threshold (roughly 400K chars / 4)
203        Self::new(100_000)
204    }
205}
206
/// Failure modes of a compaction attempt.
#[derive(Debug)]
pub enum CompactionError {
    /// The summarizer call failed; carries the underlying error rendered as text.
    Llm(String),
    /// The summarizer call succeeded but produced an empty summary.
    EmptySummary,
}

impl std::fmt::Display for CompactionError {
    /// Render a short human-readable description of the failure.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::EmptySummary => f.write_str("LLM returned empty summary"),
            Self::Llm(cause) => write!(f, "Compaction LLM error: {cause}"),
        }
    }
}

impl std::error::Error for CompactionError {}
226
227/// Estimate token count for messages (rough: chars / 4).
228/// Uses char count (not byte length) for correct non-ASCII estimation.
229pub fn estimate_tokens(messages: &[Message]) -> usize {
230    messages
231        .iter()
232        .map(|m| m.content.chars().count() / 4 + 1)
233        .sum()
234}
235
/// Default system prompt sent to the summarizer LLM when no custom prompt is set.
const COMPACTION_PROMPT: &str = concat!(
    "Summarize this conversation concisely. Preserve:\n",
    "- Key decisions made\n",
    "- Files read, created, or modified (with paths)\n",
    "- Important findings and errors encountered\n",
    "- Current task state and next steps\n",
    "\n",
    "Be concise but thorough. Use bullet points. Do not lose critical context.",
);
244
245/// Format messages into text for summarization.
246fn format_messages_for_summary(messages: &[Message]) -> String {
247    let mut output = String::new();
248    for msg in messages {
249        let role = match msg.role {
250            crate::types::Role::System => "SYSTEM",
251            crate::types::Role::User => "USER",
252            crate::types::Role::Assistant => "ASSISTANT",
253            crate::types::Role::Tool => "TOOL",
254        };
255        // Truncate very long messages for summarization (char-safe boundary)
256        let content = if msg.content.chars().count() > 2000 {
257            let truncated: String = msg.content.chars().take(2000).collect();
258            format!(
259                "{}... [truncated, {} chars total]",
260                truncated,
261                msg.content.chars().count()
262            )
263        } else {
264            msg.content.clone()
265        };
266        output.push_str(&format!("[{}]: {}\n\n", role, content));
267    }
268    output
269}
270
271/// Format generic agent messages (role string + content) for summarization.
272#[cfg(feature = "session")]
fn format_agent_messages_for_summary(messages: &[(&str, &str)]) -> String {
    // Maximum number of characters of a single message passed through.
    const MAX_CHARS: usize = 2000;

    let mut rendered = String::new();
    for &(role, body) in messages {
        // Known roles are upper-cased; any other role string is used as-is.
        let label = match role {
            "system" => "SYSTEM",
            "user" => "USER",
            "assistant" => "ASSISTANT",
            "tool" => "TOOL",
            unknown => unknown,
        };

        rendered.push('[');
        rendered.push_str(label);
        rendered.push_str("]: ");

        // `nth(MAX_CHARS)` only yields a value when the text has more than
        // MAX_CHARS characters; its byte offset is where the kept prefix ends,
        // so the slice below always lands on a char boundary.
        match body.char_indices().nth(MAX_CHARS) {
            Some((cut, _)) => {
                rendered.push_str(&body[..cut]);
                rendered.push_str(&format!(
                    "... [truncated, {} chars total]",
                    body.chars().count()
                ));
            }
            None => rendered.push_str(body),
        }

        rendered.push_str("\n\n");
    }
    rendered
}
297
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::SgrError;
    use std::sync::Mutex;

    /// Scripted [`LlmClient`] shared by all compaction tests.
    ///
    /// `complete` always answers with `reply` and records the content of every
    /// message in the most recent request, so tests can inspect what was sent
    /// to the summarizer. The other trait methods are never used by compaction.
    struct StubClient {
        reply: &'static str,
        seen: Mutex<Vec<String>>,
    }

    impl StubClient {
        /// Build a stub whose `complete` always returns `reply`.
        fn replying(reply: &'static str) -> Self {
            Self {
                reply,
                seen: Mutex::new(Vec::new()),
            }
        }

        /// Message contents of the last `complete` request (empty if never called).
        fn last_request(&self) -> Vec<String> {
            self.seen.lock().unwrap().clone()
        }
    }

    #[async_trait::async_trait]
    impl LlmClient for StubClient {
        async fn structured_call(
            &self,
            _: &[Message],
            _: &serde_json::Value,
        ) -> Result<
            (
                Option<serde_json::Value>,
                Vec<crate::types::ToolCall>,
                String,
            ),
            SgrError,
        > {
            unimplemented!()
        }

        async fn tools_call(
            &self,
            _: &[Message],
            _: &[crate::tool::ToolDef],
        ) -> Result<Vec<crate::types::ToolCall>, SgrError> {
            unimplemented!()
        }

        async fn complete(&self, messages: &[Message]) -> Result<String, SgrError> {
            *self.seen.lock().unwrap() = messages.iter().map(|m| m.content.to_string()).collect();
            Ok(self.reply.to_string())
        }
    }

    /// Nine-message conversation: 2 head messages, 5 middle, 2 tail.
    fn sample_conversation() -> Vec<Message> {
        vec![
            Message::system("System prompt"),
            Message::user("Initial task"),
            Message::assistant("Step 1 done"),
            Message::user("Continue"),
            Message::assistant("Step 2 done"),
            Message::user("Continue more"),
            Message::assistant("Step 3 done"),
            // last 2 to keep:
            Message::user("Final step"),
            Message::assistant("All done"),
        ]
    }

    #[test]
    fn estimate_tokens_basic() {
        let msgs = vec![
            Message::system("Hello world"), // 11 chars → 11/4 + 1 = 3 tokens
            Message::user("How are you"),   // 11 chars → 11/4 + 1 = 3 tokens
        ];
        assert_eq!(estimate_tokens(&msgs), 6);
    }

    #[test]
    fn estimate_tokens_non_ascii() {
        // "Привет мир" = 10 chars but 19 bytes; the estimate must use chars.
        let msgs = vec![Message::user("Привет мир")];
        // 10/4 + 1 = 3, whereas a byte-based estimate would give 19/4 + 1 = 5.
        assert_eq!(estimate_tokens(&msgs), 3);
    }

    #[test]
    fn format_messages_non_ascii_truncation() {
        // 3000 Cyrillic chars — truncation must not split a multi-byte char.
        let cyrillic: String = "Б".repeat(3000);
        let msgs = vec![Message::user(&cyrillic)];
        let formatted = format_messages_for_summary(&msgs);
        assert!(formatted.contains("truncated"));
        assert!(formatted.contains("3000 chars total"));
    }

    #[test]
    fn needs_compaction_under_threshold() {
        let compactor = Compactor::new(1000);
        let msgs = vec![Message::user("short")];
        assert!(!compactor.needs_compaction(&msgs));
    }

    #[test]
    fn needs_compaction_over_threshold() {
        let compactor = Compactor::new(10);
        let msgs: Vec<Message> = (0..100)
            .map(|i| {
                Message::user(format!(
                    "Message number {} with some content to pad it out",
                    i
                ))
            })
            .collect();
        assert!(compactor.needs_compaction(&msgs));
    }

    #[test]
    fn format_messages_truncates_long() {
        let long_msg = "x".repeat(5000);
        let msgs = vec![Message::user(&long_msg)];
        let formatted = format_messages_for_summary(&msgs);
        assert!(formatted.contains("truncated"));
        assert!(formatted.contains("5000 chars total"));
        assert!(formatted.len() < 5000);
    }

    #[test]
    fn compactor_default() {
        let c = Compactor::default();
        assert_eq!(c.threshold, 100_000);
        assert_eq!(c.keep_recent, 10);
        assert_eq!(c.keep_start, 2);
    }

    #[test]
    fn compactor_with_keep() {
        let c = Compactor::new(50_000).with_keep(3, 5);
        assert_eq!(c.keep_start, 3);
        assert_eq!(c.keep_recent, 5);
    }

    #[test]
    fn with_prompt_overrides_default() {
        let c = Compactor::new(1000).with_prompt("Custom: summarize sales data");
        assert_eq!(c.prompt.as_deref(), Some("Custom: summarize sales data"));
    }

    #[tokio::test]
    async fn compact_not_needed() {
        let client = StubClient::replying("Summary of conversation.");
        let compactor = Compactor::new(100_000);
        let mut msgs = vec![Message::user("short")];

        let result = compactor.compact(&client, &mut msgs).await.unwrap();

        assert!(!result);
        assert_eq!(msgs.len(), 1);
        // Under the threshold the summarizer must not be contacted at all.
        assert!(client.last_request().is_empty());
    }

    #[tokio::test]
    async fn compact_skips_when_too_few_messages() {
        let client = StubClient::replying("unused");
        // Over the (tiny) threshold, but 5 messages <= keep_start + keep_recent + 1.
        let compactor = Compactor::new(5).with_keep(2, 2);
        let mut msgs: Vec<Message> = sample_conversation().into_iter().take(5).collect();

        let result = compactor.compact(&client, &mut msgs).await.unwrap();

        assert!(!result);
        assert_eq!(msgs.len(), 5);
        assert!(client.last_request().is_empty());
    }

    #[tokio::test]
    async fn compact_replaces_old_messages() {
        let client = StubClient::replying(
            "Key decisions: implemented auth module. Files: src/auth.rs created.",
        );
        let compactor = Compactor::new(5).with_keep(2, 2); // very low threshold
        let mut msgs = sample_conversation();

        let result = compactor.compact(&client, &mut msgs).await.unwrap();
        assert!(result);

        // Should have: system, initial, compacted summary, last 2
        assert_eq!(msgs.len(), 5);
        assert_eq!(msgs[0].content, "System prompt");
        assert_eq!(msgs[1].content, "Initial task");
        assert!(msgs[2].content.contains("compacted"));
        assert!(msgs[2].content.contains("Key decisions"));
        assert_eq!(msgs[3].content, "Final step");
        assert_eq!(msgs[4].content, "All done");
    }

    #[tokio::test]
    async fn compact_empty_summary_is_error() {
        let client = StubClient::replying("");
        let compactor = Compactor::new(5).with_keep(2, 2);
        let mut msgs = sample_conversation();

        let err = compactor.compact(&client, &mut msgs).await.unwrap_err();

        assert!(matches!(err, CompactionError::EmptySummary));
        // A failed compaction must leave the history untouched.
        assert_eq!(msgs.len(), 9);
        assert_eq!(msgs[2].content, "Step 1 done");
    }

    #[tokio::test]
    async fn compact_uses_custom_prompt() {
        let client = StubClient::replying("Summary");
        let compactor = Compactor::new(5)
            .with_keep(1, 1)
            .with_prompt("SALES FOCUS: summarize this");

        let mut msgs = vec![
            Message::system("sys"),
            Message::user("msg1"),
            Message::assistant("resp1"),
            Message::user("msg2"),
            Message::assistant("resp2"),
            Message::user("last"),
        ];

        let result = compactor.compact(&client, &mut msgs).await.unwrap();
        assert!(result);

        // The first message of the summarizer request is the system prompt.
        let request = client.last_request();
        assert!(request[0].contains("SALES FOCUS"));
    }

    #[cfg(feature = "session")]
    mod session_tests {
        use super::*;
        use crate::session::Session;
        use crate::session::simple::{SimpleMsg, SimpleRole};
        use std::path::PathBuf;

        /// Removes the wrapped directory when dropped, so cleanup also happens
        /// when a test fails part-way through.
        struct TempDirGuard(PathBuf);

        impl Drop for TempDirGuard {
            fn drop(&mut self) {
                let _ = std::fs::remove_dir_all(&self.0);
            }
        }

        /// Create a session backed by a directory unique to this process and
        /// test (`tag`). Tests run in parallel, so a shared directory would let
        /// one test delete another test's files.
        fn make_session(tag: &str) -> (Session<SimpleMsg>, TempDirGuard) {
            let dir = std::env::temp_dir().join(format!(
                "sgr_compact_session_test_{}_{tag}",
                std::process::id()
            ));
            let _ = std::fs::remove_dir_all(&dir);
            let session = Session::new(dir.to_str().unwrap(), 100).unwrap();
            (session, TempDirGuard(dir))
        }

        /// Push `count` alternating user/assistant messages named `{prefix} {i}`.
        fn push_turns(session: &mut Session<SimpleMsg>, prefix: &str, count: usize) {
            for i in 0..count {
                let role = if i % 2 == 0 {
                    SimpleRole::User
                } else {
                    SimpleRole::Assistant
                };
                session.push(role, format!("{prefix} {i}"));
            }
        }

        #[test]
        fn estimate_session_tokens_basic() {
            let (mut session, _dir) = make_session("estimate_basic");
            session.push(SimpleRole::User, "Hello world".into()); // 11 chars → 3
            session.push(SimpleRole::Assistant, "Hi there".into()); // 8 chars → 3
            let est = Compactor::estimate_session_tokens(&session);
            assert!(est > 0 && est < 100);
        }

        #[tokio::test]
        async fn compact_session_not_needed() {
            let client = StubClient::replying("summary");
            let (mut session, _dir) = make_session("not_needed");
            session.push(SimpleRole::User, "short msg".into());

            let compactor = Compactor::new(100_000);
            let result = compactor
                .compact_session(&client, &mut session)
                .await
                .unwrap();

            assert_eq!(result, 0);
            assert!(client.last_request().is_empty());
        }

        #[tokio::test]
        async fn compact_session_replaces_middle() {
            let client = StubClient::replying("Compacted: auth module created");
            let (mut session, _dir) = make_session("replaces_middle");
            session.push(SimpleRole::System, "system prompt".into());
            session.push(SimpleRole::User, "initial task".into());
            push_turns(&mut session, "msg", 6);
            session.push(SimpleRole::User, "final".into());
            session.push(SimpleRole::Assistant, "done".into());

            let compactor = Compactor::new(5).with_keep(2, 2);
            let result = compactor
                .compact_session(&client, &mut session)
                .await
                .unwrap();
            // 10 messages − 2 head − 2 tail = 6 replaced.
            assert_eq!(result, 6);

            // Check structure: keep_start(2) + compacted(1) + keep_recent(2) = 5
            assert_eq!(session.messages().len(), 5);
            assert!(session.messages()[2].content().contains("<compacted"));
            assert!(session.messages()[2].content().contains("auth module"));
            assert_eq!(session.messages()[3].content(), "final");
            assert_eq!(session.messages()[4].content(), "done");
        }

        #[tokio::test]
        async fn compact_session_incremental_preserves_prior() {
            let client = StubClient::replying("Merged summary");
            let (mut session, _dir) = make_session("incremental");
            session.push(SimpleRole::System, "system".into());
            session.push(SimpleRole::User, "initial".into());
            // Simulate a prior compacted block in the middle
            session.push(
                SimpleRole::System,
                "<compacted turns=\"5\">\nprior context here\n</compacted>".into(),
            );
            push_turns(&mut session, "new msg", 4);
            session.push(SimpleRole::User, "keep1".into());
            session.push(SimpleRole::Assistant, "keep2".into());

            let compactor = Compactor::new(5).with_keep(2, 2);
            let result = compactor
                .compact_session(&client, &mut session)
                .await
                .unwrap();
            assert!(result > 0);

            // The second message of the summarizer request is the user content,
            // which must carry the prior summary alongside the new messages.
            let request = client.last_request();
            assert!(
                request[1].contains("Previous summary") && request[1].contains("prior context here"),
                "should send prior summary to LLM"
            );
            assert!(session.messages()[2].content().contains("<compacted"));
            assert!(session.messages()[2].content().contains("Merged summary"));
        }
    }
}