1use crate::client::LlmClient;
7use crate::types::Message;
8
9#[cfg(feature = "session")]
10use crate::session::{AgentMessage, MessageRole, Session};
11
/// Shrinks long conversations by replacing their middle section with an
/// LLM-generated summary, keeping the opening and the most recent messages
/// verbatim.
///
/// Build one with [`Compactor::new`] (or `Default`) and tune it with
/// [`Compactor::with_keep`] and [`Compactor::with_prompt`].
#[derive(Debug, Clone)]
pub struct Compactor {
    /// Estimated token count above which compaction is performed.
    pub threshold: usize,
    /// Number of trailing messages that are never compacted.
    pub keep_recent: usize,
    /// Number of leading messages (e.g. system prompt and initial task) that
    /// are never compacted.
    pub keep_start: usize,
    /// Custom summarization system prompt; `None` selects the built-in one.
    prompt: Option<String>,
}
23
24impl Compactor {
25 pub fn new(threshold: usize) -> Self {
27 Self {
28 threshold,
29 keep_recent: 10,
30 keep_start: 2,
31 prompt: None,
32 }
33 }
34
35 pub fn with_keep(mut self, start: usize, recent: usize) -> Self {
37 self.keep_start = start;
38 self.keep_recent = recent;
39 self
40 }
41
42 pub fn with_prompt(mut self, prompt: impl Into<String>) -> Self {
46 self.prompt = Some(prompt.into());
47 self
48 }
49
50 pub fn needs_compaction(&self, messages: &[Message]) -> bool {
52 estimate_tokens(messages) > self.threshold
53 }
54
55 pub async fn compact(
60 &self,
61 summarizer: &dyn LlmClient,
62 messages: &mut Vec<Message>,
63 ) -> Result<bool, CompactionError> {
64 let est = estimate_tokens(messages);
65 if est <= self.threshold {
66 return Ok(false);
67 }
68
69 let total = messages.len();
70 if total <= self.keep_start + self.keep_recent + 1 {
71 return Ok(false);
73 }
74
75 let compact_end = total - self.keep_recent;
76 let to_compact = &messages[self.keep_start..compact_end];
77
78 if to_compact.is_empty() {
79 return Ok(false);
80 }
81
82 let formatted = format_messages_for_summary(to_compact);
84
85 let prompt = self.prompt.as_deref().unwrap_or(COMPACTION_PROMPT);
86 let summary_prompt = vec![Message::system(prompt), Message::user(&formatted)];
87
88 let summary = summarizer
89 .complete(&summary_prompt)
90 .await
91 .map_err(|e| CompactionError::Llm(e.to_string()))?;
92
93 if summary.is_empty() {
94 return Err(CompactionError::EmptySummary);
95 }
96
97 let compacted_count = compact_end - self.keep_start;
99 messages.drain(self.keep_start..compact_end);
100 messages.insert(
101 self.keep_start,
102 Message::system(format!(
103 "<compacted count=\"{}\">\n{}\n</compacted>",
104 compacted_count, summary
105 )),
106 );
107
108 Ok(true)
109 }
110}
111
#[cfg(feature = "session")]
impl Compactor {
    /// Rough token estimate for a session: about one token per four
    /// characters of content, plus one token per message.
    pub fn estimate_session_tokens<M: AgentMessage>(session: &Session<M>) -> usize {
        session
            .messages()
            .iter()
            .map(|m: &M| m.content().chars().count() / 4 + 1)
            .sum()
    }

    /// Returns `true` when [`Compactor::estimate_session_tokens`] exceeds the
    /// configured threshold.
    pub fn needs_session_compaction<M: AgentMessage>(&self, session: &Session<M>) -> bool {
        Self::estimate_session_tokens(session) > self.threshold
    }

    /// Session counterpart of `Compactor::compact`. It replaces messages
    /// `keep_start..len - keep_recent` with one system message
    /// `<compacted turns="N">…</compacted>`.
    ///
    /// Every earlier summary found in that range (content starting with
    /// `<compacted`) is forwarded to the summarizer verbatim as a "previous
    /// summary". Only the remaining messages are formatted for summarization.
    ///
    /// Returns the number of messages that were replaced, or `0` when no
    /// compaction was needed or possible.
    ///
    /// # Errors
    ///
    /// [`CompactionError::Llm`] if the summarizer call fails, and
    /// [`CompactionError::EmptySummary`] if it returns only whitespace. In
    /// both cases the session is left untouched.
    pub async fn compact_session<M: AgentMessage>(
        &self,
        summarizer: &dyn LlmClient,
        session: &mut Session<M>,
    ) -> Result<usize, CompactionError> {
        if !self.needs_session_compaction(session) {
            return Ok(0);
        }

        let total = session.messages().len();
        // Saturating so that huge `keep_*` values disable compaction instead
        // of overflowing.
        let min_len = self
            .keep_start
            .saturating_add(self.keep_recent)
            .saturating_add(1);
        if total <= min_len {
            return Ok(0);
        }
        // Cannot underflow: total > keep_start + keep_recent + 1.
        let compact_end = total - self.keep_recent;
        let compacted_count = compact_end - self.keep_start;

        let user_content = {
            // Collect *all* prior summaries; keeping only the last one would
            // silently discard older compacted context.
            let mut prior_summaries: Vec<&str> = Vec::new();
            let mut new_messages: Vec<(&str, &str)> = Vec::new();
            for m in &session.messages()[self.keep_start..compact_end] {
                let content: &str = m.content();
                if content.starts_with("<compacted") {
                    prior_summaries.push(content);
                } else {
                    new_messages.push((m.role().as_str(), content));
                }
            }

            let formatted = format_agent_messages_for_summary(&new_messages);
            if prior_summaries.is_empty() {
                formatted
            } else {
                let prev = prior_summaries.join("\n\n");
                format!(
                    "Previous summary (preserve verbatim, do not re-summarize):\n{prev}\n\nNew messages to summarize:\n{formatted}"
                )
            }
        };

        let prompt = self.prompt.as_deref().unwrap_or(COMPACTION_PROMPT);
        let summary_prompt = vec![Message::system(prompt), Message::user(&user_content)];

        let summary = summarizer
            .complete(&summary_prompt)
            .await
            .map_err(|e| CompactionError::Llm(e.to_string()))?;

        let summary = summary.trim();
        if summary.is_empty() {
            return Err(CompactionError::EmptySummary);
        }

        let msgs = session.messages_mut();
        msgs.drain(self.keep_start..compact_end);
        let summary_content =
            format!("<compacted turns=\"{compacted_count}\">\n{summary}\n</compacted>");
        msgs.insert(self.keep_start, M::new(M::Role::system(), summary_content));

        Ok(compacted_count)
    }
}
199
200impl Default for Compactor {
201 fn default() -> Self {
202 Self::new(100_000)
204 }
205}
206
/// Failure modes of conversation compaction.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CompactionError {
    /// The summarizer call failed; carries the client's error message.
    Llm(String),
    /// The summarizer produced no summary text.
    EmptySummary,
}

impl std::fmt::Display for CompactionError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Llm(e) => write!(f, "Compaction LLM error: {e}"),
            Self::EmptySummary => f.write_str("LLM returned empty summary"),
        }
    }
}

impl std::error::Error for CompactionError {}
226
227pub fn estimate_tokens(messages: &[Message]) -> usize {
230 messages
231 .iter()
232 .map(|m| m.content.chars().count() / 4 + 1)
233 .sum()
234}
235
/// Default system prompt sent to the summarizer when no custom prompt was
/// configured via `Compactor::with_prompt`.
const COMPACTION_PROMPT: &str = r#"Summarize this conversation concisely. Preserve:
- Key decisions made
- Files read, created, or modified (with paths)
- Important findings and errors encountered
- Current task state and next steps

Be concise but thorough. Use bullet points. Do not lose critical context."#;
244
245fn format_messages_for_summary(messages: &[Message]) -> String {
247 let mut output = String::new();
248 for msg in messages {
249 let role = match msg.role {
250 crate::types::Role::System => "SYSTEM",
251 crate::types::Role::User => "USER",
252 crate::types::Role::Assistant => "ASSISTANT",
253 crate::types::Role::Tool => "TOOL",
254 };
255 let content = if msg.content.chars().count() > 2000 {
257 let truncated: String = msg.content.chars().take(2000).collect();
258 format!(
259 "{}... [truncated, {} chars total]",
260 truncated,
261 msg.content.chars().count()
262 )
263 } else {
264 msg.content.clone()
265 };
266 output.push_str(&format!("[{}]: {}\n\n", role, content));
267 }
268 output
269}
270
271#[cfg(feature = "session")]
/// Renders `(role, content)` pairs as `[LABEL]: content` paragraphs for the
/// summarizer.
///
/// The roles `system`, `user`, `assistant` and `tool` are upper-cased; any
/// other role string is emitted unchanged. Content over 2000 characters is
/// cut to its first 2000 characters, followed by a note giving the original
/// length.
fn format_agent_messages_for_summary(messages: &[(&str, &str)]) -> String {
    const MAX_CHARS: usize = 2000;

    let mut rendered = String::new();
    for &(role, content) in messages {
        let label = match role {
            "system" => "SYSTEM",
            "user" => "USER",
            "assistant" => "ASSISTANT",
            "tool" => "TOOL",
            other => other,
        };
        // Byte offset of the first char past the limit, if there is one.
        let body = match content.char_indices().nth(MAX_CHARS) {
            Some((cut, _)) => format!(
                "{}... [truncated, {} chars total]",
                &content[..cut],
                content.chars().count()
            ),
            None => content.to_string(),
        };
        rendered += &format!("[{}]: {}\n\n", label, body);
    }
    rendered
}
297
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::SgrError;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicBool, Ordering};

    /// Test double for [`LlmClient`]. `complete` returns whatever the wrapped
    /// closure produces for the prompt. The other trait methods are never
    /// used by the compactor and panic if called.
    struct FnClient<F>(F);

    #[async_trait::async_trait]
    impl<F> LlmClient for FnClient<F>
    where
        F: Fn(&[Message]) -> String + Send + Sync + 'static,
    {
        async fn structured_call(
            &self,
            _: &[Message],
            _: &serde_json::Value,
        ) -> Result<
            (
                Option<serde_json::Value>,
                Vec<crate::types::ToolCall>,
                String,
            ),
            SgrError,
        > {
            unimplemented!()
        }

        async fn tools_call(
            &self,
            _: &[Message],
            _: &[crate::tool::ToolDef],
        ) -> Result<Vec<crate::types::ToolCall>, SgrError> {
            unimplemented!()
        }

        async fn complete(&self, messages: &[Message]) -> Result<String, SgrError> {
            Ok((self.0)(messages))
        }
    }

    /// A client whose `complete` always answers with `text`.
    fn reply(
        text: &'static str,
    ) -> FnClient<impl Fn(&[Message]) -> String + Send + Sync + 'static> {
        FnClient(move |_: &[Message]| text.to_string())
    }

    #[test]
    fn estimate_tokens_basic() {
        let msgs = vec![Message::system("Hello world"), Message::user("How are you")];
        let est = estimate_tokens(&msgs);
        assert!(est > 0);
        assert!(est < 100);
    }

    #[test]
    fn estimate_tokens_non_ascii() {
        // 10 chars -> 10 / 4 + 1 = 3, regardless of UTF-8 byte length.
        let msgs = vec![Message::user("Привет мир")];
        assert_eq!(estimate_tokens(&msgs), 3);
    }

    #[test]
    fn format_messages_non_ascii_truncation() {
        let cyrillic: String = "Б".repeat(3000);
        let msgs = vec![Message::user(&cyrillic)];
        let formatted = format_messages_for_summary(&msgs);
        assert!(formatted.contains("truncated"));
    }

    #[test]
    fn needs_compaction_under_threshold() {
        let compactor = Compactor::new(1000);
        let msgs = vec![Message::user("short")];
        assert!(!compactor.needs_compaction(&msgs));
    }

    #[test]
    fn needs_compaction_over_threshold() {
        let compactor = Compactor::new(10);
        let msgs: Vec<Message> = (0..100)
            .map(|i| {
                Message::user(format!(
                    "Message number {} with some content to pad it out",
                    i
                ))
            })
            .collect();
        assert!(compactor.needs_compaction(&msgs));
    }

    #[test]
    fn format_messages_truncates_long() {
        let long_msg = "x".repeat(5000);
        let msgs = vec![Message::user(&long_msg)];
        let formatted = format_messages_for_summary(&msgs);
        assert!(formatted.contains("truncated"));
        assert!(formatted.len() < 5000);
    }

    #[test]
    fn compactor_default() {
        let c = Compactor::default();
        assert_eq!(c.threshold, 100_000);
        assert_eq!(c.keep_recent, 10);
        assert_eq!(c.keep_start, 2);
    }

    #[test]
    fn compactor_with_keep() {
        let c = Compactor::new(50_000).with_keep(3, 5);
        assert_eq!(c.keep_start, 3);
        assert_eq!(c.keep_recent, 5);
    }

    #[tokio::test]
    async fn compact_not_needed() {
        let compactor = Compactor::new(100_000);
        let mut msgs = vec![Message::user("short")];
        let result = compactor
            .compact(&reply("Summary of conversation."), &mut msgs)
            .await
            .unwrap();
        assert!(!result);
        assert_eq!(msgs.len(), 1);
    }

    #[tokio::test]
    async fn compact_replaces_old_messages() {
        let client =
            reply("Key decisions: implemented auth module. Files: src/auth.rs created.");
        let compactor = Compactor::new(5).with_keep(2, 2);
        let mut msgs = vec![
            Message::system("System prompt"),
            Message::user("Initial task"),
            Message::assistant("Step 1 done"),
            Message::user("Continue"),
            Message::assistant("Step 2 done"),
            Message::user("Continue more"),
            Message::assistant("Step 3 done"),
            Message::user("Final step"),
            Message::assistant("All done"),
        ];

        let result = compactor.compact(&client, &mut msgs).await.unwrap();
        assert!(result);

        // 2 kept at the start + 1 summary + 2 kept at the end.
        assert_eq!(msgs.len(), 5);
        assert!(msgs[2].content.contains("compacted"));
        assert!(msgs[2].content.contains("Key decisions"));
        assert_eq!(msgs[3].content, "Final step");
        assert_eq!(msgs[4].content, "All done");
    }

    #[test]
    fn with_prompt_overrides_default() {
        let c = Compactor::new(1000).with_prompt("Custom: summarize sales data");
        assert_eq!(c.prompt.as_deref(), Some("Custom: summarize sales data"));
    }

    #[tokio::test]
    async fn compact_uses_custom_prompt() {
        let saw_custom = Arc::new(AtomicBool::new(false));
        let flag = Arc::clone(&saw_custom);
        let client = FnClient(move |messages: &[Message]| {
            if messages[0].content.contains("SALES FOCUS") {
                flag.store(true, Ordering::SeqCst);
            }
            "Summary".to_string()
        });

        let compactor = Compactor::new(5)
            .with_keep(1, 1)
            .with_prompt("SALES FOCUS: summarize this");

        let mut msgs = vec![
            Message::system("sys"),
            Message::user("msg1"),
            Message::assistant("resp1"),
            Message::user("msg2"),
            Message::assistant("resp2"),
            Message::user("last"),
        ];

        let result = compactor.compact(&client, &mut msgs).await.unwrap();
        assert!(result);
        assert!(saw_custom.load(Ordering::SeqCst));
    }

    #[cfg(feature = "session")]
    mod session_tests {
        use super::*;
        use crate::session::Session;
        use crate::session::simple::{SimpleMsg, SimpleRole};
        use std::path::PathBuf;

        /// Per-test scratch directory that is removed on drop, even if the
        /// test panics.
        ///
        /// Each test needs its own directory because the harness runs tests
        /// in parallel. With a single shared directory, one test could wipe
        /// another's session files or load them.
        struct ScratchDir(PathBuf);

        impl ScratchDir {
            fn new(test_name: &str) -> Self {
                let dir = std::env::temp_dir().join(format!(
                    "sgr_compact_session_test_{}_{test_name}",
                    std::process::id()
                ));
                let _ = std::fs::remove_dir_all(&dir);
                Self(dir)
            }
        }

        impl Drop for ScratchDir {
            fn drop(&mut self) {
                let _ = std::fs::remove_dir_all(&self.0);
            }
        }

        fn make_session(dir: &ScratchDir) -> Session<SimpleMsg> {
            Session::new(dir.0.to_str().unwrap(), 100).unwrap()
        }

        #[test]
        fn estimate_session_tokens_basic() {
            let dir = ScratchDir::new("estimate_basic");
            let mut session = make_session(&dir);
            session.push(SimpleRole::User, "Hello world".into());
            session.push(SimpleRole::Assistant, "Hi there".into());
            let est = Compactor::estimate_session_tokens(&session);
            assert!(est > 0 && est < 100);
        }

        #[tokio::test]
        async fn compact_session_not_needed() {
            let dir = ScratchDir::new("not_needed");
            let mut session = make_session(&dir);
            session.push(SimpleRole::User, "short msg".into());
            let compactor = Compactor::new(100_000);
            let result = compactor
                .compact_session(&reply("summary"), &mut session)
                .await
                .unwrap();
            assert_eq!(result, 0);
        }

        #[tokio::test]
        async fn compact_session_replaces_middle() {
            let dir = ScratchDir::new("replaces_middle");
            let mut session = make_session(&dir);
            session.push(SimpleRole::System, "system prompt".into());
            session.push(SimpleRole::User, "initial task".into());
            for i in 0..6 {
                let role = if i % 2 == 0 {
                    SimpleRole::User
                } else {
                    SimpleRole::Assistant
                };
                session.push(role, format!("msg {i}"));
            }
            session.push(SimpleRole::User, "final".into());
            session.push(SimpleRole::Assistant, "done".into());

            let compactor = Compactor::new(5).with_keep(2, 2);
            let result = compactor
                .compact_session(&reply("Compacted: auth module created"), &mut session)
                .await
                .unwrap();
            assert!(result > 0);

            // 2 kept at the start + 1 summary + 2 kept at the end.
            assert_eq!(session.messages().len(), 5);
            assert!(session.messages()[2].content().contains("<compacted"));
            assert!(session.messages()[2].content().contains("auth module"));
            assert_eq!(session.messages()[3].content(), "final");
            assert_eq!(session.messages()[4].content(), "done");
        }

        #[tokio::test]
        async fn compact_session_incremental_preserves_prior() {
            let saw_prior = Arc::new(AtomicBool::new(false));
            let flag = Arc::clone(&saw_prior);
            let client = FnClient(move |messages: &[Message]| {
                if messages[1].content.contains("Previous summary")
                    && messages[1].content.contains("prior context here")
                {
                    flag.store(true, Ordering::SeqCst);
                }
                "Merged summary".to_string()
            });

            let dir = ScratchDir::new("incremental");
            let mut session = make_session(&dir);
            session.push(SimpleRole::System, "system".into());
            session.push(SimpleRole::User, "initial".into());
            session.push(
                SimpleRole::System,
                "<compacted turns=\"5\">\nprior context here\n</compacted>".into(),
            );
            for i in 0..4 {
                let role = if i % 2 == 0 {
                    SimpleRole::User
                } else {
                    SimpleRole::Assistant
                };
                session.push(role, format!("new msg {i}"));
            }
            session.push(SimpleRole::User, "keep1".into());
            session.push(SimpleRole::Assistant, "keep2".into());

            let compactor = Compactor::new(5).with_keep(2, 2);
            let result = compactor
                .compact_session(&client, &mut session)
                .await
                .unwrap();

            assert!(result > 0);
            assert!(
                saw_prior.load(Ordering::SeqCst),
                "should send prior summary to LLM"
            );
            assert!(session.messages()[2].content().contains("<compacted"));
        }
    }
}