1use async_trait::async_trait;
28use serde_json::Value;
29use std::sync::Mutex;
30
31use crate::traits::memory::{Memory, MemoryEntry};
32use crate::types::message::{Message, MessageRole};
33use crate::Result;
34
35pub struct CompressedMemory<M: Memory> {
45 inner: M,
46 threshold: usize,
48 keep_recent: usize,
50 summaries: Mutex<std::collections::HashMap<String, String>>,
52}
53
54impl<M: Memory> CompressedMemory<M> {
55 pub fn new(inner: M, threshold: usize, keep_recent: usize) -> Self {
60 Self {
61 inner,
62 threshold,
63 keep_recent,
64 summaries: Mutex::new(std::collections::HashMap::new()),
65 }
66 }
67
68 #[must_use]
70 pub fn threshold(&self) -> usize {
71 self.threshold
72 }
73
74 #[must_use]
76 pub fn keep_recent(&self) -> usize {
77 self.keep_recent
78 }
79
80 fn summarize(messages: &[Message]) -> String {
84 let mut summary = String::from("[Compressed context summary]\n");
85 for msg in messages {
86 let role = match msg.role {
87 MessageRole::User => "User",
88 MessageRole::Assistant => "Assistant",
89 MessageRole::System => "System",
90 MessageRole::Tool => "Tool",
91 };
92 let content = if msg.content.len() > 100 {
94 let mut end = 100;
96 while end > 0 && !msg.content.is_char_boundary(end) {
97 end -= 1;
98 }
99 format!("{}...", &msg.content[..end])
100 } else {
101 msg.content.clone()
102 };
103 summary.push_str(&format!("- {role}: {content}\n"));
104 }
105 summary
106 }
107}
108
109#[async_trait]
110impl<M: Memory> Memory for CompressedMemory<M> {
111 async fn messages(&self, session_id: &str) -> Result<Vec<Message>> {
112 let all_messages = self.inner.messages(session_id).await?;
113
114 if all_messages.len() <= self.threshold {
115 return Ok(all_messages);
116 }
117
118 let (system_msgs, conversation): (Vec<_>, Vec<_>) = all_messages
120 .into_iter()
121 .partition(|m| m.role == MessageRole::System);
122
123 if conversation.len() <= self.keep_recent {
124 let mut result = system_msgs;
125 result.extend(conversation);
126 return Ok(result);
127 }
128
129 let split_point = conversation.len().saturating_sub(self.keep_recent);
131 let (old_messages, recent_messages) = conversation.split_at(split_point);
132
133 let summary = {
135 let mut cache = self.summaries.lock().unwrap_or_else(|e| e.into_inner());
136 let key = format!("{session_id}:{split_point}");
137 cache
138 .entry(key)
139 .or_insert_with(|| Self::summarize(old_messages))
140 .clone()
141 };
142
143 let mut result = system_msgs;
145 result.push(Message::system(summary));
146 result.extend(recent_messages.iter().cloned());
147
148 Ok(result)
149 }
150
151 async fn append(&self, session_id: &str, message: Message) -> Result<()> {
152 {
154 let mut cache = self.summaries.lock().unwrap_or_else(|e| e.into_inner());
155 let prefix = format!("{session_id}:");
156 cache.retain(|k, _| !k.starts_with(&prefix));
157 }
158 self.inner.append(session_id, message).await
159 }
160
161 async fn get_context(&self, session_id: &str, key: &str) -> Result<Option<Value>> {
162 self.inner.get_context(session_id, key).await
163 }
164
165 async fn set_context(&self, session_id: &str, key: &str, value: Value) -> Result<()> {
166 self.inner.set_context(session_id, key, value).await
167 }
168
169 async fn recall(&self, query: &str, limit: usize) -> Result<Vec<MemoryEntry>> {
170 self.inner.recall(query, limit).await
171 }
172
173 async fn store(&self, entry: MemoryEntry) -> Result<()> {
174 self.inner.store(entry).await
175 }
176
177 async fn create_session(&self) -> Result<String> {
178 self.inner.create_session().await
179 }
180
181 async fn list_sessions(&self) -> Result<Vec<String>> {
182 self.inner.list_sessions().await
183 }
184
185 async fn delete_session(&self, session_id: &str) -> Result<()> {
186 {
188 let mut cache = self.summaries.lock().unwrap_or_else(|e| e.into_inner());
189 cache.retain(|k, _| !k.starts_with(session_id));
190 }
191 self.inner.delete_session(session_id).await
192 }
193}
194
#[cfg(test)]
mod tests {
    use super::*;
    use crate::memory::in_memory::InMemoryMemory;

    /// Session id used by every test in this module.
    const SESSION: &str = "s1";

    /// Appends `count` user messages with contents `msg 0`, `msg 1`, ... to
    /// [`SESSION`] in `store`.
    async fn push_numbered<S: Memory>(store: &S, count: usize) {
        for n in 0..count {
            store
                .append(SESSION, Message::user(format!("msg {n}")))
                .await
                .unwrap();
        }
    }

    /// Histories at or below the threshold are returned unchanged.
    #[tokio::test]
    async fn test_below_threshold_no_compression() {
        let memory = CompressedMemory::new(InMemoryMemory::new(), 10, 3);
        push_numbered(&memory, 5).await;

        let history = memory.messages(SESSION).await.unwrap();
        assert_eq!(history.len(), 5, "should return all messages uncompressed");
    }

    /// Above the threshold the history becomes one summary plus the
    /// `keep_recent` newest messages.
    #[tokio::test]
    async fn test_above_threshold_compresses() {
        let memory = CompressedMemory::new(InMemoryMemory::new(), 5, 3);
        push_numbered(&memory, 8).await;

        let history = memory.messages(SESSION).await.unwrap();
        assert_eq!(history.len(), 4, "should compress to summary + 3 recent");

        let head = &history[0].content;
        assert!(
            head.contains("[Compressed context summary]"),
            "first msg should be summary, got: {}",
            head
        );

        let tail: Vec<&str> = history
            .iter()
            .skip(1)
            .map(|m| m.content.as_str())
            .collect();
        assert_eq!(tail, ["msg 5", "msg 6", "msg 7"]);
    }

    /// A system prompt stays first and untouched when compression kicks in.
    #[tokio::test]
    async fn test_system_prompt_preserved() {
        let backing = InMemoryMemory::new();
        backing
            .append(SESSION, Message::system("You are helpful"))
            .await
            .unwrap();
        push_numbered(&backing, 10).await;

        let memory = CompressedMemory::new(backing, 5, 3);
        let history = memory.messages(SESSION).await.unwrap();

        let first = &history[0];
        assert_eq!(first.role, MessageRole::System);
        assert_eq!(first.content, "You are helpful");
    }

    /// Appending a message must produce a different summary afterwards.
    #[tokio::test]
    async fn test_append_invalidates_cache() {
        let memory = CompressedMemory::new(InMemoryMemory::new(), 3, 2);
        push_numbered(&memory, 5).await;

        let before = memory.messages(SESSION).await.unwrap()[0].content.clone();

        memory
            .append(SESSION, Message::user("msg 5"))
            .await
            .unwrap();

        let after = memory.messages(SESSION).await.unwrap();
        assert_ne!(after[0].content, before);
    }

    /// A `CompressedMemory` can wrap another `CompressedMemory`.
    #[tokio::test]
    async fn test_stackable_decorator() {
        let lower = CompressedMemory::new(InMemoryMemory::new(), 20, 5);
        let upper = CompressedMemory::new(lower, 10, 3);
        push_numbered(&upper, 15).await;

        let history = upper.messages(SESSION).await.unwrap();
        assert_eq!(history.len(), 4, "stacked decorators should compress");
    }

    /// Context reads and writes pass straight through to the inner store.
    #[tokio::test]
    async fn test_working_memory_delegates() {
        let memory = CompressedMemory::new(InMemoryMemory::new(), 5, 2);
        let payload = serde_json::json!("value");

        memory
            .set_context(SESSION, "key", serde_json::json!("value"))
            .await
            .unwrap();

        let fetched = memory.get_context(SESSION, "key").await.unwrap();
        assert_eq!(fetched, Some(payload));
    }

    /// Truncating multi-byte UTF-8 content must not panic.
    #[tokio::test]
    async fn test_multibyte_content_no_panic() {
        let memory = CompressedMemory::new(InMemoryMemory::new(), 3, 1);

        let long_vietnamese = "Xin chào! Đây là một tin nhắn rất dài bằng tiếng Việt để kiểm tra rằng việc cắt ngắn không gây lỗi panic khi gặp ký tự đa byte UTF-8.";
        assert!(
            long_vietnamese.len() > 100,
            "test premise: string must be >100 bytes"
        );

        for _ in 0..5 {
            memory
                .append(SESSION, Message::user(long_vietnamese))
                .await
                .unwrap();
        }

        let history = memory.messages(SESSION).await.unwrap();
        assert_eq!(history.len(), 2);
        assert!(history[0].content.contains("[Compressed"));
    }

    /// With a zero threshold any non-empty history is eligible for compression.
    #[tokio::test]
    async fn test_threshold_zero_always_compresses() {
        let memory = CompressedMemory::new(InMemoryMemory::new(), 0, 1);

        for text in ["a", "b"] {
            memory.append(SESSION, Message::user(text)).await.unwrap();
        }

        let history = memory.messages(SESSION).await.unwrap();
        assert_eq!(history.len(), 2);
        assert!(history[0].content.contains("[Compressed"));
        assert_eq!(history[1].content, "b");
    }

    /// When `keep_recent` covers the whole conversation nothing is summarized.
    #[tokio::test]
    async fn test_keep_recent_exceeds_message_count() {
        let memory = CompressedMemory::new(InMemoryMemory::new(), 0, 100);

        memory
            .append(SESSION, Message::user("only one"))
            .await
            .unwrap();

        let history = memory.messages(SESSION).await.unwrap();
        assert_eq!(history.len(), 1);
        assert_eq!(history[0].content, "only one");
    }
}