Skip to main content

traitclaw_memory_sqlite/
lib.rs

//! SQLite memory backend for the `TraitClaw` AI agent framework.
//!
//! Provides persistent conversation history, working memory, and FTS5-powered
//! long-term recall — all backed by a single `SQLite` database file.
//!
//! # Quick Start
//!
//! ```rust,no_run
//! use traitclaw_memory_sqlite::SqliteMemory;
//!
//! let memory = SqliteMemory::new("./agent.db").expect("Failed to open database");
//! ```
13
14#![deny(missing_docs)]
15#![allow(clippy::redundant_closure)]
16
17use std::sync::Mutex;
18
19use async_trait::async_trait;
20use rusqlite::Connection;
21use serde_json::Value;
22use traitclaw_core::traits::memory::{Memory, MemoryEntry};
23use traitclaw_core::types::message::{Message, MessageRole};
24use traitclaw_core::Result;
25
26/// SQLite-backed memory backend.
27///
28/// Uses a single `SQLite` database file to persist:
29/// - **Conversation memory** — `sessions` + `messages` tables
30/// - **Working memory** — `working_memory` (key/value per session)
31/// - **Long-term memory** — `long_term_memory` + FTS5 virtual table
32pub struct SqliteMemory {
33    conn: Mutex<Connection>,
34}
35
36impl SqliteMemory {
37    /// Open (or create) a SQLite database at the given path.
38    ///
39    /// The schema is auto-created/migrated on first access.
40    ///
41    /// # Errors
42    ///
43    /// Returns an error if the database cannot be opened or the schema
44    /// cannot be created.
45    pub fn new(path: &str) -> Result<Self> {
46        let conn = Connection::open(path)
47            .map_err(|e| traitclaw_core::Error::Runtime(format!("SQLite open error: {e}")))?;
48
49        init_schema(&conn)?;
50
51        Ok(Self {
52            conn: Mutex::new(conn),
53        })
54    }
55
56    /// Create an in-memory SQLite database (useful for testing).
57    ///
58    /// # Errors
59    ///
60    /// Returns an error if the schema cannot be created.
61    pub fn in_memory() -> Result<Self> {
62        let conn = Connection::open_in_memory()
63            .map_err(|e| traitclaw_core::Error::Runtime(format!("SQLite open error: {e}")))?;
64
65        init_schema(&conn)?;
66
67        Ok(Self {
68            conn: Mutex::new(conn),
69        })
70    }
71}
72
73fn init_schema(conn: &Connection) -> Result<()> {
74    conn.execute_batch(
75        "
76        CREATE TABLE IF NOT EXISTS sessions (
77            id         TEXT PRIMARY KEY,
78            created_at INTEGER NOT NULL DEFAULT (strftime('%s','now'))
79        );
80
81        CREATE TABLE IF NOT EXISTS messages (
82            id         INTEGER PRIMARY KEY AUTOINCREMENT,
83            session_id TEXT NOT NULL REFERENCES sessions(id) ON DELETE CASCADE,
84            role       TEXT NOT NULL,
85            content    TEXT NOT NULL DEFAULT '',
86            tool_call_id TEXT,
87            created_at INTEGER NOT NULL DEFAULT (strftime('%s','now'))
88        );
89        CREATE INDEX IF NOT EXISTS idx_messages_session ON messages(session_id);
90
91        CREATE TABLE IF NOT EXISTS working_memory (
92            session_id TEXT NOT NULL,
93            key        TEXT NOT NULL,
94            value      TEXT NOT NULL,
95            PRIMARY KEY (session_id, key)
96        );
97
98        CREATE TABLE IF NOT EXISTS long_term_memory (
99            id         TEXT PRIMARY KEY,
100            content    TEXT NOT NULL,
101            metadata   TEXT,
102            created_at INTEGER NOT NULL DEFAULT (strftime('%s','now'))
103        );
104
105        CREATE VIRTUAL TABLE IF NOT EXISTS long_term_fts
106            USING fts5(content, content_rowid='rowid');
107        ",
108    )
109    .map_err(|e| traitclaw_core::Error::Runtime(format!("Schema init error: {e}")))?;
110
111    Ok(())
112}
113
114fn role_to_str(role: &MessageRole) -> &'static str {
115    match role {
116        MessageRole::System => "system",
117        MessageRole::User => "user",
118        MessageRole::Assistant => "assistant",
119        MessageRole::Tool => "tool",
120        _ => "unknown",
121    }
122}
123
124fn str_to_role(s: &str) -> MessageRole {
125    match s {
126        "system" => MessageRole::System,
127        "assistant" => MessageRole::Assistant,
128        "tool" => MessageRole::Tool,
129        // "user" and anything unknown default to User
130        _ => MessageRole::User,
131    }
132}
133
134#[async_trait]
135impl Memory for SqliteMemory {
136    async fn messages(&self, session_id: &str) -> Result<Vec<Message>> {
137        let conn = self
138            .conn
139            .lock()
140            .map_err(|e| traitclaw_core::Error::Runtime(format!("Lock error: {e}")))?;
141        let mut stmt = conn
142            .prepare("SELECT role, content, tool_call_id FROM messages WHERE session_id = ?1 ORDER BY id")
143            .map_err(|e| traitclaw_core::Error::Runtime(format!("Query error: {e}")))?;
144
145        let rows = stmt
146            .query_map([session_id], |row| {
147                let role: String = row.get(0)?;
148                let content: String = row.get(1)?;
149                let tool_call_id: Option<String> = row.get(2)?;
150                Ok(Message {
151                    role: str_to_role(&role),
152                    content,
153                    tool_call_id,
154                })
155            })
156            .map_err(|e| traitclaw_core::Error::Runtime(format!("Query error: {e}")))?;
157
158        let mut messages = Vec::new();
159        for row in rows {
160            messages
161                .push(row.map_err(|e| traitclaw_core::Error::Runtime(format!("Row error: {e}")))?);
162        }
163        Ok(messages)
164    }
165
166    async fn append(&self, session_id: &str, message: Message) -> Result<()> {
167        let conn = self
168            .conn
169            .lock()
170            .map_err(|e| traitclaw_core::Error::Runtime(format!("Lock error: {e}")))?;
171        conn.execute(
172            "INSERT INTO messages (session_id, role, content, tool_call_id) VALUES (?1, ?2, ?3, ?4)",
173            rusqlite::params![
174                session_id,
175                role_to_str(&message.role),
176                message.content,
177                message.tool_call_id,
178            ],
179        )
180        .map_err(|e| traitclaw_core::Error::Runtime(format!("Insert error: {e}")))?;
181        Ok(())
182    }
183
184    async fn get_context(&self, session_id: &str, key: &str) -> Result<Option<Value>> {
185        let conn = self
186            .conn
187            .lock()
188            .map_err(|e| traitclaw_core::Error::Runtime(format!("Lock error: {e}")))?;
189        let result: rusqlite::Result<String> = conn.query_row(
190            "SELECT value FROM working_memory WHERE session_id = ?1 AND key = ?2",
191            rusqlite::params![session_id, key],
192            |row| row.get(0),
193        );
194        match result {
195            Ok(json_str) => {
196                let val: Value = serde_json::from_str(&json_str).unwrap_or(Value::String(json_str));
197                Ok(Some(val))
198            }
199            Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
200            Err(e) => Err(traitclaw_core::Error::Runtime(format!("Query error: {e}"))),
201        }
202    }
203
204    async fn set_context(&self, session_id: &str, key: &str, value: Value) -> Result<()> {
205        let conn = self
206            .conn
207            .lock()
208            .map_err(|e| traitclaw_core::Error::Runtime(format!("Lock error: {e}")))?;
209        let json_str = serde_json::to_string(&value)
210            .map_err(|e| traitclaw_core::Error::Runtime(format!("JSON error: {e}")))?;
211        conn.execute(
212            "INSERT OR REPLACE INTO working_memory (session_id, key, value) VALUES (?1, ?2, ?3)",
213            rusqlite::params![session_id, key, json_str],
214        )
215        .map_err(|e| traitclaw_core::Error::Runtime(format!("Insert error: {e}")))?;
216        Ok(())
217    }
218
219    async fn recall(&self, query: &str, limit: usize) -> Result<Vec<MemoryEntry>> {
220        let conn = self
221            .conn
222            .lock()
223            .map_err(|e| traitclaw_core::Error::Runtime(format!("Lock error: {e}")))?;
224        let mut stmt = conn
225            .prepare(
226                "SELECT m.id, m.content, m.metadata, m.created_at
227                 FROM long_term_fts f
228                 JOIN long_term_memory m ON m.rowid = f.rowid
229                 WHERE long_term_fts MATCH ?1
230                 ORDER BY rank
231                 LIMIT ?2",
232            )
233            .map_err(|e| traitclaw_core::Error::Runtime(format!("Query error: {e}")))?;
234
235        let rows = stmt
236            .query_map(rusqlite::params![query, limit], |row| {
237                let id: String = row.get(0)?;
238                let content: String = row.get(1)?;
239                let metadata_str: Option<String> = row.get(2)?;
240                let created_at: u64 = row.get(3)?;
241                let metadata = metadata_str.and_then(|s| serde_json::from_str::<Value>(&s).ok());
242                let mut entry = MemoryEntry::now(id, content);
243                entry.metadata = metadata;
244                entry.created_at = created_at;
245                Ok(entry)
246            })
247            .map_err(|e| traitclaw_core::Error::Runtime(format!("Query error: {e}")))?;
248
249        let mut entries = Vec::new();
250        for row in rows {
251            entries
252                .push(row.map_err(|e| traitclaw_core::Error::Runtime(format!("Row error: {e}")))?);
253        }
254        Ok(entries)
255    }
256
257    async fn store(&self, entry: MemoryEntry) -> Result<()> {
258        let conn = self
259            .conn
260            .lock()
261            .map_err(|e| traitclaw_core::Error::Runtime(format!("Lock error: {e}")))?;
262        let metadata_str = entry
263            .metadata
264            .as_ref()
265            .map(|v| serde_json::to_string(v).unwrap_or_default());
266
267        conn.execute(
268            "INSERT OR REPLACE INTO long_term_memory (id, content, metadata, created_at) VALUES (?1, ?2, ?3, ?4)",
269            rusqlite::params![entry.id, entry.content, metadata_str, entry.created_at],
270        )
271        .map_err(|e| traitclaw_core::Error::Runtime(format!("Insert error: {e}")))?;
272
273        // Insert into FTS index
274        let rowid: i64 = conn
275            .query_row(
276                "SELECT rowid FROM long_term_memory WHERE id = ?1",
277                rusqlite::params![entry.id],
278                |row| row.get(0),
279            )
280            .map_err(|e| traitclaw_core::Error::Runtime(format!("Rowid error: {e}")))?;
281
282        conn.execute(
283            "INSERT OR REPLACE INTO long_term_fts (rowid, content) VALUES (?1, ?2)",
284            rusqlite::params![rowid, entry.content],
285        )
286        .map_err(|e| traitclaw_core::Error::Runtime(format!("FTS insert error: {e}")))?;
287
288        Ok(())
289    }
290
291    async fn create_session(&self) -> Result<String> {
292        let id = uuid::Uuid::new_v4().to_string();
293        let conn = self
294            .conn
295            .lock()
296            .map_err(|e| traitclaw_core::Error::Runtime(format!("Lock error: {e}")))?;
297        conn.execute("INSERT INTO sessions (id) VALUES (?1)", [&id])
298            .map_err(|e| traitclaw_core::Error::Runtime(format!("Insert error: {e}")))?;
299        Ok(id)
300    }
301
302    async fn list_sessions(&self) -> Result<Vec<String>> {
303        let conn = self
304            .conn
305            .lock()
306            .map_err(|e| traitclaw_core::Error::Runtime(format!("Lock error: {e}")))?;
307        let mut stmt = conn
308            .prepare("SELECT id FROM sessions ORDER BY created_at")
309            .map_err(|e| traitclaw_core::Error::Runtime(format!("Query error: {e}")))?;
310        let rows = stmt
311            .query_map([], |row| row.get::<_, String>(0))
312            .map_err(|e| traitclaw_core::Error::Runtime(format!("Query error: {e}")))?;
313
314        let mut ids = Vec::new();
315        for row in rows {
316            ids.push(row.map_err(|e| traitclaw_core::Error::Runtime(format!("Row error: {e}")))?);
317        }
318        Ok(ids)
319    }
320
321    async fn delete_session(&self, session_id: &str) -> Result<()> {
322        let conn = self
323            .conn
324            .lock()
325            .map_err(|e| traitclaw_core::Error::Runtime(format!("Lock error: {e}")))?;
326        // CASCADE will delete messages; manually delete working_memory
327        conn.execute(
328            "DELETE FROM working_memory WHERE session_id = ?1",
329            [session_id],
330        )
331        .map_err(|e| traitclaw_core::Error::Runtime(format!("Delete error: {e}")))?;
332        conn.execute("DELETE FROM messages WHERE session_id = ?1", [session_id])
333            .map_err(|e| traitclaw_core::Error::Runtime(format!("Delete error: {e}")))?;
334        conn.execute("DELETE FROM sessions WHERE id = ?1", [session_id])
335            .map_err(|e| traitclaw_core::Error::Runtime(format!("Delete error: {e}")))?;
336        Ok(())
337    }
338}
339
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an in-memory store together with one freshly created session id.
    async fn store_with_session() -> (SqliteMemory, String) {
        let store = SqliteMemory::in_memory().unwrap();
        let session = store.create_session().await.unwrap();
        (store, session)
    }

    // A created session is listed, and disappears again once deleted.
    #[tokio::test]
    async fn test_session_lifecycle() {
        let (store, session) = store_with_session().await;

        let before = store.list_sessions().await.unwrap();
        assert!(before.contains(&session));

        store.delete_session(&session).await.unwrap();

        let after = store.list_sessions().await.unwrap();
        assert!(!after.contains(&session));
    }

    // Appended messages come back in order with role and content intact.
    #[tokio::test]
    async fn test_conversation_persistence() {
        let (store, session) = store_with_session().await;

        store.append(&session, Message::user("Hello")).await.unwrap();
        store
            .append(&session, Message::assistant("Hi there!"))
            .await
            .unwrap();

        let history = store.messages(&session).await.unwrap();
        assert_eq!(history.len(), 2);

        let (first, second) = (&history[0], &history[1]);
        assert!(matches!(first.role, MessageRole::User));
        assert_eq!(first.content, "Hello");
        assert!(matches!(second.role, MessageRole::Assistant));
        assert_eq!(second.content, "Hi there!");
    }

    // Working memory starts empty, stores a value, and overwrites it on re-set.
    #[tokio::test]
    async fn test_working_memory() {
        let (store, session) = store_with_session().await;

        // Nothing stored yet.
        let missing = store.get_context(&session, "task").await.unwrap();
        assert!(missing.is_none());

        // First write is readable.
        store
            .set_context(&session, "task", serde_json::json!("coding"))
            .await
            .unwrap();
        let stored = store.get_context(&session, "task").await.unwrap().unwrap();
        assert_eq!(stored, serde_json::json!("coding"));

        // Second write replaces the first.
        store
            .set_context(&session, "task", serde_json::json!("testing"))
            .await
            .unwrap();
        let replaced = store.get_context(&session, "task").await.unwrap().unwrap();
        assert_eq!(replaced, serde_json::json!("testing"));
    }

    // Stored entries can be found again through the full-text index.
    #[tokio::test]
    async fn test_long_term_store_and_recall() {
        let store = SqliteMemory::in_memory().unwrap();

        let facts = [
            ("1", "Rust is a systems programming language"),
            ("2", "Python is great for data science"),
            ("3", "Rust has zero-cost abstractions"),
        ];
        for (id, text) in facts {
            store.store(MemoryEntry::now(id, text)).await.unwrap();
        }

        let hits = store.recall("Rust programming", 10).await.unwrap();
        assert!(!hits.is_empty());
        // At least one hit must be about Rust.
        assert!(hits.iter().any(|hit| hit.content.contains("Rust")));
    }

    // Recall on an empty store returns nothing.
    #[tokio::test]
    async fn test_recall_empty() {
        let store = SqliteMemory::in_memory().unwrap();
        let hits = store.recall("anything", 10).await.unwrap();
        assert!(hits.is_empty());
    }

    // Deleting a session removes its messages and its working memory.
    #[tokio::test]
    async fn test_delete_session_clears_messages_and_context() {
        let (store, session) = store_with_session().await;

        store.append(&session, Message::user("test")).await.unwrap();
        store
            .set_context(&session, "key", serde_json::json!("val"))
            .await
            .unwrap();

        store.delete_session(&session).await.unwrap();

        let history = store.messages(&session).await.unwrap();
        assert!(history.is_empty());
        let context = store.get_context(&session, "key").await.unwrap();
        assert!(context.is_none());
    }

    // A tool message keeps its role and tool-call id through a round trip.
    #[tokio::test]
    async fn test_tool_message_with_call_id() {
        let (store, session) = store_with_session().await;

        let tool_reply = Message {
            role: MessageRole::Tool,
            content: "result".into(),
            tool_call_id: Some("call_1".into()),
        };
        store.append(&session, tool_reply).await.unwrap();

        let history = store.messages(&session).await.unwrap();
        assert_eq!(history.len(), 1);
        assert!(matches!(history[0].role, MessageRole::Tool));
        assert_eq!(history[0].tool_call_id.as_deref(), Some("call_1"));
    }
}