Skip to main content

walrus_memory/sqlite/
mod.rs

1//! SQLite-backed memory store with optional embedding support.
2//!
3//! Wraps a `rusqlite::Connection` in a `Mutex` for thread safety.
4//! Generic over `E: Embedder` for optional vector search.
5
6use anyhow::Result;
7use compact_str::CompactString;
8use rusqlite::Connection;
9use serde_json::Value;
10use std::{collections::HashMap, path::Path, sync::Mutex};
11use utils::{cosine_similarity, decode_embedding, mmr_rerank, now_unix};
12use wcore::{Embedder, MemoryEntry, RecallOptions};
13
14mod memory;
15mod sql;
16mod utils;
17
18/// SQLite-backed memory store with optional embedding support.
19pub struct SqliteMemory<E: Embedder> {
20    pub(crate) conn: Mutex<Connection>,
21    pub(crate) embedder: Option<E>,
22}
23
24impl<E: Embedder> SqliteMemory<E> {
25    /// Open or create a SQLite database at the given path.
26    pub fn open(path: impl AsRef<Path>) -> Result<Self> {
27        let conn = Connection::open(path)?;
28        let mem = Self {
29            conn: Mutex::new(conn),
30            embedder: None,
31        };
32        mem.init_schema()?;
33        Ok(mem)
34    }
35
36    /// Create an in-memory database (useful for testing).
37    pub fn in_memory() -> Result<Self> {
38        let conn = Connection::open_in_memory()?;
39        let mem = Self {
40            conn: Mutex::new(conn),
41            embedder: None,
42        };
43        mem.init_schema()?;
44        Ok(mem)
45    }
46
47    /// Attach an embedder for vector search.
48    pub fn with_embedder(mut self, embedder: E) -> Self {
49        self.embedder = Some(embedder);
50        self
51    }
52
53    /// Initialize the database schema.
54    fn init_schema(&self) -> Result<()> {
55        let conn = self.conn.lock().unwrap_or_else(|e| e.into_inner());
56        conn.execute_batch(sql::SCHEMA)?;
57        Ok(())
58    }
59
60    /// Execute the recall pipeline synchronously.
61    ///
62    /// 1. BM25 via FTS5 MATCH (under lock)
63    /// 2. Vector scan (under lock, if embeddings requested)
64    /// 3. Lock released — scoring, RRF fusion, MMR done without lock
65    pub(crate) fn recall_sync(
66        &self,
67        query: &str,
68        options: &RecallOptions,
69        query_embedding: Option<&[f32]>,
70    ) -> Result<Vec<MemoryEntry>> {
71        let now = now_unix();
72        let limit = if options.limit == 0 {
73            10
74        } else {
75            options.limit
76        };
77
78        // Phase 1: DB queries under lock. Collect raw rows, release lock.
79        let (bm25_candidates, vec_candidates) = {
80            let conn = self.conn.lock().unwrap_or_else(|e| e.into_inner());
81
82            // BM25 path: FTS5 MATCH.
83            let mut fts_stmt = conn.prepare(sql::RECALL_FTS)?;
84            let bm25: Vec<(MemoryEntry, f64)> = fts_stmt
85                .query_map([query], |row| {
86                    let emb_blob: Option<Vec<u8>> = row.get(6)?;
87                    Ok(MemoryEntry {
88                        key: CompactString::new(row.get::<_, String>(0)?),
89                        value: row.get(1)?,
90                        metadata: row
91                            .get::<_, Option<String>>(2)?
92                            .and_then(|s| serde_json::from_str(&s).ok()),
93                        created_at: row.get::<_, i64>(3)? as u64,
94                        accessed_at: row.get::<_, i64>(4)? as u64,
95                        access_count: row.get::<_, i32>(5)? as u32,
96                        embedding: emb_blob.map(|b| decode_embedding(&b)),
97                    })
98                    .map(|entry| (entry, row.get::<_, f64>(7).unwrap_or(0.0)))
99                })?
100                .filter_map(|r| r.ok())
101                .collect();
102
103            // Vector path (only if query embedding provided).
104            let vec = if query_embedding.is_some() {
105                let mut vec_stmt = conn.prepare(sql::RECALL_VECTOR)?;
106                vec_stmt
107                    .query_map([], |row| {
108                        let emb_blob: Option<Vec<u8>> = row.get(6)?;
109                        Ok(MemoryEntry {
110                            key: CompactString::new(row.get::<_, String>(0)?),
111                            value: row.get(1)?,
112                            metadata: row
113                                .get::<_, Option<String>>(2)?
114                                .and_then(|s| serde_json::from_str(&s).ok()),
115                            created_at: row.get::<_, i64>(3)? as u64,
116                            accessed_at: row.get::<_, i64>(4)? as u64,
117                            access_count: row.get::<_, i32>(5)? as u32,
118                            embedding: emb_blob.map(|b| decode_embedding(&b)),
119                        })
120                    })?
121                    .filter_map(|r| r.ok())
122                    .collect::<Vec<_>>()
123            } else {
124                Vec::new()
125            };
126
127            (bm25, vec)
128            // conn lock dropped here
129        };
130
131        // Phase 2: Scoring and fusion (no lock held).
132
133        // Temporal decay: score * e^(-lambda * age_days), half-life 30 days.
134        let lambda = std::f64::consts::LN_2 / 30.0;
135        let bm25_scored: Vec<(MemoryEntry, f64)> = bm25_candidates
136            .into_iter()
137            .map(|(entry, bm25_rank)| {
138                let bm25_score = -bm25_rank;
139                let age_days = now.saturating_sub(entry.accessed_at) as f64 / 86400.0;
140                let decay = (-lambda * age_days).exp();
141                (entry, bm25_score * decay)
142            })
143            .collect();
144
145        let scored = if let Some(q_emb) = query_embedding {
146            // Compute cosine similarity for vector candidates.
147            let vec_scored: Vec<(MemoryEntry, f64)> = vec_candidates
148                .into_iter()
149                .filter_map(|entry| {
150                    let sim = entry
151                        .embedding
152                        .as_ref()
153                        .map(|e| cosine_similarity(e, q_emb))
154                        .unwrap_or(0.0);
155                    if sim > 0.0 { Some((entry, sim)) } else { None }
156                })
157                .collect();
158
159            // RRF fusion: score = 1/(k + rank_bm25) + 1/(k + rank_vector), k=60.
160            let k = 60.0_f64;
161
162            let mut bm25_ranked = bm25_scored;
163            bm25_ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
164
165            let mut vec_ranked = vec_scored;
166            vec_ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
167
168            // Compute RRF scores while vecs are borrowed, then drain entries.
169            let rrf_scores: Vec<f64>;
170            let bm25_in_vec: Vec<bool>;
171            {
172                let vec_rank_map: HashMap<&str, usize> = vec_ranked
173                    .iter()
174                    .enumerate()
175                    .map(|(i, (e, _))| (e.key.as_str(), i + 1))
176                    .collect();
177                let bm25_key_set: HashMap<&str, ()> = bm25_ranked
178                    .iter()
179                    .map(|(e, _)| (e.key.as_str(), ()))
180                    .collect();
181
182                rrf_scores = bm25_ranked
183                    .iter()
184                    .enumerate()
185                    .map(|(i, (e, _))| {
186                        1.0 / (k + (i + 1) as f64)
187                            + vec_rank_map
188                                .get(e.key.as_str())
189                                .map(|&r| 1.0 / (k + r as f64))
190                                .unwrap_or(0.0)
191                    })
192                    .collect();
193
194                bm25_in_vec = vec_ranked
195                    .iter()
196                    .map(|(e, _)| bm25_key_set.contains_key(e.key.as_str()))
197                    .collect();
198            }
199
200            let mut fused = Vec::with_capacity(bm25_ranked.len() + vec_ranked.len());
201            for (score, (entry, _)) in rrf_scores.into_iter().zip(bm25_ranked) {
202                fused.push((entry, score));
203            }
204            for (i, (entry, _)) in vec_ranked.into_iter().enumerate() {
205                if bm25_in_vec[i] {
206                    continue;
207                }
208                fused.push((entry, 1.0 / (k + (i + 1) as f64)));
209            }
210            fused.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
211            fused
212        } else {
213            bm25_scored
214        };
215
216        if scored.is_empty() {
217            return Ok(Vec::new());
218        }
219
220        // Phase 3: Filters and MMR (no lock held).
221        let mut filtered = scored;
222        if let Some((start, end)) = options.time_range {
223            filtered.retain(|(entry, _)| entry.created_at >= start && entry.created_at <= end);
224        }
225        if let Some(threshold) = options.relevance_threshold {
226            filtered.retain(|(_, score)| *score >= threshold as f64);
227        }
228        if filtered.is_empty() {
229            return Ok(Vec::new());
230        }
231
232        let use_cosine = query_embedding.is_some();
233        Ok(mmr_rerank(filtered, limit, 0.7, use_cosine))
234    }
235
236    /// Store a key-value pair with optional metadata and embedding.
237    pub fn store_with_metadata(
238        &self,
239        key: &str,
240        value: &str,
241        metadata: Option<&Value>,
242        embedding: Option<&[f32]>,
243    ) -> Result<()> {
244        let conn = self.conn.lock().unwrap_or_else(|e| e.into_inner());
245        let now = now_unix() as i64;
246        let meta_json = metadata.map(serde_json::to_string).transpose()?;
247        let emb_blob: Option<Vec<u8>> =
248            embedding.map(|e| e.iter().flat_map(|f| f.to_le_bytes()).collect());
249
250        conn.execute(
251            sql::UPSERT_FULL,
252            rusqlite::params![key, value, meta_json, now, emb_blob],
253        )?;
254        Ok(())
255    }
256
257    /// Get a full MemoryEntry for a key.
258    pub fn get_entry(&self, key: &str) -> Option<MemoryEntry> {
259        let conn = self.conn.lock().unwrap_or_else(|e| e.into_inner());
260        conn.query_row(sql::SELECT_ENTRY, [key], |row| {
261            let emb_blob: Option<Vec<u8>> = row.get(6)?;
262            Ok(MemoryEntry {
263                key: CompactString::new(row.get::<_, String>(0)?),
264                value: row.get(1)?,
265                metadata: row
266                    .get::<_, Option<String>>(2)?
267                    .and_then(|s| serde_json::from_str(&s).ok()),
268                created_at: row.get::<_, i64>(3)? as u64,
269                accessed_at: row.get::<_, i64>(4)? as u64,
270                access_count: row.get::<_, i32>(5)? as u32,
271                embedding: emb_blob.map(|b| decode_embedding(&b)),
272            })
273        })
274        .ok()
275    }
276}