Skip to main content

sh_layer3/memory_system/
working.rs

//! # Working Memory
//!
//! Working memory: the current conversation context, held in temporary storage.
4
5use crate::memory_system::{DecayPolicy, MemoryStore, TimeBasedDecay};
6use crate::types::{Layer3Result, MemoryEntry, MemoryQuery, MemoryTier};
7use async_trait::async_trait;
8use parking_lot::RwLock;
9use std::collections::VecDeque;
10use std::sync::Arc;
11
12/// Working Memory 实现
13///
14/// 使用环形缓冲区存储最近 N 条记忆。
15#[allow(dead_code)]
16pub struct WorkingMemory {
17    /// 存储缓冲区
18    buffer: Arc<RwLock<VecDeque<MemoryEntry>>>,
19    /// 最大容量
20    max_size: usize,
21    /// 衰减策略
22    #[allow(dead_code)]
23    decay_policy: Box<dyn DecayPolicy>,
24}
25
26impl WorkingMemory {
27    pub fn new(max_size: usize) -> Self {
28        Self {
29            buffer: Arc::new(RwLock::new(VecDeque::with_capacity(max_size))),
30            max_size,
31            decay_policy: Box::new(TimeBasedDecay::default()),
32        }
33    }
34}
35
36impl Default for WorkingMemory {
37    fn default() -> Self {
38        Self::new(100)
39    }
40}
41
42#[async_trait]
43impl MemoryStore for WorkingMemory {
44    fn tier(&self) -> MemoryTier {
45        MemoryTier::Working
46    }
47
48    async fn store(&self, entry: MemoryEntry) -> Layer3Result<String> {
49        let mut buffer = self.buffer.write();
50        if buffer.len() >= self.max_size {
51            buffer.pop_front();
52        }
53        let id = entry.id.clone();
54        buffer.push_back(entry);
55        Ok(id)
56    }
57
58    async fn get(&self, id: &str) -> Layer3Result<Option<MemoryEntry>> {
59        let buffer = self.buffer.read();
60        Ok(buffer.iter().find(|e| e.id == id).cloned())
61    }
62
63    async fn delete(&self, id: &str) -> Layer3Result<bool> {
64        let mut buffer = self.buffer.write();
65        let len_before = buffer.len();
66        buffer.retain(|e| e.id != id);
67        Ok(buffer.len() < len_before)
68    }
69
70    async fn query(&self, query: &MemoryQuery) -> Layer3Result<Vec<MemoryEntry>> {
71        let buffer = self.buffer.read();
72        let results: Vec<MemoryEntry> = buffer
73            .iter()
74            .filter(|e| {
75                if let Some(tier) = query.tier {
76                    if e.tier != tier {
77                        return false;
78                    }
79                }
80                e.content.contains(&query.query)
81            })
82            .take(query.limit.unwrap_or(10))
83            .cloned()
84            .collect();
85        Ok(results)
86    }
87
88    async fn list(&self, limit: Option<usize>) -> Layer3Result<Vec<MemoryEntry>> {
89        let buffer = self.buffer.read();
90        Ok(buffer
91            .iter()
92            .take(limit.unwrap_or(usize::MAX))
93            .cloned()
94            .collect())
95    }
96
97    async fn clear(&self) -> Layer3Result<usize> {
98        let mut buffer = self.buffer.write();
99        let count = buffer.len();
100        buffer.clear();
101        Ok(count)
102    }
103
104    async fn count(&self) -> Layer3Result<usize> {
105        Ok(self.buffer.read().len())
106    }
107}
108
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a working-tier entry with the given id and fixed sample content.
    fn sample_entry(id: &str) -> MemoryEntry {
        MemoryEntry {
            id: id.to_string(),
            tier: MemoryTier::Working,
            content: "test content".to_string(),
            metadata: Default::default(),
            created_at: chrono::Utc::now(),
            last_accessed: chrono::Utc::now(),
            access_count: 0,
            importance: 0.5,
        }
    }

    /// Storing a single entry in an empty store makes `count` report 1.
    #[tokio::test]
    async fn test_working_memory_store() {
        let store = WorkingMemory::new(10);

        store.store(sample_entry("test-1")).await.unwrap();

        let held = store.count().await.unwrap();
        assert_eq!(held, 1);
    }
}