Skip to main content

walrus_memory/sqlite/
mod.rs

1//! SQLite-backed memory store with optional embedding support.
2//!
3//! Wraps a `rusqlite::Connection` in a `Mutex` for thread safety.
4//! Generic over `E: Embedder` for optional vector search.
5
6use crate::{
7    Embedder, MemoryEntry, RecallOptions,
8    utils::{cosine_similarity, decode_embedding, mmr_rerank, now_unix},
9};
10use anyhow::Result;
11use compact_str::CompactString;
12use rusqlite::Connection;
13use serde_json::Value;
14use std::{collections::HashMap, path::Path, sync::Mutex};
15
16mod memory;
17mod sql;
18
19/// SQLite-backed memory store with optional embedding support.
20pub struct SqliteMemory<E: Embedder> {
21    pub(crate) conn: Mutex<Connection>,
22    pub(crate) embedder: Option<E>,
23}
24
25impl<E: Embedder> SqliteMemory<E> {
26    /// Open or create a SQLite database at the given path.
27    pub fn open(path: impl AsRef<Path>) -> Result<Self> {
28        let conn = Connection::open(path)?;
29        let mem = Self {
30            conn: Mutex::new(conn),
31            embedder: None,
32        };
33        mem.init_schema()?;
34        Ok(mem)
35    }
36
37    /// Create an in-memory database (useful for testing).
38    pub fn in_memory() -> Result<Self> {
39        let conn = Connection::open_in_memory()?;
40        let mem = Self {
41            conn: Mutex::new(conn),
42            embedder: None,
43        };
44        mem.init_schema()?;
45        Ok(mem)
46    }
47
48    /// Attach an embedder for vector search.
49    pub fn with_embedder(mut self, embedder: E) -> Self {
50        self.embedder = Some(embedder);
51        self
52    }
53
54    /// Initialize the database schema.
55    fn init_schema(&self) -> Result<()> {
56        let conn = self.conn.lock().unwrap();
57        conn.execute_batch(sql::SCHEMA)?;
58        Ok(())
59    }
60
61    /// Execute the recall pipeline synchronously.
62    ///
63    /// 1. BM25 via FTS5 MATCH (under lock)
64    /// 2. Vector scan (under lock, if embeddings requested)
65    /// 3. Lock released — scoring, RRF fusion, MMR done without lock
66    pub(crate) fn recall_sync(
67        &self,
68        query: &str,
69        options: &RecallOptions,
70        query_embedding: Option<&[f32]>,
71    ) -> Result<Vec<MemoryEntry>> {
72        let now = now_unix();
73        let limit = if options.limit == 0 {
74            10
75        } else {
76            options.limit
77        };
78
79        // Phase 1: DB queries under lock. Collect raw rows, release lock.
80        let (bm25_candidates, vec_candidates) = {
81            let conn = self.conn.lock().unwrap();
82
83            // BM25 path: FTS5 MATCH.
84            let mut fts_stmt = conn.prepare(sql::RECALL_FTS)?;
85            let bm25: Vec<(MemoryEntry, f64)> = fts_stmt
86                .query_map([query], |row| {
87                    let emb_blob: Option<Vec<u8>> = row.get(6)?;
88                    Ok(MemoryEntry {
89                        key: CompactString::new(row.get::<_, String>(0)?),
90                        value: row.get(1)?,
91                        metadata: row
92                            .get::<_, Option<String>>(2)?
93                            .and_then(|s| serde_json::from_str(&s).ok()),
94                        created_at: row.get::<_, i64>(3)? as u64,
95                        accessed_at: row.get::<_, i64>(4)? as u64,
96                        access_count: row.get::<_, i32>(5)? as u32,
97                        embedding: emb_blob.map(|b| decode_embedding(&b)),
98                    })
99                    .map(|entry| (entry, row.get::<_, f64>(7).unwrap_or(0.0)))
100                })?
101                .filter_map(|r| r.ok())
102                .collect();
103
104            // Vector path (only if query embedding provided).
105            let vec = if query_embedding.is_some() {
106                let mut vec_stmt = conn.prepare(sql::RECALL_VECTOR)?;
107                vec_stmt
108                    .query_map([], |row| {
109                        let emb_blob: Option<Vec<u8>> = row.get(6)?;
110                        Ok(MemoryEntry {
111                            key: CompactString::new(row.get::<_, String>(0)?),
112                            value: row.get(1)?,
113                            metadata: row
114                                .get::<_, Option<String>>(2)?
115                                .and_then(|s| serde_json::from_str(&s).ok()),
116                            created_at: row.get::<_, i64>(3)? as u64,
117                            accessed_at: row.get::<_, i64>(4)? as u64,
118                            access_count: row.get::<_, i32>(5)? as u32,
119                            embedding: emb_blob.map(|b| decode_embedding(&b)),
120                        })
121                    })?
122                    .filter_map(|r| r.ok())
123                    .collect::<Vec<_>>()
124            } else {
125                Vec::new()
126            };
127
128            (bm25, vec)
129            // conn lock dropped here
130        };
131
132        // Phase 2: Scoring and fusion (no lock held).
133
134        // Temporal decay: score * e^(-lambda * age_days), half-life 30 days.
135        let lambda = std::f64::consts::LN_2 / 30.0;
136        let bm25_scored: Vec<(MemoryEntry, f64)> = bm25_candidates
137            .into_iter()
138            .map(|(entry, bm25_rank)| {
139                let bm25_score = -bm25_rank;
140                let age_days = now.saturating_sub(entry.accessed_at) as f64 / 86400.0;
141                let decay = (-lambda * age_days).exp();
142                (entry, bm25_score * decay)
143            })
144            .collect();
145
146        let scored = if let Some(q_emb) = query_embedding {
147            // Compute cosine similarity for vector candidates.
148            let vec_scored: Vec<(MemoryEntry, f64)> = vec_candidates
149                .into_iter()
150                .filter_map(|entry| {
151                    let sim = entry
152                        .embedding
153                        .as_ref()
154                        .map(|e| cosine_similarity(e, q_emb))
155                        .unwrap_or(0.0);
156                    if sim > 0.0 { Some((entry, sim)) } else { None }
157                })
158                .collect();
159
160            // RRF fusion: score = 1/(k + rank_bm25) + 1/(k + rank_vector), k=60.
161            let k = 60.0_f64;
162
163            let mut bm25_ranked = bm25_scored;
164            bm25_ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
165
166            let mut vec_ranked = vec_scored;
167            vec_ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
168
169            // Compute RRF scores while vecs are borrowed, then drain entries.
170            let rrf_scores: Vec<f64>;
171            let bm25_in_vec: Vec<bool>;
172            {
173                let vec_rank_map: HashMap<&str, usize> = vec_ranked
174                    .iter()
175                    .enumerate()
176                    .map(|(i, (e, _))| (e.key.as_str(), i + 1))
177                    .collect();
178                let bm25_key_set: HashMap<&str, ()> = bm25_ranked
179                    .iter()
180                    .map(|(e, _)| (e.key.as_str(), ()))
181                    .collect();
182
183                rrf_scores = bm25_ranked
184                    .iter()
185                    .enumerate()
186                    .map(|(i, (e, _))| {
187                        1.0 / (k + (i + 1) as f64)
188                            + vec_rank_map
189                                .get(e.key.as_str())
190                                .map(|&r| 1.0 / (k + r as f64))
191                                .unwrap_or(0.0)
192                    })
193                    .collect();
194
195                bm25_in_vec = vec_ranked
196                    .iter()
197                    .map(|(e, _)| bm25_key_set.contains_key(e.key.as_str()))
198                    .collect();
199            }
200
201            let mut fused = Vec::with_capacity(bm25_ranked.len() + vec_ranked.len());
202            for (score, (entry, _)) in rrf_scores.into_iter().zip(bm25_ranked) {
203                fused.push((entry, score));
204            }
205            for (i, (entry, _)) in vec_ranked.into_iter().enumerate() {
206                if bm25_in_vec[i] {
207                    continue;
208                }
209                fused.push((entry, 1.0 / (k + (i + 1) as f64)));
210            }
211            fused.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
212            fused
213        } else {
214            bm25_scored
215        };
216
217        if scored.is_empty() {
218            return Ok(Vec::new());
219        }
220
221        // Phase 3: Filters and MMR (no lock held).
222        let mut filtered = scored;
223        if let Some((start, end)) = options.time_range {
224            filtered.retain(|(entry, _)| entry.created_at >= start && entry.created_at <= end);
225        }
226        if let Some(threshold) = options.relevance_threshold {
227            filtered.retain(|(_, score)| *score >= threshold as f64);
228        }
229        if filtered.is_empty() {
230            return Ok(Vec::new());
231        }
232
233        let use_cosine = query_embedding.is_some();
234        Ok(mmr_rerank(filtered, limit, 0.7, use_cosine))
235    }
236
237    /// Store a key-value pair with optional metadata and embedding.
238    pub fn store_with_metadata(
239        &self,
240        key: &str,
241        value: &str,
242        metadata: Option<&Value>,
243        embedding: Option<&[f32]>,
244    ) -> Result<()> {
245        let conn = self.conn.lock().unwrap();
246        let now = now_unix() as i64;
247        let meta_json = metadata.map(|m| serde_json::to_string(m).unwrap());
248        let emb_blob: Option<Vec<u8>> =
249            embedding.map(|e| e.iter().flat_map(|f| f.to_le_bytes()).collect());
250
251        conn.execute(
252            sql::UPSERT_FULL,
253            rusqlite::params![key, value, meta_json, now, emb_blob],
254        )?;
255        Ok(())
256    }
257
258    /// Get a full MemoryEntry for a key.
259    pub fn get_entry(&self, key: &str) -> Option<MemoryEntry> {
260        let conn = self.conn.lock().unwrap();
261        conn.query_row(sql::SELECT_ENTRY, [key], |row| {
262            let emb_blob: Option<Vec<u8>> = row.get(6)?;
263            Ok(MemoryEntry {
264                key: CompactString::new(row.get::<_, String>(0)?),
265                value: row.get(1)?,
266                metadata: row
267                    .get::<_, Option<String>>(2)?
268                    .and_then(|s| serde_json::from_str(&s).ok()),
269                created_at: row.get::<_, i64>(3)? as u64,
270                accessed_at: row.get::<_, i64>(4)? as u64,
271                access_count: row.get::<_, i32>(5)? as u32,
272                embedding: emb_blob.map(|b| decode_embedding(&b)),
273            })
274        })
275        .ok()
276    }
277}