Skip to main content

sh_layer3/memory_system/
session.rs

1//! # Session Memory
2//!
3//! 会话记忆:单次会话内的持久化存储。
4//!
5//! 支持可选的文件后端持久化,包括:
6//! - 自动保存(Drop 时 dirty flag)
7//! - 原子写入
8//! - 版本迁移
9
10use crate::memory_system::{
11    DecayPolicy, FileBackend, MemoryStore, StorageContainer, TimeBasedDecay,
12};
13use crate::types::{Layer3Result, MemoryEntry, MemoryQuery, MemoryTier};
14use async_trait::async_trait;
15use parking_lot::RwLock;
16use std::collections::HashMap;
17use std::path::PathBuf;
18use std::sync::Arc;
19
/// Auto-save configuration for `SessionMemory`.
///
/// Auto-saving only ever happens when `enabled` is set AND the session has a file backend.
#[derive(Debug, Clone)]
pub struct AutoSaveConfig {
    /// Master switch: enables saving after mutations and the final save on drop
    /// (the latter only when there are unsaved changes).
    pub enabled: bool,
    /// Intended save interval in milliseconds.
    ///
    /// NOTE(review): nothing in the session module reads this field — there is no
    /// timer-driven save. Saves happen after mutations, on explicit `save()`, or on drop.
    pub interval_ms: u64,
    /// Save after every `store`/`delete`, regardless of `min_changes`.
    pub save_on_store: bool,
    /// Number of unsaved changes that triggers an auto-save when `save_on_store` is false.
    pub min_changes: usize,
}

impl Default for AutoSaveConfig {
    /// Auto-save disabled; if enabled later, saves after every 5 changes.
    fn default() -> Self {
        Self {
            enabled: false,
            interval_ms: 5000,
            save_on_store: false,
            min_changes: 5,
        }
    }
}
43
/// Thread-safe "has unsaved changes" flag, consulted on drop to decide whether to save.
///
/// A single boolean needs no lock: an `AtomicBool` with sequentially consistent
/// ordering gives the same visibility guarantees without lock overhead.
#[derive(Debug, Default)]
struct DirtyFlag {
    dirty: std::sync::atomic::AtomicBool,
}

impl DirtyFlag {
    /// Creates a clean (not dirty) flag.
    fn new() -> Self {
        Self {
            dirty: std::sync::atomic::AtomicBool::new(false),
        }
    }

    /// Records that there are unsaved changes.
    fn mark_dirty(&self) {
        self.dirty
            .store(true, std::sync::atomic::Ordering::SeqCst);
    }

    /// Records that all changes have been persisted.
    fn mark_clean(&self) {
        self.dirty
            .store(false, std::sync::atomic::Ordering::SeqCst);
    }

    /// Returns `true` if there are unsaved changes.
    fn is_dirty(&self) -> bool {
        self.dirty.load(std::sync::atomic::Ordering::SeqCst)
    }
}
69
70/// Session Memory 实现
71///
72/// 使用 HashMap 存储会话期间的记忆,支持可选的文件持久化。
73///
74/// 特性:
75/// - Drop 时自动保存(如果 dirty)
76/// - 原子文件写入
77/// - 版本迁移支持
78pub struct SessionMemory {
79    /// 存储
80    storage: Arc<RwLock<HashMap<String, MemoryEntry>>>,
81    /// 会话 ID
82    session_id: String,
83    /// 衰减策略(未来功能预留)
84    #[allow(dead_code)]
85    decay_policy: Box<dyn DecayPolicy>,
86    /// 文件后端(可选)
87    file_backend: Option<Arc<dyn FileBackend>>,
88    /// 自动保存配置
89    auto_save_config: AutoSaveConfig,
90    /// 自上次保存后的变更计数
91    changes_since_save: Arc<RwLock<usize>>,
92    /// Dirty flag for Drop auto-save
93    dirty_flag: Arc<DirtyFlag>,
94    /// Prevent Drop from running multiple times
95    drop_guard: Arc<RwLock<bool>>,
96}
97
98impl SessionMemory {
99    /// 创建新的 Session Memory(无持久化)
100    pub fn new(session_id: impl Into<String>) -> Self {
101        Self {
102            storage: Arc::new(RwLock::new(HashMap::new())),
103            session_id: session_id.into(),
104            decay_policy: Box::new(TimeBasedDecay::default()),
105            file_backend: None,
106            auto_save_config: AutoSaveConfig::default(),
107            changes_since_save: Arc::new(RwLock::new(0)),
108            dirty_flag: Arc::new(DirtyFlag::new()),
109            drop_guard: Arc::new(RwLock::new(false)),
110        }
111    }
112
113    /// 创建带文件持久化的 Session Memory
114    pub fn with_persistence(
115        session_id: impl Into<String>,
116        backend: Arc<dyn FileBackend>,
117        auto_save: AutoSaveConfig,
118    ) -> Self {
119        Self {
120            storage: Arc::new(RwLock::new(HashMap::new())),
121            session_id: session_id.into(),
122            decay_policy: Box::new(TimeBasedDecay::default()),
123            file_backend: Some(backend),
124            auto_save_config: auto_save,
125            changes_since_save: Arc::new(RwLock::new(0)),
126            dirty_flag: Arc::new(DirtyFlag::new()),
127            drop_guard: Arc::new(RwLock::new(false)),
128        }
129    }
130
131    /// 创建带持久化存储的 Session Memory(推荐使用)
132    ///
133    /// 这是 `with_persistence` 的简化版本,使用 JSON 后端。
134    ///
135    /// # 参数
136    /// - `session_id`: 会话唯一标识符
137    /// - `path`: 存储文件路径
138    /// - `auto_save_on_drop`: 是否在 Drop 时自动保存
139    ///
140    /// # 示例
141    /// ```ignore
142    /// let memory = SessionMemory::with_persistent_storage(
143    ///     "my-session",
144    ///     "session.json",
145    ///     true
146    /// );
147    /// ```
148    pub fn with_persistent_storage(
149        session_id: impl Into<String>,
150        path: impl Into<PathBuf>,
151        auto_save_on_drop: bool,
152    ) -> Self {
153        let session_id = session_id.into();
154        let backend = Arc::new(crate::memory_system::JsonFileBackend::with_session_id(
155            path,
156            session_id.clone(),
157        ));
158        Self::with_persistence(
159            session_id,
160            backend,
161            AutoSaveConfig {
162                enabled: auto_save_on_drop,
163                save_on_store: false,
164                min_changes: 1,
165                ..Default::default()
166            },
167        )
168    }
169
170    /// 创建带 JSON 文件持久化的 Session Memory
171    pub fn with_json_backend(session_id: impl Into<String>, path: impl Into<PathBuf>) -> Self {
172        let backend = Arc::new(crate::memory_system::JsonFileBackend::new(path));
173        Self::with_persistence(
174            session_id,
175            backend,
176            AutoSaveConfig {
177                enabled: true,
178                save_on_store: true,
179                ..Default::default()
180            },
181        )
182    }
183
184    /// 从文件加载已保存的会话
185    ///
186    /// 如果文件不存在,创建新的空会话。
187    /// 如果文件存在,加载其中的数据并支持版本迁移。
188    pub async fn load_from_file(
189        session_id: impl Into<String>,
190        path: impl Into<PathBuf>,
191    ) -> Layer3Result<Self> {
192        let session_id = session_id.into();
193        let path = path.into();
194        let backend = Arc::new(crate::memory_system::JsonFileBackend::with_session_id(
195            &path,
196            session_id.clone(),
197        ));
198
199        let container = backend.load_container().await?;
200        let storage: HashMap<String, MemoryEntry> = container
201            .entries
202            .into_iter()
203            .map(|e| (e.id.clone(), e))
204            .collect();
205
206        tracing::info!(
207            "Loaded session {} from {} ({} entries, version {})",
208            session_id,
209            path.display(),
210            storage.len(),
211            container.version
212        );
213
214        Ok(Self {
215            storage: Arc::new(RwLock::new(storage)),
216            session_id,
217            decay_policy: Box::new(TimeBasedDecay::default()),
218            file_backend: Some(backend),
219            auto_save_config: AutoSaveConfig {
220                enabled: true,
221                save_on_store: true,
222                ..Default::default()
223            },
224            changes_since_save: Arc::new(RwLock::new(0)),
225            dirty_flag: Arc::new(DirtyFlag::new()),
226            drop_guard: Arc::new(RwLock::new(false)),
227        })
228    }
229
230    /// 尝试从文件加载,失败则创建新会话
231    pub async fn load_or_create(
232        session_id: impl Into<String>,
233        path: impl Into<PathBuf>,
234    ) -> Layer3Result<Self> {
235        let path = path.into();
236        if path.exists() {
237            Self::load_from_file(session_id, &path).await
238        } else {
239            Ok(Self::with_persistent_storage(session_id, &path, true))
240        }
241    }
242
243    /// 手动保存到文件
244    pub async fn save(&self) -> Layer3Result<()> {
245        if let Some(backend) = &self.file_backend {
246            let entries: Vec<MemoryEntry> = self.storage.read().values().cloned().collect();
247            backend
248                .save_with_session(&self.session_id, &entries)
249                .await?;
250            *self.changes_since_save.write() = 0;
251            self.dirty_flag.mark_clean();
252            tracing::info!(
253                "Session {} saved to {}",
254                self.session_id,
255                backend.path().display()
256            );
257        }
258        Ok(())
259    }
260
261    /// 同步保存(阻塞版本,用于 Drop)
262    fn save_sync(&self) -> Layer3Result<()> {
263        if let Some(backend) = &self.file_backend {
264            let entries: Vec<MemoryEntry> = self.storage.read().values().cloned().collect();
265
266            // Use blocking file write for Drop
267            let json =
268                serde_json::to_string_pretty(&StorageContainer::new(&self.session_id, entries))?;
269
270            let path = backend.path().to_path_buf();
271            let temp_path = path.with_extension(format!("tmp.{}", std::process::id()));
272
273            // Write to temp file
274            std::fs::write(&temp_path, &json)?;
275            // Atomic rename
276            std::fs::rename(&temp_path, &path)?;
277
278            tracing::info!(
279                "Session {} saved (sync) to {}",
280                self.session_id,
281                path.display()
282            );
283        }
284        Ok(())
285    }
286
287    /// 检查是否需要自动保存
288    fn should_auto_save(&self) -> bool {
289        if !self.auto_save_config.enabled || self.file_backend.is_none() {
290            return false;
291        }
292
293        let changes = *self.changes_since_save.read();
294        if self.auto_save_config.save_on_store {
295            return true;
296        }
297
298        changes >= self.auto_save_config.min_changes
299    }
300
301    /// 执行自动保存(如果需要)
302    async fn maybe_auto_save(&self) -> Layer3Result<()> {
303        if self.should_auto_save() {
304            self.save().await?;
305        }
306        Ok(())
307    }
308
309    /// 获取会话 ID
310    pub fn session_id(&self) -> &str {
311        &self.session_id
312    }
313
314    /// 获取变更计数
315    pub fn changes_since_save(&self) -> usize {
316        *self.changes_since_save.read()
317    }
318
319    /// 检查是否有未保存的变更
320    pub fn is_dirty(&self) -> bool {
321        self.dirty_flag.is_dirty()
322    }
323
324    /// 获取文件后端路径(如果有)
325    pub fn persistence_path(&self) -> Option<&std::path::Path> {
326        self.file_backend.as_ref().map(|b| b.path())
327    }
328}
329
330impl Drop for SessionMemory {
331    fn drop(&mut self) {
332        // Prevent double-drop
333        if *self.drop_guard.read() {
334            return;
335        }
336        *self.drop_guard.write() = true;
337
338        // Auto-save if dirty and backend exists
339        if self.dirty_flag.is_dirty()
340            && self.file_backend.is_some()
341            && self.auto_save_config.enabled
342        {
343            if let Err(e) = self.save_sync() {
344                tracing::error!("Failed to auto-save session {}: {}", self.session_id, e);
345            }
346        }
347    }
348}
349
350impl Default for SessionMemory {
351    fn default() -> Self {
352        Self::new("default")
353    }
354}
355
356#[async_trait]
357impl MemoryStore for SessionMemory {
358    fn tier(&self) -> MemoryTier {
359        MemoryTier::Session
360    }
361
362    async fn store(&self, entry: MemoryEntry) -> Layer3Result<String> {
363        let id = entry.id.clone();
364        {
365            let mut storage = self.storage.write();
366            storage.insert(id.clone(), entry);
367        }
368        *self.changes_since_save.write() += 1;
369        self.dirty_flag.mark_dirty();
370
371        // Auto-save if configured
372        self.maybe_auto_save().await?;
373
374        Ok(id)
375    }
376
377    async fn get(&self, id: &str) -> Layer3Result<Option<MemoryEntry>> {
378        Ok(self.storage.read().get(id).cloned())
379    }
380
381    async fn delete(&self, id: &str) -> Layer3Result<bool> {
382        let removed = self.storage.write().remove(id).is_some();
383        if removed {
384            *self.changes_since_save.write() += 1;
385            self.dirty_flag.mark_dirty();
386            drop(self.storage.write()); // Ensure lock is dropped before await
387            self.maybe_auto_save().await?;
388        }
389        Ok(removed)
390    }
391
392    async fn query(&self, query: &MemoryQuery) -> Layer3Result<Vec<MemoryEntry>> {
393        let storage = self.storage.read();
394        let results: Vec<MemoryEntry> = storage
395            .values()
396            .filter(|e| {
397                if let Some(tier) = query.tier {
398                    if e.tier != tier {
399                        return false;
400                    }
401                }
402                e.content.contains(&query.query)
403            })
404            .take(query.limit.unwrap_or(10))
405            .cloned()
406            .collect();
407        Ok(results)
408    }
409
410    async fn list(&self, limit: Option<usize>) -> Layer3Result<Vec<MemoryEntry>> {
411        let storage = self.storage.read();
412        Ok(storage
413            .values()
414            .take(limit.unwrap_or(usize::MAX))
415            .cloned()
416            .collect())
417    }
418
419    async fn clear(&self) -> Layer3Result<usize> {
420        let count = {
421            let mut storage = self.storage.write();
422            let count = storage.len();
423            storage.clear();
424            count
425        };
426
427        // Clear file backend too
428        if let Some(backend) = &self.file_backend {
429            backend.clear().await?;
430        }
431        *self.changes_since_save.write() = 0;
432        self.dirty_flag.mark_clean();
433
434        Ok(count)
435    }
436
437    async fn count(&self) -> Layer3Result<usize> {
438        Ok(self.storage.read().len())
439    }
440}
441
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::MemoryTier;
    use tempfile::tempdir;

    /// Builds a session-tier entry with the given id and content and default metadata.
    fn create_test_entry(id: &str, content: &str) -> MemoryEntry {
        MemoryEntry {
            id: id.to_string(),
            tier: MemoryTier::Session,
            content: content.to_string(),
            metadata: Default::default(),
            created_at: chrono::Utc::now(),
            last_accessed: chrono::Utc::now(),
            access_count: 0,
            importance: 0.5,
        }
    }

    #[tokio::test]
    async fn test_session_memory() {
        let memory = SessionMemory::new("test-session");
        assert_eq!(memory.tier(), MemoryTier::Session);
        assert!(!memory.is_dirty());
    }

    #[tokio::test]
    async fn test_session_with_persistence() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("session.json");

        let memory = SessionMemory::with_json_backend("test-session", &path);
        memory
            .store(create_test_entry("1", "test content"))
            .await
            .unwrap();

        // Should auto-save
        assert!(path.exists());
        assert!(!memory.is_dirty()); // Saved, so not dirty
    }

    #[tokio::test]
    async fn test_load_from_file() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("session.json");

        // First save some data
        let memory = SessionMemory::with_json_backend("test-session", &path);
        memory
            .store(create_test_entry("1", "saved content"))
            .await
            .unwrap();
        memory.save().await.unwrap();

        // Load from file
        let loaded = SessionMemory::load_from_file("test-session", &path)
            .await
            .unwrap();
        let entry = loaded.get("1").await.unwrap().unwrap();
        assert_eq!(entry.content, "saved content");
    }

    #[tokio::test]
    async fn test_manual_save() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("session.json");

        let memory = SessionMemory::new("test-session");
        // No backend, save should succeed but do nothing
        memory.save().await.unwrap();
        assert!(!path.exists());
    }

    #[tokio::test]
    async fn test_dirty_flag() {
        let memory = SessionMemory::new("test-session");

        assert!(!memory.is_dirty());

        memory.store(create_test_entry("1", "test")).await.unwrap();
        // Dirty flag tracks unsaved changes even without backend
        assert!(memory.is_dirty());

        let dir = tempdir().unwrap();
        let path = dir.path().join("session.json");
        let persistent = SessionMemory::with_persistent_storage("test", &path, false);

        assert!(!persistent.is_dirty());
        persistent
            .store(create_test_entry("1", "test"))
            .await
            .unwrap();
        assert!(persistent.is_dirty());

        persistent.save().await.unwrap();
        assert!(!persistent.is_dirty());
    }

    #[tokio::test]
    async fn test_drop_auto_save() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("drop_save.json");

        {
            let memory = SessionMemory::with_persistent_storage("test-drop", &path, true);
            memory
                .store(create_test_entry("drop-1", "content before drop"))
                .await
                .unwrap();
            // Auto-save triggers when changes >= min_changes (1), so dirty may be cleared
            // Drop happens here
        }

        // File should exist after drop
        assert!(path.exists());

        // Load and verify
        let loaded = SessionMemory::load_from_file("test-drop", &path)
            .await
            .unwrap();
        let entry = loaded.get("drop-1").await.unwrap().unwrap();
        assert_eq!(entry.content, "content before drop");
    }

    #[tokio::test]
    async fn test_load_or_create() {
        let dir = tempdir().unwrap();
        let existing_path = dir.path().join("existing.json");
        let new_path = dir.path().join("new.json");

        // Create existing file
        {
            let memory = SessionMemory::with_json_backend("existing", &existing_path);
            memory
                .store(create_test_entry("existing-1", "existing content"))
                .await
                .unwrap();
        }

        // Load existing
        let loaded = SessionMemory::load_or_create("existing", &existing_path)
            .await
            .unwrap();
        assert!(loaded.get("existing-1").await.unwrap().is_some());

        // Create new
        let new_memory = SessionMemory::load_or_create("new", &new_path)
            .await
            .unwrap();
        assert!(!new_path.exists()); // Not saved yet
        new_memory
            .store(create_test_entry("new-1", "new content"))
            .await
            .unwrap();
        assert!(new_path.exists()); // Auto-saved
    }

    #[tokio::test]
    async fn test_version_migration() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("legacy_session.json");

        // Write legacy format (just array of entries)
        let legacy_entries = vec![create_test_entry("legacy-1", "legacy content")];
        let legacy_json = serde_json::to_string_pretty(&legacy_entries).unwrap();
        std::fs::write(&path, legacy_json).unwrap();

        // Load should migrate
        let loaded = SessionMemory::load_from_file("migrated", &path)
            .await
            .unwrap();
        let entry = loaded.get("legacy-1").await.unwrap().unwrap();
        assert_eq!(entry.content, "legacy content");
    }

    #[tokio::test]
    async fn test_persistence_path() {
        let memory = SessionMemory::new("test");
        assert!(memory.persistence_path().is_none());

        let dir = tempdir().unwrap();
        let path = dir.path().join("session.json");
        let persistent = SessionMemory::with_json_backend("test", &path);

        let stored_path = persistent.persistence_path().unwrap();
        assert_eq!(stored_path, path);
    }

    /// Despite its name, this exercises many saves through a shared `Arc`, sequentially.
    #[tokio::test]
    async fn test_thread_safety() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("concurrent.json");
        let memory = Arc::new(SessionMemory::with_json_backend("concurrent", &path));

        // Sequential stores to avoid Windows file locking issues
        // Note: Concurrent stores with save_on_store=true can cause race conditions
        // on Windows due to atomic_write rename operations
        for i in 0..10 {
            memory
                .store(create_test_entry(&format!("entry-{}", i), "content"))
                .await
                .unwrap();
        }

        // All entries should be saved
        assert!(path.exists());
        let loaded = SessionMemory::load_from_file("concurrent", &path)
            .await
            .unwrap();
        assert_eq!(loaded.count().await.unwrap(), 10);
    }

    #[test]
    fn test_dirty_flag_thread_safety() {
        let flag = Arc::new(DirtyFlag::new());
        let mut handles = vec![];

        for _ in 0..100 {
            let f = flag.clone();
            let handle = std::thread::spawn(move || {
                f.mark_dirty();
                assert!(f.is_dirty());
            });
            handles.push(handle);
        }

        for handle in handles {
            handle.join().unwrap();
        }

        assert!(flag.is_dirty());
    }

    #[tokio::test]
    async fn test_delete_tracks_changes() {
        let memory = SessionMemory::new("test-delete");
        memory
            .store(create_test_entry("1", "to delete"))
            .await
            .unwrap();
        assert_eq!(memory.changes_since_save(), 1);

        assert!(memory.delete("1").await.unwrap());
        assert_eq!(memory.changes_since_save(), 2);
        assert!(memory.is_dirty());
        assert_eq!(memory.count().await.unwrap(), 0);

        // Deleting a missing id is a no-op and must not count as a change.
        assert!(!memory.delete("1").await.unwrap());
        assert_eq!(memory.changes_since_save(), 2);
    }

    #[tokio::test]
    async fn test_clear_resets_change_tracking() {
        let memory = SessionMemory::new("test-clear");
        memory.store(create_test_entry("a", "one")).await.unwrap();
        memory.store(create_test_entry("b", "two")).await.unwrap();
        assert!(memory.is_dirty());

        assert_eq!(memory.clear().await.unwrap(), 2);
        assert_eq!(memory.count().await.unwrap(), 0);
        assert_eq!(memory.changes_since_save(), 0);
        assert!(!memory.is_dirty());
    }
}