1#![deny(missing_docs)]
15#![allow(clippy::redundant_closure)]
16
17use std::sync::Mutex;
18
19use async_trait::async_trait;
20use rusqlite::Connection;
21use serde_json::Value;
22use traitclaw_core::traits::memory::{Memory, MemoryEntry};
23use traitclaw_core::types::message::{Message, MessageRole};
24use traitclaw_core::Result;
25
26pub struct SqliteMemory {
33 conn: Mutex<Connection>,
34}
35
36impl SqliteMemory {
37 pub fn new(path: &str) -> Result<Self> {
46 let conn = Connection::open(path)
47 .map_err(|e| traitclaw_core::Error::Runtime(format!("SQLite open error: {e}")))?;
48
49 init_schema(&conn)?;
50
51 Ok(Self {
52 conn: Mutex::new(conn),
53 })
54 }
55
56 pub fn in_memory() -> Result<Self> {
62 let conn = Connection::open_in_memory()
63 .map_err(|e| traitclaw_core::Error::Runtime(format!("SQLite open error: {e}")))?;
64
65 init_schema(&conn)?;
66
67 Ok(Self {
68 conn: Mutex::new(conn),
69 })
70 }
71}
72
73fn init_schema(conn: &Connection) -> Result<()> {
74 conn.execute_batch(
75 "
76 CREATE TABLE IF NOT EXISTS sessions (
77 id TEXT PRIMARY KEY,
78 created_at INTEGER NOT NULL DEFAULT (strftime('%s','now'))
79 );
80
81 CREATE TABLE IF NOT EXISTS messages (
82 id INTEGER PRIMARY KEY AUTOINCREMENT,
83 session_id TEXT NOT NULL REFERENCES sessions(id) ON DELETE CASCADE,
84 role TEXT NOT NULL,
85 content TEXT NOT NULL DEFAULT '',
86 tool_call_id TEXT,
87 created_at INTEGER NOT NULL DEFAULT (strftime('%s','now'))
88 );
89 CREATE INDEX IF NOT EXISTS idx_messages_session ON messages(session_id);
90
91 CREATE TABLE IF NOT EXISTS working_memory (
92 session_id TEXT NOT NULL,
93 key TEXT NOT NULL,
94 value TEXT NOT NULL,
95 PRIMARY KEY (session_id, key)
96 );
97
98 CREATE TABLE IF NOT EXISTS long_term_memory (
99 id TEXT PRIMARY KEY,
100 content TEXT NOT NULL,
101 metadata TEXT,
102 created_at INTEGER NOT NULL DEFAULT (strftime('%s','now'))
103 );
104
105 CREATE VIRTUAL TABLE IF NOT EXISTS long_term_fts
106 USING fts5(content, content_rowid='rowid');
107 ",
108 )
109 .map_err(|e| traitclaw_core::Error::Runtime(format!("Schema init error: {e}")))?;
110
111 Ok(())
112}
113
114fn role_to_str(role: &MessageRole) -> &'static str {
115 match role {
116 MessageRole::System => "system",
117 MessageRole::User => "user",
118 MessageRole::Assistant => "assistant",
119 MessageRole::Tool => "tool",
120 _ => "unknown",
121 }
122}
123
124fn str_to_role(s: &str) -> MessageRole {
125 match s {
126 "system" => MessageRole::System,
127 "assistant" => MessageRole::Assistant,
128 "tool" => MessageRole::Tool,
129 _ => MessageRole::User,
131 }
132}
133
134#[async_trait]
135impl Memory for SqliteMemory {
136 async fn messages(&self, session_id: &str) -> Result<Vec<Message>> {
137 let conn = self
138 .conn
139 .lock()
140 .map_err(|e| traitclaw_core::Error::Runtime(format!("Lock error: {e}")))?;
141 let mut stmt = conn
142 .prepare("SELECT role, content, tool_call_id FROM messages WHERE session_id = ?1 ORDER BY id")
143 .map_err(|e| traitclaw_core::Error::Runtime(format!("Query error: {e}")))?;
144
145 let rows = stmt
146 .query_map([session_id], |row| {
147 let role: String = row.get(0)?;
148 let content: String = row.get(1)?;
149 let tool_call_id: Option<String> = row.get(2)?;
150 Ok(Message {
151 role: str_to_role(&role),
152 content,
153 tool_call_id,
154 })
155 })
156 .map_err(|e| traitclaw_core::Error::Runtime(format!("Query error: {e}")))?;
157
158 let mut messages = Vec::new();
159 for row in rows {
160 messages
161 .push(row.map_err(|e| traitclaw_core::Error::Runtime(format!("Row error: {e}")))?);
162 }
163 Ok(messages)
164 }
165
166 async fn append(&self, session_id: &str, message: Message) -> Result<()> {
167 let conn = self
168 .conn
169 .lock()
170 .map_err(|e| traitclaw_core::Error::Runtime(format!("Lock error: {e}")))?;
171 conn.execute(
172 "INSERT INTO messages (session_id, role, content, tool_call_id) VALUES (?1, ?2, ?3, ?4)",
173 rusqlite::params![
174 session_id,
175 role_to_str(&message.role),
176 message.content,
177 message.tool_call_id,
178 ],
179 )
180 .map_err(|e| traitclaw_core::Error::Runtime(format!("Insert error: {e}")))?;
181 Ok(())
182 }
183
184 async fn get_context(&self, session_id: &str, key: &str) -> Result<Option<Value>> {
185 let conn = self
186 .conn
187 .lock()
188 .map_err(|e| traitclaw_core::Error::Runtime(format!("Lock error: {e}")))?;
189 let result: rusqlite::Result<String> = conn.query_row(
190 "SELECT value FROM working_memory WHERE session_id = ?1 AND key = ?2",
191 rusqlite::params![session_id, key],
192 |row| row.get(0),
193 );
194 match result {
195 Ok(json_str) => {
196 let val: Value = serde_json::from_str(&json_str).unwrap_or(Value::String(json_str));
197 Ok(Some(val))
198 }
199 Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
200 Err(e) => Err(traitclaw_core::Error::Runtime(format!("Query error: {e}"))),
201 }
202 }
203
204 async fn set_context(&self, session_id: &str, key: &str, value: Value) -> Result<()> {
205 let conn = self
206 .conn
207 .lock()
208 .map_err(|e| traitclaw_core::Error::Runtime(format!("Lock error: {e}")))?;
209 let json_str = serde_json::to_string(&value)
210 .map_err(|e| traitclaw_core::Error::Runtime(format!("JSON error: {e}")))?;
211 conn.execute(
212 "INSERT OR REPLACE INTO working_memory (session_id, key, value) VALUES (?1, ?2, ?3)",
213 rusqlite::params![session_id, key, json_str],
214 )
215 .map_err(|e| traitclaw_core::Error::Runtime(format!("Insert error: {e}")))?;
216 Ok(())
217 }
218
219 async fn recall(&self, query: &str, limit: usize) -> Result<Vec<MemoryEntry>> {
220 let conn = self
221 .conn
222 .lock()
223 .map_err(|e| traitclaw_core::Error::Runtime(format!("Lock error: {e}")))?;
224 let mut stmt = conn
225 .prepare(
226 "SELECT m.id, m.content, m.metadata, m.created_at
227 FROM long_term_fts f
228 JOIN long_term_memory m ON m.rowid = f.rowid
229 WHERE long_term_fts MATCH ?1
230 ORDER BY rank
231 LIMIT ?2",
232 )
233 .map_err(|e| traitclaw_core::Error::Runtime(format!("Query error: {e}")))?;
234
235 let rows = stmt
236 .query_map(rusqlite::params![query, limit], |row| {
237 let id: String = row.get(0)?;
238 let content: String = row.get(1)?;
239 let metadata_str: Option<String> = row.get(2)?;
240 let created_at: u64 = row.get(3)?;
241 let metadata = metadata_str.and_then(|s| serde_json::from_str::<Value>(&s).ok());
242 let mut entry = MemoryEntry::now(id, content);
243 entry.metadata = metadata;
244 entry.created_at = created_at;
245 Ok(entry)
246 })
247 .map_err(|e| traitclaw_core::Error::Runtime(format!("Query error: {e}")))?;
248
249 let mut entries = Vec::new();
250 for row in rows {
251 entries
252 .push(row.map_err(|e| traitclaw_core::Error::Runtime(format!("Row error: {e}")))?);
253 }
254 Ok(entries)
255 }
256
257 async fn store(&self, entry: MemoryEntry) -> Result<()> {
258 let conn = self
259 .conn
260 .lock()
261 .map_err(|e| traitclaw_core::Error::Runtime(format!("Lock error: {e}")))?;
262 let metadata_str = entry
263 .metadata
264 .as_ref()
265 .map(|v| serde_json::to_string(v).unwrap_or_default());
266
267 conn.execute(
268 "INSERT OR REPLACE INTO long_term_memory (id, content, metadata, created_at) VALUES (?1, ?2, ?3, ?4)",
269 rusqlite::params![entry.id, entry.content, metadata_str, entry.created_at],
270 )
271 .map_err(|e| traitclaw_core::Error::Runtime(format!("Insert error: {e}")))?;
272
273 let rowid: i64 = conn
275 .query_row(
276 "SELECT rowid FROM long_term_memory WHERE id = ?1",
277 rusqlite::params![entry.id],
278 |row| row.get(0),
279 )
280 .map_err(|e| traitclaw_core::Error::Runtime(format!("Rowid error: {e}")))?;
281
282 conn.execute(
283 "INSERT OR REPLACE INTO long_term_fts (rowid, content) VALUES (?1, ?2)",
284 rusqlite::params![rowid, entry.content],
285 )
286 .map_err(|e| traitclaw_core::Error::Runtime(format!("FTS insert error: {e}")))?;
287
288 Ok(())
289 }
290
291 async fn create_session(&self) -> Result<String> {
292 let id = uuid::Uuid::new_v4().to_string();
293 let conn = self
294 .conn
295 .lock()
296 .map_err(|e| traitclaw_core::Error::Runtime(format!("Lock error: {e}")))?;
297 conn.execute("INSERT INTO sessions (id) VALUES (?1)", [&id])
298 .map_err(|e| traitclaw_core::Error::Runtime(format!("Insert error: {e}")))?;
299 Ok(id)
300 }
301
302 async fn list_sessions(&self) -> Result<Vec<String>> {
303 let conn = self
304 .conn
305 .lock()
306 .map_err(|e| traitclaw_core::Error::Runtime(format!("Lock error: {e}")))?;
307 let mut stmt = conn
308 .prepare("SELECT id FROM sessions ORDER BY created_at")
309 .map_err(|e| traitclaw_core::Error::Runtime(format!("Query error: {e}")))?;
310 let rows = stmt
311 .query_map([], |row| row.get::<_, String>(0))
312 .map_err(|e| traitclaw_core::Error::Runtime(format!("Query error: {e}")))?;
313
314 let mut ids = Vec::new();
315 for row in rows {
316 ids.push(row.map_err(|e| traitclaw_core::Error::Runtime(format!("Row error: {e}")))?);
317 }
318 Ok(ids)
319 }
320
321 async fn delete_session(&self, session_id: &str) -> Result<()> {
322 let conn = self
323 .conn
324 .lock()
325 .map_err(|e| traitclaw_core::Error::Runtime(format!("Lock error: {e}")))?;
326 conn.execute(
328 "DELETE FROM working_memory WHERE session_id = ?1",
329 [session_id],
330 )
331 .map_err(|e| traitclaw_core::Error::Runtime(format!("Delete error: {e}")))?;
332 conn.execute("DELETE FROM messages WHERE session_id = ?1", [session_id])
333 .map_err(|e| traitclaw_core::Error::Runtime(format!("Delete error: {e}")))?;
334 conn.execute("DELETE FROM sessions WHERE id = ?1", [session_id])
335 .map_err(|e| traitclaw_core::Error::Runtime(format!("Delete error: {e}")))?;
336 Ok(())
337 }
338}
339
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an empty in-memory store for a single test.
    fn fresh_memory() -> SqliteMemory {
        SqliteMemory::in_memory().unwrap()
    }

    #[tokio::test]
    async fn test_session_lifecycle() {
        let store = fresh_memory();
        let session = store.create_session().await.unwrap();

        let listed = store.list_sessions().await.unwrap();
        assert!(listed.contains(&session));

        store.delete_session(&session).await.unwrap();
        let listed = store.list_sessions().await.unwrap();
        assert!(!listed.contains(&session));
    }

    #[tokio::test]
    async fn test_conversation_persistence() {
        let store = fresh_memory();
        let session = store.create_session().await.unwrap();

        for outgoing in [Message::user("Hello"), Message::assistant("Hi there!")] {
            store.append(&session, outgoing).await.unwrap();
        }

        let history = store.messages(&session).await.unwrap();
        assert_eq!(history.len(), 2);

        let (first, second) = (&history[0], &history[1]);
        assert!(matches!(first.role, MessageRole::User));
        assert_eq!(first.content, "Hello");
        assert!(matches!(second.role, MessageRole::Assistant));
        assert_eq!(second.content, "Hi there!");
    }

    #[tokio::test]
    async fn test_working_memory() {
        let store = fresh_memory();
        let session = store.create_session().await.unwrap();

        assert!(store.get_context(&session, "task").await.unwrap().is_none());

        // Write, read back, then overwrite and read back again.
        for task in ["coding", "testing"] {
            store
                .set_context(&session, "task", serde_json::json!(task))
                .await
                .unwrap();
            let current = store.get_context(&session, "task").await.unwrap().unwrap();
            assert_eq!(current, serde_json::json!(task));
        }
    }

    #[tokio::test]
    async fn test_long_term_store_and_recall() {
        let store = fresh_memory();

        let facts = [
            ("1", "Rust is a systems programming language"),
            ("2", "Python is great for data science"),
            ("3", "Rust has zero-cost abstractions"),
        ];
        for (id, text) in facts {
            store.store(MemoryEntry::now(id, text)).await.unwrap();
        }

        let hits = store.recall("Rust programming", 10).await.unwrap();
        assert!(!hits.is_empty());
        assert!(hits.iter().any(|hit| hit.content.contains("Rust")));
    }

    #[tokio::test]
    async fn test_recall_empty() {
        let store = fresh_memory();
        let hits = store.recall("anything", 10).await.unwrap();
        assert!(hits.is_empty());
    }

    #[tokio::test]
    async fn test_delete_session_clears_messages_and_context() {
        let store = fresh_memory();
        let session = store.create_session().await.unwrap();

        store.append(&session, Message::user("test")).await.unwrap();
        store
            .set_context(&session, "key", serde_json::json!("val"))
            .await
            .unwrap();

        store.delete_session(&session).await.unwrap();

        assert!(store.messages(&session).await.unwrap().is_empty());
        assert!(store.get_context(&session, "key").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_tool_message_with_call_id() {
        let store = fresh_memory();
        let session = store.create_session().await.unwrap();

        let tool_reply = Message {
            role: MessageRole::Tool,
            content: "result".into(),
            tool_call_id: Some("call_1".into()),
        };
        store.append(&session, tool_reply).await.unwrap();

        let history = store.messages(&session).await.unwrap();
        assert_eq!(history.len(), 1);
        let only = &history[0];
        assert!(matches!(only.role, MessageRole::Tool));
        assert_eq!(only.tool_call_id, Some("call_1".into()));
    }
}