1use crate::{
7 Embedder, MemoryEntry, RecallOptions,
8 utils::{cosine_similarity, decode_embedding, mmr_rerank, now_unix},
9};
10use anyhow::Result;
11use compact_str::CompactString;
12use rusqlite::Connection;
13use serde_json::Value;
14use std::{collections::HashMap, path::Path, sync::Mutex};
15
16mod memory;
17mod sql;
18
19pub struct SqliteMemory<E: Embedder> {
21 pub(crate) conn: Mutex<Connection>,
22 pub(crate) embedder: Option<E>,
23}
24
25impl<E: Embedder> SqliteMemory<E> {
26 pub fn open(path: impl AsRef<Path>) -> Result<Self> {
28 let conn = Connection::open(path)?;
29 let mem = Self {
30 conn: Mutex::new(conn),
31 embedder: None,
32 };
33 mem.init_schema()?;
34 Ok(mem)
35 }
36
37 pub fn in_memory() -> Result<Self> {
39 let conn = Connection::open_in_memory()?;
40 let mem = Self {
41 conn: Mutex::new(conn),
42 embedder: None,
43 };
44 mem.init_schema()?;
45 Ok(mem)
46 }
47
48 pub fn with_embedder(mut self, embedder: E) -> Self {
50 self.embedder = Some(embedder);
51 self
52 }
53
54 fn init_schema(&self) -> Result<()> {
56 let conn = self.conn.lock().unwrap();
57 conn.execute_batch(sql::SCHEMA)?;
58 Ok(())
59 }
60
61 pub(crate) fn recall_sync(
67 &self,
68 query: &str,
69 options: &RecallOptions,
70 query_embedding: Option<&[f32]>,
71 ) -> Result<Vec<MemoryEntry>> {
72 let now = now_unix();
73 let limit = if options.limit == 0 {
74 10
75 } else {
76 options.limit
77 };
78
79 let (bm25_candidates, vec_candidates) = {
81 let conn = self.conn.lock().unwrap();
82
83 let mut fts_stmt = conn.prepare(sql::RECALL_FTS)?;
85 let bm25: Vec<(MemoryEntry, f64)> = fts_stmt
86 .query_map([query], |row| {
87 let emb_blob: Option<Vec<u8>> = row.get(6)?;
88 Ok(MemoryEntry {
89 key: CompactString::new(row.get::<_, String>(0)?),
90 value: row.get(1)?,
91 metadata: row
92 .get::<_, Option<String>>(2)?
93 .and_then(|s| serde_json::from_str(&s).ok()),
94 created_at: row.get::<_, i64>(3)? as u64,
95 accessed_at: row.get::<_, i64>(4)? as u64,
96 access_count: row.get::<_, i32>(5)? as u32,
97 embedding: emb_blob.map(|b| decode_embedding(&b)),
98 })
99 .map(|entry| (entry, row.get::<_, f64>(7).unwrap_or(0.0)))
100 })?
101 .filter_map(|r| r.ok())
102 .collect();
103
104 let vec = if query_embedding.is_some() {
106 let mut vec_stmt = conn.prepare(sql::RECALL_VECTOR)?;
107 vec_stmt
108 .query_map([], |row| {
109 let emb_blob: Option<Vec<u8>> = row.get(6)?;
110 Ok(MemoryEntry {
111 key: CompactString::new(row.get::<_, String>(0)?),
112 value: row.get(1)?,
113 metadata: row
114 .get::<_, Option<String>>(2)?
115 .and_then(|s| serde_json::from_str(&s).ok()),
116 created_at: row.get::<_, i64>(3)? as u64,
117 accessed_at: row.get::<_, i64>(4)? as u64,
118 access_count: row.get::<_, i32>(5)? as u32,
119 embedding: emb_blob.map(|b| decode_embedding(&b)),
120 })
121 })?
122 .filter_map(|r| r.ok())
123 .collect::<Vec<_>>()
124 } else {
125 Vec::new()
126 };
127
128 (bm25, vec)
129 };
131
132 let lambda = std::f64::consts::LN_2 / 30.0;
136 let bm25_scored: Vec<(MemoryEntry, f64)> = bm25_candidates
137 .into_iter()
138 .map(|(entry, bm25_rank)| {
139 let bm25_score = -bm25_rank;
140 let age_days = now.saturating_sub(entry.accessed_at) as f64 / 86400.0;
141 let decay = (-lambda * age_days).exp();
142 (entry, bm25_score * decay)
143 })
144 .collect();
145
146 let scored = if let Some(q_emb) = query_embedding {
147 let vec_scored: Vec<(MemoryEntry, f64)> = vec_candidates
149 .into_iter()
150 .filter_map(|entry| {
151 let sim = entry
152 .embedding
153 .as_ref()
154 .map(|e| cosine_similarity(e, q_emb))
155 .unwrap_or(0.0);
156 if sim > 0.0 { Some((entry, sim)) } else { None }
157 })
158 .collect();
159
160 let k = 60.0_f64;
162
163 let mut bm25_ranked = bm25_scored;
164 bm25_ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
165
166 let mut vec_ranked = vec_scored;
167 vec_ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
168
169 let rrf_scores: Vec<f64>;
171 let bm25_in_vec: Vec<bool>;
172 {
173 let vec_rank_map: HashMap<&str, usize> = vec_ranked
174 .iter()
175 .enumerate()
176 .map(|(i, (e, _))| (e.key.as_str(), i + 1))
177 .collect();
178 let bm25_key_set: HashMap<&str, ()> = bm25_ranked
179 .iter()
180 .map(|(e, _)| (e.key.as_str(), ()))
181 .collect();
182
183 rrf_scores = bm25_ranked
184 .iter()
185 .enumerate()
186 .map(|(i, (e, _))| {
187 1.0 / (k + (i + 1) as f64)
188 + vec_rank_map
189 .get(e.key.as_str())
190 .map(|&r| 1.0 / (k + r as f64))
191 .unwrap_or(0.0)
192 })
193 .collect();
194
195 bm25_in_vec = vec_ranked
196 .iter()
197 .map(|(e, _)| bm25_key_set.contains_key(e.key.as_str()))
198 .collect();
199 }
200
201 let mut fused = Vec::with_capacity(bm25_ranked.len() + vec_ranked.len());
202 for (score, (entry, _)) in rrf_scores.into_iter().zip(bm25_ranked) {
203 fused.push((entry, score));
204 }
205 for (i, (entry, _)) in vec_ranked.into_iter().enumerate() {
206 if bm25_in_vec[i] {
207 continue;
208 }
209 fused.push((entry, 1.0 / (k + (i + 1) as f64)));
210 }
211 fused.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
212 fused
213 } else {
214 bm25_scored
215 };
216
217 if scored.is_empty() {
218 return Ok(Vec::new());
219 }
220
221 let mut filtered = scored;
223 if let Some((start, end)) = options.time_range {
224 filtered.retain(|(entry, _)| entry.created_at >= start && entry.created_at <= end);
225 }
226 if let Some(threshold) = options.relevance_threshold {
227 filtered.retain(|(_, score)| *score >= threshold as f64);
228 }
229 if filtered.is_empty() {
230 return Ok(Vec::new());
231 }
232
233 let use_cosine = query_embedding.is_some();
234 Ok(mmr_rerank(filtered, limit, 0.7, use_cosine))
235 }
236
237 pub fn store_with_metadata(
239 &self,
240 key: &str,
241 value: &str,
242 metadata: Option<&Value>,
243 embedding: Option<&[f32]>,
244 ) -> Result<()> {
245 let conn = self.conn.lock().unwrap();
246 let now = now_unix() as i64;
247 let meta_json = metadata.map(|m| serde_json::to_string(m).unwrap());
248 let emb_blob: Option<Vec<u8>> =
249 embedding.map(|e| e.iter().flat_map(|f| f.to_le_bytes()).collect());
250
251 conn.execute(
252 sql::UPSERT_FULL,
253 rusqlite::params![key, value, meta_json, now, emb_blob],
254 )?;
255 Ok(())
256 }
257
258 pub fn get_entry(&self, key: &str) -> Option<MemoryEntry> {
260 let conn = self.conn.lock().unwrap();
261 conn.query_row(sql::SELECT_ENTRY, [key], |row| {
262 let emb_blob: Option<Vec<u8>> = row.get(6)?;
263 Ok(MemoryEntry {
264 key: CompactString::new(row.get::<_, String>(0)?),
265 value: row.get(1)?,
266 metadata: row
267 .get::<_, Option<String>>(2)?
268 .and_then(|s| serde_json::from_str(&s).ok()),
269 created_at: row.get::<_, i64>(3)? as u64,
270 accessed_at: row.get::<_, i64>(4)? as u64,
271 access_count: row.get::<_, i32>(5)? as u32,
272 embedding: emb_blob.map(|b| decode_embedding(&b)),
273 })
274 })
275 .ok()
276 }
277}