1use anyhow::{Context, Result};
2use chrono::{DateTime, Utc};
3use rusqlite::ffi::sqlite3_auto_extension;
4use rusqlite::{params, Connection};
5use sqlite_vec::sqlite3_vec_init;
6use std::collections::HashMap;
7use std::path::Path;
8use std::sync::{Mutex, Once};
9
10use crate::config::MemoryConfig;
11use crate::types::{MemoryCategory, MemoryEntry, MemorySource, MemoryStats, UserProfileFact};
12
13fn ensure_sqlite_vec_registered() {
15 static INIT: Once = Once::new();
16 INIT.call_once(|| unsafe {
17 #[allow(clippy::missing_transmute_annotations)]
18 sqlite3_auto_extension(Some(std::mem::transmute(sqlite3_vec_init as *const ())));
19 });
20}
21
/// SQLite-backed store for memories (with a vec0 embedding table and an FTS5 text index),
/// an embedding cache, and user profile facts.
pub struct MemoryStore {
    /// The single database connection; every method locks it for the duration of its call.
    conn: Mutex<Connection>,
    /// Vector width used when `init_schema` creates the `memories_vec` table (`float[N]`).
    embedding_dimension: usize,
}
26
27impl MemoryStore {
28 pub fn new(config: &MemoryConfig) -> Result<Self> {
29 let path = &config.storage_path;
30 if let Some(parent) = Path::new(path).parent() {
31 std::fs::create_dir_all(parent)
32 .with_context(|| format!("Failed to create directory for {path}"))?;
33 }
34
35 ensure_sqlite_vec_registered();
36
37 let conn = Connection::open(path)
38 .with_context(|| format!("Failed to open memory database at {path}"))?;
39
40 conn.execute_batch("PRAGMA journal_mode=WAL; PRAGMA foreign_keys=ON;")?;
41
42 let store = Self {
43 conn: Mutex::new(conn),
44 embedding_dimension: config.embedding_dimension,
45 };
46 store.init_schema()?;
47 Ok(store)
48 }
49
50 pub fn new_in_memory(embedding_dimension: usize) -> Result<Self> {
51 ensure_sqlite_vec_registered();
52
53 let conn = Connection::open_in_memory()?;
54 conn.execute_batch("PRAGMA journal_mode=WAL; PRAGMA foreign_keys=ON;")?;
55
56 let store = Self {
57 conn: Mutex::new(conn),
58 embedding_dimension,
59 };
60 store.init_schema()?;
61 Ok(store)
62 }
63
64 fn init_schema(&self) -> Result<()> {
65 let conn = self.conn.lock().unwrap();
66 conn.execute_batch(
67 "CREATE TABLE IF NOT EXISTS memories (
68 id TEXT PRIMARY KEY,
69 content TEXT NOT NULL,
70 source TEXT NOT NULL DEFAULT 'manual',
71 category TEXT NOT NULL DEFAULT 'other',
72 container_tag TEXT NOT NULL DEFAULT 'default',
73 metadata TEXT DEFAULT '{}',
74 session_id TEXT,
75 created_at TEXT NOT NULL DEFAULT (datetime('now')),
76 updated_at TEXT NOT NULL DEFAULT (datetime('now'))
77 );
78
79 CREATE INDEX IF NOT EXISTS idx_memories_container ON memories(container_tag);
80 CREATE INDEX IF NOT EXISTS idx_memories_category ON memories(category);
81 CREATE INDEX IF NOT EXISTS idx_memories_session ON memories(session_id);
82 CREATE INDEX IF NOT EXISTS idx_memories_created ON memories(created_at);
83
84 CREATE VIRTUAL TABLE IF NOT EXISTS memories_fts USING fts5(
85 content,
86 id UNINDEXED,
87 tokenize='porter unicode61'
88 );
89
90 CREATE TABLE IF NOT EXISTS embedding_cache (
91 content_hash TEXT PRIMARY KEY,
92 embedding BLOB NOT NULL,
93 model TEXT NOT NULL,
94 created_at TEXT NOT NULL DEFAULT (datetime('now'))
95 );
96
97 CREATE TABLE IF NOT EXISTS user_profile (
98 id TEXT PRIMARY KEY,
99 fact_type TEXT NOT NULL DEFAULT 'static',
100 category TEXT NOT NULL DEFAULT 'fact',
101 content TEXT NOT NULL,
102 created_at TEXT NOT NULL DEFAULT (datetime('now')),
103 updated_at TEXT NOT NULL DEFAULT (datetime('now'))
104 );",
105 )?;
106
107 let dim = self.embedding_dimension;
108 conn.execute_batch(&format!(
109 "CREATE VIRTUAL TABLE IF NOT EXISTS memories_vec USING vec0(
110 id TEXT PRIMARY KEY,
111 embedding float[{dim}]
112 );"
113 ))?;
114
115 Ok(())
116 }
117
118 pub fn insert(&self, entry: &MemoryEntry, embedding: &[f32]) -> Result<()> {
119 let conn = self.conn.lock().unwrap();
120 let tx = conn.unchecked_transaction()?;
121
122 let metadata_json = serde_json::to_string(&entry.metadata)?;
123 let created = entry.created_at.to_rfc3339();
124 let updated = entry.updated_at.to_rfc3339();
125
126 tx.execute(
127 "INSERT OR REPLACE INTO memories (id, content, source, category, container_tag, metadata, session_id, created_at, updated_at)
128 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)",
129 params![
130 entry.id,
131 entry.content,
132 entry.source.as_str(),
133 entry.category.as_str(),
134 entry.container_tag,
135 metadata_json,
136 entry.session_id,
137 created,
138 updated,
139 ],
140 )?;
141
142 let embedding_bytes = embedding_to_bytes(embedding);
143 tx.execute(
144 "INSERT OR REPLACE INTO memories_vec (id, embedding) VALUES (?1, ?2)",
145 params![entry.id, embedding_bytes],
146 )?;
147
148 tx.execute(
149 "INSERT OR REPLACE INTO memories_fts (id, content) VALUES (?1, ?2)",
150 params![entry.id, entry.content],
151 )?;
152
153 tx.commit()?;
154 Ok(())
155 }
156
157 pub fn store_text_only(&self, entry: &MemoryEntry) -> Result<i64> {
158 let conn = self.conn.lock().unwrap();
159 let tx = conn.unchecked_transaction()?;
160
161 let metadata_json = serde_json::to_string(&entry.metadata)?;
162 let created = entry.created_at.to_rfc3339();
163 let updated = entry.updated_at.to_rfc3339();
164
165 tx.execute(
166 "INSERT OR REPLACE INTO memories (id, content, source, category, container_tag, metadata, session_id, created_at, updated_at)
167 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)",
168 params![
169 entry.id,
170 entry.content,
171 entry.source.as_str(),
172 entry.category.as_str(),
173 entry.container_tag,
174 metadata_json,
175 entry.session_id,
176 created,
177 updated,
178 ],
179 )?;
180
181 tx.execute(
182 "INSERT OR REPLACE INTO memories_fts (id, content) VALUES (?1, ?2)",
183 params![entry.id, entry.content],
184 )?;
185
186 let row_id = tx.last_insert_rowid();
187 tx.commit()?;
188 Ok(row_id)
189 }
190
191 pub fn update_embedding(&self, id: i64, embedding: &[f32]) -> Result<()> {
192 let conn = self.conn.lock().unwrap();
193 let tx = conn.unchecked_transaction()?;
194 let embedding_bytes = embedding_to_bytes(embedding);
195 let updated = tx.execute(
196 "INSERT OR REPLACE INTO memories_vec (id, embedding)
197 SELECT id, ?1 FROM memories WHERE rowid = ?2",
198 params![embedding_bytes, id],
199 )?;
200
201 if updated == 0 {
202 anyhow::bail!("memory row not found for rowid {id}");
203 }
204
205 tx.commit()?;
206 Ok(())
207 }
208
209 pub fn get(&self, id: &str) -> Result<Option<MemoryEntry>> {
210 let conn = self.conn.lock().unwrap();
211 let mut stmt = conn.prepare(
212 "SELECT id, content, source, category, container_tag, metadata, session_id, created_at, updated_at
213 FROM memories WHERE id = ?1",
214 )?;
215
216 let result = stmt
217 .query_row(params![id], |row| Ok(row_to_entry(row)))
218 .optional()?;
219
220 match result {
221 Some(entry) => Ok(Some(entry?)),
222 None => Ok(None),
223 }
224 }
225
226 pub fn delete(&self, id: &str) -> Result<bool> {
227 let conn = self.conn.lock().unwrap();
228 let tx = conn.unchecked_transaction()?;
229
230 let deleted = tx.execute("DELETE FROM memories WHERE id = ?1", params![id])?;
231 tx.execute("DELETE FROM memories_vec WHERE id = ?1", params![id])?;
232 tx.execute("DELETE FROM memories_fts WHERE id = ?1", params![id])?;
233
234 tx.commit()?;
235 Ok(deleted > 0)
236 }
237
238 pub fn delete_by_ids(&self, ids: &[String]) -> Result<usize> {
239 if ids.is_empty() {
240 return Ok(0);
241 }
242 let conn = self.conn.lock().unwrap();
243 let tx = conn.unchecked_transaction()?;
244 let mut count = 0;
245
246 for id in ids {
247 count += tx.execute("DELETE FROM memories WHERE id = ?1", params![id])?;
248 tx.execute("DELETE FROM memories_vec WHERE id = ?1", params![id])?;
249 tx.execute("DELETE FROM memories_fts WHERE id = ?1", params![id])?;
250 }
251
252 tx.commit()?;
253 Ok(count)
254 }
255
256 pub fn list(&self, container_tag: Option<&str>, limit: usize) -> Result<Vec<MemoryEntry>> {
257 let conn = self.conn.lock().unwrap();
258 let (sql, param_values): (String, Vec<Box<dyn rusqlite::types::ToSql>>) = match container_tag {
259 Some(tag) => (
260 "SELECT id, content, source, category, container_tag, metadata, session_id, created_at, updated_at
261 FROM memories WHERE container_tag = ?1 ORDER BY created_at DESC LIMIT ?2"
262 .to_string(),
263 vec![Box::new(tag.to_string()), Box::new(limit as i64)],
264 ),
265 None => (
266 "SELECT id, content, source, category, container_tag, metadata, session_id, created_at, updated_at
267 FROM memories ORDER BY created_at DESC LIMIT ?1"
268 .to_string(),
269 vec![Box::new(limit as i64)],
270 ),
271 };
272
273 let mut stmt = conn.prepare(&sql)?;
274 let params_ref: Vec<&dyn rusqlite::types::ToSql> =
275 param_values.iter().map(|b| b.as_ref()).collect();
276 let rows = stmt.query_map(params_ref.as_slice(), |row| Ok(row_to_entry(row)))?;
277
278 let mut entries = Vec::new();
279 for row in rows {
280 entries.push(row??);
281 }
282 Ok(entries)
283 }
284
285 pub fn vector_search(&self, embedding: &[f32], limit: usize) -> Result<Vec<(String, f32)>> {
286 let conn = self.conn.lock().unwrap();
287 let embedding_bytes = embedding_to_bytes(embedding);
288
289 let mut stmt = conn.prepare(
290 "SELECT id, distance FROM memories_vec WHERE embedding MATCH ?1 ORDER BY distance LIMIT ?2",
291 )?;
292
293 let rows = stmt.query_map(params![embedding_bytes, limit as i64], |row| {
294 Ok((row.get::<_, String>(0)?, row.get::<_, f32>(1)?))
295 })?;
296
297 let mut results = Vec::new();
298 for row in rows {
299 results.push(row?);
300 }
301 Ok(results)
302 }
303
304 pub fn fts_search(&self, query: &str, limit: usize) -> Result<Vec<(String, f64)>> {
305 let conn = self.conn.lock().unwrap();
306 let mut stmt = conn.prepare(
307 "SELECT id, bm25(memories_fts) as rank FROM memories_fts WHERE memories_fts MATCH ?1 ORDER BY rank LIMIT ?2",
308 )?;
309
310 let rows = stmt.query_map(params![query, limit as i64], |row| {
311 Ok((row.get::<_, String>(0)?, row.get::<_, f64>(1)?))
312 })?;
313
314 let mut results = Vec::new();
315 for row in rows {
316 results.push(row?);
317 }
318 Ok(results)
319 }
320
321 pub fn get_cached_embedding(&self, content_hash: &str) -> Result<Option<Vec<f32>>> {
322 let conn = self.conn.lock().unwrap();
323 let mut stmt =
324 conn.prepare("SELECT embedding FROM embedding_cache WHERE content_hash = ?1")?;
325
326 let result = stmt
327 .query_row(params![content_hash], |row| {
328 let bytes: Vec<u8> = row.get(0)?;
329 Ok(bytes_to_embedding(&bytes))
330 })
331 .optional()?;
332
333 Ok(result)
334 }
335
336 pub fn cache_embedding(
337 &self,
338 content_hash: &str,
339 embedding: &[f32],
340 model: &str,
341 ) -> Result<()> {
342 let conn = self.conn.lock().unwrap();
343 let bytes = embedding_to_bytes(embedding);
344 conn.execute(
345 "INSERT OR REPLACE INTO embedding_cache (content_hash, embedding, model) VALUES (?1, ?2, ?3)",
346 params![content_hash, bytes, model],
347 )?;
348 Ok(())
349 }
350
351 pub fn add_profile_fact(&self, fact: &UserProfileFact) -> Result<()> {
352 let conn = self.conn.lock().unwrap();
353 conn.execute(
354 "INSERT OR REPLACE INTO user_profile (id, fact_type, category, content, created_at, updated_at)
355 VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
356 params![
357 fact.id,
358 fact.fact_type,
359 fact.category,
360 fact.content,
361 fact.created_at.to_rfc3339(),
362 fact.updated_at.to_rfc3339(),
363 ],
364 )?;
365 Ok(())
366 }
367
368 pub fn get_profile_facts(&self, fact_type: Option<&str>) -> Result<Vec<UserProfileFact>> {
369 let conn = self.conn.lock().unwrap();
370 let (sql, param_values): (String, Vec<Box<dyn rusqlite::types::ToSql>>) = match fact_type {
371 Some(ft) => (
372 "SELECT id, fact_type, category, content, created_at, updated_at FROM user_profile WHERE fact_type = ?1 ORDER BY created_at DESC".to_string(),
373 vec![Box::new(ft.to_string())],
374 ),
375 None => (
376 "SELECT id, fact_type, category, content, created_at, updated_at FROM user_profile ORDER BY created_at DESC".to_string(),
377 vec![],
378 ),
379 };
380
381 let mut stmt = conn.prepare(&sql)?;
382 let params_ref: Vec<&dyn rusqlite::types::ToSql> =
383 param_values.iter().map(|b| b.as_ref()).collect();
384 let rows = stmt.query_map(params_ref.as_slice(), |row| Ok(row_to_profile_fact(row)))?;
385
386 let mut facts = Vec::new();
387 for row in rows {
388 facts.push(row??);
389 }
390 Ok(facts)
391 }
392
393 pub fn remove_profile_fact(&self, id: &str) -> Result<bool> {
394 let conn = self.conn.lock().unwrap();
395 let deleted = conn.execute("DELETE FROM user_profile WHERE id = ?1", params![id])?;
396 Ok(deleted > 0)
397 }
398
399 pub fn stats(&self) -> Result<MemoryStats> {
400 let conn = self.conn.lock().unwrap();
401
402 let total: usize = conn.query_row("SELECT COUNT(*) FROM memories", [], |row| row.get(0))?;
403
404 let mut by_category = HashMap::new();
405 let mut stmt = conn.prepare("SELECT category, COUNT(*) FROM memories GROUP BY category")?;
406 let rows = stmt.query_map([], |row| {
407 Ok((row.get::<_, String>(0)?, row.get::<_, usize>(1)?))
408 })?;
409 for row in rows {
410 let (cat, count) = row?;
411 by_category.insert(cat, count);
412 }
413
414 let mut by_container = HashMap::new();
415 let mut stmt =
416 conn.prepare("SELECT container_tag, COUNT(*) FROM memories GROUP BY container_tag")?;
417 let rows = stmt.query_map([], |row| {
418 Ok((row.get::<_, String>(0)?, row.get::<_, usize>(1)?))
419 })?;
420 for row in rows {
421 let (tag, count) = row?;
422 by_container.insert(tag, count);
423 }
424
425 let db_size = conn
426 .query_row(
427 "SELECT page_count * page_size FROM pragma_page_count, pragma_page_size",
428 [],
429 |row| row.get::<_, u64>(0),
430 )
431 .unwrap_or(0);
432
433 Ok(MemoryStats {
434 total_memories: total,
435 total_by_category: by_category,
436 total_by_container: by_container,
437 db_size_bytes: db_size,
438 })
439 }
440
441 pub fn count(&self) -> Result<usize> {
442 let conn = self.conn.lock().unwrap();
443 let count: usize = conn.query_row("SELECT COUNT(*) FROM memories", [], |row| row.get(0))?;
444 Ok(count)
445 }
446
447 pub fn cleanup(&self, container_tag: &str, retention_days: Option<u32>, max_memories: Option<usize>) -> Result<usize> {
448 let conn = self.conn.lock().unwrap();
449 let tx = conn.unchecked_transaction()?;
450 let mut total_deleted = 0;
451
452 if retention_days.is_none() && max_memories.is_none() {
454 return Ok(0);
455 }
456
457 if let Some(days) = retention_days {
459 let deleted = tx.execute(
460 "DELETE FROM memories WHERE created_at < datetime('now', ?1) AND container_tag = ?2",
461 params![format!("-{} days", days), container_tag],
462 )?;
463 total_deleted += deleted;
464 }
465
466 if let Some(max_count) = max_memories {
468 let deleted = tx.execute(
469 "DELETE FROM memories WHERE container_tag = ?1 AND id NOT IN (SELECT id FROM memories WHERE container_tag = ?2 ORDER BY created_at DESC LIMIT ?3)",
470 params![container_tag, container_tag, max_count as i64],
471 )?;
472 total_deleted += deleted;
473 }
474
475 tx.execute(
477 "DELETE FROM memories_fts WHERE rowid NOT IN (SELECT rowid FROM memories)",
478 [],
479 )?;
480
481 tx.commit()?;
482
483 tracing::info!(deleted = total_deleted, "memory cleanup completed");
484 Ok(total_deleted)
485 }
486}
487
488fn embedding_to_bytes(embedding: &[f32]) -> Vec<u8> {
489 embedding.iter().flat_map(|f| f.to_le_bytes()).collect()
490}
491
492fn bytes_to_embedding(bytes: &[u8]) -> Vec<f32> {
493 bytes
494 .chunks_exact(4)
495 .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
496 .collect()
497}
498
499fn row_to_entry(row: &rusqlite::Row<'_>) -> Result<MemoryEntry> {
500 let metadata_str: String = row.get(5)?;
501 let metadata: HashMap<String, serde_json::Value> =
502 serde_json::from_str(&metadata_str).unwrap_or_default();
503
504 let created_str: String = row.get(7)?;
505 let updated_str: String = row.get(8)?;
506
507 let created_at = DateTime::parse_from_rfc3339(&created_str)
508 .map(|dt| dt.with_timezone(&Utc))
509 .unwrap_or_else(|_| Utc::now());
510 let updated_at = DateTime::parse_from_rfc3339(&updated_str)
511 .map(|dt| dt.with_timezone(&Utc))
512 .unwrap_or_else(|_| Utc::now());
513
514 Ok(MemoryEntry {
515 id: row.get(0)?,
516 content: row.get(1)?,
517 source: MemorySource::from_str_lossy(&row.get::<_, String>(2)?),
518 category: MemoryCategory::from_str_lossy(&row.get::<_, String>(3)?),
519 container_tag: row.get(4)?,
520 metadata,
521 session_id: row.get(6)?,
522 created_at,
523 updated_at,
524 })
525}
526
527fn row_to_profile_fact(row: &rusqlite::Row<'_>) -> Result<UserProfileFact> {
528 let created_str: String = row.get(4)?;
529 let updated_str: String = row.get(5)?;
530
531 let created_at = DateTime::parse_from_rfc3339(&created_str)
532 .map(|dt| dt.with_timezone(&Utc))
533 .unwrap_or_else(|_| Utc::now());
534 let updated_at = DateTime::parse_from_rfc3339(&updated_str)
535 .map(|dt| dt.with_timezone(&Utc))
536 .unwrap_or_else(|_| Utc::now());
537
538 Ok(UserProfileFact {
539 id: row.get(0)?,
540 fact_type: row.get(1)?,
541 category: row.get(2)?,
542 content: row.get(3)?,
543 created_at,
544 updated_at,
545 })
546}
547
548use rusqlite::OptionalExtension;
549
#[cfg(test)]
mod tests {
    //! Unit tests for `MemoryStore`. Each test opens a fresh in-memory database with
    //! 128-dimensional embeddings.

    use super::*;

    /// Builds a deterministic `dim`-length vector whose i-th value is `sin((i + seed) / dim)`.
    fn make_embedding(dim: usize, seed: f32) -> Vec<f32> {
        (0..dim)
            .map(|i| ((i as f32 + seed) / dim as f32).sin())
            .collect()
    }

    // A freshly created store has the schema in place and no memories.
    #[test]
    fn create_store_and_schema() {
        let store = MemoryStore::new_in_memory(128).unwrap();
        let stats = store.stats().unwrap();
        assert_eq!(stats.total_memories, 0);
    }

    // Round-trips an entry through `insert` and `get`.
    // NOTE(review): the `Preference` category presumably comes from `MemoryEntry::new`
    // classifying the content — confirm in `types`.
    #[test]
    fn insert_and_get() {
        let store = MemoryStore::new_in_memory(128).unwrap();
        let entry = MemoryEntry::new("I prefer dark mode", MemorySource::Manual, "default");
        let embedding = make_embedding(128, 1.0);

        store.insert(&entry, &embedding).unwrap();
        let retrieved = store.get(&entry.id).unwrap().unwrap();

        assert_eq!(retrieved.content, "I prefer dark mode");
        assert_eq!(retrieved.category, MemoryCategory::Preference);
        assert_eq!(retrieved.container_tag, "default");
    }

    // `store_text_only` writes the memory without an embedding and returns a positive rowid.
    #[test]
    fn store_text_only_inserts_and_returns_row_id() {
        let store = MemoryStore::new_in_memory(128).unwrap();
        let entry = MemoryEntry::new("text-only memory", MemorySource::Conversation, "default");

        let row_id = store.store_text_only(&entry).unwrap();

        assert!(row_id > 0);
        assert_eq!(store.count().unwrap(), 1);
        let retrieved = store.get(&entry.id).unwrap().unwrap();
        assert_eq!(retrieved.content, "text-only memory");
    }

    // The rowid from `store_text_only` can be fed to `update_embedding`, after which the
    // memory is found by vector search.
    #[test]
    fn update_embedding_updates_vector_row() {
        let store = MemoryStore::new_in_memory(128).unwrap();
        let entry = MemoryEntry::new("needs embedding", MemorySource::Conversation, "default");
        let embedding = make_embedding(128, 4.0);

        let row_id = store.store_text_only(&entry).unwrap();
        store.update_embedding(row_id, &embedding).unwrap();

        let results = store.vector_search(&embedding, 1).unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0, entry.id);
    }

    // `delete` reports success and the memory is no longer counted or retrievable.
    #[test]
    fn delete_removes_from_all_tables() {
        let store = MemoryStore::new_in_memory(128).unwrap();
        let entry = MemoryEntry::new("test content", MemorySource::Manual, "default");
        let embedding = make_embedding(128, 1.0);

        store.insert(&entry, &embedding).unwrap();
        assert_eq!(store.count().unwrap(), 1);

        let deleted = store.delete(&entry.id).unwrap();
        assert!(deleted);
        assert_eq!(store.count().unwrap(), 0);
        assert!(store.get(&entry.id).unwrap().is_none());
    }

    // `list` filters by container tag, and `None` returns every container.
    #[test]
    fn list_with_container_filter() {
        let store = MemoryStore::new_in_memory(128).unwrap();

        let e1 = MemoryEntry::new("entry one", MemorySource::Manual, "work");
        let e2 = MemoryEntry::new("entry two", MemorySource::Manual, "personal");
        let e3 = MemoryEntry::new("entry three", MemorySource::Manual, "work");

        store.insert(&e1, &make_embedding(128, 1.0)).unwrap();
        store.insert(&e2, &make_embedding(128, 2.0)).unwrap();
        store.insert(&e3, &make_embedding(128, 3.0)).unwrap();

        let work = store.list(Some("work"), 10).unwrap();
        assert_eq!(work.len(), 2);

        let personal = store.list(Some("personal"), 10).unwrap();
        assert_eq!(personal.len(), 1);

        let all = store.list(None, 10).unwrap();
        assert_eq!(all.len(), 3);
    }

    // Vector search returns some hits and never more than the requested limit.
    #[test]
    fn vector_search_returns_results() {
        let store = MemoryStore::new_in_memory(128).unwrap();

        for i in 0..5 {
            let entry = MemoryEntry::new(format!("memory {i}"), MemorySource::Manual, "default");
            store
                .insert(&entry, &make_embedding(128, i as f32))
                .unwrap();
        }

        let query_embedding = make_embedding(128, 2.5);
        let results = store.vector_search(&query_embedding, 3).unwrap();
        assert!(!results.is_empty());
        assert!(results.len() <= 3);
    }

    // Full-text search matches only the two entries that contain "rust".
    #[test]
    fn fts_search_returns_results() {
        let store = MemoryStore::new_in_memory(128).unwrap();

        let e1 = MemoryEntry::new("rust programming language", MemorySource::Manual, "default");
        let e2 = MemoryEntry::new("python scripting language", MemorySource::Manual, "default");
        let e3 = MemoryEntry::new("rust async runtime tokio", MemorySource::Manual, "default");

        store.insert(&e1, &make_embedding(128, 1.0)).unwrap();
        store.insert(&e2, &make_embedding(128, 2.0)).unwrap();
        store.insert(&e3, &make_embedding(128, 3.0)).unwrap();

        let results = store.fts_search("rust", 10).unwrap();
        assert_eq!(results.len(), 2);
    }

    // Cache miss before storing, then a byte-exact round-trip of the cached vector.
    #[test]
    fn embedding_cache() {
        let store = MemoryStore::new_in_memory(128).unwrap();
        let hash = "abc123";
        let embedding = make_embedding(128, 1.0);

        assert!(store.get_cached_embedding(hash).unwrap().is_none());

        store
            .cache_embedding(hash, &embedding, "test-model")
            .unwrap();

        let cached = store.get_cached_embedding(hash).unwrap().unwrap();
        assert_eq!(cached.len(), 128);
        assert!((cached[0] - embedding[0]).abs() < 1e-6);
    }

    // `stats` counts all memories and lists each container tag.
    #[test]
    fn stats_counting() {
        let store = MemoryStore::new_in_memory(128).unwrap();

        let e1 = MemoryEntry::new("I prefer vim", MemorySource::Manual, "work");
        let e2 = MemoryEntry::new("The sky is blue", MemorySource::Manual, "personal");

        store.insert(&e1, &make_embedding(128, 1.0)).unwrap();
        store.insert(&e2, &make_embedding(128, 2.0)).unwrap();

        let stats = store.stats().unwrap();
        assert_eq!(stats.total_memories, 2);
        assert!(stats.total_by_container.contains_key("work"));
        assert!(stats.total_by_container.contains_key("personal"));
    }

    // Add, filter by fact type, remove, then confirm that no facts remain.
    #[test]
    fn profile_facts_crud() {
        let store = MemoryStore::new_in_memory(128).unwrap();

        let fact = UserProfileFact {
            id: "f1".to_string(),
            fact_type: "static".to_string(),
            category: "preference".to_string(),
            content: "Prefers dark mode".to_string(),
            created_at: Utc::now(),
            updated_at: Utc::now(),
        };

        store.add_profile_fact(&fact).unwrap();

        let facts = store.get_profile_facts(Some("static")).unwrap();
        assert_eq!(facts.len(), 1);
        assert_eq!(facts[0].content, "Prefers dark mode");

        let removed = store.remove_profile_fact("f1").unwrap();
        assert!(removed);

        let facts = store.get_profile_facts(None).unwrap();
        assert!(facts.is_empty());
    }
    // A memory created just now is far inside a 1000-day retention window, so it survives.
    #[test]
    fn cleanup_with_retention_days_keeps_recent_entries() {
        let store = MemoryStore::new_in_memory(128).unwrap();
        let container = "test-container";

        let entry = MemoryEntry::new("recent memory", MemorySource::Manual, container);
        store.insert(&entry, &make_embedding(128, 1.0)).unwrap();
        assert_eq!(store.count().unwrap(), 1);

        let deleted = store.cleanup(container, Some(1000), None).unwrap();
        assert_eq!(deleted, 0);
        assert_eq!(store.count().unwrap(), 1);
    }

    // With a cap of 2, three of the five inserted memories are removed.
    #[test]
    fn cleanup_with_max_memories_keeps_only_newest() {
        let store = MemoryStore::new_in_memory(128).unwrap();
        let container = "test-container";

        for i in 0..5 {
            let entry = MemoryEntry::new(format!("memory {i}"), MemorySource::Manual, container);
            store.insert(&entry, &make_embedding(128, i as f32)).unwrap();
        }
        assert_eq!(store.count().unwrap(), 5);

        let deleted = store.cleanup(container, None, Some(2)).unwrap();
        assert_eq!(deleted, 3);
        assert_eq!(store.count().unwrap(), 2);
    }

    // With neither limit set, cleanup deletes nothing.
    #[test]
    fn cleanup_with_both_none_is_noop() {
        let store = MemoryStore::new_in_memory(128).unwrap();
        let container = "test-container";

        let entry = MemoryEntry::new("memory", MemorySource::Manual, container);
        store.insert(&entry, &make_embedding(128, 1.0)).unwrap();
        assert_eq!(store.count().unwrap(), 1);

        let deleted = store.cleanup(container, None, None).unwrap();
        assert_eq!(deleted, 0);
        assert_eq!(store.count().unwrap(), 1);
    }

}