1use chrono::{DateTime, Utc};
4use serde::{Deserialize, Serialize};
5use std::collections::VecDeque;
6
7use crate::compression::TemporalCompressor;
8use crate::horizon::HorizonConfig;
9use vex_persist::VectorStoreBackend;
10
/// A single remembered event held by an episodic memory store.
///
/// Derives serde `Serialize`/`Deserialize`, so episodes can be encoded and
/// decoded with any serde format.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Episode {
    /// Identifier assigned from a process-wide counter by [`Episode::new`].
    /// NOTE(review): the counter restarts at 0 in every process, and cloning or
    /// deserializing an episode copies its id, so ids are not guaranteed unique.
    pub id: u64,
    /// The remembered text; compression may rewrite it in place.
    pub content: String,
    /// Creation timestamp (UTC); used for decay, horizon checks and age eviction.
    pub created_at: DateTime<Utc>,
    /// Importance at creation time; [`Episode::new`] clamps it to `[0.0, 1.0]`.
    pub base_importance: f64,
    /// Pinned episodes are skipped by eviction and compression.
    pub pinned: bool,
    /// Free-form labels that can be used to look episodes up by tag.
    pub tags: Vec<String>,
}
27
28impl Episode {
29 pub fn new(content: &str, importance: f64) -> Self {
31 static COUNTER: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0);
32 Self {
33 id: COUNTER.fetch_add(1, std::sync::atomic::Ordering::Relaxed),
34 content: content.to_string(),
35 created_at: Utc::now(),
36 base_importance: importance.clamp(0.0, 1.0),
37 pinned: false,
38 tags: Vec::new(),
39 }
40 }
41
42 pub fn pinned(content: &str) -> Self {
44 let mut ep = Self::new(content, 1.0);
45 ep.pinned = true;
46 ep
47 }
48
49 pub fn with_tag(mut self, tag: &str) -> Self {
51 self.tags.push(tag.to_string());
52 self
53 }
54}
55
/// A newest-first store of [`Episode`]s with optional age/capacity eviction
/// and content compression, both controlled by `config`.
///
/// Pinned episodes are never evicted or compressed.
#[derive(Debug, Clone)]
pub struct EpisodicMemory {
    /// Horizon, capacity (`max_entries`) and the auto-evict / auto-compress switches.
    pub config: HorizonConfig,
    /// Decay model used to score, age out and compress episodes.
    pub compressor: TemporalCompressor,
    /// Stored episodes; the most recently added one is at the front.
    episodes: VecDeque<Episode>,
}
66
67impl EpisodicMemory {
68 pub fn new(config: HorizonConfig) -> Self {
70 let max_age = config
71 .horizon
72 .duration()
73 .unwrap_or(chrono::Duration::weeks(52));
74 Self {
75 config,
76 compressor: TemporalCompressor::new(
77 crate::compression::DecayStrategy::Exponential,
78 max_age,
79 ),
80 episodes: VecDeque::new(),
81 }
82 }
83
84 pub fn add(&mut self, episode: Episode) {
86 self.episodes.push_front(episode);
87 self.maybe_evict();
88 }
89
90 pub fn remember(&mut self, content: &str, importance: f64) {
92 self.add(Episode::new(content, importance));
93 }
94
95 pub fn episodes(&self) -> impl Iterator<Item = &Episode> {
97 self.episodes.iter()
98 }
99
100 pub fn by_tag(&self, tag: &str) -> Vec<&Episode> {
102 self.episodes
103 .iter()
104 .filter(|e| e.tags.contains(&tag.to_string()))
105 .collect()
106 }
107
108 pub fn recent(&self) -> Vec<&Episode> {
110 self.episodes
111 .iter()
112 .filter(|e| self.config.horizon.contains(e.created_at))
113 .collect()
114 }
115
116 pub fn by_importance(&self) -> Vec<(&Episode, f64)> {
118 let mut episodes: Vec<_> = self
119 .episodes
120 .iter()
121 .map(|e| {
122 let importance = if e.pinned {
123 1.0
124 } else {
125 self.compressor.importance(e.created_at, e.base_importance)
126 };
127 (e, importance)
128 })
129 .collect();
130 episodes.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
131 episodes
132 }
133
134 pub fn len(&self) -> usize {
136 self.episodes.len()
137 }
138
139 pub fn is_empty(&self) -> bool {
141 self.episodes.is_empty()
142 }
143
144 pub fn clear(&mut self) {
146 self.episodes.retain(|e| e.pinned);
147 }
148
149 fn maybe_evict(&mut self) {
152 if !self.config.auto_evict {
153 return;
154 }
155
156 let max_age_ids: std::collections::HashSet<u64> = self
164 .episodes
165 .iter()
166 .filter(|e| !e.pinned && self.compressor.should_evict(e.created_at))
167 .map(|e| e.id)
168 .collect();
169
170 if !max_age_ids.is_empty() {
171 self.episodes.retain(|e| !max_age_ids.contains(&e.id));
172 }
173
174 let current_len = self.episodes.len();
176 if current_len <= self.config.max_entries {
177 return;
178 }
179
180 let pinned_count = self.episodes.iter().filter(|e| e.pinned).count();
183 if pinned_count >= self.config.max_entries {
184 self.episodes.retain(|e| e.pinned);
185 return;
186 }
187
188 let slots_for_non_pinned = self.config.max_entries - pinned_count;
189
190 let mut candidates: Vec<(f64, DateTime<Utc>, u64)> = self
193 .episodes
194 .iter()
195 .filter(|e| !e.pinned)
196 .map(|e| {
197 (
198 self.compressor.importance(e.created_at, e.base_importance),
199 e.created_at,
200 e.id,
201 )
202 })
203 .collect();
204
205 if candidates.len() > slots_for_non_pinned {
208 let target_idx = candidates.len() - slots_for_non_pinned;
211
212 candidates.select_nth_unstable_by(target_idx, |a, b| {
214 a.0.partial_cmp(&b.0)
215 .unwrap_or(std::cmp::Ordering::Equal)
216 .then_with(|| a.1.cmp(&b.1))
217 });
218
219 let keep_ids: std::collections::HashSet<u64> =
222 candidates[target_idx..].iter().map(|c| c.2).collect();
223
224 self.episodes
226 .retain(|e| e.pinned || keep_ids.contains(&e.id));
227 }
228 }
229
230 pub fn compress_old(&mut self) -> usize {
232 if !self.config.auto_compress {
233 return 0;
234 }
235
236 let mut count = 0;
237 for episode in &mut self.episodes {
238 if episode.pinned {
239 continue;
240 }
241
242 let ratio = self.compressor.compression_ratio(episode.created_at);
243 if ratio > 0.1 {
244 episode.content = self.compressor.compress(&episode.content, ratio);
245 count += 1;
246 }
247 }
248 count
249 }
250
251 pub async fn compress_old_with_llm<L: vex_llm::LlmProvider + vex_llm::EmbeddingProvider>(
254 &mut self,
255 llm: &L,
256 vector_store: Option<&dyn VectorStoreBackend>,
257 tenant_id: Option<&str>,
258 ) -> usize {
259 if !self.config.auto_compress {
260 return 0;
261 }
262
263 let mut count = 0;
264 for episode in &mut self.episodes {
265 if episode.pinned {
266 continue;
267 }
268
269 let ratio = self.compressor.compression_ratio(episode.created_at);
270 if ratio > 0.1 {
271 match self
272 .compressor
273 .compress_with_llm(&episode.content, ratio, llm, vector_store, tenant_id)
274 .await
275 {
276 Ok(compressed) => {
277 tracing::debug!(
278 episode_id = %episode.id,
279 original_len = episode.content.len(),
280 compressed_len = compressed.len(),
281 ratio = ratio,
282 "Compressed episode with LLM"
283 );
284 episode.content = compressed;
285 count += 1;
286 }
287 Err(e) => {
288 tracing::warn!("LLM compression failed for episode {}: {}", episode.id, e);
289 episode.content = self.compressor.compress(&episode.content, ratio);
291 count += 1;
292 }
293 }
294 }
295 }
296 count
297 }
298
299 pub async fn summarize_all_with_llm<L: vex_llm::LlmProvider>(
303 &mut self,
304 llm: &L,
305 ) -> Result<String, vex_llm::LlmError> {
306 if self.episodes.is_empty() {
307 return Ok(String::from("No memories recorded."));
308 }
309
310 let all_content: String = self
312 .episodes
313 .iter()
314 .map(|e| {
315 format!(
316 "[{}] (importance: {:.1}): {}",
317 e.created_at.format("%Y-%m-%d %H:%M"),
318 e.base_importance,
319 e.content
320 )
321 })
322 .collect::<Vec<_>>()
323 .join("\n\n");
324
325 let prompt = format!(
326 "You are a memory consolidation system. Summarize the following episodic memories \
327 into a coherent narrative that preserves the most important information, decisions, \
328 and context. Focus on factual content and key events.\n\n\
329 MEMORIES:\n{}\n\n\
330 CONSOLIDATED SUMMARY:",
331 all_content
332 );
333
334 let summary = llm.ask(&prompt).await.map(|s| s.trim().to_string())?;
335
336 let mut pinned_ep = Episode::pinned(&summary);
338 pinned_ep.tags.push("consolidated_summary".to_string());
339 self.episodes.push_front(pinned_ep);
340
341 Ok(summary)
342 }
343
344 pub fn summarize(&self) -> String {
346 let total = self.len();
347 let pinned = self.episodes.iter().filter(|e| e.pinned).count();
348 let recent = self.recent().len();
349
350 format!(
351 "Memory: {} total ({} pinned, {} recent within {:?})",
352 total, pinned, recent, self.config.horizon
353 )
354 }
355}
356
357impl Default for EpisodicMemory {
358 fn default() -> Self {
359 Self::new(HorizonConfig::default())
360 }
361}
362
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_episodic_memory() {
        let mut memory = EpisodicMemory::default();

        memory.remember("First event", 0.8);
        memory.remember("Second event", 0.5);
        memory.add(Episode::pinned("Important system info"));

        assert_eq!(memory.len(), 3);
        assert_eq!(memory.recent().len(), 3);
    }

    #[test]
    fn test_by_importance() {
        let mut memory = EpisodicMemory::default();

        memory.remember("Low importance", 0.2);
        memory.remember("High importance", 0.9);

        let sorted = memory.by_importance();
        assert!(sorted[0].1 > sorted[1].1);
    }

    #[test]
    fn test_pinned_not_evicted() {
        // Enable eviction explicitly so the test does not depend on the default
        // value of `auto_evict`.
        let config = HorizonConfig {
            max_entries: 2,
            auto_evict: true,
            ..Default::default()
        };

        let mut memory = EpisodicMemory::new(config);
        memory.add(Episode::pinned("System"));
        memory.remember("Event 1", 0.5);
        memory.remember("Event 2", 0.5);
        memory.remember("Event 3", 0.5);

        assert!(memory.episodes().any(|e| e.content == "System"));
        // Capacity is enforced: the pinned episode plus one surviving event.
        assert_eq!(memory.len(), 2);
    }

    #[test]
    fn test_by_tag() {
        let mut memory = EpisodicMemory::default();
        memory.add(Episode::new("tagged", 0.5).with_tag("work"));
        memory.remember("untagged", 0.5);

        let hits = memory.by_tag("work");
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].content, "tagged");
        assert!(memory.by_tag("missing").is_empty());
    }

    #[test]
    fn test_clear_keeps_pinned() {
        let mut memory = EpisodicMemory::default();
        memory.remember("transient", 0.5);
        memory.add(Episode::pinned("System"));

        memory.clear();

        assert_eq!(memory.len(), 1);
        assert!(memory.episodes().all(|e| e.pinned));
    }
}
406}