sh_layer3/memory_system/
system.rs1use crate::memory_system::{session::SessionMemory, working::WorkingMemory, MemoryStore};
6use crate::types::{Layer3Result, MemoryEntry, MemoryQuery, MemoryTier};
7use async_trait::async_trait;
8use sh_layer2::generate_short_id;
9use std::collections::HashMap;
10use std::sync::Arc;
11
12#[allow(dead_code)]
16pub struct UnifiedMemorySystem {
17 working: Arc<WorkingMemory>,
19 session: Arc<SessionMemory>,
21 project: Option<Arc<dyn MemoryStore>>,
23 long_term: Option<Arc<dyn MemoryStore>>,
25 #[allow(dead_code)]
27 session_id: String,
28}
29
30impl UnifiedMemorySystem {
31 pub fn new(session_id: impl Into<String>) -> Self {
33 let session_id = session_id.into();
34 Self {
35 working: Arc::new(WorkingMemory::new(100)),
36 session: Arc::new(SessionMemory::new(&session_id)),
37 project: None,
38 long_term: None,
39 session_id,
40 }
41 }
42
43 pub fn with_project(mut self, project: Arc<dyn MemoryStore>) -> Self {
45 self.project = Some(project);
46 self
47 }
48
49 pub fn with_long_term(mut self, long_term: Arc<dyn MemoryStore>) -> Self {
51 self.long_term = Some(long_term);
52 self
53 }
54
55 pub fn working(&self) -> &WorkingMemory {
57 &self.working
58 }
59
60 pub fn session(&self) -> &SessionMemory {
62 &self.session
63 }
64
65 pub async fn store_at(
67 &self,
68 tier: MemoryTier,
69 content: impl Into<String>,
70 ) -> Layer3Result<String> {
71 let entry = MemoryEntry {
72 id: generate_short_id(),
73 tier,
74 content: content.into(),
75 metadata: Default::default(),
76 created_at: chrono::Utc::now(),
77 last_accessed: chrono::Utc::now(),
78 access_count: 0,
79 importance: 0.5,
80 };
81
82 match tier {
83 MemoryTier::Working => self.working.store(entry).await,
84 MemoryTier::Session => self.session.store(entry).await,
85 MemoryTier::Project => {
86 if let Some(ref project) = self.project {
87 project.store(entry).await
88 } else {
89 self.session.store(entry).await
90 }
91 }
92 MemoryTier::LongTerm => {
93 if let Some(ref long_term) = self.long_term {
94 long_term.store(entry).await
95 } else {
96 self.session.store(entry).await
97 }
98 }
99 }
100 }
101
102 pub async fn query_all(&self, query: &MemoryQuery) -> Layer3Result<Vec<MemoryEntry>> {
106 let mut results = Vec::new();
107 let limit = query.limit.unwrap_or(10);
108
109 let working_results = self.working.query(query).await?;
111 results.extend(working_results);
112 if results.len() >= limit {
113 return Ok(results.into_iter().take(limit).collect());
114 }
115
116 let session_results = self.session.query(query).await?;
118 results.extend(session_results);
119 if results.len() >= limit {
120 return Ok(results.into_iter().take(limit).collect());
121 }
122
123 if let Some(ref project) = self.project {
125 let project_results = project.query(query).await?;
126 results.extend(project_results);
127 if results.len() >= limit {
128 return Ok(results.into_iter().take(limit).collect());
129 }
130 }
131
132 if let Some(ref long_term) = self.long_term {
134 let long_term_results = long_term.query(query).await?;
135 results.extend(long_term_results);
136 }
137
138 Ok(results.into_iter().take(limit).collect())
139 }
140
141 pub async fn stats(&self) -> Layer3Result<HashMap<MemoryTier, usize>> {
143 let mut stats = HashMap::new();
144 stats.insert(MemoryTier::Working, self.working.count().await?);
145 stats.insert(MemoryTier::Session, self.session.count().await?);
146 if let Some(ref project) = self.project {
147 stats.insert(MemoryTier::Project, project.count().await?);
148 }
149 if let Some(ref long_term) = self.long_term {
150 stats.insert(MemoryTier::LongTerm, long_term.count().await?);
151 }
152 Ok(stats)
153 }
154
155 pub async fn clear_tier(&self, tier: MemoryTier) -> Layer3Result<usize> {
157 match tier {
158 MemoryTier::Working => self.working.clear().await,
159 MemoryTier::Session => self.session.clear().await,
160 MemoryTier::Project => {
161 if let Some(ref project) = self.project {
162 project.clear().await
163 } else {
164 Ok(0)
165 }
166 }
167 MemoryTier::LongTerm => {
168 if let Some(ref long_term) = self.long_term {
169 long_term.clear().await
170 } else {
171 Ok(0)
172 }
173 }
174 }
175 }
176}
177
178#[async_trait]
180impl crate::memory_system::MemorySystem for UnifiedMemorySystem {
181 async fn store(&self, tier: MemoryTier, content: String) -> Layer3Result<String> {
182 self.store_at(tier, content).await
183 }
184
185 async fn get(&self, tier: MemoryTier, id: &str) -> Layer3Result<Option<MemoryEntry>> {
186 match tier {
187 MemoryTier::Working => self.working.get(id).await,
188 MemoryTier::Session => self.session.get(id).await,
189 MemoryTier::Project => {
190 if let Some(ref project) = self.project {
191 project.get(id).await
192 } else {
193 Ok(None)
194 }
195 }
196 MemoryTier::LongTerm => {
197 if let Some(ref long_term) = self.long_term {
198 long_term.get(id).await
199 } else {
200 Ok(None)
201 }
202 }
203 }
204 }
205
206 async fn query_all(&self, query: &MemoryQuery) -> Layer3Result<Vec<MemoryEntry>> {
207 self.query_all(query).await
208 }
209
210 async fn query(&self, tier: MemoryTier, query: &MemoryQuery) -> Layer3Result<Vec<MemoryEntry>> {
211 match tier {
212 MemoryTier::Working => self.working.query(query).await,
213 MemoryTier::Session => self.session.query(query).await,
214 MemoryTier::Project => {
215 if let Some(ref project) = self.project {
216 project.query(query).await
217 } else {
218 Ok(Vec::new())
219 }
220 }
221 MemoryTier::LongTerm => {
222 if let Some(ref long_term) = self.long_term {
223 long_term.query(query).await
224 } else {
225 Ok(Vec::new())
226 }
227 }
228 }
229 }
230
231 async fn delete(&self, tier: MemoryTier, id: &str) -> Layer3Result<bool> {
232 match tier {
233 MemoryTier::Working => self.working.delete(id).await,
234 MemoryTier::Session => self.session.delete(id).await,
235 MemoryTier::Project => {
236 if let Some(ref project) = self.project {
237 project.delete(id).await
238 } else {
239 Ok(false)
240 }
241 }
242 MemoryTier::LongTerm => {
243 if let Some(ref long_term) = self.long_term {
244 long_term.delete(id).await
245 } else {
246 Ok(false)
247 }
248 }
249 }
250 }
251
252 async fn clear(&self, tier: MemoryTier) -> Layer3Result<usize> {
253 self.clear_tier(tier).await
254 }
255
256 async fn stats(&self) -> Layer3Result<HashMap<MemoryTier, usize>> {
257 self.stats().await
258 }
259}
260
#[cfg(test)]
mod tests {
    use super::*;

    /// Storing into working memory yields a non-empty id, and the working tier is
    /// reported by `stats`.
    #[tokio::test]
    async fn test_unified_memory_system() {
        let memory = UnifiedMemorySystem::new("test-session");

        let stored_id = memory
            .store_at(MemoryTier::Working, "test working memory")
            .await
            .unwrap();
        assert!(!stored_id.is_empty());

        let per_tier = memory.stats().await.unwrap();
        assert!(per_tier.contains_key(&MemoryTier::Working));
    }

    /// A freshly constructed system has no optional stores attached.
    #[test]
    fn test_memory_system_creation() {
        let memory = UnifiedMemorySystem::new("test");
        let (project, long_term) = (&memory.project, &memory.long_term);
        assert!(project.is_none());
        assert!(long_term.is_none());
    }
}
287}