1use std::collections::HashMap;
19use std::sync::{Arc, Mutex};
20
21use async_trait::async_trait;
22use serde_json::Value;
23use wesichain_core::{Document, Embedding, VectorStore, WesichainError};
24
25use crate::Memory;
26
/// Conversation memory backed by an embedding model and a vector store.
///
/// Holds shared handles to the embedder and the store together with the
/// lookup settings used when memories are retrieved.
pub struct VectorMemoryStore<E, V> {
    /// Shared embedding backend.
    embedder: Arc<E>,
    /// Shared vector store holding the saved documents.
    store: Arc<V>,
    /// Maximum number of results requested from the store per lookup.
    top_k: usize,
    /// Key under which retrieved memories are returned to the caller.
    memory_key: String,
}
37
38impl<E, V> VectorMemoryStore<E, V>
39where
40 E: Embedding + Send + Sync + 'static,
41 V: VectorStore + Send + Sync + 'static,
42{
43 pub fn new(embedder: Arc<E>, store: Arc<V>, top_k: usize) -> Self {
49 Self {
50 embedder,
51 store,
52 top_k,
53 memory_key: "history".to_string(),
54 }
55 }
56
57 pub fn with_memory_key(mut self, key: impl Into<String>) -> Self {
59 self.memory_key = key.into();
60 self
61 }
62}
63
64#[async_trait]
65impl<E, V> Memory for VectorMemoryStore<E, V>
66where
67 E: Embedding + Send + Sync + 'static,
68 V: VectorStore + Send + Sync + 'static,
69{
70 async fn load_memory_variables(
71 &self,
72 thread_id: &str,
73 ) -> Result<HashMap<String, Value>, WesichainError> {
74 let query_emb = self.embedder.embed(thread_id).await.map_err(|e| {
77 WesichainError::LlmProvider(format!("embedding failed: {e}"))
78 })?;
79
80 let results = self
81 .store
82 .search(&query_emb, self.top_k, None)
83 .await
84 .map_err(|e| WesichainError::LlmProvider(format!("vector search failed: {e}")))?;
85
86 let memories: Vec<Value> = results
87 .into_iter()
88 .map(|r| Value::String(r.document.content))
89 .collect();
90
91 let mut map = HashMap::new();
92 map.insert(self.memory_key.clone(), Value::Array(memories));
93 Ok(map)
94 }
95
96 async fn save_context(
97 &self,
98 thread_id: &str,
99 inputs: &HashMap<String, Value>,
100 outputs: &HashMap<String, Value>,
101 ) -> Result<(), WesichainError> {
102 let input_text = inputs
104 .get("input")
105 .or_else(|| inputs.values().next())
106 .map(|v| v.as_str().unwrap_or("").to_string())
107 .unwrap_or_default();
108 let output_text = outputs
109 .get("output")
110 .or_else(|| outputs.values().next())
111 .map(|v| v.as_str().unwrap_or("").to_string())
112 .unwrap_or_default();
113
114 let turn_text = format!("Human: {input_text}\nAI: {output_text}");
115
116 let embedding = self.embedder.embed(&turn_text).await.map_err(|e| {
117 WesichainError::LlmProvider(format!("embedding failed: {e}"))
118 })?;
119
120 let mut metadata = HashMap::new();
121 metadata.insert("thread_id".to_string(), Value::String(thread_id.to_string()));
122
123 let doc = Document {
124 id: uuid::Uuid::new_v4().to_string(),
125 content: turn_text,
126 metadata,
127 embedding: Some(embedding),
128 };
129
130 self.store
131 .add(vec![doc])
132 .await
133 .map_err(|e| WesichainError::LlmProvider(format!("vector store write failed: {e}")))?;
134
135 Ok(())
136 }
137
138 async fn clear(&self, _thread_id: &str) -> Result<(), WesichainError> {
139 Ok(())
142 }
143}
144
145#[derive(Default, Clone)]
153pub struct EntityMemory {
154 inner: Arc<Mutex<HashMap<String, HashMap<String, Value>>>>,
155}
156
157impl EntityMemory {
158 pub fn new() -> Self {
159 Self::default()
160 }
161
162 pub fn upsert(&self, thread_id: &str, key: impl Into<String>, value: Value) {
164 let mut guard = self.inner.lock().unwrap();
165 guard
166 .entry(thread_id.to_string())
167 .or_default()
168 .insert(key.into(), value);
169 }
170
171 pub fn entities(&self, thread_id: &str) -> HashMap<String, Value> {
173 let guard = self.inner.lock().unwrap();
174 guard.get(thread_id).cloned().unwrap_or_default()
175 }
176}
177
178#[async_trait]
179impl Memory for EntityMemory {
180 async fn load_memory_variables(
181 &self,
182 thread_id: &str,
183 ) -> Result<HashMap<String, Value>, WesichainError> {
184 let entities = self.entities(thread_id);
185 let mut map = HashMap::new();
186 map.insert(
187 "entities".to_string(),
188 serde_json::to_value(entities)
189 .map_err(|e| WesichainError::Custom(e.to_string()))?,
190 );
191 Ok(map)
192 }
193
194 async fn save_context(
195 &self,
196 thread_id: &str,
197 _inputs: &HashMap<String, Value>,
198 outputs: &HashMap<String, Value>,
199 ) -> Result<(), WesichainError> {
200 let mut guard = self.inner.lock().unwrap();
201 let thread_entities = guard.entry(thread_id.to_string()).or_default();
202 for (k, v) in outputs {
203 thread_entities.insert(k.clone(), v.clone());
204 }
205 Ok(())
206 }
207
208 async fn clear(&self, thread_id: &str) -> Result<(), WesichainError> {
209 let mut guard = self.inner.lock().unwrap();
210 guard.remove(thread_id);
211 Ok(())
212 }
213}
214
215pub struct MemoryRouter {
222 layers: Vec<Arc<dyn Memory>>,
223}
224
225impl MemoryRouter {
226 pub fn new(layers: Vec<Arc<dyn Memory>>) -> Self {
227 Self { layers }
228 }
229
230 pub fn push(mut self, layer: Arc<dyn Memory>) -> Self {
231 self.layers.push(layer);
232 self
233 }
234}
235
236#[async_trait]
237impl Memory for MemoryRouter {
238 async fn load_memory_variables(
239 &self,
240 thread_id: &str,
241 ) -> Result<HashMap<String, Value>, WesichainError> {
242 let mut merged = HashMap::new();
243 for layer in &self.layers {
244 let vars = layer.load_memory_variables(thread_id).await?;
245 merged.extend(vars);
246 }
247 Ok(merged)
248 }
249
250 async fn save_context(
251 &self,
252 thread_id: &str,
253 inputs: &HashMap<String, Value>,
254 outputs: &HashMap<String, Value>,
255 ) -> Result<(), WesichainError> {
256 for layer in &self.layers {
257 layer.save_context(thread_id, inputs, outputs).await?;
258 }
259 Ok(())
260 }
261
262 async fn clear(&self, thread_id: &str) -> Result<(), WesichainError> {
263 for layer in &self.layers {
264 layer.clear(thread_id).await?;
265 }
266 Ok(())
267 }
268}
269
#[cfg(test)]
mod tests {
    use super::*;
    use wesichain_core::{EmbeddingError, MetadataFilter, SearchResult, StoreError};

    /// Embedder that returns the same three-dimensional vector for any text.
    struct StubEmbedder;

    #[async_trait]
    impl Embedding for StubEmbedder {
        async fn embed(&self, _text: &str) -> Result<Vec<f32>, EmbeddingError> {
            Ok(vec![0.1, 0.2, 0.3])
        }

        async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>, EmbeddingError> {
            let vectors = texts.iter().map(|_| vec![0.1, 0.2, 0.3]).collect();
            Ok(vectors)
        }

        fn dimension(&self) -> usize {
            3
        }
    }

    /// Vector store that keeps documents in insertion order and ignores the
    /// query vector and filter when searching.
    #[derive(Default)]
    struct StubVectorStore {
        docs: Mutex<Vec<Document>>,
    }

    #[async_trait]
    impl VectorStore for StubVectorStore {
        async fn add(&self, new_docs: Vec<Document>) -> Result<(), StoreError> {
            let mut stored = self.docs.lock().unwrap();
            stored.extend(new_docs);
            Ok(())
        }

        async fn search(
            &self,
            _query: &[f32],
            top_k: usize,
            _filter: Option<&MetadataFilter>,
        ) -> Result<Vec<SearchResult>, StoreError> {
            let stored = self.docs.lock().unwrap();
            let hits = stored
                .iter()
                .take(top_k)
                .cloned()
                .map(|document| SearchResult {
                    document,
                    score: 1.0,
                })
                .collect();
            Ok(hits)
        }

        async fn delete(&self, _ids: &[String]) -> Result<(), StoreError> {
            Ok(())
        }
    }

    /// A saved turn comes back from a later lookup under the default key.
    #[tokio::test]
    async fn vector_memory_round_trip() {
        let embedder = Arc::new(StubEmbedder);
        let backend = Arc::new(StubVectorStore::default());
        let store = VectorMemoryStore::new(embedder, backend, 5);

        let inputs = HashMap::from([("input".to_string(), Value::from("Hello"))]);
        let outputs = HashMap::from([("output".to_string(), Value::from("Hi!"))]);

        store.save_context("t1", &inputs, &outputs).await.unwrap();

        let vars = store.load_memory_variables("latest message").await.unwrap();
        let history = vars.get("history").unwrap().as_array().unwrap();
        assert_eq!(history.len(), 1);
        assert!(history[0].as_str().unwrap().contains("Human: Hello"));
    }

    /// Outputs saved for a thread are visible when that thread is loaded.
    #[tokio::test]
    async fn entity_memory_stores_and_loads() {
        let mem = EntityMemory::new();
        let inputs = HashMap::from([("user".to_string(), Value::from("Alice"))]);
        let outputs = HashMap::from([("project".to_string(), Value::from("wesichain"))]);

        mem.save_context("t1", &inputs, &outputs).await.unwrap();

        let vars = mem.load_memory_variables("t1").await.unwrap();
        let entities = vars.get("entities").unwrap().as_object().unwrap();
        assert_eq!(entities.get("project").unwrap().as_str().unwrap(), "wesichain");
    }

    /// Clearing a thread removes its entities.
    #[tokio::test]
    async fn entity_memory_clear() {
        let mem = EntityMemory::new();
        mem.upsert("t1", "foo", Value::from("bar"));
        mem.clear("t1").await.unwrap();
        assert!(mem.entities("t1").is_empty());
    }

    /// The router exposes the variables produced by its layers.
    #[tokio::test]
    async fn memory_router_merges_layers() {
        let entity_mem = Arc::new(EntityMemory::new());
        entity_mem.upsert("t1", "entity_key", Value::from("v1"));

        let layers: Vec<Arc<dyn Memory>> = vec![entity_mem];
        let router = MemoryRouter::new(layers);
        let vars = router.load_memory_variables("t1").await.unwrap();
        assert!(vars.contains_key("entities"));
    }
}