1use super::config::MnemoConfig;
4use super::schema::{CURRENT_VERSION, SCHEMA, VERSION_CHECK};
5use super::traits::{CompactionStats, MemoryEntry, MemoryHit, MemoryStore, SummaryKind, Summarizer};
6use super::estimate_tokens;
7use anyhow::{Context, Result};
8use async_trait::async_trait;
9use std::path::Path;
10use std::sync::Mutex;
11use std::time::Instant;
12
13pub struct SqliteMemoryStore {
15 conn: Mutex<rusqlite::Connection>,
16 config: MnemoConfig,
17 conversation_id: Mutex<Option<i64>>,
19}
20
21impl SqliteMemoryStore {
22 pub async fn open(path: &Path, config: MnemoConfig) -> Result<Self> {
24 if let Some(parent) = path.parent() {
26 std::fs::create_dir_all(parent)
27 .with_context(|| format!("Failed to create mnemo directory: {:?}", parent))?;
28 }
29
30 let conn = rusqlite::Connection::open(path)
32 .with_context(|| format!("Failed to open mnemo database: {:?}", path))?;
33
34 conn.execute_batch("PRAGMA journal_mode=WAL; PRAGMA synchronous=NORMAL;")?;
36
37 conn.execute_batch(VERSION_CHECK)?;
39 conn.execute_batch(SCHEMA)?;
40
41 let version: i32 = conn
43 .query_row(
44 "SELECT COALESCE(MAX(version), 0) FROM schema_version",
45 [],
46 |row| row.get(0),
47 )
48 .unwrap_or(0);
49
50 if version < CURRENT_VERSION {
51 conn.execute(
52 "INSERT OR REPLACE INTO schema_version (version) VALUES (?)",
53 [CURRENT_VERSION],
54 )?;
55 }
56
57 Ok(Self {
58 conn: Mutex::new(conn),
59 config,
60 conversation_id: Mutex::new(None),
61 })
62 }
63
64 fn ensure_conversation(&self) -> Result<i64> {
66 let mut conv_id = self.conversation_id.lock().unwrap();
67 if let Some(id) = *conv_id {
68 return Ok(id);
69 }
70
71 let conn = self.conn.lock().unwrap();
72
73 let existing: Option<i64> = conn
75 .query_row(
76 "SELECT id FROM conversations WHERE agent_id = 'default' AND session_id = 'main'",
77 [],
78 |row| row.get(0),
79 )
80 .ok();
81
82 let id = if let Some(id) = existing {
83 id
84 } else {
85 conn.execute(
86 "INSERT INTO conversations (agent_id, session_id) VALUES ('default', 'main')",
87 [],
88 )?;
89 conn.last_insert_rowid()
90 };
91
92 *conv_id = Some(id);
93 Ok(id)
94 }
95
96 fn get_compaction_candidates(&self, count: usize) -> Result<Vec<MemoryEntry>> {
98 let conv_id = self.ensure_conversation()?;
99 let conn = self.conn.lock().unwrap();
100
101 let total: i64 = conn.query_row(
102 "SELECT COUNT(*) FROM context_items WHERE conversation_id = ?",
103 [conv_id],
104 |row| row.get(0),
105 )?;
106
107 let fresh_tail = self.config.fresh_tail_messages;
108 if (total as usize) <= fresh_tail {
109 return Ok(Vec::new());
110 }
111
112 let available = (total as usize) - fresh_tail;
113 let to_get = count.min(available);
114
115 let mut stmt = conn.prepare(
116 "SELECT item_type, ref_id FROM context_items
117 WHERE conversation_id = ?
118 ORDER BY position ASC
119 LIMIT ?",
120 )?;
121
122 let items: Vec<(String, i64)> = stmt
123 .query_map(rusqlite::params![conv_id, to_get as i64], |row| {
124 Ok((row.get(0)?, row.get(1)?))
125 })?
126 .filter_map(|r| r.ok())
127 .collect();
128
129 let mut entries = Vec::new();
130 for (item_type, ref_id) in items {
131 match item_type.as_str() {
132 "message" => {
133 if let Ok(entry) = self.get_message_entry(&conn, ref_id) {
134 entries.push(entry);
135 }
136 }
137 "summary" => {
138 if let Ok(entry) = self.get_summary_entry(&conn, ref_id) {
139 entries.push(entry);
140 }
141 }
142 _ => {}
143 }
144 }
145
146 Ok(entries)
147 }
148
149 fn get_message_entry(&self, conn: &rusqlite::Connection, id: i64) -> Result<MemoryEntry> {
150 conn.query_row(
151 "SELECT id, role, content, token_count, created_at FROM messages WHERE id = ?",
152 [id],
153 |row| {
154 Ok(MemoryEntry {
155 id: row.get(0)?,
156 role: row.get(1)?,
157 content: row.get(2)?,
158 token_count: row.get::<_, i32>(3)? as usize,
159 timestamp: row.get(4)?,
160 depth: 0,
161 })
162 },
163 )
164 .map_err(|e| anyhow::anyhow!("Failed to get message {}: {}", id, e))
165 }
166
167 fn get_summary_entry(&self, conn: &rusqlite::Connection, id: i64) -> Result<MemoryEntry> {
168 conn.query_row(
169 "SELECT id, depth, content, token_count, created_at FROM summaries WHERE id = ?",
170 [id],
171 |row| {
172 Ok(MemoryEntry {
173 id: row.get(0)?,
174 role: "summary".to_string(),
175 content: row.get(2)?,
176 token_count: row.get::<_, i32>(3)? as usize,
177 timestamp: row.get(4)?,
178 depth: row.get::<_, i32>(1)? as u8,
179 })
180 },
181 )
182 .map_err(|e| anyhow::anyhow!("Failed to get summary {}: {}", id, e))
183 }
184
185 fn create_summary_from_messages(
187 &self,
188 message_ids: &[i64],
189 summary_content: &str,
190 ) -> Result<i64> {
191 let conv_id = self.ensure_conversation()?;
192 let conn = self.conn.lock().unwrap();
193 let token_count = estimate_tokens(summary_content) as i32;
194
195 conn.execute(
197 "INSERT INTO summaries (conversation_id, depth, content, token_count) VALUES (?, 0, ?, ?)",
198 rusqlite::params![conv_id, summary_content, token_count],
199 )?;
200
201 let summary_id = conn.last_insert_rowid();
202
203 for &msg_id in message_ids {
205 conn.execute(
206 "INSERT INTO summary_messages (summary_id, message_id) VALUES (?, ?)",
207 [summary_id, msg_id],
208 )?;
209 }
210
211 let mut positions_to_remove = Vec::new();
213 for &msg_id in message_ids {
214 let pos: Option<i64> = conn
215 .query_row(
216 "SELECT position FROM context_items WHERE conversation_id = ? AND item_type = 'message' AND ref_id = ?",
217 rusqlite::params![conv_id, msg_id],
218 |row| row.get(0),
219 )
220 .ok();
221 if let Some(p) = pos {
222 positions_to_remove.push(p);
223 }
224 }
225
226 for &msg_id in message_ids {
228 conn.execute(
229 "DELETE FROM context_items WHERE conversation_id = ? AND item_type = 'message' AND ref_id = ?",
230 rusqlite::params![conv_id, msg_id],
231 )?;
232 }
233
234 if let Some(&min_pos) = positions_to_remove.iter().min() {
236 conn.execute(
237 "INSERT INTO context_items (conversation_id, item_type, ref_id, position) VALUES (?, 'summary', ?, ?)",
238 rusqlite::params![conv_id, summary_id, min_pos],
239 )?;
240 }
241
242 Ok(summary_id)
243 }
244
245 fn needs_compaction(&self) -> Result<bool> {
247 let conv_id = self.ensure_conversation()?;
248 let conn = self.conn.lock().unwrap();
249 let count: i64 = conn.query_row(
250 "SELECT COUNT(*) FROM context_items WHERE conversation_id = ?",
251 [conv_id],
252 |row| row.get(0),
253 )?;
254 Ok(count as usize > self.config.threshold_items)
255 }
256}
257
258#[async_trait]
259impl MemoryStore for SqliteMemoryStore {
260 fn name(&self) -> &str {
261 "sqlite"
262 }
263
264 async fn ingest(&self, role: &str, content: &str, token_count: usize) -> Result<i64> {
265 let conv_id = self.ensure_conversation()?;
266 let conn = self.conn.lock().unwrap();
267
268 let seq: i64 = conn.query_row(
270 "SELECT COALESCE(MAX(seq), 0) + 1 FROM messages WHERE conversation_id = ?",
271 [conv_id],
272 |row| row.get(0),
273 )?;
274
275 conn.execute(
277 "INSERT INTO messages (conversation_id, role, content, seq, token_count) VALUES (?, ?, ?, ?, ?)",
278 rusqlite::params![conv_id, role, content, seq, token_count as i32],
279 )?;
280
281 let msg_id = conn.last_insert_rowid();
282
283 let position: i64 = conn.query_row(
285 "SELECT COALESCE(MAX(position), 0) + 1 FROM context_items WHERE conversation_id = ?",
286 [conv_id],
287 |row| row.get(0),
288 )?;
289
290 conn.execute(
291 "INSERT INTO context_items (conversation_id, item_type, ref_id, position) VALUES (?, 'message', ?, ?)",
292 rusqlite::params![conv_id, msg_id, position],
293 )?;
294
295 conn.execute(
297 "UPDATE conversations SET updated_at = strftime('%s', 'now') WHERE id = ?",
298 [conv_id],
299 )?;
300
301 Ok(msg_id)
302 }
303
304 async fn search(&self, query: &str, limit: usize) -> Result<Vec<MemoryHit>> {
305 let conn = self.conn.lock().unwrap();
306
307 let mut stmt = conn.prepare(
308 "SELECT m.id, m.role, m.content, m.token_count, m.created_at
309 FROM messages m
310 JOIN messages_fts fts ON m.id = fts.rowid
311 WHERE messages_fts MATCH ?
312 ORDER BY rank
313 LIMIT ?",
314 )?;
315
316 let hits: Vec<MemoryHit> = stmt
317 .query_map(rusqlite::params![query, limit as i64], |row| {
318 let content: String = row.get(2)?;
319 Ok(MemoryHit {
320 entry: MemoryEntry {
321 id: row.get(0)?,
322 role: row.get(1)?,
323 content: content.clone(),
324 token_count: row.get::<_, i32>(3)? as usize,
325 timestamp: row.get(4)?,
326 depth: 0,
327 },
328 score: 1.0, snippet: content,
330 })
331 })?
332 .filter_map(|r| r.ok())
333 .collect();
334
335 Ok(hits)
336 }
337
338 async fn get_context(&self, max_tokens: usize) -> Result<String> {
339 let entries = self.get_context_entries(max_tokens).await?;
340 Ok(super::generate_context_md(&entries))
341 }
342
343 async fn get_context_entries(&self, max_tokens: usize) -> Result<Vec<MemoryEntry>> {
344 let conv_id = self.ensure_conversation()?;
345 let conn = self.conn.lock().unwrap();
346
347 let mut stmt = conn.prepare(
348 "SELECT item_type, ref_id FROM context_items
349 WHERE conversation_id = ?
350 ORDER BY position ASC",
351 )?;
352
353 let items: Vec<(String, i64)> = stmt
354 .query_map([conv_id], |row| Ok((row.get(0)?, row.get(1)?)))?
355 .filter_map(|r| r.ok())
356 .collect();
357
358 let mut entries = Vec::new();
359 let mut total_tokens = 0;
360
361 for (item_type, ref_id) in items {
362 let entry = match item_type.as_str() {
363 "message" => self.get_message_entry(&conn, ref_id).ok(),
364 "summary" => self.get_summary_entry(&conn, ref_id).ok(),
365 _ => None,
366 };
367
368 if let Some(e) = entry {
369 if total_tokens + e.token_count > max_tokens && !entries.is_empty() {
370 break;
371 }
372 total_tokens += e.token_count;
373 entries.push(e);
374 }
375 }
376
377 Ok(entries)
378 }
379
380 async fn compact(&self, summarizer: &dyn Summarizer) -> Result<CompactionStats> {
381 let start = Instant::now();
382 let mut stats = CompactionStats::default();
383
384 if !self.needs_compaction()? {
385 return Ok(stats);
386 }
387
388 let candidates = self.get_compaction_candidates(self.config.leaf_chunk_size)?;
390
391 let messages: Vec<&MemoryEntry> = candidates.iter().filter(|e| e.depth == 0).collect();
393
394 if messages.len() >= self.config.leaf_chunk_size {
395 let chunk: Vec<_> = messages
396 .iter()
397 .take(self.config.leaf_chunk_size)
398 .cloned()
399 .cloned()
400 .collect();
401
402 let message_ids: Vec<i64> = chunk.iter().map(|e| e.id).collect();
403 let tokens_before: usize = chunk.iter().map(|e| e.token_count).sum();
404
405 let summary_text = summarizer.summarize(&chunk, SummaryKind::Leaf).await?;
406 self.create_summary_from_messages(&message_ids, &summary_text)?;
407
408 stats.messages_compacted = chunk.len();
409 stats.summaries_created = 1;
410 stats.tokens_saved = tokens_before.saturating_sub(estimate_tokens(&summary_text));
411 }
412
413 stats.duration = start.elapsed();
414 Ok(stats)
415 }
416
417 async fn message_count(&self) -> Result<usize> {
418 let conn = self.conn.lock().unwrap();
419 let count: i64 =
420 conn.query_row("SELECT COUNT(*) FROM messages", [], |row| row.get(0))?;
421 Ok(count as usize)
422 }
423
424 async fn summary_count(&self) -> Result<usize> {
425 let conn = self.conn.lock().unwrap();
426 let count: i64 =
427 conn.query_row("SELECT COUNT(*) FROM summaries", [], |row| row.get(0))?;
428 Ok(count as usize)
429 }
430
431 async fn flush(&self) -> Result<()> {
432 let conn = self.conn.lock().unwrap();
433 conn.execute_batch("PRAGMA wal_checkpoint(TRUNCATE);")?;
434 Ok(())
435 }
436}