Skip to main content

traitclaw_memory_sqlite/
lib.rs

1//! SQLite memory backend for the `TraitClaw` AI agent framework.
2//!
3//! Provides persistent conversation history, working memory, and FTS5-powered
4//! long-term recall — all backed by a single `SQLite` database file.
5//!
6//! # Quick Start
7//!
8//! ```rust,no_run
9//! use traitclaw_memory_sqlite::SqliteMemory;
10//!
11//! let memory = SqliteMemory::new("./agent.db").expect("Failed to open database");
12//! ```
13
14#![deny(warnings)]
15#![deny(missing_docs)]
16#![warn(clippy::pedantic)]
17#![allow(clippy::module_name_repetitions)]
18#![allow(clippy::doc_markdown)] // proper noun: SQLite
19
20use std::sync::Mutex;
21
22use async_trait::async_trait;
23use rusqlite::Connection;
24use serde_json::Value;
25use traitclaw_core::traits::memory::{Memory, MemoryEntry};
26use traitclaw_core::types::message::{Message, MessageRole};
27use traitclaw_core::Result;
28
29/// SQLite-backed memory backend.
30///
31/// Uses a single `SQLite` database file to persist:
32/// - **Conversation memory** — `sessions` + `messages` tables
33/// - **Working memory** — `working_memory` (key/value per session)
34/// - **Long-term memory** — `long_term_memory` + FTS5 virtual table
35pub struct SqliteMemory {
36    conn: Mutex<Connection>,
37}
38
39impl SqliteMemory {
40    /// Open (or create) a SQLite database at the given path.
41    ///
42    /// The schema is auto-created/migrated on first access.
43    ///
44    /// # Errors
45    ///
46    /// Returns an error if the database cannot be opened or the schema
47    /// cannot be created.
48    pub fn new(path: &str) -> Result<Self> {
49        let conn = Connection::open(path)
50            .map_err(|e| traitclaw_core::Error::Runtime(format!("SQLite open error: {e}")))?;
51
52        init_schema(&conn)?;
53
54        Ok(Self {
55            conn: Mutex::new(conn),
56        })
57    }
58
59    /// Create an in-memory SQLite database (useful for testing).
60    ///
61    /// # Errors
62    ///
63    /// Returns an error if the schema cannot be created.
64    pub fn in_memory() -> Result<Self> {
65        let conn = Connection::open_in_memory()
66            .map_err(|e| traitclaw_core::Error::Runtime(format!("SQLite open error: {e}")))?;
67
68        init_schema(&conn)?;
69
70        Ok(Self {
71            conn: Mutex::new(conn),
72        })
73    }
74}
75
76fn init_schema(conn: &Connection) -> Result<()> {
77    conn.execute_batch(
78        "
79        CREATE TABLE IF NOT EXISTS sessions (
80            id         TEXT PRIMARY KEY,
81            created_at INTEGER NOT NULL DEFAULT (strftime('%s','now'))
82        );
83
84        CREATE TABLE IF NOT EXISTS messages (
85            id         INTEGER PRIMARY KEY AUTOINCREMENT,
86            session_id TEXT NOT NULL REFERENCES sessions(id) ON DELETE CASCADE,
87            role       TEXT NOT NULL,
88            content    TEXT NOT NULL DEFAULT '',
89            tool_call_id TEXT,
90            created_at INTEGER NOT NULL DEFAULT (strftime('%s','now'))
91        );
92        CREATE INDEX IF NOT EXISTS idx_messages_session ON messages(session_id);
93
94        CREATE TABLE IF NOT EXISTS working_memory (
95            session_id TEXT NOT NULL,
96            key        TEXT NOT NULL,
97            value      TEXT NOT NULL,
98            PRIMARY KEY (session_id, key)
99        );
100
101        CREATE TABLE IF NOT EXISTS long_term_memory (
102            id         TEXT PRIMARY KEY,
103            content    TEXT NOT NULL,
104            metadata   TEXT,
105            created_at INTEGER NOT NULL DEFAULT (strftime('%s','now'))
106        );
107
108        CREATE VIRTUAL TABLE IF NOT EXISTS long_term_fts
109            USING fts5(content, content_rowid='rowid');
110        ",
111    )
112    .map_err(|e| traitclaw_core::Error::Runtime(format!("Schema init error: {e}")))?;
113
114    Ok(())
115}
116
117fn role_to_str(role: &MessageRole) -> &'static str {
118    match role {
119        MessageRole::System => "system",
120        MessageRole::User => "user",
121        MessageRole::Assistant => "assistant",
122        MessageRole::Tool => "tool",
123        _ => "unknown",
124    }
125}
126
127fn str_to_role(s: &str) -> MessageRole {
128    match s {
129        "system" => MessageRole::System,
130        "assistant" => MessageRole::Assistant,
131        "tool" => MessageRole::Tool,
132        // "user" and anything unknown default to User
133        _ => MessageRole::User,
134    }
135}
136
137#[async_trait]
138impl Memory for SqliteMemory {
139    async fn messages(&self, session_id: &str) -> Result<Vec<Message>> {
140        let conn = self
141            .conn
142            .lock()
143            .map_err(|e| traitclaw_core::Error::Runtime(format!("Lock error: {e}")))?;
144        let mut stmt = conn
145            .prepare("SELECT role, content, tool_call_id FROM messages WHERE session_id = ?1 ORDER BY id")
146            .map_err(|e| traitclaw_core::Error::Runtime(format!("Query error: {e}")))?;
147
148        let rows = stmt
149            .query_map([session_id], |row| {
150                let role: String = row.get(0)?;
151                let content: String = row.get(1)?;
152                let tool_call_id: Option<String> = row.get(2)?;
153                Ok(Message {
154                    role: str_to_role(&role),
155                    content,
156                    tool_call_id,
157                })
158            })
159            .map_err(|e| traitclaw_core::Error::Runtime(format!("Query error: {e}")))?;
160
161        let mut messages = Vec::new();
162        for row in rows {
163            messages
164                .push(row.map_err(|e| traitclaw_core::Error::Runtime(format!("Row error: {e}")))?);
165        }
166        Ok(messages)
167    }
168
169    async fn append(&self, session_id: &str, message: Message) -> Result<()> {
170        let conn = self
171            .conn
172            .lock()
173            .map_err(|e| traitclaw_core::Error::Runtime(format!("Lock error: {e}")))?;
174        conn.execute(
175            "INSERT INTO messages (session_id, role, content, tool_call_id) VALUES (?1, ?2, ?3, ?4)",
176            rusqlite::params![
177                session_id,
178                role_to_str(&message.role),
179                message.content,
180                message.tool_call_id,
181            ],
182        )
183        .map_err(|e| traitclaw_core::Error::Runtime(format!("Insert error: {e}")))?;
184        Ok(())
185    }
186
187    async fn get_context(&self, session_id: &str, key: &str) -> Result<Option<Value>> {
188        let conn = self
189            .conn
190            .lock()
191            .map_err(|e| traitclaw_core::Error::Runtime(format!("Lock error: {e}")))?;
192        let result: rusqlite::Result<String> = conn.query_row(
193            "SELECT value FROM working_memory WHERE session_id = ?1 AND key = ?2",
194            rusqlite::params![session_id, key],
195            |row| row.get(0),
196        );
197        match result {
198            Ok(json_str) => {
199                let val: Value = serde_json::from_str(&json_str).unwrap_or(Value::String(json_str));
200                Ok(Some(val))
201            }
202            Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
203            Err(e) => Err(traitclaw_core::Error::Runtime(format!("Query error: {e}"))),
204        }
205    }
206
207    async fn set_context(&self, session_id: &str, key: &str, value: Value) -> Result<()> {
208        let conn = self
209            .conn
210            .lock()
211            .map_err(|e| traitclaw_core::Error::Runtime(format!("Lock error: {e}")))?;
212        let json_str = serde_json::to_string(&value)
213            .map_err(|e| traitclaw_core::Error::Runtime(format!("JSON error: {e}")))?;
214        conn.execute(
215            "INSERT OR REPLACE INTO working_memory (session_id, key, value) VALUES (?1, ?2, ?3)",
216            rusqlite::params![session_id, key, json_str],
217        )
218        .map_err(|e| traitclaw_core::Error::Runtime(format!("Insert error: {e}")))?;
219        Ok(())
220    }
221
222    async fn recall(&self, query: &str, limit: usize) -> Result<Vec<MemoryEntry>> {
223        let conn = self
224            .conn
225            .lock()
226            .map_err(|e| traitclaw_core::Error::Runtime(format!("Lock error: {e}")))?;
227        let mut stmt = conn
228            .prepare(
229                "SELECT m.id, m.content, m.metadata, m.created_at
230                 FROM long_term_fts f
231                 JOIN long_term_memory m ON m.rowid = f.rowid
232                 WHERE long_term_fts MATCH ?1
233                 ORDER BY rank
234                 LIMIT ?2",
235            )
236            .map_err(|e| traitclaw_core::Error::Runtime(format!("Query error: {e}")))?;
237
238        let rows = stmt
239            .query_map(rusqlite::params![query, limit], |row| {
240                let id: String = row.get(0)?;
241                let content: String = row.get(1)?;
242                let metadata_str: Option<String> = row.get(2)?;
243                let created_at: u64 = row.get(3)?;
244                let metadata = metadata_str.and_then(|s| serde_json::from_str::<Value>(&s).ok());
245                let mut entry = MemoryEntry::now(id, content);
246                entry.metadata = metadata;
247                entry.created_at = created_at;
248                Ok(entry)
249            })
250            .map_err(|e| traitclaw_core::Error::Runtime(format!("Query error: {e}")))?;
251
252        let mut entries = Vec::new();
253        for row in rows {
254            entries
255                .push(row.map_err(|e| traitclaw_core::Error::Runtime(format!("Row error: {e}")))?);
256        }
257        Ok(entries)
258    }
259
260    async fn store(&self, entry: MemoryEntry) -> Result<()> {
261        let conn = self
262            .conn
263            .lock()
264            .map_err(|e| traitclaw_core::Error::Runtime(format!("Lock error: {e}")))?;
265        let metadata_str = entry
266            .metadata
267            .as_ref()
268            .map(|v| serde_json::to_string(v).unwrap_or_default());
269
270        conn.execute(
271            "INSERT OR REPLACE INTO long_term_memory (id, content, metadata, created_at) VALUES (?1, ?2, ?3, ?4)",
272            rusqlite::params![entry.id, entry.content, metadata_str, entry.created_at],
273        )
274        .map_err(|e| traitclaw_core::Error::Runtime(format!("Insert error: {e}")))?;
275
276        // Insert into FTS index
277        let rowid: i64 = conn
278            .query_row(
279                "SELECT rowid FROM long_term_memory WHERE id = ?1",
280                rusqlite::params![entry.id],
281                |row| row.get(0),
282            )
283            .map_err(|e| traitclaw_core::Error::Runtime(format!("Rowid error: {e}")))?;
284
285        conn.execute(
286            "INSERT OR REPLACE INTO long_term_fts (rowid, content) VALUES (?1, ?2)",
287            rusqlite::params![rowid, entry.content],
288        )
289        .map_err(|e| traitclaw_core::Error::Runtime(format!("FTS insert error: {e}")))?;
290
291        Ok(())
292    }
293
294    async fn create_session(&self) -> Result<String> {
295        let id = uuid::Uuid::new_v4().to_string();
296        let conn = self
297            .conn
298            .lock()
299            .map_err(|e| traitclaw_core::Error::Runtime(format!("Lock error: {e}")))?;
300        conn.execute("INSERT INTO sessions (id) VALUES (?1)", [&id])
301            .map_err(|e| traitclaw_core::Error::Runtime(format!("Insert error: {e}")))?;
302        Ok(id)
303    }
304
305    async fn list_sessions(&self) -> Result<Vec<String>> {
306        let conn = self
307            .conn
308            .lock()
309            .map_err(|e| traitclaw_core::Error::Runtime(format!("Lock error: {e}")))?;
310        let mut stmt = conn
311            .prepare("SELECT id FROM sessions ORDER BY created_at")
312            .map_err(|e| traitclaw_core::Error::Runtime(format!("Query error: {e}")))?;
313        let rows = stmt
314            .query_map([], |row| row.get::<_, String>(0))
315            .map_err(|e| traitclaw_core::Error::Runtime(format!("Query error: {e}")))?;
316
317        let mut ids = Vec::new();
318        for row in rows {
319            ids.push(row.map_err(|e| traitclaw_core::Error::Runtime(format!("Row error: {e}")))?);
320        }
321        Ok(ids)
322    }
323
324    async fn delete_session(&self, session_id: &str) -> Result<()> {
325        let conn = self
326            .conn
327            .lock()
328            .map_err(|e| traitclaw_core::Error::Runtime(format!("Lock error: {e}")))?;
329        // CASCADE will delete messages; manually delete working_memory
330        conn.execute(
331            "DELETE FROM working_memory WHERE session_id = ?1",
332            [session_id],
333        )
334        .map_err(|e| traitclaw_core::Error::Runtime(format!("Delete error: {e}")))?;
335        conn.execute("DELETE FROM messages WHERE session_id = ?1", [session_id])
336            .map_err(|e| traitclaw_core::Error::Runtime(format!("Delete error: {e}")))?;
337        conn.execute("DELETE FROM sessions WHERE id = ?1", [session_id])
338            .map_err(|e| traitclaw_core::Error::Runtime(format!("Delete error: {e}")))?;
339        Ok(())
340    }
341}
342
#[cfg(test)]
mod tests {
    use super::*;

    /// Open a throwaway in-memory backend for one test.
    fn fresh_memory() -> SqliteMemory {
        SqliteMemory::in_memory().unwrap()
    }

    #[tokio::test]
    async fn test_session_lifecycle() {
        let memory = fresh_memory();
        let session = memory.create_session().await.unwrap();

        // Listed while alive, gone after deletion.
        assert!(memory.list_sessions().await.unwrap().contains(&session));
        memory.delete_session(&session).await.unwrap();
        assert!(!memory.list_sessions().await.unwrap().contains(&session));
    }

    #[tokio::test]
    async fn test_conversation_persistence() {
        let memory = fresh_memory();
        let session = memory.create_session().await.unwrap();

        for message in [Message::user("Hello"), Message::assistant("Hi there!")] {
            memory.append(&session, message).await.unwrap();
        }

        let history = memory.messages(&session).await.unwrap();
        assert_eq!(history.len(), 2);
        let (first, second) = (&history[0], &history[1]);
        assert!(matches!(first.role, MessageRole::User));
        assert_eq!(first.content, "Hello");
        assert!(matches!(second.role, MessageRole::Assistant));
        assert_eq!(second.content, "Hi there!");
    }

    #[tokio::test]
    async fn test_working_memory() {
        let memory = fresh_memory();
        let session = memory.create_session().await.unwrap();

        // Nothing stored yet.
        assert!(memory.get_context(&session, "task").await.unwrap().is_none());

        // Each write is read back, and the second write overwrites the first.
        for value in ["coding", "testing"] {
            memory
                .set_context(&session, "task", serde_json::json!(value))
                .await
                .unwrap();
            let stored = memory.get_context(&session, "task").await.unwrap();
            assert_eq!(stored, Some(serde_json::json!(value)));
        }
    }

    #[tokio::test]
    async fn test_long_term_store_and_recall() {
        let memory = fresh_memory();
        let corpus = [
            ("1", "Rust is a systems programming language"),
            ("2", "Python is great for data science"),
            ("3", "Rust has zero-cost abstractions"),
        ];
        for (id, text) in corpus {
            memory.store(MemoryEntry::now(id, text)).await.unwrap();
        }

        let hits = memory.recall("Rust programming", 10).await.unwrap();
        assert!(!hits.is_empty());
        // At least one hit must be a Rust-related entry.
        assert!(hits.iter().any(|hit| hit.content.contains("Rust")));
    }

    #[tokio::test]
    async fn test_recall_empty() {
        let hits = fresh_memory().recall("anything", 10).await.unwrap();
        assert!(hits.is_empty());
    }

    #[tokio::test]
    async fn test_delete_session_clears_messages_and_context() {
        let memory = fresh_memory();
        let session = memory.create_session().await.unwrap();
        memory.append(&session, Message::user("test")).await.unwrap();
        memory
            .set_context(&session, "key", serde_json::json!("val"))
            .await
            .unwrap();

        memory.delete_session(&session).await.unwrap();

        assert!(memory.messages(&session).await.unwrap().is_empty());
        assert!(memory.get_context(&session, "key").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_tool_message_with_call_id() {
        let memory = fresh_memory();
        let session = memory.create_session().await.unwrap();
        let tool_reply = Message {
            role: MessageRole::Tool,
            content: "result".into(),
            tool_call_id: Some("call_1".into()),
        };
        memory.append(&session, tool_reply).await.unwrap();

        let history = memory.messages(&session).await.unwrap();
        assert_eq!(history.len(), 1);
        assert!(matches!(history[0].role, MessageRole::Tool));
        assert_eq!(history[0].tool_call_id.as_deref(), Some("call_1"));
    }
}