1use anyhow::Result;
7use compact_str::CompactString;
8use rusqlite::Connection;
9use serde_json::Value;
10use std::{collections::HashMap, path::Path, sync::Mutex};
11use utils::{cosine_similarity, decode_embedding, mmr_rerank, now_unix};
12use wcore::{Embedder, MemoryEntry, RecallOptions};
13
14mod memory;
15mod sql;
16mod utils;
17
/// SQLite-backed memory store.
///
/// Recall combines full-text search with an optional vector-similarity pass
/// (see `recall_sync`); `E` is the type of the optional embedder.
pub struct SqliteMemory<E: Embedder> {
    /// The single SQLite connection, guarded by a mutex so that `&self`
    /// methods can lock and use it.
    pub(crate) conn: Mutex<Connection>,
    /// Optional embedder. The constructors `open`/`in_memory` set it to `None`;
    /// `with_embedder` attaches one.
    pub(crate) embedder: Option<E>,
}
23
24impl<E: Embedder> SqliteMemory<E> {
25 pub fn open(path: impl AsRef<Path>) -> Result<Self> {
27 let conn = Connection::open(path)?;
28 let mem = Self {
29 conn: Mutex::new(conn),
30 embedder: None,
31 };
32 mem.init_schema()?;
33 Ok(mem)
34 }
35
36 pub fn in_memory() -> Result<Self> {
38 let conn = Connection::open_in_memory()?;
39 let mem = Self {
40 conn: Mutex::new(conn),
41 embedder: None,
42 };
43 mem.init_schema()?;
44 Ok(mem)
45 }
46
47 pub fn with_embedder(mut self, embedder: E) -> Self {
49 self.embedder = Some(embedder);
50 self
51 }
52
53 fn init_schema(&self) -> Result<()> {
55 let conn = self.conn.lock().unwrap();
56 conn.execute_batch(sql::SCHEMA)?;
57 Ok(())
58 }
59
60 pub(crate) fn recall_sync(
66 &self,
67 query: &str,
68 options: &RecallOptions,
69 query_embedding: Option<&[f32]>,
70 ) -> Result<Vec<MemoryEntry>> {
71 let now = now_unix();
72 let limit = if options.limit == 0 {
73 10
74 } else {
75 options.limit
76 };
77
78 let (bm25_candidates, vec_candidates) = {
80 let conn = self.conn.lock().unwrap();
81
82 let mut fts_stmt = conn.prepare(sql::RECALL_FTS)?;
84 let bm25: Vec<(MemoryEntry, f64)> = fts_stmt
85 .query_map([query], |row| {
86 let emb_blob: Option<Vec<u8>> = row.get(6)?;
87 Ok(MemoryEntry {
88 key: CompactString::new(row.get::<_, String>(0)?),
89 value: row.get(1)?,
90 metadata: row
91 .get::<_, Option<String>>(2)?
92 .and_then(|s| serde_json::from_str(&s).ok()),
93 created_at: row.get::<_, i64>(3)? as u64,
94 accessed_at: row.get::<_, i64>(4)? as u64,
95 access_count: row.get::<_, i32>(5)? as u32,
96 embedding: emb_blob.map(|b| decode_embedding(&b)),
97 })
98 .map(|entry| (entry, row.get::<_, f64>(7).unwrap_or(0.0)))
99 })?
100 .filter_map(|r| r.ok())
101 .collect();
102
103 let vec = if query_embedding.is_some() {
105 let mut vec_stmt = conn.prepare(sql::RECALL_VECTOR)?;
106 vec_stmt
107 .query_map([], |row| {
108 let emb_blob: Option<Vec<u8>> = row.get(6)?;
109 Ok(MemoryEntry {
110 key: CompactString::new(row.get::<_, String>(0)?),
111 value: row.get(1)?,
112 metadata: row
113 .get::<_, Option<String>>(2)?
114 .and_then(|s| serde_json::from_str(&s).ok()),
115 created_at: row.get::<_, i64>(3)? as u64,
116 accessed_at: row.get::<_, i64>(4)? as u64,
117 access_count: row.get::<_, i32>(5)? as u32,
118 embedding: emb_blob.map(|b| decode_embedding(&b)),
119 })
120 })?
121 .filter_map(|r| r.ok())
122 .collect::<Vec<_>>()
123 } else {
124 Vec::new()
125 };
126
127 (bm25, vec)
128 };
130
131 let lambda = std::f64::consts::LN_2 / 30.0;
135 let bm25_scored: Vec<(MemoryEntry, f64)> = bm25_candidates
136 .into_iter()
137 .map(|(entry, bm25_rank)| {
138 let bm25_score = -bm25_rank;
139 let age_days = now.saturating_sub(entry.accessed_at) as f64 / 86400.0;
140 let decay = (-lambda * age_days).exp();
141 (entry, bm25_score * decay)
142 })
143 .collect();
144
145 let scored = if let Some(q_emb) = query_embedding {
146 let vec_scored: Vec<(MemoryEntry, f64)> = vec_candidates
148 .into_iter()
149 .filter_map(|entry| {
150 let sim = entry
151 .embedding
152 .as_ref()
153 .map(|e| cosine_similarity(e, q_emb))
154 .unwrap_or(0.0);
155 if sim > 0.0 { Some((entry, sim)) } else { None }
156 })
157 .collect();
158
159 let k = 60.0_f64;
161
162 let mut bm25_ranked = bm25_scored;
163 bm25_ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
164
165 let mut vec_ranked = vec_scored;
166 vec_ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
167
168 let rrf_scores: Vec<f64>;
170 let bm25_in_vec: Vec<bool>;
171 {
172 let vec_rank_map: HashMap<&str, usize> = vec_ranked
173 .iter()
174 .enumerate()
175 .map(|(i, (e, _))| (e.key.as_str(), i + 1))
176 .collect();
177 let bm25_key_set: HashMap<&str, ()> = bm25_ranked
178 .iter()
179 .map(|(e, _)| (e.key.as_str(), ()))
180 .collect();
181
182 rrf_scores = bm25_ranked
183 .iter()
184 .enumerate()
185 .map(|(i, (e, _))| {
186 1.0 / (k + (i + 1) as f64)
187 + vec_rank_map
188 .get(e.key.as_str())
189 .map(|&r| 1.0 / (k + r as f64))
190 .unwrap_or(0.0)
191 })
192 .collect();
193
194 bm25_in_vec = vec_ranked
195 .iter()
196 .map(|(e, _)| bm25_key_set.contains_key(e.key.as_str()))
197 .collect();
198 }
199
200 let mut fused = Vec::with_capacity(bm25_ranked.len() + vec_ranked.len());
201 for (score, (entry, _)) in rrf_scores.into_iter().zip(bm25_ranked) {
202 fused.push((entry, score));
203 }
204 for (i, (entry, _)) in vec_ranked.into_iter().enumerate() {
205 if bm25_in_vec[i] {
206 continue;
207 }
208 fused.push((entry, 1.0 / (k + (i + 1) as f64)));
209 }
210 fused.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
211 fused
212 } else {
213 bm25_scored
214 };
215
216 if scored.is_empty() {
217 return Ok(Vec::new());
218 }
219
220 let mut filtered = scored;
222 if let Some((start, end)) = options.time_range {
223 filtered.retain(|(entry, _)| entry.created_at >= start && entry.created_at <= end);
224 }
225 if let Some(threshold) = options.relevance_threshold {
226 filtered.retain(|(_, score)| *score >= threshold as f64);
227 }
228 if filtered.is_empty() {
229 return Ok(Vec::new());
230 }
231
232 let use_cosine = query_embedding.is_some();
233 Ok(mmr_rerank(filtered, limit, 0.7, use_cosine))
234 }
235
236 pub fn store_with_metadata(
238 &self,
239 key: &str,
240 value: &str,
241 metadata: Option<&Value>,
242 embedding: Option<&[f32]>,
243 ) -> Result<()> {
244 let conn = self.conn.lock().unwrap();
245 let now = now_unix() as i64;
246 let meta_json = metadata.map(|m| serde_json::to_string(m).unwrap());
247 let emb_blob: Option<Vec<u8>> =
248 embedding.map(|e| e.iter().flat_map(|f| f.to_le_bytes()).collect());
249
250 conn.execute(
251 sql::UPSERT_FULL,
252 rusqlite::params![key, value, meta_json, now, emb_blob],
253 )?;
254 Ok(())
255 }
256
257 pub fn get_entry(&self, key: &str) -> Option<MemoryEntry> {
259 let conn = self.conn.lock().unwrap();
260 conn.query_row(sql::SELECT_ENTRY, [key], |row| {
261 let emb_blob: Option<Vec<u8>> = row.get(6)?;
262 Ok(MemoryEntry {
263 key: CompactString::new(row.get::<_, String>(0)?),
264 value: row.get(1)?,
265 metadata: row
266 .get::<_, Option<String>>(2)?
267 .and_then(|s| serde_json::from_str(&s).ok()),
268 created_at: row.get::<_, i64>(3)? as u64,
269 accessed_at: row.get::<_, i64>(4)? as u64,
270 access_count: row.get::<_, i32>(5)? as u32,
271 embedding: emb_blob.map(|b| decode_embedding(&b)),
272 })
273 })
274 .ok()
275 }
276}