Skip to main content

sh_layer3/memory_system/
long_term.rs

//! # Long-term Memory
//!
//! Long-term memory: general-purpose knowledge shared across projects, kept in a vector store.
4
5use crate::memory_system::{DecayPolicy, MemoryStore, TimeBasedDecay};
6use crate::retriever_engine::RetrieverEngine;
7use crate::types::{Layer3Result, MemoryEntry, MemoryQuery, MemoryTier};
8use async_trait::async_trait;
9use parking_lot::RwLock;
10use std::sync::Arc;
11
12/// Long-term Memory 实现
13///
14/// 使用向量数据库存储,支持语义检索。
15#[allow(dead_code)]
16pub struct LongTermMemory {
17    /// 检索引擎
18    retriever: Option<Arc<dyn RetrieverEngine>>,
19    /// 本地缓存
20    cache: Arc<RwLock<Vec<MemoryEntry>>>,
21    /// 衰减策略
22    #[allow(dead_code)]
23    decay_policy: Box<dyn DecayPolicy>,
24}
25
26impl LongTermMemory {
27    pub fn new(retriever: Option<Arc<dyn RetrieverEngine>>) -> Self {
28        Self {
29            retriever,
30            cache: Arc::new(RwLock::new(Vec::new())),
31            decay_policy: Box::new(TimeBasedDecay::default()),
32        }
33    }
34}
35
36impl Default for LongTermMemory {
37    fn default() -> Self {
38        Self::new(None)
39    }
40}
41
42#[async_trait]
43impl MemoryStore for LongTermMemory {
44    fn tier(&self) -> MemoryTier {
45        MemoryTier::LongTerm
46    }
47
48    async fn store(&self, entry: MemoryEntry) -> Layer3Result<String> {
49        let id = entry.id.clone();
50
51        // 存储到检索引擎(如果有)
52        if let Some(retriever) = &self.retriever {
53            use crate::retriever_engine::Document;
54            let doc = Document::new(&entry.content).with_source(&entry.id);
55            retriever.index(vec![doc]).await?;
56        }
57
58        // 缓存
59        self.cache.write().push(entry);
60
61        Ok(id)
62    }
63
64    async fn get(&self, id: &str) -> Layer3Result<Option<MemoryEntry>> {
65        let cache = self.cache.read();
66        Ok(cache.iter().find(|e| e.id == id).cloned())
67    }
68
69    async fn delete(&self, id: &str) -> Layer3Result<bool> {
70        // 从检索引擎删除
71        if let Some(retriever) = &self.retriever {
72            retriever.delete(&[id.to_string()]).await?;
73        }
74
75        // 从缓存删除
76        let mut cache = self.cache.write();
77        let len_before = cache.len();
78        cache.retain(|e| e.id != id);
79        Ok(cache.len() < len_before)
80    }
81
82    async fn query(&self, query: &MemoryQuery) -> Layer3Result<Vec<MemoryEntry>> {
83        // 使用向量检索
84        if let Some(retriever) = &self.retriever {
85            let results = retriever
86                .retrieve(&query.query, query.limit.unwrap_or(10))
87                .await?;
88            let entries: Vec<MemoryEntry> = results
89                .into_iter()
90                .map(|r| MemoryEntry {
91                    id: r.doc_id,
92                    tier: MemoryTier::LongTerm,
93                    content: r.content,
94                    metadata: r.metadata.into_iter().collect(),
95                    created_at: chrono::Utc::now(),
96                    last_accessed: chrono::Utc::now(),
97                    access_count: 0,
98                    importance: r.score,
99                })
100                .collect();
101            return Ok(entries);
102        }
103
104        // 回退到缓存搜索
105        let cache = self.cache.read();
106        let results: Vec<MemoryEntry> = cache
107            .iter()
108            .filter(|e| e.content.contains(&query.query))
109            .take(query.limit.unwrap_or(10))
110            .cloned()
111            .collect();
112        Ok(results)
113    }
114
115    async fn list(&self, limit: Option<usize>) -> Layer3Result<Vec<MemoryEntry>> {
116        let cache = self.cache.read();
117        Ok(cache
118            .iter()
119            .take(limit.unwrap_or(usize::MAX))
120            .cloned()
121            .collect())
122    }
123
124    async fn clear(&self) -> Layer3Result<usize> {
125        let count = self.cache.read().len();
126        self.cache.write().clear();
127
128        // 清空检索引擎
129        if let Some(retriever) = &self.retriever {
130            retriever.clear().await?;
131        }
132
133        Ok(count)
134    }
135
136    async fn count(&self) -> Layer3Result<usize> {
137        Ok(self.cache.read().len())
138    }
139}
140
#[cfg(test)]
mod tests {
    use super::*;

    /// A default-constructed long-term memory reports the `LongTerm` tier.
    #[test]
    fn test_long_term_memory_tier() {
        let store = LongTermMemory::default();
        let reported = store.tier();
        assert_eq!(reported, MemoryTier::LongTerm);
    }
}