Skip to main content

rustyclaw_core/mnemo/
sqlite_store.rs

1//! SQLite-backed memory store implementation.
2
3use super::config::MnemoConfig;
4use super::schema::{CURRENT_VERSION, SCHEMA, VERSION_CHECK};
5use super::traits::{CompactionStats, MemoryEntry, MemoryHit, MemoryStore, SummaryKind, Summarizer};
6use super::estimate_tokens;
7use anyhow::{Context, Result};
8use async_trait::async_trait;
9use std::path::Path;
10use std::sync::Mutex;
11use std::time::Instant;
12
13/// SQLite-backed memory store with FTS5 support.
14pub struct SqliteMemoryStore {
15    conn: Mutex<rusqlite::Connection>,
16    config: MnemoConfig,
17    /// Current conversation ID (set after first ingest).
18    conversation_id: Mutex<Option<i64>>,
19}
20
21impl SqliteMemoryStore {
22    /// Open or create a memory database at the given path.
23    pub async fn open(path: &Path, config: MnemoConfig) -> Result<Self> {
24        // Ensure parent directory exists
25        if let Some(parent) = path.parent() {
26            std::fs::create_dir_all(parent)
27                .with_context(|| format!("Failed to create mnemo directory: {:?}", parent))?;
28        }
29
30        // Open SQLite with WAL mode for better concurrency
31        let conn = rusqlite::Connection::open(path)
32            .with_context(|| format!("Failed to open mnemo database: {:?}", path))?;
33
34        // Enable WAL mode
35        conn.execute_batch("PRAGMA journal_mode=WAL; PRAGMA synchronous=NORMAL;")?;
36
37        // Create schema
38        conn.execute_batch(VERSION_CHECK)?;
39        conn.execute_batch(SCHEMA)?;
40
41        // Check/update version
42        let version: i32 = conn
43            .query_row(
44                "SELECT COALESCE(MAX(version), 0) FROM schema_version",
45                [],
46                |row| row.get(0),
47            )
48            .unwrap_or(0);
49
50        if version < CURRENT_VERSION {
51            conn.execute(
52                "INSERT OR REPLACE INTO schema_version (version) VALUES (?)",
53                [CURRENT_VERSION],
54            )?;
55        }
56
57        Ok(Self {
58            conn: Mutex::new(conn),
59            config,
60            conversation_id: Mutex::new(None),
61        })
62    }
63
64    /// Get or create the default conversation.
65    fn ensure_conversation(&self) -> Result<i64> {
66        let mut conv_id = self.conversation_id.lock().unwrap();
67        if let Some(id) = *conv_id {
68            return Ok(id);
69        }
70
71        let conn = self.conn.lock().unwrap();
72
73        // Try to find existing
74        let existing: Option<i64> = conn
75            .query_row(
76                "SELECT id FROM conversations WHERE agent_id = 'default' AND session_id = 'main'",
77                [],
78                |row| row.get(0),
79            )
80            .ok();
81
82        let id = if let Some(id) = existing {
83            id
84        } else {
85            conn.execute(
86                "INSERT INTO conversations (agent_id, session_id) VALUES ('default', 'main')",
87                [],
88            )?;
89            conn.last_insert_rowid()
90        };
91
92        *conv_id = Some(id);
93        Ok(id)
94    }
95
96    /// Get compaction candidates (oldest non-fresh context items).
97    fn get_compaction_candidates(&self, count: usize) -> Result<Vec<MemoryEntry>> {
98        let conv_id = self.ensure_conversation()?;
99        let conn = self.conn.lock().unwrap();
100
101        let total: i64 = conn.query_row(
102            "SELECT COUNT(*) FROM context_items WHERE conversation_id = ?",
103            [conv_id],
104            |row| row.get(0),
105        )?;
106
107        let fresh_tail = self.config.fresh_tail_messages;
108        if (total as usize) <= fresh_tail {
109            return Ok(Vec::new());
110        }
111
112        let available = (total as usize) - fresh_tail;
113        let to_get = count.min(available);
114
115        let mut stmt = conn.prepare(
116            "SELECT item_type, ref_id FROM context_items 
117             WHERE conversation_id = ? 
118             ORDER BY position ASC
119             LIMIT ?",
120        )?;
121
122        let items: Vec<(String, i64)> = stmt
123            .query_map(rusqlite::params![conv_id, to_get as i64], |row| {
124                Ok((row.get(0)?, row.get(1)?))
125            })?
126            .filter_map(|r| r.ok())
127            .collect();
128
129        let mut entries = Vec::new();
130        for (item_type, ref_id) in items {
131            match item_type.as_str() {
132                "message" => {
133                    if let Ok(entry) = self.get_message_entry(&conn, ref_id) {
134                        entries.push(entry);
135                    }
136                }
137                "summary" => {
138                    if let Ok(entry) = self.get_summary_entry(&conn, ref_id) {
139                        entries.push(entry);
140                    }
141                }
142                _ => {}
143            }
144        }
145
146        Ok(entries)
147    }
148
149    fn get_message_entry(&self, conn: &rusqlite::Connection, id: i64) -> Result<MemoryEntry> {
150        conn.query_row(
151            "SELECT id, role, content, token_count, created_at FROM messages WHERE id = ?",
152            [id],
153            |row| {
154                Ok(MemoryEntry {
155                    id: row.get(0)?,
156                    role: row.get(1)?,
157                    content: row.get(2)?,
158                    token_count: row.get::<_, i32>(3)? as usize,
159                    timestamp: row.get(4)?,
160                    depth: 0,
161                })
162            },
163        )
164        .map_err(|e| anyhow::anyhow!("Failed to get message {}: {}", id, e))
165    }
166
167    fn get_summary_entry(&self, conn: &rusqlite::Connection, id: i64) -> Result<MemoryEntry> {
168        conn.query_row(
169            "SELECT id, depth, content, token_count, created_at FROM summaries WHERE id = ?",
170            [id],
171            |row| {
172                Ok(MemoryEntry {
173                    id: row.get(0)?,
174                    role: "summary".to_string(),
175                    content: row.get(2)?,
176                    token_count: row.get::<_, i32>(3)? as usize,
177                    timestamp: row.get(4)?,
178                    depth: row.get::<_, i32>(1)? as u8,
179                })
180            },
181        )
182        .map_err(|e| anyhow::anyhow!("Failed to get summary {}: {}", id, e))
183    }
184
185    /// Create a summary from messages.
186    fn create_summary_from_messages(
187        &self,
188        message_ids: &[i64],
189        summary_content: &str,
190    ) -> Result<i64> {
191        let conv_id = self.ensure_conversation()?;
192        let conn = self.conn.lock().unwrap();
193        let token_count = estimate_tokens(summary_content) as i32;
194
195        // Insert summary at depth 0 (leaf)
196        conn.execute(
197            "INSERT INTO summaries (conversation_id, depth, content, token_count) VALUES (?, 0, ?, ?)",
198            rusqlite::params![conv_id, summary_content, token_count],
199        )?;
200
201        let summary_id = conn.last_insert_rowid();
202
203        // Link to source messages
204        for &msg_id in message_ids {
205            conn.execute(
206                "INSERT INTO summary_messages (summary_id, message_id) VALUES (?, ?)",
207                [summary_id, msg_id],
208            )?;
209        }
210
211        // Remove source messages from context, add summary
212        let mut positions_to_remove = Vec::new();
213        for &msg_id in message_ids {
214            let pos: Option<i64> = conn
215                .query_row(
216                    "SELECT position FROM context_items WHERE conversation_id = ? AND item_type = 'message' AND ref_id = ?",
217                    rusqlite::params![conv_id, msg_id],
218                    |row| row.get(0),
219                )
220                .ok();
221            if let Some(p) = pos {
222                positions_to_remove.push(p);
223            }
224        }
225
226        // Delete old context items
227        for &msg_id in message_ids {
228            conn.execute(
229                "DELETE FROM context_items WHERE conversation_id = ? AND item_type = 'message' AND ref_id = ?",
230                rusqlite::params![conv_id, msg_id],
231            )?;
232        }
233
234        // Insert summary at the lowest removed position
235        if let Some(&min_pos) = positions_to_remove.iter().min() {
236            conn.execute(
237                "INSERT INTO context_items (conversation_id, item_type, ref_id, position) VALUES (?, 'summary', ?, ?)",
238                rusqlite::params![conv_id, summary_id, min_pos],
239            )?;
240        }
241
242        Ok(summary_id)
243    }
244
245    /// Check if compaction is needed.
246    fn needs_compaction(&self) -> Result<bool> {
247        let conv_id = self.ensure_conversation()?;
248        let conn = self.conn.lock().unwrap();
249        let count: i64 = conn.query_row(
250            "SELECT COUNT(*) FROM context_items WHERE conversation_id = ?",
251            [conv_id],
252            |row| row.get(0),
253        )?;
254        Ok(count as usize > self.config.threshold_items)
255    }
256}
257
258#[async_trait]
259impl MemoryStore for SqliteMemoryStore {
260    fn name(&self) -> &str {
261        "sqlite"
262    }
263
264    async fn ingest(&self, role: &str, content: &str, token_count: usize) -> Result<i64> {
265        let conv_id = self.ensure_conversation()?;
266        let conn = self.conn.lock().unwrap();
267
268        // Get next sequence number
269        let seq: i64 = conn.query_row(
270            "SELECT COALESCE(MAX(seq), 0) + 1 FROM messages WHERE conversation_id = ?",
271            [conv_id],
272            |row| row.get(0),
273        )?;
274
275        // Insert message
276        conn.execute(
277            "INSERT INTO messages (conversation_id, role, content, seq, token_count) VALUES (?, ?, ?, ?, ?)",
278            rusqlite::params![conv_id, role, content, seq, token_count as i32],
279        )?;
280
281        let msg_id = conn.last_insert_rowid();
282
283        // Add to context items
284        let position: i64 = conn.query_row(
285            "SELECT COALESCE(MAX(position), 0) + 1 FROM context_items WHERE conversation_id = ?",
286            [conv_id],
287            |row| row.get(0),
288        )?;
289
290        conn.execute(
291            "INSERT INTO context_items (conversation_id, item_type, ref_id, position) VALUES (?, 'message', ?, ?)",
292            rusqlite::params![conv_id, msg_id, position],
293        )?;
294
295        // Update conversation timestamp
296        conn.execute(
297            "UPDATE conversations SET updated_at = strftime('%s', 'now') WHERE id = ?",
298            [conv_id],
299        )?;
300
301        Ok(msg_id)
302    }
303
304    async fn search(&self, query: &str, limit: usize) -> Result<Vec<MemoryHit>> {
305        let conn = self.conn.lock().unwrap();
306
307        let mut stmt = conn.prepare(
308            "SELECT m.id, m.role, m.content, m.token_count, m.created_at
309             FROM messages m
310             JOIN messages_fts fts ON m.id = fts.rowid
311             WHERE messages_fts MATCH ?
312             ORDER BY rank
313             LIMIT ?",
314        )?;
315
316        let hits: Vec<MemoryHit> = stmt
317            .query_map(rusqlite::params![query, limit as i64], |row| {
318                let content: String = row.get(2)?;
319                Ok(MemoryHit {
320                    entry: MemoryEntry {
321                        id: row.get(0)?,
322                        role: row.get(1)?,
323                        content: content.clone(),
324                        token_count: row.get::<_, i32>(3)? as usize,
325                        timestamp: row.get(4)?,
326                        depth: 0,
327                    },
328                    score: 1.0, // FTS5 doesn't expose raw scores easily
329                    snippet: content,
330                })
331            })?
332            .filter_map(|r| r.ok())
333            .collect();
334
335        Ok(hits)
336    }
337
338    async fn get_context(&self, max_tokens: usize) -> Result<String> {
339        let entries = self.get_context_entries(max_tokens).await?;
340        Ok(super::generate_context_md(&entries))
341    }
342
343    async fn get_context_entries(&self, max_tokens: usize) -> Result<Vec<MemoryEntry>> {
344        let conv_id = self.ensure_conversation()?;
345        let conn = self.conn.lock().unwrap();
346
347        let mut stmt = conn.prepare(
348            "SELECT item_type, ref_id FROM context_items 
349             WHERE conversation_id = ? 
350             ORDER BY position ASC",
351        )?;
352
353        let items: Vec<(String, i64)> = stmt
354            .query_map([conv_id], |row| Ok((row.get(0)?, row.get(1)?)))?
355            .filter_map(|r| r.ok())
356            .collect();
357
358        let mut entries = Vec::new();
359        let mut total_tokens = 0;
360
361        for (item_type, ref_id) in items {
362            let entry = match item_type.as_str() {
363                "message" => self.get_message_entry(&conn, ref_id).ok(),
364                "summary" => self.get_summary_entry(&conn, ref_id).ok(),
365                _ => None,
366            };
367
368            if let Some(e) = entry {
369                if total_tokens + e.token_count > max_tokens && !entries.is_empty() {
370                    break;
371                }
372                total_tokens += e.token_count;
373                entries.push(e);
374            }
375        }
376
377        Ok(entries)
378    }
379
380    async fn compact(&self, summarizer: &dyn Summarizer) -> Result<CompactionStats> {
381        let start = Instant::now();
382        let mut stats = CompactionStats::default();
383
384        if !self.needs_compaction()? {
385            return Ok(stats);
386        }
387
388        // Get candidates for leaf compaction
389        let candidates = self.get_compaction_candidates(self.config.leaf_chunk_size)?;
390
391        // Only compact messages at depth 0
392        let messages: Vec<&MemoryEntry> = candidates.iter().filter(|e| e.depth == 0).collect();
393
394        if messages.len() >= self.config.leaf_chunk_size {
395            let chunk: Vec<_> = messages
396                .iter()
397                .take(self.config.leaf_chunk_size)
398                .cloned()
399                .cloned()
400                .collect();
401
402            let message_ids: Vec<i64> = chunk.iter().map(|e| e.id).collect();
403            let tokens_before: usize = chunk.iter().map(|e| e.token_count).sum();
404
405            let summary_text = summarizer.summarize(&chunk, SummaryKind::Leaf).await?;
406            self.create_summary_from_messages(&message_ids, &summary_text)?;
407
408            stats.messages_compacted = chunk.len();
409            stats.summaries_created = 1;
410            stats.tokens_saved = tokens_before.saturating_sub(estimate_tokens(&summary_text));
411        }
412
413        stats.duration = start.elapsed();
414        Ok(stats)
415    }
416
417    async fn message_count(&self) -> Result<usize> {
418        let conn = self.conn.lock().unwrap();
419        let count: i64 =
420            conn.query_row("SELECT COUNT(*) FROM messages", [], |row| row.get(0))?;
421        Ok(count as usize)
422    }
423
424    async fn summary_count(&self) -> Result<usize> {
425        let conn = self.conn.lock().unwrap();
426        let count: i64 =
427            conn.query_row("SELECT COUNT(*) FROM summaries", [], |row| row.get(0))?;
428        Ok(count as usize)
429    }
430
431    async fn flush(&self) -> Result<()> {
432        let conn = self.conn.lock().unwrap();
433        conn.execute_batch("PRAGMA wal_checkpoint(TRUNCATE);")?;
434        Ok(())
435    }
436}