Skip to main content

sh_layer3/memory_system/
system.rs

//! # Unified Memory System
//!
//! A unified interface that integrates the four memory tiers.

5use crate::memory_system::{session::SessionMemory, working::WorkingMemory, MemoryStore};
6use crate::types::{Layer3Result, MemoryEntry, MemoryQuery, MemoryTier};
7use async_trait::async_trait;
8use sh_layer2::generate_short_id;
9use std::collections::HashMap;
10use std::sync::Arc;
11
12/// 统一记忆系统
13///
14/// 整合 Working, Session, Project, LongTerm 四层记忆。
15#[allow(dead_code)]
16pub struct UnifiedMemorySystem {
17    /// 工作记忆(内存环形缓冲)
18    working: Arc<WorkingMemory>,
19    /// 会话记忆(会话级别 HashMap)
20    session: Arc<SessionMemory>,
21    /// 项目记忆(文件持久化)
22    project: Option<Arc<dyn MemoryStore>>,
23    /// 长期记忆(向量检索)
24    long_term: Option<Arc<dyn MemoryStore>>,
25    /// 当前会话 ID
26    #[allow(dead_code)]
27    session_id: String,
28}
29
30impl UnifiedMemorySystem {
31    /// 创建新的记忆系统
32    pub fn new(session_id: impl Into<String>) -> Self {
33        let session_id = session_id.into();
34        Self {
35            working: Arc::new(WorkingMemory::new(100)),
36            session: Arc::new(SessionMemory::new(&session_id)),
37            project: None,
38            long_term: None,
39            session_id,
40        }
41    }
42
43    /// 设置项目记忆存储
44    pub fn with_project(mut self, project: Arc<dyn MemoryStore>) -> Self {
45        self.project = Some(project);
46        self
47    }
48
49    /// 设置长期记忆存储
50    pub fn with_long_term(mut self, long_term: Arc<dyn MemoryStore>) -> Self {
51        self.long_term = Some(long_term);
52        self
53    }
54
55    /// 获取工作记忆
56    pub fn working(&self) -> &WorkingMemory {
57        &self.working
58    }
59
60    /// 获取会话记忆
61    pub fn session(&self) -> &SessionMemory {
62        &self.session
63    }
64
65    /// 存储到指定层级
66    pub async fn store_at(
67        &self,
68        tier: MemoryTier,
69        content: impl Into<String>,
70    ) -> Layer3Result<String> {
71        let entry = MemoryEntry {
72            id: generate_short_id(),
73            tier,
74            content: content.into(),
75            metadata: Default::default(),
76            created_at: chrono::Utc::now(),
77            last_accessed: chrono::Utc::now(),
78            access_count: 0,
79            importance: 0.5,
80        };
81
82        match tier {
83            MemoryTier::Working => self.working.store(entry).await,
84            MemoryTier::Session => self.session.store(entry).await,
85            MemoryTier::Project => {
86                if let Some(ref project) = self.project {
87                    project.store(entry).await
88                } else {
89                    self.session.store(entry).await
90                }
91            }
92            MemoryTier::LongTerm => {
93                if let Some(ref long_term) = self.long_term {
94                    long_term.store(entry).await
95                } else {
96                    self.session.store(entry).await
97                }
98            }
99        }
100    }
101
102    /// 跨层级查询
103    ///
104    /// 按 Working -> Session -> Project -> LongTerm 顺序查询
105    pub async fn query_all(&self, query: &MemoryQuery) -> Layer3Result<Vec<MemoryEntry>> {
106        let mut results = Vec::new();
107        let limit = query.limit.unwrap_or(10);
108
109        // Working
110        let working_results = self.working.query(query).await?;
111        results.extend(working_results);
112        if results.len() >= limit {
113            return Ok(results.into_iter().take(limit).collect());
114        }
115
116        // Session
117        let session_results = self.session.query(query).await?;
118        results.extend(session_results);
119        if results.len() >= limit {
120            return Ok(results.into_iter().take(limit).collect());
121        }
122
123        // Project
124        if let Some(ref project) = self.project {
125            let project_results = project.query(query).await?;
126            results.extend(project_results);
127            if results.len() >= limit {
128                return Ok(results.into_iter().take(limit).collect());
129            }
130        }
131
132        // LongTerm
133        if let Some(ref long_term) = self.long_term {
134            let long_term_results = long_term.query(query).await?;
135            results.extend(long_term_results);
136        }
137
138        Ok(results.into_iter().take(limit).collect())
139    }
140
141    /// 获取层级统计
142    pub async fn stats(&self) -> Layer3Result<HashMap<MemoryTier, usize>> {
143        let mut stats = HashMap::new();
144        stats.insert(MemoryTier::Working, self.working.count().await?);
145        stats.insert(MemoryTier::Session, self.session.count().await?);
146        if let Some(ref project) = self.project {
147            stats.insert(MemoryTier::Project, project.count().await?);
148        }
149        if let Some(ref long_term) = self.long_term {
150            stats.insert(MemoryTier::LongTerm, long_term.count().await?);
151        }
152        Ok(stats)
153    }
154
155    /// 清空指定层级
156    pub async fn clear_tier(&self, tier: MemoryTier) -> Layer3Result<usize> {
157        match tier {
158            MemoryTier::Working => self.working.clear().await,
159            MemoryTier::Session => self.session.clear().await,
160            MemoryTier::Project => {
161                if let Some(ref project) = self.project {
162                    project.clear().await
163                } else {
164                    Ok(0)
165                }
166            }
167            MemoryTier::LongTerm => {
168                if let Some(ref long_term) = self.long_term {
169                    long_term.clear().await
170                } else {
171                    Ok(0)
172                }
173            }
174        }
175    }
176}
177
178/// MemorySystem trait 实现
179#[async_trait]
180impl crate::memory_system::MemorySystem for UnifiedMemorySystem {
181    async fn store(&self, tier: MemoryTier, content: String) -> Layer3Result<String> {
182        self.store_at(tier, content).await
183    }
184
185    async fn get(&self, tier: MemoryTier, id: &str) -> Layer3Result<Option<MemoryEntry>> {
186        match tier {
187            MemoryTier::Working => self.working.get(id).await,
188            MemoryTier::Session => self.session.get(id).await,
189            MemoryTier::Project => {
190                if let Some(ref project) = self.project {
191                    project.get(id).await
192                } else {
193                    Ok(None)
194                }
195            }
196            MemoryTier::LongTerm => {
197                if let Some(ref long_term) = self.long_term {
198                    long_term.get(id).await
199                } else {
200                    Ok(None)
201                }
202            }
203        }
204    }
205
206    async fn query_all(&self, query: &MemoryQuery) -> Layer3Result<Vec<MemoryEntry>> {
207        self.query_all(query).await
208    }
209
210    async fn query(&self, tier: MemoryTier, query: &MemoryQuery) -> Layer3Result<Vec<MemoryEntry>> {
211        match tier {
212            MemoryTier::Working => self.working.query(query).await,
213            MemoryTier::Session => self.session.query(query).await,
214            MemoryTier::Project => {
215                if let Some(ref project) = self.project {
216                    project.query(query).await
217                } else {
218                    Ok(Vec::new())
219                }
220            }
221            MemoryTier::LongTerm => {
222                if let Some(ref long_term) = self.long_term {
223                    long_term.query(query).await
224                } else {
225                    Ok(Vec::new())
226                }
227            }
228        }
229    }
230
231    async fn delete(&self, tier: MemoryTier, id: &str) -> Layer3Result<bool> {
232        match tier {
233            MemoryTier::Working => self.working.delete(id).await,
234            MemoryTier::Session => self.session.delete(id).await,
235            MemoryTier::Project => {
236                if let Some(ref project) = self.project {
237                    project.delete(id).await
238                } else {
239                    Ok(false)
240                }
241            }
242            MemoryTier::LongTerm => {
243                if let Some(ref long_term) = self.long_term {
244                    long_term.delete(id).await
245                } else {
246                    Ok(false)
247                }
248            }
249        }
250    }
251
252    async fn clear(&self, tier: MemoryTier) -> Layer3Result<usize> {
253        self.clear_tier(tier).await
254    }
255
256    async fn stats(&self) -> Layer3Result<HashMap<MemoryTier, usize>> {
257        self.stats().await
258    }
259}
260
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_unified_memory_system() {
        let system = UnifiedMemorySystem::new("test-session");

        // Working memory: storing yields a non-empty ID whose content reads back.
        let id = system
            .store_at(MemoryTier::Working, "test working memory")
            .await
            .unwrap();
        assert!(!id.is_empty());

        let entry = system
            .working()
            .get(&id)
            .await
            .unwrap()
            .expect("entry just stored in working memory should be retrievable");
        assert_eq!(entry.content, "test working memory");

        // Stats: always-present tiers are reported, unattached backends are not.
        let stats = system.stats().await.unwrap();
        assert!(stats.contains_key(&MemoryTier::Working));
        assert!(stats.contains_key(&MemoryTier::Session));
        assert!(!stats.contains_key(&MemoryTier::Project));
        assert!(!stats.contains_key(&MemoryTier::LongTerm));
    }

    #[tokio::test]
    async fn test_project_tier_falls_back_to_session_without_backend() {
        let system = UnifiedMemorySystem::new("fallback-session");

        // With no project backend attached, `store_at` writes to session memory.
        let id = system
            .store_at(MemoryTier::Project, "project note")
            .await
            .unwrap();
        let entry = system
            .session()
            .get(&id)
            .await
            .unwrap()
            .expect("project entry should fall back to session memory");
        assert_eq!(entry.content, "project note");
    }

    #[test]
    fn test_memory_system_creation() {
        let system = UnifiedMemorySystem::new("test");
        assert!(system.project.is_none());
        assert!(system.long_term.is_none());
    }
}