Skip to main content

wesichain_memory/
semantic.rs

//! Semantic (vector-backed) long-term memory and entity memory.
//!
//! # Components
//!
//! - [`VectorMemoryStore`] — stores conversation turns as vector embeddings,
//!   retrieves the most semantically relevant memories at query time.
//! - [`EntityMemory`] — key/value store for named entities (e.g. user name,
//!   project name) persisted across conversation turns.
//! - [`MemoryRouter`] — fans out load/save to multiple `Memory` layers.
//!
//! # Example
//! ```ignore
//! let mem = VectorMemoryStore::new(my_embedder, my_vector_store, 3);
//! mem.save_context("t1", &inputs, &outputs).await?;
//! let vars = mem.load_memory_variables("t1").await?;
//! ```

18use std::collections::HashMap;
19use std::sync::{Arc, Mutex};
20
21use async_trait::async_trait;
22use serde_json::Value;
23use wesichain_core::{Document, Embedding, VectorStore, WesichainError};
24
25use crate::Memory;
26
27// ── VectorMemoryStore ─────────────────────────────────────────────────────────
28
/// A long-term memory that stores conversation turns as vector embeddings and
/// retrieves the most relevant memories for each new query.
///
/// Generic over the embedding model `E` and the vector store `V`; both are
/// held behind [`Arc`] so they can be shared with other components.
pub struct VectorMemoryStore<E, V> {
    // Converts turn text into embedding vectors.
    embedder: Arc<E>,
    // Backing vector store the embedded turns are written to and searched in.
    store: Arc<V>,
    // Number of memories returned per `load_memory_variables` call.
    top_k: usize,
    // Key under which retrieved memories appear in the returned variables map;
    // defaults to "history" (see `new`), overridable via `with_memory_key`.
    memory_key: String,
}
37
38impl<E, V> VectorMemoryStore<E, V>
39where
40    E: Embedding + Send + Sync + 'static,
41    V: VectorStore + Send + Sync + 'static,
42{
43    /// Create a new `VectorMemoryStore`.
44    ///
45    /// - `embedder`: converts text to vectors
46    /// - `store`: the underlying vector store
47    /// - `top_k`: number of memories to retrieve per query
48    pub fn new(embedder: Arc<E>, store: Arc<V>, top_k: usize) -> Self {
49        Self {
50            embedder,
51            store,
52            top_k,
53            memory_key: "history".to_string(),
54        }
55    }
56
57    /// Override the key used in the returned memory variables map.
58    pub fn with_memory_key(mut self, key: impl Into<String>) -> Self {
59        self.memory_key = key.into();
60        self
61    }
62}
63
64#[async_trait]
65impl<E, V> Memory for VectorMemoryStore<E, V>
66where
67    E: Embedding + Send + Sync + 'static,
68    V: VectorStore + Send + Sync + 'static,
69{
70    async fn load_memory_variables(
71        &self,
72        thread_id: &str,
73    ) -> Result<HashMap<String, Value>, WesichainError> {
74        // Use thread_id itself as the query text — callers should pass the
75        // latest user message as thread_id when they want contextual recall.
76        let query_emb = self.embedder.embed(thread_id).await.map_err(|e| {
77            WesichainError::LlmProvider(format!("embedding failed: {e}"))
78        })?;
79
80        let results = self
81            .store
82            .search(&query_emb, self.top_k, None)
83            .await
84            .map_err(|e| WesichainError::LlmProvider(format!("vector search failed: {e}")))?;
85
86        let memories: Vec<Value> = results
87            .into_iter()
88            .map(|r| Value::String(r.document.content))
89            .collect();
90
91        let mut map = HashMap::new();
92        map.insert(self.memory_key.clone(), Value::Array(memories));
93        Ok(map)
94    }
95
96    async fn save_context(
97        &self,
98        thread_id: &str,
99        inputs: &HashMap<String, Value>,
100        outputs: &HashMap<String, Value>,
101    ) -> Result<(), WesichainError> {
102        // Build a human-readable turn text and embed it
103        let input_text = inputs
104            .get("input")
105            .or_else(|| inputs.values().next())
106            .map(|v| v.as_str().unwrap_or("").to_string())
107            .unwrap_or_default();
108        let output_text = outputs
109            .get("output")
110            .or_else(|| outputs.values().next())
111            .map(|v| v.as_str().unwrap_or("").to_string())
112            .unwrap_or_default();
113
114        let turn_text = format!("Human: {input_text}\nAI: {output_text}");
115
116        let embedding = self.embedder.embed(&turn_text).await.map_err(|e| {
117            WesichainError::LlmProvider(format!("embedding failed: {e}"))
118        })?;
119
120        let mut metadata = HashMap::new();
121        metadata.insert("thread_id".to_string(), Value::String(thread_id.to_string()));
122
123        let doc = Document {
124            id: uuid::Uuid::new_v4().to_string(),
125            content: turn_text,
126            metadata,
127            embedding: Some(embedding),
128        };
129
130        self.store
131            .add(vec![doc])
132            .await
133            .map_err(|e| WesichainError::LlmProvider(format!("vector store write failed: {e}")))?;
134
135        Ok(())
136    }
137
138    async fn clear(&self, _thread_id: &str) -> Result<(), WesichainError> {
139        // Vector stores typically don't support per-thread bulk delete via the
140        // base trait — callers can downcast if needed.
141        Ok(())
142    }
143}
144
145// ── EntityMemory ──────────────────────────────────────────────────────────────
146
/// In-memory key/value store for named entities.
///
/// On `save_context` all values from the `outputs` map are stored as entities.
/// On `load_memory_variables` the full entity map for the thread is returned
/// under the `"entities"` key.
///
/// Cloning is cheap: clones share the same underlying map via [`Arc`].
#[derive(Default, Clone)]
pub struct EntityMemory {
    // thread_id -> (entity key -> entity value). The Mutex is only held for
    // short, non-await critical sections; lock().unwrap() panics only if a
    // prior holder panicked (poisoning).
    inner: Arc<Mutex<HashMap<String, HashMap<String, Value>>>>,
}
156
157impl EntityMemory {
158    pub fn new() -> Self {
159        Self::default()
160    }
161
162    /// Directly upsert an entity value for a thread.
163    pub fn upsert(&self, thread_id: &str, key: impl Into<String>, value: Value) {
164        let mut guard = self.inner.lock().unwrap();
165        guard
166            .entry(thread_id.to_string())
167            .or_default()
168            .insert(key.into(), value);
169    }
170
171    /// Return all entities for a thread as a flat map.
172    pub fn entities(&self, thread_id: &str) -> HashMap<String, Value> {
173        let guard = self.inner.lock().unwrap();
174        guard.get(thread_id).cloned().unwrap_or_default()
175    }
176}
177
178#[async_trait]
179impl Memory for EntityMemory {
180    async fn load_memory_variables(
181        &self,
182        thread_id: &str,
183    ) -> Result<HashMap<String, Value>, WesichainError> {
184        let entities = self.entities(thread_id);
185        let mut map = HashMap::new();
186        map.insert(
187            "entities".to_string(),
188            serde_json::to_value(entities)
189                .map_err(|e| WesichainError::Custom(e.to_string()))?,
190        );
191        Ok(map)
192    }
193
194    async fn save_context(
195        &self,
196        thread_id: &str,
197        _inputs: &HashMap<String, Value>,
198        outputs: &HashMap<String, Value>,
199    ) -> Result<(), WesichainError> {
200        let mut guard = self.inner.lock().unwrap();
201        let thread_entities = guard.entry(thread_id.to_string()).or_default();
202        for (k, v) in outputs {
203            thread_entities.insert(k.clone(), v.clone());
204        }
205        Ok(())
206    }
207
208    async fn clear(&self, thread_id: &str) -> Result<(), WesichainError> {
209        let mut guard = self.inner.lock().unwrap();
210        guard.remove(thread_id);
211        Ok(())
212    }
213}
214
215// ── MemoryRouter ──────────────────────────────────────────────────────────────
216
/// Routes memory operations to multiple [`Memory`] layers.
///
/// `load_memory_variables` merges results from all layers (later layers win on
/// key collisions).  `save_context` and `clear` are called on all layers.
pub struct MemoryRouter {
    // Layers are consulted in insertion order; on load, a later layer's value
    // overwrites an earlier one for the same variable key.
    layers: Vec<Arc<dyn Memory>>,
}
224
225impl MemoryRouter {
226    pub fn new(layers: Vec<Arc<dyn Memory>>) -> Self {
227        Self { layers }
228    }
229
230    pub fn push(mut self, layer: Arc<dyn Memory>) -> Self {
231        self.layers.push(layer);
232        self
233    }
234}
235
236#[async_trait]
237impl Memory for MemoryRouter {
238    async fn load_memory_variables(
239        &self,
240        thread_id: &str,
241    ) -> Result<HashMap<String, Value>, WesichainError> {
242        let mut merged = HashMap::new();
243        for layer in &self.layers {
244            let vars = layer.load_memory_variables(thread_id).await?;
245            merged.extend(vars);
246        }
247        Ok(merged)
248    }
249
250    async fn save_context(
251        &self,
252        thread_id: &str,
253        inputs: &HashMap<String, Value>,
254        outputs: &HashMap<String, Value>,
255    ) -> Result<(), WesichainError> {
256        for layer in &self.layers {
257            layer.save_context(thread_id, inputs, outputs).await?;
258        }
259        Ok(())
260    }
261
262    async fn clear(&self, thread_id: &str) -> Result<(), WesichainError> {
263        for layer in &self.layers {
264            layer.clear(thread_id).await?;
265        }
266        Ok(())
267    }
268}
269
270// ── tests ─────────────────────────────────────────────────────────────────────
271
#[cfg(test)]
mod tests {
    use super::*;
    use wesichain_core::EmbeddingError;

    // Minimal stub embedder: returns a fixed 3-dim vector regardless of the
    // input text, so tests never depend on a real embedding model.
    struct StubEmbedder;

    #[async_trait]
    impl Embedding for StubEmbedder {
        async fn embed(&self, _text: &str) -> Result<Vec<f32>, EmbeddingError> {
            Ok(vec![0.1, 0.2, 0.3])
        }
        async fn embed_batch(
            &self,
            texts: &[String],
        ) -> Result<Vec<Vec<f32>>, EmbeddingError> {
            // Same fixed vector for every input — length matches `texts`.
            Ok(texts.iter().map(|_| vec![0.1, 0.2, 0.3]).collect())
        }
        fn dimension(&self) -> usize {
            3
        }
    }

    // Minimal stub vector store: keeps docs in memory and returns the first
    // `top_k` on search — no actual similarity ranking is performed.
    #[derive(Default)]
    struct StubVectorStore {
        docs: Mutex<Vec<Document>>,
    }

    #[async_trait]
    impl VectorStore for StubVectorStore {
        async fn add(&self, new_docs: Vec<Document>) -> Result<(), wesichain_core::StoreError> {
            self.docs.lock().unwrap().extend(new_docs);
            Ok(())
        }
        async fn search(
            &self,
            _query: &[f32],
            top_k: usize,
            _filter: Option<&wesichain_core::MetadataFilter>,
        ) -> Result<Vec<wesichain_core::SearchResult>, wesichain_core::StoreError> {
            // Ignores the query vector entirely; every stored doc scores 1.0.
            let docs = self.docs.lock().unwrap();
            Ok(docs
                .iter()
                .take(top_k)
                .map(|d| wesichain_core::SearchResult { document: d.clone(), score: 1.0 })
                .collect())
        }
        async fn delete(&self, _ids: &[String]) -> Result<(), wesichain_core::StoreError> {
            Ok(())
        }
    }

    // save_context should embed and store the turn; load_memory_variables
    // should surface it under "history" formatted as "Human: …\nAI: …".
    #[tokio::test]
    async fn vector_memory_round_trip() {
        let store = VectorMemoryStore::new(
            Arc::new(StubEmbedder),
            Arc::new(StubVectorStore::default()),
            5,
        );

        let mut inputs = HashMap::new();
        inputs.insert("input".to_string(), Value::String("Hello".to_string()));
        let mut outputs = HashMap::new();
        outputs.insert("output".to_string(), Value::String("Hi!".to_string()));

        store.save_context("t1", &inputs, &outputs).await.unwrap();

        // Any string works as the query here: the stub store returns all docs.
        let vars = store.load_memory_variables("latest message").await.unwrap();
        let history = vars.get("history").unwrap().as_array().unwrap();
        assert_eq!(history.len(), 1);
        assert!(history[0].as_str().unwrap().contains("Human: Hello"));
    }

    // Only `outputs` entries become entities; `inputs` are ignored by design.
    #[tokio::test]
    async fn entity_memory_stores_and_loads() {
        let mem = EntityMemory::new();
        let mut inputs = HashMap::new();
        inputs.insert("user".to_string(), Value::String("Alice".to_string()));
        let mut outputs = HashMap::new();
        outputs.insert("project".to_string(), Value::String("wesichain".to_string()));

        mem.save_context("t1", &inputs, &outputs).await.unwrap();

        let vars = mem.load_memory_variables("t1").await.unwrap();
        let entities = vars.get("entities").unwrap().as_object().unwrap();
        assert_eq!(entities.get("project").unwrap().as_str().unwrap(), "wesichain");
    }

    // clear() must drop the whole per-thread entity map.
    #[tokio::test]
    async fn entity_memory_clear() {
        let mem = EntityMemory::new();
        mem.upsert("t1", "foo", Value::String("bar".to_string()));
        mem.clear("t1").await.unwrap();
        assert!(mem.entities("t1").is_empty());
    }

    // Router should expose keys contributed by each of its layers.
    #[tokio::test]
    async fn memory_router_merges_layers() {
        let entity_mem = Arc::new(EntityMemory::new());
        entity_mem.upsert("t1", "entity_key", Value::String("v1".to_string()));

        let router = MemoryRouter::new(vec![entity_mem]);
        let vars = router.load_memory_variables("t1").await.unwrap();
        assert!(vars.contains_key("entities"));
    }
}
379}