1use crate::memory_system::{
11 DecayPolicy, FileBackend, MemoryStore, StorageContainer, TimeBasedDecay,
12};
13use crate::types::{Layer3Result, MemoryEntry, MemoryQuery, MemoryTier};
14use async_trait::async_trait;
15use parking_lot::RwLock;
16use std::collections::HashMap;
17use std::path::PathBuf;
18use std::sync::Arc;
19
/// Controls when a persistent `SessionMemory` writes its entries to its file backend.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AutoSaveConfig {
    /// Master switch. When `false`, `SessionMemory` performs neither change-triggered
    /// auto-saves nor the save-on-drop; only explicit `save` calls persist data.
    pub enabled: bool,
    /// Interval in milliseconds. Not read by `SessionMemory` in this module —
    /// presumably reserved for a periodic saver (TODO confirm).
    pub interval_ms: u64,
    /// When `true` (and `enabled`), every change triggers a save, ignoring `min_changes`.
    pub save_on_store: bool,
    /// Number of unsaved changes that triggers an auto-save when `save_on_store` is `false`.
    pub min_changes: usize,
}

impl Default for AutoSaveConfig {
    /// Auto-save disabled; if `enabled` is switched on (e.g. via struct-update syntax) it
    /// saves once 5 changes have accumulated.
    fn default() -> Self {
        Self {
            enabled: false,
            interval_ms: 5000,
            save_on_store: false,
            min_changes: 5,
        }
    }
}
43
/// Thread-safe flag recording whether there are changes that have not been persisted.
///
/// Backed by an `AtomicBool`: a lock is unnecessary for a single boolean.
#[derive(Debug, Default)]
struct DirtyFlag {
    dirty: std::sync::atomic::AtomicBool,
}

impl DirtyFlag {
    /// Creates a clean (not dirty) flag.
    fn new() -> Self {
        Self {
            dirty: std::sync::atomic::AtomicBool::new(false),
        }
    }

    /// Records that there are unsaved changes.
    fn mark_dirty(&self) {
        self.dirty.store(true, std::sync::atomic::Ordering::SeqCst);
    }

    /// Records that all changes have been persisted (or discarded).
    fn mark_clean(&self) {
        self.dirty.store(false, std::sync::atomic::Ordering::SeqCst);
    }

    /// Returns `true` if `mark_dirty` has been called since the last `mark_clean`.
    fn is_dirty(&self) -> bool {
        self.dirty.load(std::sync::atomic::Ordering::SeqCst)
    }
}
69
70pub struct SessionMemory {
79 storage: Arc<RwLock<HashMap<String, MemoryEntry>>>,
81 session_id: String,
83 #[allow(dead_code)]
85 decay_policy: Box<dyn DecayPolicy>,
86 file_backend: Option<Arc<dyn FileBackend>>,
88 auto_save_config: AutoSaveConfig,
90 changes_since_save: Arc<RwLock<usize>>,
92 dirty_flag: Arc<DirtyFlag>,
94 drop_guard: Arc<RwLock<bool>>,
96}
97
98impl SessionMemory {
99 pub fn new(session_id: impl Into<String>) -> Self {
101 Self {
102 storage: Arc::new(RwLock::new(HashMap::new())),
103 session_id: session_id.into(),
104 decay_policy: Box::new(TimeBasedDecay::default()),
105 file_backend: None,
106 auto_save_config: AutoSaveConfig::default(),
107 changes_since_save: Arc::new(RwLock::new(0)),
108 dirty_flag: Arc::new(DirtyFlag::new()),
109 drop_guard: Arc::new(RwLock::new(false)),
110 }
111 }
112
113 pub fn with_persistence(
115 session_id: impl Into<String>,
116 backend: Arc<dyn FileBackend>,
117 auto_save: AutoSaveConfig,
118 ) -> Self {
119 Self {
120 storage: Arc::new(RwLock::new(HashMap::new())),
121 session_id: session_id.into(),
122 decay_policy: Box::new(TimeBasedDecay::default()),
123 file_backend: Some(backend),
124 auto_save_config: auto_save,
125 changes_since_save: Arc::new(RwLock::new(0)),
126 dirty_flag: Arc::new(DirtyFlag::new()),
127 drop_guard: Arc::new(RwLock::new(false)),
128 }
129 }
130
131 pub fn with_persistent_storage(
149 session_id: impl Into<String>,
150 path: impl Into<PathBuf>,
151 auto_save_on_drop: bool,
152 ) -> Self {
153 let session_id = session_id.into();
154 let backend = Arc::new(crate::memory_system::JsonFileBackend::with_session_id(
155 path,
156 session_id.clone(),
157 ));
158 Self::with_persistence(
159 session_id,
160 backend,
161 AutoSaveConfig {
162 enabled: auto_save_on_drop,
163 save_on_store: false,
164 min_changes: 1,
165 ..Default::default()
166 },
167 )
168 }
169
170 pub fn with_json_backend(session_id: impl Into<String>, path: impl Into<PathBuf>) -> Self {
172 let backend = Arc::new(crate::memory_system::JsonFileBackend::new(path));
173 Self::with_persistence(
174 session_id,
175 backend,
176 AutoSaveConfig {
177 enabled: true,
178 save_on_store: true,
179 ..Default::default()
180 },
181 )
182 }
183
184 pub async fn load_from_file(
189 session_id: impl Into<String>,
190 path: impl Into<PathBuf>,
191 ) -> Layer3Result<Self> {
192 let session_id = session_id.into();
193 let path = path.into();
194 let backend = Arc::new(crate::memory_system::JsonFileBackend::with_session_id(
195 &path,
196 session_id.clone(),
197 ));
198
199 let container = backend.load_container().await?;
200 let storage: HashMap<String, MemoryEntry> = container
201 .entries
202 .into_iter()
203 .map(|e| (e.id.clone(), e))
204 .collect();
205
206 tracing::info!(
207 "Loaded session {} from {} ({} entries, version {})",
208 session_id,
209 path.display(),
210 storage.len(),
211 container.version
212 );
213
214 Ok(Self {
215 storage: Arc::new(RwLock::new(storage)),
216 session_id,
217 decay_policy: Box::new(TimeBasedDecay::default()),
218 file_backend: Some(backend),
219 auto_save_config: AutoSaveConfig {
220 enabled: true,
221 save_on_store: true,
222 ..Default::default()
223 },
224 changes_since_save: Arc::new(RwLock::new(0)),
225 dirty_flag: Arc::new(DirtyFlag::new()),
226 drop_guard: Arc::new(RwLock::new(false)),
227 })
228 }
229
230 pub async fn load_or_create(
232 session_id: impl Into<String>,
233 path: impl Into<PathBuf>,
234 ) -> Layer3Result<Self> {
235 let path = path.into();
236 if path.exists() {
237 Self::load_from_file(session_id, &path).await
238 } else {
239 Ok(Self::with_persistent_storage(session_id, &path, true))
240 }
241 }
242
243 pub async fn save(&self) -> Layer3Result<()> {
245 if let Some(backend) = &self.file_backend {
246 let entries: Vec<MemoryEntry> = self.storage.read().values().cloned().collect();
247 backend
248 .save_with_session(&self.session_id, &entries)
249 .await?;
250 *self.changes_since_save.write() = 0;
251 self.dirty_flag.mark_clean();
252 tracing::info!(
253 "Session {} saved to {}",
254 self.session_id,
255 backend.path().display()
256 );
257 }
258 Ok(())
259 }
260
261 fn save_sync(&self) -> Layer3Result<()> {
263 if let Some(backend) = &self.file_backend {
264 let entries: Vec<MemoryEntry> = self.storage.read().values().cloned().collect();
265
266 let json =
268 serde_json::to_string_pretty(&StorageContainer::new(&self.session_id, entries))?;
269
270 let path = backend.path().to_path_buf();
271 let temp_path = path.with_extension(format!("tmp.{}", std::process::id()));
272
273 std::fs::write(&temp_path, &json)?;
275 std::fs::rename(&temp_path, &path)?;
277
278 tracing::info!(
279 "Session {} saved (sync) to {}",
280 self.session_id,
281 path.display()
282 );
283 }
284 Ok(())
285 }
286
287 fn should_auto_save(&self) -> bool {
289 if !self.auto_save_config.enabled || self.file_backend.is_none() {
290 return false;
291 }
292
293 let changes = *self.changes_since_save.read();
294 if self.auto_save_config.save_on_store {
295 return true;
296 }
297
298 changes >= self.auto_save_config.min_changes
299 }
300
301 async fn maybe_auto_save(&self) -> Layer3Result<()> {
303 if self.should_auto_save() {
304 self.save().await?;
305 }
306 Ok(())
307 }
308
309 pub fn session_id(&self) -> &str {
311 &self.session_id
312 }
313
314 pub fn changes_since_save(&self) -> usize {
316 *self.changes_since_save.read()
317 }
318
319 pub fn is_dirty(&self) -> bool {
321 self.dirty_flag.is_dirty()
322 }
323
324 pub fn persistence_path(&self) -> Option<&std::path::Path> {
326 self.file_backend.as_ref().map(|b| b.path())
327 }
328}
329
330impl Drop for SessionMemory {
331 fn drop(&mut self) {
332 if *self.drop_guard.read() {
334 return;
335 }
336 *self.drop_guard.write() = true;
337
338 if self.dirty_flag.is_dirty()
340 && self.file_backend.is_some()
341 && self.auto_save_config.enabled
342 {
343 if let Err(e) = self.save_sync() {
344 tracing::error!("Failed to auto-save session {}: {}", self.session_id, e);
345 }
346 }
347 }
348}
349
350impl Default for SessionMemory {
351 fn default() -> Self {
352 Self::new("default")
353 }
354}
355
356#[async_trait]
357impl MemoryStore for SessionMemory {
358 fn tier(&self) -> MemoryTier {
359 MemoryTier::Session
360 }
361
362 async fn store(&self, entry: MemoryEntry) -> Layer3Result<String> {
363 let id = entry.id.clone();
364 {
365 let mut storage = self.storage.write();
366 storage.insert(id.clone(), entry);
367 }
368 *self.changes_since_save.write() += 1;
369 self.dirty_flag.mark_dirty();
370
371 self.maybe_auto_save().await?;
373
374 Ok(id)
375 }
376
377 async fn get(&self, id: &str) -> Layer3Result<Option<MemoryEntry>> {
378 Ok(self.storage.read().get(id).cloned())
379 }
380
381 async fn delete(&self, id: &str) -> Layer3Result<bool> {
382 let removed = self.storage.write().remove(id).is_some();
383 if removed {
384 *self.changes_since_save.write() += 1;
385 self.dirty_flag.mark_dirty();
386 drop(self.storage.write()); self.maybe_auto_save().await?;
388 }
389 Ok(removed)
390 }
391
392 async fn query(&self, query: &MemoryQuery) -> Layer3Result<Vec<MemoryEntry>> {
393 let storage = self.storage.read();
394 let results: Vec<MemoryEntry> = storage
395 .values()
396 .filter(|e| {
397 if let Some(tier) = query.tier {
398 if e.tier != tier {
399 return false;
400 }
401 }
402 e.content.contains(&query.query)
403 })
404 .take(query.limit.unwrap_or(10))
405 .cloned()
406 .collect();
407 Ok(results)
408 }
409
410 async fn list(&self, limit: Option<usize>) -> Layer3Result<Vec<MemoryEntry>> {
411 let storage = self.storage.read();
412 Ok(storage
413 .values()
414 .take(limit.unwrap_or(usize::MAX))
415 .cloned()
416 .collect())
417 }
418
419 async fn clear(&self) -> Layer3Result<usize> {
420 let count = {
421 let mut storage = self.storage.write();
422 let count = storage.len();
423 storage.clear();
424 count
425 };
426
427 if let Some(backend) = &self.file_backend {
429 backend.clear().await?;
430 }
431 *self.changes_since_save.write() = 0;
432 self.dirty_flag.mark_clean();
433
434 Ok(count)
435 }
436
437 async fn count(&self) -> Layer3Result<usize> {
438 Ok(self.storage.read().len())
439 }
440}
441
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::MemoryTier;
    use tempfile::tempdir;

    /// Builds a session-tier entry with default metadata and mid importance.
    fn create_test_entry(id: &str, content: &str) -> MemoryEntry {
        MemoryEntry {
            id: id.to_string(),
            tier: MemoryTier::Session,
            content: content.to_string(),
            metadata: Default::default(),
            created_at: chrono::Utc::now(),
            last_accessed: chrono::Utc::now(),
            access_count: 0,
            importance: 0.5,
        }
    }

    #[tokio::test]
    async fn test_session_memory() {
        let memory = SessionMemory::new("test-session");
        assert_eq!(memory.tier(), MemoryTier::Session);
        assert!(!memory.is_dirty());
    }

    #[tokio::test]
    async fn test_session_with_persistence() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("session.json");

        let memory = SessionMemory::with_json_backend("test-session", &path);
        memory
            .store(create_test_entry("1", "test content"))
            .await
            .unwrap();

        // save_on_store persists immediately, so nothing is left dirty.
        assert!(path.exists());
        assert!(!memory.is_dirty());
    }

    #[tokio::test]
    async fn test_load_from_file() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("session.json");

        let memory = SessionMemory::with_json_backend("test-session", &path);
        memory
            .store(create_test_entry("1", "saved content"))
            .await
            .unwrap();
        memory.save().await.unwrap();

        let loaded = SessionMemory::load_from_file("test-session", &path)
            .await
            .unwrap();
        let entry = loaded.get("1").await.unwrap().unwrap();
        assert_eq!(entry.content, "saved content");
    }

    #[tokio::test]
    async fn test_manual_save() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("session.json");

        // Without a backend, save is a no-op and must not create any file.
        let memory = SessionMemory::new("test-session");
        assert!(memory.persistence_path().is_none());
        memory.save().await.unwrap();
        assert!(!path.exists());
    }

    #[tokio::test]
    async fn test_dirty_flag() {
        let memory = SessionMemory::new("test-session");

        assert!(!memory.is_dirty());

        memory.store(create_test_entry("1", "test")).await.unwrap();
        assert!(memory.is_dirty());

        let dir = tempdir().unwrap();
        let path = dir.path().join("session.json");
        let persistent = SessionMemory::with_persistent_storage("test", &path, false);

        assert!(!persistent.is_dirty());
        persistent
            .store(create_test_entry("1", "test"))
            .await
            .unwrap();
        assert!(persistent.is_dirty());

        persistent.save().await.unwrap();
        assert!(!persistent.is_dirty());
    }

    #[tokio::test]
    async fn test_drop_auto_save() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("drop_save.json");

        {
            let memory = SessionMemory::with_persistent_storage("test-drop", &path, true);
            memory
                .store(create_test_entry("drop-1", "content before drop"))
                .await
                .unwrap();
        }

        assert!(path.exists());

        let loaded = SessionMemory::load_from_file("test-drop", &path)
            .await
            .unwrap();
        let entry = loaded.get("drop-1").await.unwrap().unwrap();
        assert_eq!(entry.content, "content before drop");
    }

    #[tokio::test]
    async fn test_load_or_create() {
        let dir = tempdir().unwrap();
        let existing_path = dir.path().join("existing.json");
        let new_path = dir.path().join("new.json");

        {
            let memory = SessionMemory::with_json_backend("existing", &existing_path);
            memory
                .store(create_test_entry("existing-1", "existing content"))
                .await
                .unwrap();
        }

        let loaded = SessionMemory::load_or_create("existing", &existing_path)
            .await
            .unwrap();
        assert!(loaded.get("existing-1").await.unwrap().is_some());

        // A freshly created session writes nothing until its first change.
        let new_memory = SessionMemory::load_or_create("new", &new_path)
            .await
            .unwrap();
        assert!(!new_path.exists());
        new_memory
            .store(create_test_entry("new-1", "new content"))
            .await
            .unwrap();
        assert!(new_path.exists());
    }

    #[tokio::test]
    async fn test_version_migration() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("legacy_session.json");

        // Legacy format: a bare JSON array of entries, without the container wrapper.
        let legacy_entries = vec![create_test_entry("legacy-1", "legacy content")];
        let legacy_json = serde_json::to_string_pretty(&legacy_entries).unwrap();
        std::fs::write(&path, legacy_json).unwrap();

        let loaded = SessionMemory::load_from_file("migrated", &path)
            .await
            .unwrap();
        let entry = loaded.get("legacy-1").await.unwrap().unwrap();
        assert_eq!(entry.content, "legacy content");
    }

    #[tokio::test]
    async fn test_persistence_path() {
        let memory = SessionMemory::new("test");
        assert!(memory.persistence_path().is_none());

        let dir = tempdir().unwrap();
        let path = dir.path().join("session.json");
        let persistent = SessionMemory::with_json_backend("test", &path);

        let stored_path = persistent.persistence_path().unwrap();
        assert_eq!(stored_path, path);
    }

    #[tokio::test]
    async fn test_thread_safety() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("concurrent.json");
        let memory = Arc::new(SessionMemory::with_json_backend("concurrent", &path));

        // Stores go through a shared `Arc` handle, one after another.
        for i in 0..10 {
            memory
                .store(create_test_entry(&format!("entry-{}", i), "content"))
                .await
                .unwrap();
        }

        assert!(path.exists());
        let loaded = SessionMemory::load_from_file("concurrent", &path)
            .await
            .unwrap();
        assert_eq!(loaded.count().await.unwrap(), 10);
    }

    #[tokio::test]
    async fn test_delete_in_memory() {
        let memory = SessionMemory::new("delete");
        memory.store(create_test_entry("1", "x")).await.unwrap();

        assert!(memory.delete("1").await.unwrap());
        // Deleting a missing id is not a change.
        assert!(!memory.delete("1").await.unwrap());

        assert_eq!(memory.changes_since_save(), 2);
        assert!(memory.is_dirty());
        assert_eq!(memory.count().await.unwrap(), 0);
    }

    #[tokio::test]
    async fn test_delete_persists() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("delete.json");

        let memory = SessionMemory::with_json_backend("delete-session", &path);
        memory.store(create_test_entry("1", "keep")).await.unwrap();
        memory.store(create_test_entry("2", "remove")).await.unwrap();

        assert!(memory.delete("2").await.unwrap());
        assert!(!memory.delete("missing").await.unwrap());
        assert!(!memory.is_dirty());

        let loaded = SessionMemory::load_from_file("delete-session", &path)
            .await
            .unwrap();
        assert_eq!(loaded.count().await.unwrap(), 1);
        assert!(loaded.get("2").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_clear_resets_state() {
        let memory = SessionMemory::new("clear");
        memory.store(create_test_entry("1", "a")).await.unwrap();
        memory.store(create_test_entry("2", "b")).await.unwrap();
        assert!(memory.is_dirty());

        assert_eq!(memory.clear().await.unwrap(), 2);
        assert_eq!(memory.count().await.unwrap(), 0);
        assert_eq!(memory.changes_since_save(), 0);
        assert!(!memory.is_dirty());
    }

    #[test]
    fn test_dirty_flag_thread_safety() {
        let flag = Arc::new(DirtyFlag::new());
        let mut handles = vec![];

        for _ in 0..100 {
            let f = flag.clone();
            let handle = std::thread::spawn(move || {
                f.mark_dirty();
                assert!(f.is_dirty());
            });
            handles.push(handle);
        }

        for handle in handles {
            handle.join().unwrap();
        }

        assert!(flag.is_dirty());
    }
}