// walrus_memory/lib.rs
//! Memory backends for Walrus agents.
//!
//! Defines the [`Memory`] trait, [`Embedder`] trait, and concrete implementations:
//! [`InMemory`] (volatile) and [`SqliteMemory`] (persistent with FTS5 + vector recall).
//!
//! Memory is **not chat history**. It is structured knowledge — extracted facts,
//! user preferences, agent persona — that gets compiled into the system prompt.
//!
//! All SQL lives in `sql/*.sql` files, loaded via `include_str!`.

pub use crate::utils::cosine_similarity;
use crate::utils::{decode_embedding, mmr_rerank, now_unix};
use anyhow::Result;
use compact_str::CompactString;
use rusqlite::Connection;
use serde_json::Value;
use std::{
    collections::{HashMap, HashSet},
    future::Future,
    path::Path,
    sync::Mutex,
};

mod embedder;
mod inmemory;
mod memory;
mod sql;
mod utils;

pub use embedder::{Embedder, NoEmbedder};
pub use inmemory::InMemory;
27
/// A structured memory entry with metadata and optional embedding.
///
/// Returned by recall and lookup APIs. All timestamps are unix seconds.
#[derive(Debug, Clone, Default)]
pub struct MemoryEntry {
    /// Entry key (identity string).
    pub key: CompactString,
    /// Entry value (unbounded content).
    pub value: String,
    /// Optional structured metadata (JSON).
    pub metadata: Option<Value>,
    /// Unix timestamp when the entry was created.
    pub created_at: u64,
    /// Unix timestamp when the entry was last accessed.
    pub accessed_at: u64,
    /// Number of times the entry has been accessed.
    pub access_count: u32,
    /// Optional embedding vector for semantic search.
    pub embedding: Option<Vec<f32>>,
}
46
/// Options controlling memory recall behavior.
#[derive(Debug, Clone, Default)]
pub struct RecallOptions {
    /// Maximum number of results (0 = implementation default).
    pub limit: usize,
    /// Filter by creation time range (start, end) in unix seconds, inclusive.
    pub time_range: Option<(u64, u64)>,
    /// Minimum relevance score threshold (0.0–1.0).
    // NOTE(review): implementations that fuse scores (e.g. RRF) may produce
    // values on a different scale than 0.0–1.0 — confirm against the backend.
    pub relevance_threshold: Option<f32>,
}
57
58/// Structured knowledge memory for LLM agents.
59///
60/// Implementations store named key-value pairs that get compiled
61/// into the system prompt via [`compile()`](Memory::compile).
62///
63/// Uses `&self` for all methods — implementations must handle
64/// interior mutability (e.g. via `Mutex`).
65pub trait Memory: Send + Sync {
66    /// Get the value for a key (owned).
67    fn get(&self, key: &str) -> Option<String>;
68
69    /// Get all key-value pairs (owned).
70    fn entries(&self) -> Vec<(String, String)>;
71
72    /// Set (upsert) a key-value pair. Returns the previous value if the key existed.
73    fn set(&self, key: impl Into<String>, value: impl Into<String>) -> Option<String>;
74
75    /// Remove a key. Returns the removed value if it existed.
76    fn remove(&self, key: &str) -> Option<String>;
77
78    /// Compile all entries into a string for system prompt injection.
79    fn compile(&self) -> String {
80        let entries = self.entries();
81        if entries.is_empty() {
82            return String::new();
83        }
84
85        let mut out = String::from("<memory>\n");
86        for (key, value) in &entries {
87            out.push_str(&format!("<{key}>\n"));
88            out.push_str(value);
89            if !value.ends_with('\n') {
90                out.push('\n');
91            }
92            out.push_str(&format!("</{key}>\n"));
93        }
94        out.push_str("</memory>");
95        out
96    }
97
98    /// Store a key-value pair (async). Default delegates to `set`.
99    fn store(
100        &self,
101        key: impl Into<String> + Send,
102        value: impl Into<String> + Send,
103    ) -> impl Future<Output = Result<()>> + Send {
104        self.set(key, value);
105        async { Ok(()) }
106    }
107
108    /// Search for relevant entries (async). Default returns empty.
109    fn recall(
110        &self,
111        _query: &str,
112        _options: RecallOptions,
113    ) -> impl Future<Output = Result<Vec<MemoryEntry>>> + Send {
114        async { Ok(Vec::new()) }
115    }
116
117    /// Compile relevant entries for a query (async). Default delegates to `compile`.
118    fn compile_relevant(&self, _query: &str) -> impl Future<Output = String> + Send {
119        let compiled = self.compile();
120        async move { compiled }
121    }
122}
123
124/// Apply memory to an agent config — appends compiled memory to the system prompt.
125pub fn with_memory(mut config: wcore::AgentConfig, memory: &impl Memory) -> wcore::AgentConfig {
126    let compiled = memory.compile();
127    if !compiled.is_empty() {
128        config.system_prompt = format!("{}\n\n{compiled}", config.system_prompt);
129    }
130    config
131}
132
/// SQLite-backed memory store with optional embedding support.
///
/// Wraps a `rusqlite::Connection` in a `Mutex` for thread safety.
/// Generic over `E: Embedder` for optional vector search.
pub struct SqliteMemory<E: Embedder> {
    // Single connection guarded by a Mutex; locks are held only for the
    // duration of each SQL statement (see recall_sync's phased design).
    conn: Mutex<Connection>,
    // None until attached via `with_embedder`; gates the vector recall path.
    embedder: Option<E>,
}
141
142impl<E: Embedder> SqliteMemory<E> {
143    /// Open or create a SQLite database at the given path.
144    pub fn open(path: impl AsRef<Path>) -> Result<Self> {
145        let conn = Connection::open(path)?;
146        let mem = Self {
147            conn: Mutex::new(conn),
148            embedder: None,
149        };
150        mem.init_schema()?;
151        Ok(mem)
152    }
153
154    /// Create an in-memory database (useful for testing).
155    pub fn in_memory() -> Result<Self> {
156        let conn = Connection::open_in_memory()?;
157        let mem = Self {
158            conn: Mutex::new(conn),
159            embedder: None,
160        };
161        mem.init_schema()?;
162        Ok(mem)
163    }
164
165    /// Attach an embedder for vector search.
166    pub fn with_embedder(mut self, embedder: E) -> Self {
167        self.embedder = Some(embedder);
168        self
169    }
170
171    /// Initialize the database schema.
172    fn init_schema(&self) -> Result<()> {
173        let conn = self.conn.lock().unwrap();
174        conn.execute_batch(sql::SCHEMA)?;
175        Ok(())
176    }
177
178    /// Execute the recall pipeline synchronously.
179    ///
180    /// 1. BM25 via FTS5 MATCH (under lock)
181    /// 2. Vector scan (under lock, if embeddings requested)
182    /// 3. Lock released — scoring, RRF fusion, MMR done without lock
183    fn recall_sync(
184        &self,
185        query: &str,
186        options: &RecallOptions,
187        query_embedding: Option<&[f32]>,
188    ) -> Result<Vec<MemoryEntry>> {
189        let now = now_unix();
190        let limit = if options.limit == 0 {
191            10
192        } else {
193            options.limit
194        };
195
196        // Phase 1: DB queries under lock. Collect raw rows, release lock.
197        let (bm25_candidates, vec_candidates) = {
198            let conn = self.conn.lock().unwrap();
199
200            // BM25 path: FTS5 MATCH.
201            let mut fts_stmt = conn.prepare(sql::RECALL_FTS)?;
202            let bm25: Vec<(MemoryEntry, f64)> = fts_stmt
203                .query_map([query], |row| {
204                    let emb_blob: Option<Vec<u8>> = row.get(6)?;
205                    Ok(MemoryEntry {
206                        key: CompactString::new(row.get::<_, String>(0)?),
207                        value: row.get(1)?,
208                        metadata: row
209                            .get::<_, Option<String>>(2)?
210                            .and_then(|s| serde_json::from_str(&s).ok()),
211                        created_at: row.get::<_, i64>(3)? as u64,
212                        accessed_at: row.get::<_, i64>(4)? as u64,
213                        access_count: row.get::<_, i32>(5)? as u32,
214                        embedding: emb_blob.map(|b| decode_embedding(&b)),
215                    })
216                    .map(|entry| (entry, row.get::<_, f64>(7).unwrap_or(0.0)))
217                })?
218                .filter_map(|r| r.ok())
219                .collect();
220
221            // Vector path (only if query embedding provided).
222            let vec = if query_embedding.is_some() {
223                let mut vec_stmt = conn.prepare(sql::RECALL_VECTOR)?;
224                vec_stmt
225                    .query_map([], |row| {
226                        let emb_blob: Option<Vec<u8>> = row.get(6)?;
227                        Ok(MemoryEntry {
228                            key: CompactString::new(row.get::<_, String>(0)?),
229                            value: row.get(1)?,
230                            metadata: row
231                                .get::<_, Option<String>>(2)?
232                                .and_then(|s| serde_json::from_str(&s).ok()),
233                            created_at: row.get::<_, i64>(3)? as u64,
234                            accessed_at: row.get::<_, i64>(4)? as u64,
235                            access_count: row.get::<_, i32>(5)? as u32,
236                            embedding: emb_blob.map(|b| decode_embedding(&b)),
237                        })
238                    })?
239                    .filter_map(|r| r.ok())
240                    .collect::<Vec<_>>()
241            } else {
242                Vec::new()
243            };
244
245            (bm25, vec)
246            // conn lock dropped here
247        };
248
249        // Phase 2: Scoring and fusion (no lock held).
250
251        // Temporal decay: score * e^(-lambda * age_days), half-life 30 days.
252        let lambda = std::f64::consts::LN_2 / 30.0;
253        let bm25_scored: Vec<(MemoryEntry, f64)> = bm25_candidates
254            .into_iter()
255            .map(|(entry, bm25_rank)| {
256                let bm25_score = -bm25_rank;
257                let age_days = now.saturating_sub(entry.accessed_at) as f64 / 86400.0;
258                let decay = (-lambda * age_days).exp();
259                (entry, bm25_score * decay)
260            })
261            .collect();
262
263        let scored = if let Some(q_emb) = query_embedding {
264            // Compute cosine similarity for vector candidates.
265            let vec_scored: Vec<(MemoryEntry, f64)> = vec_candidates
266                .into_iter()
267                .filter_map(|entry| {
268                    let sim = entry
269                        .embedding
270                        .as_ref()
271                        .map(|e| cosine_similarity(e, q_emb))
272                        .unwrap_or(0.0);
273                    if sim > 0.0 { Some((entry, sim)) } else { None }
274                })
275                .collect();
276
277            // RRF fusion: score = 1/(k + rank_bm25) + 1/(k + rank_vector), k=60.
278            // Borrowed-key HashMaps for O(1) rank lookup, no key cloning.
279            let k = 60.0_f64;
280
281            let mut bm25_ranked = bm25_scored;
282            bm25_ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
283
284            let mut vec_ranked = vec_scored;
285            vec_ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
286
287            // Compute RRF scores while vecs are borrowed, then drain entries.
288            let rrf_scores: Vec<f64>;
289            let bm25_in_vec: Vec<bool>;
290            {
291                let vec_rank_map: HashMap<&str, usize> = vec_ranked
292                    .iter()
293                    .enumerate()
294                    .map(|(i, (e, _))| (e.key.as_str(), i + 1))
295                    .collect();
296                let bm25_key_set: HashMap<&str, ()> = bm25_ranked
297                    .iter()
298                    .map(|(e, _)| (e.key.as_str(), ()))
299                    .collect();
300
301                // Score BM25 entries (index = rank).
302                rrf_scores = bm25_ranked
303                    .iter()
304                    .enumerate()
305                    .map(|(i, (e, _))| {
306                        1.0 / (k + (i + 1) as f64)
307                            + vec_rank_map
308                                .get(e.key.as_str())
309                                .map(|&r| 1.0 / (k + r as f64))
310                                .unwrap_or(0.0)
311                    })
312                    .collect();
313
314                // Mark which vec entries are also in bm25 (for dedup).
315                bm25_in_vec = vec_ranked
316                    .iter()
317                    .map(|(e, _)| bm25_key_set.contains_key(e.key.as_str()))
318                    .collect();
319                // borrowed maps dropped here
320            }
321
322            // Drain entries and pair with scores.
323            let mut fused = Vec::with_capacity(bm25_ranked.len() + vec_ranked.len());
324            for (score, (entry, _)) in rrf_scores.into_iter().zip(bm25_ranked) {
325                fused.push((entry, score));
326            }
327            for (i, (entry, _)) in vec_ranked.into_iter().enumerate() {
328                if bm25_in_vec[i] {
329                    continue;
330                }
331                fused.push((entry, 1.0 / (k + (i + 1) as f64)));
332            }
333            fused.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
334            fused
335        } else {
336            bm25_scored
337        };
338
339        if scored.is_empty() {
340            return Ok(Vec::new());
341        }
342
343        // Phase 3: Filters and MMR (no lock held).
344        let mut filtered = scored;
345        if let Some((start, end)) = options.time_range {
346            filtered.retain(|(entry, _)| entry.created_at >= start && entry.created_at <= end);
347        }
348        if let Some(threshold) = options.relevance_threshold {
349            filtered.retain(|(_, score)| *score >= threshold as f64);
350        }
351        if filtered.is_empty() {
352            return Ok(Vec::new());
353        }
354
355        let use_cosine = query_embedding.is_some();
356        Ok(mmr_rerank(filtered, limit, 0.7, use_cosine))
357    }
358
359    /// Store a key-value pair with optional metadata and embedding.
360    pub fn store_with_metadata(
361        &self,
362        key: &str,
363        value: &str,
364        metadata: Option<&Value>,
365        embedding: Option<&[f32]>,
366    ) -> Result<()> {
367        let conn = self.conn.lock().unwrap();
368        let now = now_unix() as i64;
369        let meta_json = metadata.map(|m| serde_json::to_string(m).unwrap());
370        let emb_blob: Option<Vec<u8>> =
371            embedding.map(|e| e.iter().flat_map(|f| f.to_le_bytes()).collect());
372
373        conn.execute(
374            sql::UPSERT_FULL,
375            rusqlite::params![key, value, meta_json, now, emb_blob],
376        )?;
377        Ok(())
378    }
379
380    /// Get a full MemoryEntry for a key.
381    pub fn get_entry(&self, key: &str) -> Option<MemoryEntry> {
382        let conn = self.conn.lock().unwrap();
383        conn.query_row(sql::SELECT_ENTRY, [key], |row| {
384            let emb_blob: Option<Vec<u8>> = row.get(6)?;
385            Ok(MemoryEntry {
386                key: CompactString::new(row.get::<_, String>(0)?),
387                value: row.get(1)?,
388                metadata: row
389                    .get::<_, Option<String>>(2)?
390                    .and_then(|s| serde_json::from_str(&s).ok()),
391                created_at: row.get::<_, i64>(3)? as u64,
392                accessed_at: row.get::<_, i64>(4)? as u64,
393                access_count: row.get::<_, i32>(5)? as u32,
394                embedding: emb_blob.map(|b| decode_embedding(&b)),
395            })
396        })
397        .ok()
398    }
399}