Skip to main content

traitclaw_core/memory/
compressed.rs

//! Compressed memory decorator — automatic context window management.
//!
//! `CompressedMemory` wraps any [`Memory`] implementation and automatically
//! summarizes older messages when a session exceeds a configured message-count
//! threshold, keeping the agent within its context window budget.
//!
//! # Architecture Decision
//!
//! Uses the **Decorator Pattern** — `CompressedMemory<M>` implements `Memory`
//! for any `M: Memory`. This means it can wrap `InMemoryMemory`, `SqliteMemory`,
//! or even another `CompressedMemory` (stackable).
//!
//! # Example
//!
//! ```rust,no_run
//! use traitclaw_core::memory::compressed::CompressedMemory;
//! use traitclaw_core::memory::in_memory::InMemoryMemory;
//!
//! // Wrap an in-memory store: once a session holds more than 20 messages,
//! // older messages are summarized and the 5 most recent are kept verbatim.
//! let memory = CompressedMemory::new(
//!     InMemoryMemory::new(),
//!     20,  // compress when > 20 messages
//!     5,   // keep last 5 messages uncompressed
//! );
//! ```
26
27use async_trait::async_trait;
28use serde_json::Value;
29use std::sync::Mutex;
30
31use crate::traits::memory::{Memory, MemoryEntry};
32use crate::types::message::{Message, MessageRole};
33use crate::Result;
34
/// A [`Memory`] decorator that replaces older conversation messages with a summary.
///
/// Reading a session through this wrapper returns its messages unchanged while
/// the session holds at most `threshold` messages (system messages included).
/// Beyond that:
/// - every System-role message is kept and placed first, in its original order;
/// - the most recent `keep_recent` non-system messages are kept verbatim;
/// - all older non-system messages are collapsed into a single System-role
///   summary message.
///
/// The built-in summary is a plain-text digest (role plus truncated content for
/// each message), so it needs no model call. For LLM-powered summarization,
/// write a separate [`Memory`] decorator that calls a provider.
///
/// The messages stored in `inner` are never rewritten. Only the view returned
/// when reading is compressed. Generated summaries are cached per session and
/// split point.
///
/// The `M: Memory` bound lives on the `impl` blocks rather than on the struct,
/// following the Rust convention of bounding impls instead of type definitions.
pub struct CompressedMemory<M> {
    /// The decorated memory backend; all storage is delegated to it.
    inner: M,
    /// Compress when message count exceeds this.
    threshold: usize,
    /// Number of recent non-system messages to keep uncompressed.
    keep_recent: usize,
    /// Cached summaries, keyed by `"{session_id}:{split_point}"`.
    summaries: Mutex<std::collections::HashMap<String, String>>,
}
53
54impl<M: Memory> CompressedMemory<M> {
55    /// Create a new compressed memory wrapping the given inner memory.
56    ///
57    /// - `threshold`: Trigger compression when message count exceeds this
58    /// - `keep_recent`: Number of most recent messages to preserve uncompressed
59    pub fn new(inner: M, threshold: usize, keep_recent: usize) -> Self {
60        Self {
61            inner,
62            threshold,
63            keep_recent,
64            summaries: Mutex::new(std::collections::HashMap::new()),
65        }
66    }
67
68    /// Get the compression threshold.
69    #[must_use]
70    pub fn threshold(&self) -> usize {
71        self.threshold
72    }
73
74    /// Get the number of recent messages kept uncompressed.
75    #[must_use]
76    pub fn keep_recent(&self) -> usize {
77        self.keep_recent
78    }
79
80    /// Generate a simple summary of messages.
81    ///
82    /// For production use, replace this with an LLM-based summarizer.
83    fn summarize(messages: &[Message]) -> String {
84        let mut summary = String::from("[Compressed context summary]\n");
85        for msg in messages {
86            let role = match msg.role {
87                MessageRole::User => "User",
88                MessageRole::Assistant => "Assistant",
89                MessageRole::System => "System",
90                MessageRole::Tool => "Tool",
91            };
92            // Truncate long messages in summary (char-boundary safe)
93            let content = if msg.content.len() > 100 {
94                // Find the last valid char boundary at or before byte 100
95                let mut end = 100;
96                while end > 0 && !msg.content.is_char_boundary(end) {
97                    end -= 1;
98                }
99                format!("{}...", &msg.content[..end])
100            } else {
101                msg.content.clone()
102            };
103            summary.push_str(&format!("- {role}: {content}\n"));
104        }
105        summary
106    }
107}
108
109#[async_trait]
110impl<M: Memory> Memory for CompressedMemory<M> {
111    async fn messages(&self, session_id: &str) -> Result<Vec<Message>> {
112        let all_messages = self.inner.messages(session_id).await?;
113
114        if all_messages.len() <= self.threshold {
115            return Ok(all_messages);
116        }
117
118        // Separate system prompt (if present)
119        let (system_msgs, conversation): (Vec<_>, Vec<_>) = all_messages
120            .into_iter()
121            .partition(|m| m.role == MessageRole::System);
122
123        if conversation.len() <= self.keep_recent {
124            let mut result = system_msgs;
125            result.extend(conversation);
126            return Ok(result);
127        }
128
129        // Split into old (to compress) and recent (to keep)
130        let split_point = conversation.len().saturating_sub(self.keep_recent);
131        let (old_messages, recent_messages) = conversation.split_at(split_point);
132
133        // Generate or retrieve cached summary
134        let summary = {
135            let mut cache = self.summaries.lock().unwrap_or_else(|e| e.into_inner());
136            let key = format!("{session_id}:{split_point}");
137            cache
138                .entry(key)
139                .or_insert_with(|| Self::summarize(old_messages))
140                .clone()
141        };
142
143        // Reconstruct: system + summary + recent
144        let mut result = system_msgs;
145        result.push(Message::system(summary));
146        result.extend(recent_messages.iter().cloned());
147
148        Ok(result)
149    }
150
151    async fn append(&self, session_id: &str, message: Message) -> Result<()> {
152        // Invalidate summary cache for this session
153        {
154            let mut cache = self.summaries.lock().unwrap_or_else(|e| e.into_inner());
155            let prefix = format!("{session_id}:");
156            cache.retain(|k, _| !k.starts_with(&prefix));
157        }
158        self.inner.append(session_id, message).await
159    }
160
161    async fn get_context(&self, session_id: &str, key: &str) -> Result<Option<Value>> {
162        self.inner.get_context(session_id, key).await
163    }
164
165    async fn set_context(&self, session_id: &str, key: &str, value: Value) -> Result<()> {
166        self.inner.set_context(session_id, key, value).await
167    }
168
169    async fn recall(&self, query: &str, limit: usize) -> Result<Vec<MemoryEntry>> {
170        self.inner.recall(query, limit).await
171    }
172
173    async fn store(&self, entry: MemoryEntry) -> Result<()> {
174        self.inner.store(entry).await
175    }
176
177    async fn create_session(&self) -> Result<String> {
178        self.inner.create_session().await
179    }
180
181    async fn list_sessions(&self) -> Result<Vec<String>> {
182        self.inner.list_sessions().await
183    }
184
185    async fn delete_session(&self, session_id: &str) -> Result<()> {
186        // Clean up summary cache
187        {
188            let mut cache = self.summaries.lock().unwrap_or_else(|e| e.into_inner());
189            cache.retain(|k, _| !k.starts_with(session_id));
190        }
191        self.inner.delete_session(session_id).await
192    }
193}
194
#[cfg(test)]
mod tests {
    use super::*;
    use crate::memory::in_memory::InMemoryMemory;

    #[tokio::test]
    async fn test_below_threshold_no_compression() {
        let memory = CompressedMemory::new(InMemoryMemory::new(), 10, 3);

        // Add 5 messages (below threshold of 10)
        for i in 0..5 {
            memory
                .append("s1", Message::user(format!("msg {i}")))
                .await
                .unwrap();
        }

        let msgs = memory.messages("s1").await.unwrap();
        assert_eq!(msgs.len(), 5, "should return all messages uncompressed");
    }

    #[tokio::test]
    async fn test_above_threshold_compresses() {
        let memory = CompressedMemory::new(InMemoryMemory::new(), 5, 3);

        // Add 8 messages (above threshold of 5)
        for i in 0..8 {
            memory
                .append("s1", Message::user(format!("msg {i}")))
                .await
                .unwrap();
        }

        let msgs = memory.messages("s1").await.unwrap();
        // Should have: 1 summary + 3 recent = 4
        assert_eq!(msgs.len(), 4, "should compress to summary + 3 recent");

        // First message should be the summary
        assert!(
            msgs[0].content.contains("[Compressed context summary]"),
            "first msg should be summary, got: {}",
            msgs[0].content
        );

        // Last 3 should be the recent messages
        assert_eq!(msgs[1].content, "msg 5");
        assert_eq!(msgs[2].content, "msg 6");
        assert_eq!(msgs[3].content, "msg 7");
    }

    #[tokio::test]
    async fn test_system_prompt_preserved() {
        let inner = InMemoryMemory::new();
        inner
            .append("s1", Message::system("You are helpful"))
            .await
            .unwrap();
        for i in 0..10 {
            inner
                .append("s1", Message::user(format!("msg {i}")))
                .await
                .unwrap();
        }

        let memory = CompressedMemory::new(inner, 5, 3);
        let msgs = memory.messages("s1").await.unwrap();

        // 11 messages > threshold 5 → system prompt + summary + 3 recent
        assert_eq!(msgs.len(), 5);

        // System prompt should be first and must not be folded into the summary
        assert_eq!(msgs[0].role, MessageRole::System);
        assert_eq!(msgs[0].content, "You are helpful");
        assert!(msgs[1].content.contains("[Compressed context summary]"));
        assert!(!msgs[1].content.contains("You are helpful"));
        assert_eq!(msgs[4].content, "msg 9");
    }

    #[tokio::test]
    async fn test_append_invalidates_cache() {
        let memory = CompressedMemory::new(InMemoryMemory::new(), 3, 2);

        for i in 0..5 {
            memory
                .append("s1", Message::user(format!("msg {i}")))
                .await
                .unwrap();
        }

        // First call populates cache
        let msgs1 = memory.messages("s1").await.unwrap();
        let summary1 = msgs1[0].content.clone();
        assert!(
            !memory.summaries.lock().unwrap().is_empty(),
            "messages() should populate the summary cache"
        );

        // Add new message — must evict the session's cached summaries.
        // (Comparing summaries alone cannot prove this: the cache key contains
        // the split point, so a new summary would be produced regardless.)
        memory.append("s1", Message::user("msg 5")).await.unwrap();
        assert!(
            memory.summaries.lock().unwrap().is_empty(),
            "append() should evict the session's cached summaries"
        );

        let msgs2 = memory.messages("s1").await.unwrap();
        // Summary should be different (more messages compressed)
        assert_ne!(msgs2[0].content, summary1);
    }

    #[tokio::test]
    async fn test_delete_session_evicts_cache() {
        let memory = CompressedMemory::new(InMemoryMemory::new(), 3, 1);

        for i in 0..5 {
            memory
                .append("s1", Message::user(format!("msg {i}")))
                .await
                .unwrap();
        }

        memory.messages("s1").await.unwrap();
        assert!(!memory.summaries.lock().unwrap().is_empty());

        // The cache is cleared before delegating, so the inner result does not
        // matter for this assertion.
        let _ = memory.delete_session("s1").await;
        assert!(memory.summaries.lock().unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_stackable_decorator() {
        // CompressedMemory wrapping CompressedMemory
        let inner = CompressedMemory::new(InMemoryMemory::new(), 20, 5);
        let outer = CompressedMemory::new(inner, 10, 3);

        for i in 0..15 {
            outer
                .append("s1", Message::user(format!("msg {i}")))
                .await
                .unwrap();
        }

        let msgs = outer.messages("s1").await.unwrap();
        // Outer threshold is 10, keep_recent is 3
        // So we get: 1 summary + 3 recent = 4
        assert_eq!(msgs.len(), 4, "stacked decorators should compress");
    }

    #[tokio::test]
    async fn test_working_memory_delegates() {
        let memory = CompressedMemory::new(InMemoryMemory::new(), 5, 2);

        memory
            .set_context("s1", "key", serde_json::json!("value"))
            .await
            .unwrap();

        let val = memory.get_context("s1", "key").await.unwrap();
        assert_eq!(val, Some(serde_json::json!("value")));
    }

    #[tokio::test]
    async fn test_multibyte_content_no_panic() {
        // Regression: &msg.content[..100] used to panic on multi-byte UTF-8
        let memory = CompressedMemory::new(InMemoryMemory::new(), 3, 1);

        // Vietnamese text > 100 bytes (accented letters take 2-3 bytes in UTF-8)
        let long_vietnamese = "Xin chào! Đây là một tin nhắn rất dài bằng tiếng Việt để kiểm tra rằng việc cắt ngắn không gây lỗi panic khi gặp ký tự đa byte UTF-8.";
        assert!(
            long_vietnamese.len() > 100,
            "test premise: string must be >100 bytes"
        );

        for _ in 0..5 {
            memory
                .append("s1", Message::user(long_vietnamese))
                .await
                .unwrap();
        }

        // This should NOT panic
        let msgs = memory.messages("s1").await.unwrap();
        assert_eq!(msgs.len(), 2); // 1 summary + 1 recent
        assert!(msgs[0].content.contains("[Compressed"));
    }

    #[tokio::test]
    async fn test_long_ascii_content_truncated_to_100_bytes() {
        let memory = CompressedMemory::new(InMemoryMemory::new(), 0, 0);

        memory
            .append("s1", Message::user("x".repeat(150)))
            .await
            .unwrap();

        let msgs = memory.messages("s1").await.unwrap();
        assert_eq!(msgs.len(), 1);
        let expected_line = format!("- User: {}...\n", "x".repeat(100));
        assert!(
            msgs[0].content.contains(&expected_line),
            "summary should truncate to 100 bytes, got: {}",
            msgs[0].content
        );
    }

    #[tokio::test]
    async fn test_threshold_zero_always_compresses() {
        let memory = CompressedMemory::new(InMemoryMemory::new(), 0, 1);

        memory.append("s1", Message::user("a")).await.unwrap();
        memory.append("s1", Message::user("b")).await.unwrap();

        let msgs = memory.messages("s1").await.unwrap();
        // threshold=0 means even 2 messages triggers compression
        // summary + 1 recent = 2
        assert_eq!(msgs.len(), 2);
        assert!(msgs[0].content.contains("[Compressed"));
        assert_eq!(msgs[1].content, "b");
    }

    #[tokio::test]
    async fn test_keep_recent_zero_summarizes_everything() {
        let memory = CompressedMemory::new(InMemoryMemory::new(), 0, 0);

        memory.append("s1", Message::user("a")).await.unwrap();
        memory.append("s1", Message::user("b")).await.unwrap();

        let msgs = memory.messages("s1").await.unwrap();
        // keep_recent=0 → only the summary remains
        assert_eq!(msgs.len(), 1);
        assert!(msgs[0].content.contains("- User: a\n"));
        assert!(msgs[0].content.contains("- User: b\n"));
    }

    #[tokio::test]
    async fn test_keep_recent_exceeds_message_count() {
        let memory = CompressedMemory::new(InMemoryMemory::new(), 0, 100);

        memory
            .append("s1", Message::user("only one"))
            .await
            .unwrap();

        let msgs = memory.messages("s1").await.unwrap();
        // keep_recent (100) > message count (1) → no compression
        assert_eq!(msgs.len(), 1);
        assert_eq!(msgs[0].content, "only one");
    }
}