// uira_memory/store.rs

1use anyhow::{Context, Result};
2use chrono::{DateTime, Utc};
3use rusqlite::ffi::sqlite3_auto_extension;
4use rusqlite::{params, Connection};
5use sqlite_vec::sqlite3_vec_init;
6use std::collections::HashMap;
7use std::path::Path;
8use std::sync::{Mutex, Once};
9
10use crate::config::MemoryConfig;
11use crate::types::{MemoryCategory, MemoryEntry, MemorySource, MemoryStats, UserProfileFact};
12
13/// Register the sqlite-vec extension globally (once per process).
14fn ensure_sqlite_vec_registered() {
15    static INIT: Once = Once::new();
16    INIT.call_once(|| unsafe {
17        #[allow(clippy::missing_transmute_annotations)]
18        sqlite3_auto_extension(Some(std::mem::transmute(sqlite3_vec_init as *const ())));
19    });
20}
21
22pub struct MemoryStore {
23    conn: Mutex<Connection>,
24    embedding_dimension: usize,
25}
26
27impl MemoryStore {
28    pub fn new(config: &MemoryConfig) -> Result<Self> {
29        let path = &config.storage_path;
30        if let Some(parent) = Path::new(path).parent() {
31            std::fs::create_dir_all(parent)
32                .with_context(|| format!("Failed to create directory for {path}"))?;
33        }
34
35        ensure_sqlite_vec_registered();
36
37        let conn = Connection::open(path)
38            .with_context(|| format!("Failed to open memory database at {path}"))?;
39
40        conn.execute_batch("PRAGMA journal_mode=WAL; PRAGMA foreign_keys=ON;")?;
41
42        let store = Self {
43            conn: Mutex::new(conn),
44            embedding_dimension: config.embedding_dimension,
45        };
46        store.init_schema()?;
47        Ok(store)
48    }
49
50    pub fn new_in_memory(embedding_dimension: usize) -> Result<Self> {
51        ensure_sqlite_vec_registered();
52
53        let conn = Connection::open_in_memory()?;
54        conn.execute_batch("PRAGMA journal_mode=WAL; PRAGMA foreign_keys=ON;")?;
55
56        let store = Self {
57            conn: Mutex::new(conn),
58            embedding_dimension,
59        };
60        store.init_schema()?;
61        Ok(store)
62    }
63
64    fn init_schema(&self) -> Result<()> {
65        let conn = self.conn.lock().unwrap();
66        conn.execute_batch(
67            "CREATE TABLE IF NOT EXISTS memories (
68                id TEXT PRIMARY KEY,
69                content TEXT NOT NULL,
70                source TEXT NOT NULL DEFAULT 'manual',
71                category TEXT NOT NULL DEFAULT 'other',
72                container_tag TEXT NOT NULL DEFAULT 'default',
73                metadata TEXT DEFAULT '{}',
74                session_id TEXT,
75                created_at TEXT NOT NULL DEFAULT (datetime('now')),
76                updated_at TEXT NOT NULL DEFAULT (datetime('now'))
77            );
78
79            CREATE INDEX IF NOT EXISTS idx_memories_container ON memories(container_tag);
80            CREATE INDEX IF NOT EXISTS idx_memories_category ON memories(category);
81            CREATE INDEX IF NOT EXISTS idx_memories_session ON memories(session_id);
82            CREATE INDEX IF NOT EXISTS idx_memories_created ON memories(created_at);
83
84            CREATE VIRTUAL TABLE IF NOT EXISTS memories_fts USING fts5(
85                content,
86                id UNINDEXED,
87                tokenize='porter unicode61'
88            );
89
90            CREATE TABLE IF NOT EXISTS embedding_cache (
91                content_hash TEXT PRIMARY KEY,
92                embedding BLOB NOT NULL,
93                model TEXT NOT NULL,
94                created_at TEXT NOT NULL DEFAULT (datetime('now'))
95            );
96
97            CREATE TABLE IF NOT EXISTS user_profile (
98                id TEXT PRIMARY KEY,
99                fact_type TEXT NOT NULL DEFAULT 'static',
100                category TEXT NOT NULL DEFAULT 'fact',
101                content TEXT NOT NULL,
102                created_at TEXT NOT NULL DEFAULT (datetime('now')),
103                updated_at TEXT NOT NULL DEFAULT (datetime('now'))
104            );",
105        )?;
106
107        let dim = self.embedding_dimension;
108        conn.execute_batch(&format!(
109            "CREATE VIRTUAL TABLE IF NOT EXISTS memories_vec USING vec0(
110                id TEXT PRIMARY KEY,
111                embedding float[{dim}]
112            );"
113        ))?;
114
115        Ok(())
116    }
117
118    pub fn insert(&self, entry: &MemoryEntry, embedding: &[f32]) -> Result<()> {
119        let conn = self.conn.lock().unwrap();
120        let tx = conn.unchecked_transaction()?;
121
122        let metadata_json = serde_json::to_string(&entry.metadata)?;
123        let created = entry.created_at.to_rfc3339();
124        let updated = entry.updated_at.to_rfc3339();
125
126        tx.execute(
127            "INSERT OR REPLACE INTO memories (id, content, source, category, container_tag, metadata, session_id, created_at, updated_at)
128             VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)",
129            params![
130                entry.id,
131                entry.content,
132                entry.source.as_str(),
133                entry.category.as_str(),
134                entry.container_tag,
135                metadata_json,
136                entry.session_id,
137                created,
138                updated,
139            ],
140        )?;
141
142        let embedding_bytes = embedding_to_bytes(embedding);
143        tx.execute(
144            "INSERT OR REPLACE INTO memories_vec (id, embedding) VALUES (?1, ?2)",
145            params![entry.id, embedding_bytes],
146        )?;
147
148        tx.execute(
149            "INSERT OR REPLACE INTO memories_fts (id, content) VALUES (?1, ?2)",
150            params![entry.id, entry.content],
151        )?;
152
153        tx.commit()?;
154        Ok(())
155    }
156
157    pub fn store_text_only(&self, entry: &MemoryEntry) -> Result<i64> {
158        let conn = self.conn.lock().unwrap();
159        let tx = conn.unchecked_transaction()?;
160
161        let metadata_json = serde_json::to_string(&entry.metadata)?;
162        let created = entry.created_at.to_rfc3339();
163        let updated = entry.updated_at.to_rfc3339();
164
165        tx.execute(
166            "INSERT OR REPLACE INTO memories (id, content, source, category, container_tag, metadata, session_id, created_at, updated_at)
167             VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)",
168            params![
169                entry.id,
170                entry.content,
171                entry.source.as_str(),
172                entry.category.as_str(),
173                entry.container_tag,
174                metadata_json,
175                entry.session_id,
176                created,
177                updated,
178            ],
179        )?;
180
181        tx.execute(
182            "INSERT OR REPLACE INTO memories_fts (id, content) VALUES (?1, ?2)",
183            params![entry.id, entry.content],
184        )?;
185
186        let row_id = tx.last_insert_rowid();
187        tx.commit()?;
188        Ok(row_id)
189    }
190
191    pub fn update_embedding(&self, id: i64, embedding: &[f32]) -> Result<()> {
192        let conn = self.conn.lock().unwrap();
193        let tx = conn.unchecked_transaction()?;
194        let embedding_bytes = embedding_to_bytes(embedding);
195        let updated = tx.execute(
196            "INSERT OR REPLACE INTO memories_vec (id, embedding)
197             SELECT id, ?1 FROM memories WHERE rowid = ?2",
198            params![embedding_bytes, id],
199        )?;
200
201        if updated == 0 {
202            anyhow::bail!("memory row not found for rowid {id}");
203        }
204
205        tx.commit()?;
206        Ok(())
207    }
208
209    pub fn get(&self, id: &str) -> Result<Option<MemoryEntry>> {
210        let conn = self.conn.lock().unwrap();
211        let mut stmt = conn.prepare(
212            "SELECT id, content, source, category, container_tag, metadata, session_id, created_at, updated_at
213             FROM memories WHERE id = ?1",
214        )?;
215
216        let result = stmt
217            .query_row(params![id], |row| Ok(row_to_entry(row)))
218            .optional()?;
219
220        match result {
221            Some(entry) => Ok(Some(entry?)),
222            None => Ok(None),
223        }
224    }
225
226    pub fn delete(&self, id: &str) -> Result<bool> {
227        let conn = self.conn.lock().unwrap();
228        let tx = conn.unchecked_transaction()?;
229
230        let deleted = tx.execute("DELETE FROM memories WHERE id = ?1", params![id])?;
231        tx.execute("DELETE FROM memories_vec WHERE id = ?1", params![id])?;
232        tx.execute("DELETE FROM memories_fts WHERE id = ?1", params![id])?;
233
234        tx.commit()?;
235        Ok(deleted > 0)
236    }
237
238    pub fn delete_by_ids(&self, ids: &[String]) -> Result<usize> {
239        if ids.is_empty() {
240            return Ok(0);
241        }
242        let conn = self.conn.lock().unwrap();
243        let tx = conn.unchecked_transaction()?;
244        let mut count = 0;
245
246        for id in ids {
247            count += tx.execute("DELETE FROM memories WHERE id = ?1", params![id])?;
248            tx.execute("DELETE FROM memories_vec WHERE id = ?1", params![id])?;
249            tx.execute("DELETE FROM memories_fts WHERE id = ?1", params![id])?;
250        }
251
252        tx.commit()?;
253        Ok(count)
254    }
255
256    pub fn list(&self, container_tag: Option<&str>, limit: usize) -> Result<Vec<MemoryEntry>> {
257        let conn = self.conn.lock().unwrap();
258        let (sql, param_values): (String, Vec<Box<dyn rusqlite::types::ToSql>>) = match container_tag {
259            Some(tag) => (
260                "SELECT id, content, source, category, container_tag, metadata, session_id, created_at, updated_at
261                 FROM memories WHERE container_tag = ?1 ORDER BY created_at DESC LIMIT ?2"
262                    .to_string(),
263                vec![Box::new(tag.to_string()), Box::new(limit as i64)],
264            ),
265            None => (
266                "SELECT id, content, source, category, container_tag, metadata, session_id, created_at, updated_at
267                 FROM memories ORDER BY created_at DESC LIMIT ?1"
268                    .to_string(),
269                vec![Box::new(limit as i64)],
270            ),
271        };
272
273        let mut stmt = conn.prepare(&sql)?;
274        let params_ref: Vec<&dyn rusqlite::types::ToSql> =
275            param_values.iter().map(|b| b.as_ref()).collect();
276        let rows = stmt.query_map(params_ref.as_slice(), |row| Ok(row_to_entry(row)))?;
277
278        let mut entries = Vec::new();
279        for row in rows {
280            entries.push(row??);
281        }
282        Ok(entries)
283    }
284
285    pub fn vector_search(&self, embedding: &[f32], limit: usize) -> Result<Vec<(String, f32)>> {
286        let conn = self.conn.lock().unwrap();
287        let embedding_bytes = embedding_to_bytes(embedding);
288
289        let mut stmt = conn.prepare(
290            "SELECT id, distance FROM memories_vec WHERE embedding MATCH ?1 ORDER BY distance LIMIT ?2",
291        )?;
292
293        let rows = stmt.query_map(params![embedding_bytes, limit as i64], |row| {
294            Ok((row.get::<_, String>(0)?, row.get::<_, f32>(1)?))
295        })?;
296
297        let mut results = Vec::new();
298        for row in rows {
299            results.push(row?);
300        }
301        Ok(results)
302    }
303
304    pub fn fts_search(&self, query: &str, limit: usize) -> Result<Vec<(String, f64)>> {
305        let conn = self.conn.lock().unwrap();
306        let mut stmt = conn.prepare(
307            "SELECT id, bm25(memories_fts) as rank FROM memories_fts WHERE memories_fts MATCH ?1 ORDER BY rank LIMIT ?2",
308        )?;
309
310        let rows = stmt.query_map(params![query, limit as i64], |row| {
311            Ok((row.get::<_, String>(0)?, row.get::<_, f64>(1)?))
312        })?;
313
314        let mut results = Vec::new();
315        for row in rows {
316            results.push(row?);
317        }
318        Ok(results)
319    }
320
321    pub fn get_cached_embedding(&self, content_hash: &str) -> Result<Option<Vec<f32>>> {
322        let conn = self.conn.lock().unwrap();
323        let mut stmt =
324            conn.prepare("SELECT embedding FROM embedding_cache WHERE content_hash = ?1")?;
325
326        let result = stmt
327            .query_row(params![content_hash], |row| {
328                let bytes: Vec<u8> = row.get(0)?;
329                Ok(bytes_to_embedding(&bytes))
330            })
331            .optional()?;
332
333        Ok(result)
334    }
335
336    pub fn cache_embedding(
337        &self,
338        content_hash: &str,
339        embedding: &[f32],
340        model: &str,
341    ) -> Result<()> {
342        let conn = self.conn.lock().unwrap();
343        let bytes = embedding_to_bytes(embedding);
344        conn.execute(
345            "INSERT OR REPLACE INTO embedding_cache (content_hash, embedding, model) VALUES (?1, ?2, ?3)",
346            params![content_hash, bytes, model],
347        )?;
348        Ok(())
349    }
350
351    pub fn add_profile_fact(&self, fact: &UserProfileFact) -> Result<()> {
352        let conn = self.conn.lock().unwrap();
353        conn.execute(
354            "INSERT OR REPLACE INTO user_profile (id, fact_type, category, content, created_at, updated_at)
355             VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
356            params![
357                fact.id,
358                fact.fact_type,
359                fact.category,
360                fact.content,
361                fact.created_at.to_rfc3339(),
362                fact.updated_at.to_rfc3339(),
363            ],
364        )?;
365        Ok(())
366    }
367
368    pub fn get_profile_facts(&self, fact_type: Option<&str>) -> Result<Vec<UserProfileFact>> {
369        let conn = self.conn.lock().unwrap();
370        let (sql, param_values): (String, Vec<Box<dyn rusqlite::types::ToSql>>) = match fact_type {
371            Some(ft) => (
372                "SELECT id, fact_type, category, content, created_at, updated_at FROM user_profile WHERE fact_type = ?1 ORDER BY created_at DESC".to_string(),
373                vec![Box::new(ft.to_string())],
374            ),
375            None => (
376                "SELECT id, fact_type, category, content, created_at, updated_at FROM user_profile ORDER BY created_at DESC".to_string(),
377                vec![],
378            ),
379        };
380
381        let mut stmt = conn.prepare(&sql)?;
382        let params_ref: Vec<&dyn rusqlite::types::ToSql> =
383            param_values.iter().map(|b| b.as_ref()).collect();
384        let rows = stmt.query_map(params_ref.as_slice(), |row| Ok(row_to_profile_fact(row)))?;
385
386        let mut facts = Vec::new();
387        for row in rows {
388            facts.push(row??);
389        }
390        Ok(facts)
391    }
392
393    pub fn remove_profile_fact(&self, id: &str) -> Result<bool> {
394        let conn = self.conn.lock().unwrap();
395        let deleted = conn.execute("DELETE FROM user_profile WHERE id = ?1", params![id])?;
396        Ok(deleted > 0)
397    }
398
399    pub fn stats(&self) -> Result<MemoryStats> {
400        let conn = self.conn.lock().unwrap();
401
402        let total: usize = conn.query_row("SELECT COUNT(*) FROM memories", [], |row| row.get(0))?;
403
404        let mut by_category = HashMap::new();
405        let mut stmt = conn.prepare("SELECT category, COUNT(*) FROM memories GROUP BY category")?;
406        let rows = stmt.query_map([], |row| {
407            Ok((row.get::<_, String>(0)?, row.get::<_, usize>(1)?))
408        })?;
409        for row in rows {
410            let (cat, count) = row?;
411            by_category.insert(cat, count);
412        }
413
414        let mut by_container = HashMap::new();
415        let mut stmt =
416            conn.prepare("SELECT container_tag, COUNT(*) FROM memories GROUP BY container_tag")?;
417        let rows = stmt.query_map([], |row| {
418            Ok((row.get::<_, String>(0)?, row.get::<_, usize>(1)?))
419        })?;
420        for row in rows {
421            let (tag, count) = row?;
422            by_container.insert(tag, count);
423        }
424
425        let db_size = conn
426            .query_row(
427                "SELECT page_count * page_size FROM pragma_page_count, pragma_page_size",
428                [],
429                |row| row.get::<_, u64>(0),
430            )
431            .unwrap_or(0);
432
433        Ok(MemoryStats {
434            total_memories: total,
435            total_by_category: by_category,
436            total_by_container: by_container,
437            db_size_bytes: db_size,
438        })
439    }
440
441    pub fn count(&self) -> Result<usize> {
442        let conn = self.conn.lock().unwrap();
443        let count: usize = conn.query_row("SELECT COUNT(*) FROM memories", [], |row| row.get(0))?;
444        Ok(count)
445    }
446
447    pub fn cleanup(&self, container_tag: &str, retention_days: Option<u32>, max_memories: Option<usize>) -> Result<usize> {
448        let conn = self.conn.lock().unwrap();
449        let tx = conn.unchecked_transaction()?;
450        let mut total_deleted = 0;
451
452        // If both are None, return early (no-op)
453        if retention_days.is_none() && max_memories.is_none() {
454            return Ok(0);
455        }
456
457        // Delete old entries by retention_days
458        if let Some(days) = retention_days {
459            let deleted = tx.execute(
460                "DELETE FROM memories WHERE created_at < datetime('now', ?1) AND container_tag = ?2",
461                params![format!("-{} days", days), container_tag],
462            )?;
463            total_deleted += deleted;
464        }
465
466        // Keep only newest N memories per container tag
467        if let Some(max_count) = max_memories {
468            let deleted = tx.execute(
469                "DELETE FROM memories WHERE container_tag = ?1 AND id NOT IN (SELECT id FROM memories WHERE container_tag = ?2 ORDER BY created_at DESC LIMIT ?3)",
470                params![container_tag, container_tag, max_count as i64],
471            )?;
472            total_deleted += deleted;
473        }
474
475        // Clean up orphaned FTS entries
476        tx.execute(
477            "DELETE FROM memories_fts WHERE rowid NOT IN (SELECT rowid FROM memories)",
478            [],
479        )?;
480
481        tx.commit()?;
482
483        tracing::info!(deleted = total_deleted, "memory cleanup completed");
484        Ok(total_deleted)
485    }
486}
487
/// Serializes an embedding as consecutive little-endian `f32` values (4 bytes each).
fn embedding_to_bytes(embedding: &[f32]) -> Vec<u8> {
    let mut bytes = Vec::with_capacity(embedding.len() * 4);
    for value in embedding {
        bytes.extend_from_slice(&value.to_le_bytes());
    }
    bytes
}
491
/// Decodes consecutive little-endian `f32` values.
///
/// Trailing bytes that do not form a complete 4-byte value are ignored.
fn bytes_to_embedding(bytes: &[u8]) -> Vec<f32> {
    let mut values = Vec::with_capacity(bytes.len() / 4);
    for word in bytes.chunks_exact(4) {
        let raw: [u8; 4] = [word[0], word[1], word[2], word[3]];
        values.push(f32::from_le_bytes(raw));
    }
    values
}
498
499fn row_to_entry(row: &rusqlite::Row<'_>) -> Result<MemoryEntry> {
500    let metadata_str: String = row.get(5)?;
501    let metadata: HashMap<String, serde_json::Value> =
502        serde_json::from_str(&metadata_str).unwrap_or_default();
503
504    let created_str: String = row.get(7)?;
505    let updated_str: String = row.get(8)?;
506
507    let created_at = DateTime::parse_from_rfc3339(&created_str)
508        .map(|dt| dt.with_timezone(&Utc))
509        .unwrap_or_else(|_| Utc::now());
510    let updated_at = DateTime::parse_from_rfc3339(&updated_str)
511        .map(|dt| dt.with_timezone(&Utc))
512        .unwrap_or_else(|_| Utc::now());
513
514    Ok(MemoryEntry {
515        id: row.get(0)?,
516        content: row.get(1)?,
517        source: MemorySource::from_str_lossy(&row.get::<_, String>(2)?),
518        category: MemoryCategory::from_str_lossy(&row.get::<_, String>(3)?),
519        container_tag: row.get(4)?,
520        metadata,
521        session_id: row.get(6)?,
522        created_at,
523        updated_at,
524    })
525}
526
527fn row_to_profile_fact(row: &rusqlite::Row<'_>) -> Result<UserProfileFact> {
528    let created_str: String = row.get(4)?;
529    let updated_str: String = row.get(5)?;
530
531    let created_at = DateTime::parse_from_rfc3339(&created_str)
532        .map(|dt| dt.with_timezone(&Utc))
533        .unwrap_or_else(|_| Utc::now());
534    let updated_at = DateTime::parse_from_rfc3339(&updated_str)
535        .map(|dt| dt.with_timezone(&Utc))
536        .unwrap_or_else(|_| Utc::now());
537
538    Ok(UserProfileFact {
539        id: row.get(0)?,
540        fact_type: row.get(1)?,
541        category: row.get(2)?,
542        content: row.get(3)?,
543        created_at,
544        updated_at,
545    })
546}
547
548use rusqlite::OptionalExtension;
549
#[cfg(test)]
mod tests {
    use super::*;

    /// Vector width used by every store in this module.
    const DIM: usize = 128;

    /// Container tag shared by the cleanup tests.
    const CLEANUP_TAG: &str = "test-container";

    /// Builds a deterministic `dim`-length vector: a sine curve whose phase shifts with `seed`.
    fn make_embedding(dim: usize, seed: f32) -> Vec<f32> {
        let denominator = dim as f32;
        (0..dim)
            .map(|i| {
                let x = (i as f32 + seed) / denominator;
                x.sin()
            })
            .collect()
    }

    /// Fresh, empty in-memory store of width [`DIM`].
    fn fresh_store() -> MemoryStore {
        MemoryStore::new_in_memory(DIM).unwrap()
    }

    #[test]
    fn create_store_and_schema() {
        let mem = fresh_store();
        assert_eq!(mem.stats().unwrap().total_memories, 0);
    }

    #[test]
    fn insert_and_get() {
        let mem = fresh_store();
        let entry = MemoryEntry::new("I prefer dark mode", MemorySource::Manual, "default");
        mem.insert(&entry, &make_embedding(DIM, 1.0)).unwrap();

        let found = mem.get(&entry.id).unwrap().expect("inserted entry should be readable");
        assert_eq!(found.content, "I prefer dark mode");
        assert_eq!(found.category, MemoryCategory::Preference);
        assert_eq!(found.container_tag, "default");
    }

    #[test]
    fn store_text_only_inserts_and_returns_row_id() {
        let mem = fresh_store();
        let entry = MemoryEntry::new("text-only memory", MemorySource::Conversation, "default");

        let row_id = mem.store_text_only(&entry).unwrap();
        assert!(row_id > 0);
        assert_eq!(mem.count().unwrap(), 1);

        let found = mem.get(&entry.id).unwrap().expect("stored entry should be readable");
        assert_eq!(found.content, "text-only memory");
    }

    #[test]
    fn update_embedding_updates_vector_row() {
        let mem = fresh_store();
        let entry = MemoryEntry::new("needs embedding", MemorySource::Conversation, "default");
        let vector = make_embedding(DIM, 4.0);

        let row_id = mem.store_text_only(&entry).unwrap();
        mem.update_embedding(row_id, &vector).unwrap();

        let hits = mem.vector_search(&vector, 1).unwrap();
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].0, entry.id);
    }

    #[test]
    fn delete_removes_from_all_tables() {
        let mem = fresh_store();
        let entry = MemoryEntry::new("test content", MemorySource::Manual, "default");
        mem.insert(&entry, &make_embedding(DIM, 1.0)).unwrap();
        assert_eq!(mem.count().unwrap(), 1);

        assert!(mem.delete(&entry.id).unwrap());
        assert_eq!(mem.count().unwrap(), 0);
        assert!(mem.get(&entry.id).unwrap().is_none());
    }

    #[test]
    fn list_with_container_filter() {
        let mem = fresh_store();
        let fixtures = [
            ("entry one", "work", 1.0_f32),
            ("entry two", "personal", 2.0_f32),
            ("entry three", "work", 3.0_f32),
        ];
        for (content, tag, seed) in fixtures {
            let entry = MemoryEntry::new(content, MemorySource::Manual, tag);
            mem.insert(&entry, &make_embedding(DIM, seed)).unwrap();
        }

        assert_eq!(mem.list(Some("work"), 10).unwrap().len(), 2);
        assert_eq!(mem.list(Some("personal"), 10).unwrap().len(), 1);
        assert_eq!(mem.list(None, 10).unwrap().len(), 3);
    }

    #[test]
    fn vector_search_returns_results() {
        let mem = fresh_store();
        for i in 0..5 {
            let entry = MemoryEntry::new(format!("memory {i}"), MemorySource::Manual, "default");
            mem.insert(&entry, &make_embedding(DIM, i as f32)).unwrap();
        }

        let hits = mem.vector_search(&make_embedding(DIM, 2.5), 3).unwrap();
        assert!(!hits.is_empty());
        assert!(hits.len() <= 3);
    }

    #[test]
    fn fts_search_returns_results() {
        let mem = fresh_store();
        let contents = [
            "rust programming language",
            "python scripting language",
            "rust async runtime tokio",
        ];
        for (seed, content) in (1..).zip(contents) {
            let entry = MemoryEntry::new(content, MemorySource::Manual, "default");
            mem.insert(&entry, &make_embedding(DIM, seed as f32)).unwrap();
        }

        assert_eq!(mem.fts_search("rust", 10).unwrap().len(), 2);
    }

    #[test]
    fn embedding_cache() {
        let mem = fresh_store();
        let key = "abc123";
        let vector = make_embedding(DIM, 1.0);

        assert!(mem.get_cached_embedding(key).unwrap().is_none());
        mem.cache_embedding(key, &vector, "test-model").unwrap();

        let cached = mem
            .get_cached_embedding(key)
            .unwrap()
            .expect("embedding should be cached");
        assert_eq!(cached.len(), DIM);
        assert!((cached[0] - vector[0]).abs() < 1e-6);
    }

    #[test]
    fn stats_counting() {
        let mem = fresh_store();
        let vim = MemoryEntry::new("I prefer vim", MemorySource::Manual, "work");
        let sky = MemoryEntry::new("The sky is blue", MemorySource::Manual, "personal");
        mem.insert(&vim, &make_embedding(DIM, 1.0)).unwrap();
        mem.insert(&sky, &make_embedding(DIM, 2.0)).unwrap();

        let stats = mem.stats().unwrap();
        assert_eq!(stats.total_memories, 2);
        for tag in ["work", "personal"] {
            assert!(stats.total_by_container.contains_key(tag));
        }
    }

    #[test]
    fn profile_facts_crud() {
        let mem = fresh_store();
        let fact = UserProfileFact {
            id: "f1".into(),
            fact_type: "static".into(),
            category: "preference".into(),
            content: "Prefers dark mode".into(),
            created_at: Utc::now(),
            updated_at: Utc::now(),
        };
        mem.add_profile_fact(&fact).unwrap();

        let static_facts = mem.get_profile_facts(Some("static")).unwrap();
        assert_eq!(static_facts.len(), 1);
        assert_eq!(static_facts[0].content, "Prefers dark mode");

        assert!(mem.remove_profile_fact("f1").unwrap());
        assert!(mem.get_profile_facts(None).unwrap().is_empty());
    }

    #[test]
    fn cleanup_with_retention_days_keeps_recent_entries() {
        let mem = fresh_store();
        let entry = MemoryEntry::new("recent memory", MemorySource::Manual, CLEANUP_TAG);
        mem.insert(&entry, &make_embedding(DIM, 1.0)).unwrap();
        assert_eq!(mem.count().unwrap(), 1);

        // A 1000-day retention window cannot expire an entry created just now.
        assert_eq!(mem.cleanup(CLEANUP_TAG, Some(1000), None).unwrap(), 0);
        assert_eq!(mem.count().unwrap(), 1);
    }

    #[test]
    fn cleanup_with_max_memories_keeps_only_newest() {
        let mem = fresh_store();
        for i in 0..5 {
            let entry = MemoryEntry::new(format!("memory {i}"), MemorySource::Manual, CLEANUP_TAG);
            mem.insert(&entry, &make_embedding(DIM, i as f32)).unwrap();
        }
        assert_eq!(mem.count().unwrap(), 5);

        // Capping at two memories removes the three oldest.
        assert_eq!(mem.cleanup(CLEANUP_TAG, None, Some(2)).unwrap(), 3);
        assert_eq!(mem.count().unwrap(), 2);
    }

    #[test]
    fn cleanup_with_both_none_is_noop() {
        let mem = fresh_store();
        let entry = MemoryEntry::new("memory", MemorySource::Manual, CLEANUP_TAG);
        mem.insert(&entry, &make_embedding(DIM, 1.0)).unwrap();
        assert_eq!(mem.count().unwrap(), 1);

        // With no retention rule configured nothing may be deleted.
        assert_eq!(mem.cleanup(CLEANUP_TAG, None, None).unwrap(), 0);
        assert_eq!(mem.count().unwrap(), 1);
    }
}