sh_layer3/memory_system/
long_term.rs1use crate::memory_system::{DecayPolicy, MemoryStore, TimeBasedDecay};
6use crate::retriever_engine::RetrieverEngine;
7use crate::types::{Layer3Result, MemoryEntry, MemoryQuery, MemoryTier};
8use async_trait::async_trait;
9use parking_lot::RwLock;
10use std::sync::Arc;
11
12#[allow(dead_code)]
16pub struct LongTermMemory {
17 retriever: Option<Arc<dyn RetrieverEngine>>,
19 cache: Arc<RwLock<Vec<MemoryEntry>>>,
21 #[allow(dead_code)]
23 decay_policy: Box<dyn DecayPolicy>,
24}
25
26impl LongTermMemory {
27 pub fn new(retriever: Option<Arc<dyn RetrieverEngine>>) -> Self {
28 Self {
29 retriever,
30 cache: Arc::new(RwLock::new(Vec::new())),
31 decay_policy: Box::new(TimeBasedDecay::default()),
32 }
33 }
34}
35
36impl Default for LongTermMemory {
37 fn default() -> Self {
38 Self::new(None)
39 }
40}
41
42#[async_trait]
43impl MemoryStore for LongTermMemory {
44 fn tier(&self) -> MemoryTier {
45 MemoryTier::LongTerm
46 }
47
48 async fn store(&self, entry: MemoryEntry) -> Layer3Result<String> {
49 let id = entry.id.clone();
50
51 if let Some(retriever) = &self.retriever {
53 use crate::retriever_engine::Document;
54 let doc = Document::new(&entry.content).with_source(&entry.id);
55 retriever.index(vec![doc]).await?;
56 }
57
58 self.cache.write().push(entry);
60
61 Ok(id)
62 }
63
64 async fn get(&self, id: &str) -> Layer3Result<Option<MemoryEntry>> {
65 let cache = self.cache.read();
66 Ok(cache.iter().find(|e| e.id == id).cloned())
67 }
68
69 async fn delete(&self, id: &str) -> Layer3Result<bool> {
70 if let Some(retriever) = &self.retriever {
72 retriever.delete(&[id.to_string()]).await?;
73 }
74
75 let mut cache = self.cache.write();
77 let len_before = cache.len();
78 cache.retain(|e| e.id != id);
79 Ok(cache.len() < len_before)
80 }
81
82 async fn query(&self, query: &MemoryQuery) -> Layer3Result<Vec<MemoryEntry>> {
83 if let Some(retriever) = &self.retriever {
85 let results = retriever
86 .retrieve(&query.query, query.limit.unwrap_or(10))
87 .await?;
88 let entries: Vec<MemoryEntry> = results
89 .into_iter()
90 .map(|r| MemoryEntry {
91 id: r.doc_id,
92 tier: MemoryTier::LongTerm,
93 content: r.content,
94 metadata: r.metadata.into_iter().collect(),
95 created_at: chrono::Utc::now(),
96 last_accessed: chrono::Utc::now(),
97 access_count: 0,
98 importance: r.score,
99 })
100 .collect();
101 return Ok(entries);
102 }
103
104 let cache = self.cache.read();
106 let results: Vec<MemoryEntry> = cache
107 .iter()
108 .filter(|e| e.content.contains(&query.query))
109 .take(query.limit.unwrap_or(10))
110 .cloned()
111 .collect();
112 Ok(results)
113 }
114
115 async fn list(&self, limit: Option<usize>) -> Layer3Result<Vec<MemoryEntry>> {
116 let cache = self.cache.read();
117 Ok(cache
118 .iter()
119 .take(limit.unwrap_or(usize::MAX))
120 .cloned()
121 .collect())
122 }
123
124 async fn clear(&self) -> Layer3Result<usize> {
125 let count = self.cache.read().len();
126 self.cache.write().clear();
127
128 if let Some(retriever) = &self.retriever {
130 retriever.clear().await?;
131 }
132
133 Ok(count)
134 }
135
136 async fn count(&self) -> Layer3Result<usize> {
137 Ok(self.cache.read().len())
138 }
139}
140
#[cfg(test)]
mod tests {
    use super::*;

    /// A default-constructed store must identify itself as the long-term tier.
    #[test]
    fn test_long_term_memory_tier() {
        let store = LongTermMemory::default();
        let reported = store.tier();
        assert_eq!(reported, MemoryTier::LongTerm);
    }
}