sayr_engine/
memory.rs

1use crate::message::Message;
2use crate::storage::ConversationStore;
3
4/// In-memory transcript storage.
5#[derive(Default, Clone, Debug)]
6pub struct ConversationMemory {
7    messages: Vec<Message>,
8}
9
10impl ConversationMemory {
11    pub fn with_messages(messages: Vec<Message>) -> Self {
12        Self { messages }
13    }
14
15    pub fn push(&mut self, message: Message) {
16        self.messages.push(message);
17    }
18
19    pub fn iter(&self) -> impl DoubleEndedIterator<Item = &Message> + '_ {
20        self.messages.iter()
21    }
22
23    pub fn len(&self) -> usize {
24        self.messages.len()
25    }
26
27    pub fn is_empty(&self) -> bool {
28        self.messages.is_empty()
29    }
30}
31
32/// A conversation memory that persists messages through a pluggable backend.
33#[derive(Clone, Debug)]
34pub struct PersistentConversationMemory<S: ConversationStore> {
35    store: S,
36    inner: ConversationMemory,
37}
38
39impl<S: ConversationStore> PersistentConversationMemory<S> {
40    pub fn new(store: S) -> Self {
41        Self {
42            store,
43            inner: ConversationMemory::default(),
44        }
45    }
46
47    pub async fn load(mut self) -> crate::Result<Self> {
48        let stored = self.store.load().await?;
49        self.inner = ConversationMemory::with_messages(stored);
50        Ok(self)
51    }
52
53    pub fn as_memory(&self) -> &ConversationMemory {
54        &self.inner
55    }
56
57    pub async fn push(&mut self, message: Message) -> crate::Result<()> {
58        self.store.append(&message).await?;
59        self.inner.push(message);
60        Ok(())
61    }
62
63    pub async fn clear(&mut self) -> crate::Result<()> {
64        self.store.clear().await?;
65        self.inner = ConversationMemory::default();
66        Ok(())
67    }
68}
69
// ─────────────────────────────────────────────────────────────────────────────
// Memory Strategies
// ─────────────────────────────────────────────────────────────────────────────
73
74/// Memory strategy trait for managing conversation context
75pub trait MemoryStrategy: Send + Sync {
76    /// Apply the strategy to get messages to send to the LLM
77    fn get_context_messages(&self, messages: &[Message]) -> Vec<Message>;
78
79    /// Name of the strategy
80    fn name(&self) -> &str;
81}
82
83/// Keep all messages (default, no limiting)
84#[derive(Clone, Default)]
85pub struct FullMemoryStrategy;
86
87impl MemoryStrategy for FullMemoryStrategy {
88    fn get_context_messages(&self, messages: &[Message]) -> Vec<Message> {
89        messages.to_vec()
90    }
91
92    fn name(&self) -> &str {
93        "full"
94    }
95}
96
97/// Keep only the last N messages (sliding window)
98#[derive(Clone)]
99pub struct WindowedMemoryStrategy {
100    window_size: usize,
101    keep_system: bool,
102}
103
104impl WindowedMemoryStrategy {
105    pub fn new(window_size: usize) -> Self {
106        Self {
107            window_size,
108            keep_system: true,
109        }
110    }
111
112    pub fn without_system(mut self) -> Self {
113        self.keep_system = false;
114        self
115    }
116}
117
118impl MemoryStrategy for WindowedMemoryStrategy {
119    fn get_context_messages(&self, messages: &[Message]) -> Vec<Message> {
120        use crate::message::Role;
121
122        if messages.len() <= self.window_size {
123            return messages.to_vec();
124        }
125
126        let mut result = Vec::new();
127
128        // Keep system messages if configured
129        if self.keep_system {
130            for msg in messages {
131                if msg.role == Role::System {
132                    result.push(msg.clone());
133                }
134            }
135        }
136
137        // Add the last N non-system messages
138        let non_system: Vec<&Message> = messages
139            .iter()
140            .filter(|m| m.role != Role::System)
141            .collect();
142
143        let start = if non_system.len() > self.window_size {
144            non_system.len() - self.window_size
145        } else {
146            0
147        };
148
149        for msg in &non_system[start..] {
150            result.push((*msg).clone());
151        }
152
153        result
154    }
155
156    fn name(&self) -> &str {
157        "windowed"
158    }
159}
160
161/// Keep first and last N messages, summarize the middle
162#[derive(Clone)]
163pub struct SummarizedMemoryStrategy {
164    /// Number of messages to keep at the start
165    keep_first: usize,
166    /// Number of messages to keep at the end
167    keep_last: usize,
168    /// Summary of the middle (set after summarization)
169    summary: Option<String>,
170}
171
172impl SummarizedMemoryStrategy {
173    pub fn new(keep_first: usize, keep_last: usize) -> Self {
174        Self {
175            keep_first,
176            keep_last,
177            summary: None,
178        }
179    }
180
181    pub fn with_summary(mut self, summary: impl Into<String>) -> Self {
182        self.summary = Some(summary.into());
183        self
184    }
185
186    /// Check if summarization is needed (more than keep_first + keep_last messages)
187    pub fn needs_summary(&self, messages: &[Message]) -> bool {
188        messages.len() > self.keep_first + self.keep_last
189    }
190
191    /// Get messages that need to be summarized
192    pub fn messages_to_summarize<'a>(&self, messages: &'a [Message]) -> &'a [Message] {
193        if messages.len() <= self.keep_first + self.keep_last {
194            return &[];
195        }
196        let end = messages.len() - self.keep_last;
197        &messages[self.keep_first..end]
198    }
199}
200
201impl MemoryStrategy for SummarizedMemoryStrategy {
202    fn get_context_messages(&self, messages: &[Message]) -> Vec<Message> {
203        if messages.len() <= self.keep_first + self.keep_last {
204            return messages.to_vec();
205        }
206
207        let mut result = Vec::new();
208
209        // Add first N messages
210        for msg in messages.iter().take(self.keep_first) {
211            result.push(msg.clone());
212        }
213
214        // Add summary as a system message if available
215        if let Some(ref summary) = self.summary {
216            result.push(Message::system(format!(
217                "[Summary of {} messages]: {}",
218                messages.len() - self.keep_first - self.keep_last,
219                summary
220            )));
221        }
222
223        // Add last N messages
224        let start = messages.len() - self.keep_last;
225        for msg in &messages[start..] {
226            result.push(msg.clone());
227        }
228
229        result
230    }
231
232    fn name(&self) -> &str {
233        "summarized"
234    }
235}
236
237/// Token-based memory limiting (approximate)
238#[derive(Clone)]
239pub struct TokenLimitedMemoryStrategy {
240    max_tokens: usize,
241    /// Approximate characters per token (default: 4)
242    chars_per_token: usize,
243}
244
245impl TokenLimitedMemoryStrategy {
246    pub fn new(max_tokens: usize) -> Self {
247        Self {
248            max_tokens,
249            chars_per_token: 4,
250        }
251    }
252
253    pub fn with_chars_per_token(mut self, chars: usize) -> Self {
254        self.chars_per_token = chars;
255        self
256    }
257
258    fn estimate_tokens(&self, content: &str) -> usize {
259        content.len() / self.chars_per_token
260    }
261}
262
263impl MemoryStrategy for TokenLimitedMemoryStrategy {
264    fn get_context_messages(&self, messages: &[Message]) -> Vec<Message> {
265        use crate::message::Role;
266
267        let mut result = Vec::new();
268        let mut total_tokens = 0;
269
270        // Always include system messages first
271        for msg in messages {
272            if msg.role == Role::System {
273                let tokens = self.estimate_tokens(&msg.content);
274                total_tokens += tokens;
275                result.push(msg.clone());
276            }
277        }
278
279        // Add messages from the end until we hit the limit
280        let non_system: Vec<&Message> = messages
281            .iter()
282            .filter(|m| m.role != Role::System)
283            .collect();
284
285        let mut temp = Vec::new();
286        for msg in non_system.iter().rev() {
287            let tokens = self.estimate_tokens(&msg.content);
288            if total_tokens + tokens > self.max_tokens {
289                break;
290            }
291            total_tokens += tokens;
292            temp.push((*msg).clone());
293        }
294
295        // Reverse to maintain chronological order
296        temp.reverse();
297        result.extend(temp);
298
299        result
300    }
301
302    fn name(&self) -> &str {
303        "token_limited"
304    }
305}
306
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_conversation_memory() {
        let mut memory = ConversationMemory::default();
        assert!(memory.is_empty());

        memory.push(Message::user("first"));
        memory.push(Message::assistant("second"));

        assert_eq!(memory.len(), 2);
        assert!(!memory.is_empty());
        assert_eq!(memory.iter().next().unwrap().content, "first");
        assert_eq!(memory.iter().next_back().unwrap().content, "second");
    }

    #[test]
    fn test_full_strategy() {
        let messages = vec![Message::system("s"), Message::user("u")];
        let context = FullMemoryStrategy.get_context_messages(&messages);

        assert_eq!(context.len(), 2);
        assert_eq!(context[0].content, "s");
        assert_eq!(context[1].content, "u");
    }

    #[test]
    fn test_windowed_strategy() {
        let messages = vec![
            Message::system("You are a helpful assistant"),
            Message::user("Hello"),
            Message::assistant("Hi there!"),
            Message::user("How are you?"),
            Message::assistant("I'm doing well!"),
            Message::user("What's 2+2?"),
            Message::assistant("4"),
        ];

        let strategy = WindowedMemoryStrategy::new(4);
        let context = strategy.get_context_messages(&messages);

        // Should keep system + last 4 non-system messages, in order
        assert_eq!(context.len(), 5); // 1 system + 4 recent
        assert_eq!(context[0].content, "You are a helpful assistant");
        assert_eq!(context[1].content, "How are you?");
        assert_eq!(context[4].content, "4");
    }

    #[test]
    fn test_summarized_strategy() {
        let messages = vec![
            Message::system("s"),
            Message::user("1"),
            Message::assistant("2"),
            Message::user("3"),
            Message::assistant("4"),
        ];

        let strategy = SummarizedMemoryStrategy::new(1, 2).with_summary("numbers");
        assert!(strategy.needs_summary(&messages));
        assert_eq!(strategy.messages_to_summarize(&messages).len(), 2);

        let context = strategy.get_context_messages(&messages);
        assert_eq!(context.len(), 4);
        assert_eq!(context[0].content, "s");
        assert_eq!(context[1].content, "[Summary of 2 messages]: numbers");
        assert_eq!(context[2].content, "3");
        assert_eq!(context[3].content, "4");
    }

    #[test]
    fn test_token_limited_strategy() {
        let messages = vec![
            Message::system("System"),
            Message::user("A".repeat(100)),
            Message::assistant("B".repeat(100)),
            Message::user("C".repeat(100)),
        ];

        let strategy = TokenLimitedMemoryStrategy::new(50); // ~200 chars
        let context = strategy.get_context_messages(&messages);

        // Each 100-char message costs ~25 tokens, so next to the system prompt
        // only the most recent message fits in a 50-token budget.
        assert_eq!(context.len(), 2);
        assert_eq!(context[0].content, "System");
        assert_eq!(context[1].content, "C".repeat(100));
    }
}
347