1pub use crate::utils::cosine_similarity;
12use crate::utils::{decode_embedding, mmr_rerank, now_unix};
13use anyhow::Result;
14use compact_str::CompactString;
15use rusqlite::Connection;
16use serde_json::Value;
17use std::{collections::HashMap, future::Future, path::Path, sync::Mutex};
18
19mod embedder;
20mod inmemory;
21mod memory;
22mod sql;
23mod utils;
24
25pub use embedder::{Embedder, NoEmbedder};
26pub use inmemory::InMemory;
27
28#[derive(Debug, Clone, Default)]
30pub struct MemoryEntry {
31 pub key: CompactString,
33 pub value: String,
35 pub metadata: Option<Value>,
37 pub created_at: u64,
39 pub accessed_at: u64,
41 pub access_count: u32,
43 pub embedding: Option<Vec<f32>>,
45}
46
/// Tuning knobs for [`Memory::recall`].
///
/// `Default` yields "no limit given, no time filter, no score threshold".
#[derive(Clone, Debug, Default)]
pub struct RecallOptions {
    /// Maximum number of entries to return; `SqliteMemory` substitutes 10 when this is 0.
    pub limit: usize,
    /// Optional inclusive `(start, end)` bound applied to `created_at`.
    pub time_range: Option<(u64, u64)>,
    /// Optional minimum score; candidates scoring below it are dropped before re-ranking.
    pub relevance_threshold: Option<f32>,
}
57
58pub trait Memory: Send + Sync {
66 fn get(&self, key: &str) -> Option<String>;
68
69 fn entries(&self) -> Vec<(String, String)>;
71
72 fn set(&self, key: impl Into<String>, value: impl Into<String>) -> Option<String>;
74
75 fn remove(&self, key: &str) -> Option<String>;
77
78 fn compile(&self) -> String {
80 let entries = self.entries();
81 if entries.is_empty() {
82 return String::new();
83 }
84
85 let mut out = String::from("<memory>\n");
86 for (key, value) in &entries {
87 out.push_str(&format!("<{key}>\n"));
88 out.push_str(value);
89 if !value.ends_with('\n') {
90 out.push('\n');
91 }
92 out.push_str(&format!("</{key}>\n"));
93 }
94 out.push_str("</memory>");
95 out
96 }
97
98 fn store(
100 &self,
101 key: impl Into<String> + Send,
102 value: impl Into<String> + Send,
103 ) -> impl Future<Output = Result<()>> + Send {
104 self.set(key, value);
105 async { Ok(()) }
106 }
107
108 fn recall(
110 &self,
111 _query: &str,
112 _options: RecallOptions,
113 ) -> impl Future<Output = Result<Vec<MemoryEntry>>> + Send {
114 async { Ok(Vec::new()) }
115 }
116
117 fn compile_relevant(&self, _query: &str) -> impl Future<Output = String> + Send {
119 let compiled = self.compile();
120 async move { compiled }
121 }
122}
123
124pub fn with_memory(mut config: wcore::AgentConfig, memory: &impl Memory) -> wcore::AgentConfig {
126 let compiled = memory.compile();
127 if !compiled.is_empty() {
128 config.system_prompt = format!("{}\n\n{compiled}", config.system_prompt);
129 }
130 config
131}
132
133pub struct SqliteMemory<E: Embedder> {
138 conn: Mutex<Connection>,
139 embedder: Option<E>,
140}
141
142impl<E: Embedder> SqliteMemory<E> {
143 pub fn open(path: impl AsRef<Path>) -> Result<Self> {
145 let conn = Connection::open(path)?;
146 let mem = Self {
147 conn: Mutex::new(conn),
148 embedder: None,
149 };
150 mem.init_schema()?;
151 Ok(mem)
152 }
153
154 pub fn in_memory() -> Result<Self> {
156 let conn = Connection::open_in_memory()?;
157 let mem = Self {
158 conn: Mutex::new(conn),
159 embedder: None,
160 };
161 mem.init_schema()?;
162 Ok(mem)
163 }
164
165 pub fn with_embedder(mut self, embedder: E) -> Self {
167 self.embedder = Some(embedder);
168 self
169 }
170
171 fn init_schema(&self) -> Result<()> {
173 let conn = self.conn.lock().unwrap();
174 conn.execute_batch(sql::SCHEMA)?;
175 Ok(())
176 }
177
178 fn recall_sync(
184 &self,
185 query: &str,
186 options: &RecallOptions,
187 query_embedding: Option<&[f32]>,
188 ) -> Result<Vec<MemoryEntry>> {
189 let now = now_unix();
190 let limit = if options.limit == 0 {
191 10
192 } else {
193 options.limit
194 };
195
196 let (bm25_candidates, vec_candidates) = {
198 let conn = self.conn.lock().unwrap();
199
200 let mut fts_stmt = conn.prepare(sql::RECALL_FTS)?;
202 let bm25: Vec<(MemoryEntry, f64)> = fts_stmt
203 .query_map([query], |row| {
204 let emb_blob: Option<Vec<u8>> = row.get(6)?;
205 Ok(MemoryEntry {
206 key: CompactString::new(row.get::<_, String>(0)?),
207 value: row.get(1)?,
208 metadata: row
209 .get::<_, Option<String>>(2)?
210 .and_then(|s| serde_json::from_str(&s).ok()),
211 created_at: row.get::<_, i64>(3)? as u64,
212 accessed_at: row.get::<_, i64>(4)? as u64,
213 access_count: row.get::<_, i32>(5)? as u32,
214 embedding: emb_blob.map(|b| decode_embedding(&b)),
215 })
216 .map(|entry| (entry, row.get::<_, f64>(7).unwrap_or(0.0)))
217 })?
218 .filter_map(|r| r.ok())
219 .collect();
220
221 let vec = if query_embedding.is_some() {
223 let mut vec_stmt = conn.prepare(sql::RECALL_VECTOR)?;
224 vec_stmt
225 .query_map([], |row| {
226 let emb_blob: Option<Vec<u8>> = row.get(6)?;
227 Ok(MemoryEntry {
228 key: CompactString::new(row.get::<_, String>(0)?),
229 value: row.get(1)?,
230 metadata: row
231 .get::<_, Option<String>>(2)?
232 .and_then(|s| serde_json::from_str(&s).ok()),
233 created_at: row.get::<_, i64>(3)? as u64,
234 accessed_at: row.get::<_, i64>(4)? as u64,
235 access_count: row.get::<_, i32>(5)? as u32,
236 embedding: emb_blob.map(|b| decode_embedding(&b)),
237 })
238 })?
239 .filter_map(|r| r.ok())
240 .collect::<Vec<_>>()
241 } else {
242 Vec::new()
243 };
244
245 (bm25, vec)
246 };
248
249 let lambda = std::f64::consts::LN_2 / 30.0;
253 let bm25_scored: Vec<(MemoryEntry, f64)> = bm25_candidates
254 .into_iter()
255 .map(|(entry, bm25_rank)| {
256 let bm25_score = -bm25_rank;
257 let age_days = now.saturating_sub(entry.accessed_at) as f64 / 86400.0;
258 let decay = (-lambda * age_days).exp();
259 (entry, bm25_score * decay)
260 })
261 .collect();
262
263 let scored = if let Some(q_emb) = query_embedding {
264 let vec_scored: Vec<(MemoryEntry, f64)> = vec_candidates
266 .into_iter()
267 .filter_map(|entry| {
268 let sim = entry
269 .embedding
270 .as_ref()
271 .map(|e| cosine_similarity(e, q_emb))
272 .unwrap_or(0.0);
273 if sim > 0.0 { Some((entry, sim)) } else { None }
274 })
275 .collect();
276
277 let k = 60.0_f64;
280
281 let mut bm25_ranked = bm25_scored;
282 bm25_ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
283
284 let mut vec_ranked = vec_scored;
285 vec_ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
286
287 let rrf_scores: Vec<f64>;
289 let bm25_in_vec: Vec<bool>;
290 {
291 let vec_rank_map: HashMap<&str, usize> = vec_ranked
292 .iter()
293 .enumerate()
294 .map(|(i, (e, _))| (e.key.as_str(), i + 1))
295 .collect();
296 let bm25_key_set: HashMap<&str, ()> = bm25_ranked
297 .iter()
298 .map(|(e, _)| (e.key.as_str(), ()))
299 .collect();
300
301 rrf_scores = bm25_ranked
303 .iter()
304 .enumerate()
305 .map(|(i, (e, _))| {
306 1.0 / (k + (i + 1) as f64)
307 + vec_rank_map
308 .get(e.key.as_str())
309 .map(|&r| 1.0 / (k + r as f64))
310 .unwrap_or(0.0)
311 })
312 .collect();
313
314 bm25_in_vec = vec_ranked
316 .iter()
317 .map(|(e, _)| bm25_key_set.contains_key(e.key.as_str()))
318 .collect();
319 }
321
322 let mut fused = Vec::with_capacity(bm25_ranked.len() + vec_ranked.len());
324 for (score, (entry, _)) in rrf_scores.into_iter().zip(bm25_ranked) {
325 fused.push((entry, score));
326 }
327 for (i, (entry, _)) in vec_ranked.into_iter().enumerate() {
328 if bm25_in_vec[i] {
329 continue;
330 }
331 fused.push((entry, 1.0 / (k + (i + 1) as f64)));
332 }
333 fused.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
334 fused
335 } else {
336 bm25_scored
337 };
338
339 if scored.is_empty() {
340 return Ok(Vec::new());
341 }
342
343 let mut filtered = scored;
345 if let Some((start, end)) = options.time_range {
346 filtered.retain(|(entry, _)| entry.created_at >= start && entry.created_at <= end);
347 }
348 if let Some(threshold) = options.relevance_threshold {
349 filtered.retain(|(_, score)| *score >= threshold as f64);
350 }
351 if filtered.is_empty() {
352 return Ok(Vec::new());
353 }
354
355 let use_cosine = query_embedding.is_some();
356 Ok(mmr_rerank(filtered, limit, 0.7, use_cosine))
357 }
358
359 pub fn store_with_metadata(
361 &self,
362 key: &str,
363 value: &str,
364 metadata: Option<&Value>,
365 embedding: Option<&[f32]>,
366 ) -> Result<()> {
367 let conn = self.conn.lock().unwrap();
368 let now = now_unix() as i64;
369 let meta_json = metadata.map(|m| serde_json::to_string(m).unwrap());
370 let emb_blob: Option<Vec<u8>> =
371 embedding.map(|e| e.iter().flat_map(|f| f.to_le_bytes()).collect());
372
373 conn.execute(
374 sql::UPSERT_FULL,
375 rusqlite::params![key, value, meta_json, now, emb_blob],
376 )?;
377 Ok(())
378 }
379
380 pub fn get_entry(&self, key: &str) -> Option<MemoryEntry> {
382 let conn = self.conn.lock().unwrap();
383 conn.query_row(sql::SELECT_ENTRY, [key], |row| {
384 let emb_blob: Option<Vec<u8>> = row.get(6)?;
385 Ok(MemoryEntry {
386 key: CompactString::new(row.get::<_, String>(0)?),
387 value: row.get(1)?,
388 metadata: row
389 .get::<_, Option<String>>(2)?
390 .and_then(|s| serde_json::from_str(&s).ok()),
391 created_at: row.get::<_, i64>(3)? as u64,
392 accessed_at: row.get::<_, i64>(4)? as u64,
393 access_count: row.get::<_, i32>(5)? as u32,
394 embedding: emb_blob.map(|b| decode_embedding(&b)),
395 })
396 })
397 .ok()
398 }
399}