1#![deny(warnings)]
15#![deny(missing_docs)]
16#![warn(clippy::pedantic)]
17#![allow(clippy::module_name_repetitions)]
18#![allow(clippy::doc_markdown)] use std::sync::Mutex;
21
22use async_trait::async_trait;
23use rusqlite::Connection;
24use serde_json::Value;
25use traitclaw_core::traits::memory::{Memory, MemoryEntry};
26use traitclaw_core::types::message::{Message, MessageRole};
27use traitclaw_core::Result;
28
29pub struct SqliteMemory {
36 conn: Mutex<Connection>,
37}
38
39impl SqliteMemory {
40 pub fn new(path: &str) -> Result<Self> {
49 let conn = Connection::open(path)
50 .map_err(|e| traitclaw_core::Error::Runtime(format!("SQLite open error: {e}")))?;
51
52 init_schema(&conn)?;
53
54 Ok(Self {
55 conn: Mutex::new(conn),
56 })
57 }
58
59 pub fn in_memory() -> Result<Self> {
65 let conn = Connection::open_in_memory()
66 .map_err(|e| traitclaw_core::Error::Runtime(format!("SQLite open error: {e}")))?;
67
68 init_schema(&conn)?;
69
70 Ok(Self {
71 conn: Mutex::new(conn),
72 })
73 }
74}
75
76fn init_schema(conn: &Connection) -> Result<()> {
77 conn.execute_batch(
78 "
79 CREATE TABLE IF NOT EXISTS sessions (
80 id TEXT PRIMARY KEY,
81 created_at INTEGER NOT NULL DEFAULT (strftime('%s','now'))
82 );
83
84 CREATE TABLE IF NOT EXISTS messages (
85 id INTEGER PRIMARY KEY AUTOINCREMENT,
86 session_id TEXT NOT NULL REFERENCES sessions(id) ON DELETE CASCADE,
87 role TEXT NOT NULL,
88 content TEXT NOT NULL DEFAULT '',
89 tool_call_id TEXT,
90 created_at INTEGER NOT NULL DEFAULT (strftime('%s','now'))
91 );
92 CREATE INDEX IF NOT EXISTS idx_messages_session ON messages(session_id);
93
94 CREATE TABLE IF NOT EXISTS working_memory (
95 session_id TEXT NOT NULL,
96 key TEXT NOT NULL,
97 value TEXT NOT NULL,
98 PRIMARY KEY (session_id, key)
99 );
100
101 CREATE TABLE IF NOT EXISTS long_term_memory (
102 id TEXT PRIMARY KEY,
103 content TEXT NOT NULL,
104 metadata TEXT,
105 created_at INTEGER NOT NULL DEFAULT (strftime('%s','now'))
106 );
107
108 CREATE VIRTUAL TABLE IF NOT EXISTS long_term_fts
109 USING fts5(content, content_rowid='rowid');
110 ",
111 )
112 .map_err(|e| traitclaw_core::Error::Runtime(format!("Schema init error: {e}")))?;
113
114 Ok(())
115}
116
117fn role_to_str(role: &MessageRole) -> &'static str {
118 match role {
119 MessageRole::System => "system",
120 MessageRole::User => "user",
121 MessageRole::Assistant => "assistant",
122 MessageRole::Tool => "tool",
123 _ => "unknown",
124 }
125}
126
127fn str_to_role(s: &str) -> MessageRole {
128 match s {
129 "system" => MessageRole::System,
130 "assistant" => MessageRole::Assistant,
131 "tool" => MessageRole::Tool,
132 _ => MessageRole::User,
134 }
135}
136
137#[async_trait]
138impl Memory for SqliteMemory {
139 async fn messages(&self, session_id: &str) -> Result<Vec<Message>> {
140 let conn = self
141 .conn
142 .lock()
143 .map_err(|e| traitclaw_core::Error::Runtime(format!("Lock error: {e}")))?;
144 let mut stmt = conn
145 .prepare("SELECT role, content, tool_call_id FROM messages WHERE session_id = ?1 ORDER BY id")
146 .map_err(|e| traitclaw_core::Error::Runtime(format!("Query error: {e}")))?;
147
148 let rows = stmt
149 .query_map([session_id], |row| {
150 let role: String = row.get(0)?;
151 let content: String = row.get(1)?;
152 let tool_call_id: Option<String> = row.get(2)?;
153 Ok(Message {
154 role: str_to_role(&role),
155 content,
156 tool_call_id,
157 })
158 })
159 .map_err(|e| traitclaw_core::Error::Runtime(format!("Query error: {e}")))?;
160
161 let mut messages = Vec::new();
162 for row in rows {
163 messages
164 .push(row.map_err(|e| traitclaw_core::Error::Runtime(format!("Row error: {e}")))?);
165 }
166 Ok(messages)
167 }
168
169 async fn append(&self, session_id: &str, message: Message) -> Result<()> {
170 let conn = self
171 .conn
172 .lock()
173 .map_err(|e| traitclaw_core::Error::Runtime(format!("Lock error: {e}")))?;
174 conn.execute(
175 "INSERT INTO messages (session_id, role, content, tool_call_id) VALUES (?1, ?2, ?3, ?4)",
176 rusqlite::params![
177 session_id,
178 role_to_str(&message.role),
179 message.content,
180 message.tool_call_id,
181 ],
182 )
183 .map_err(|e| traitclaw_core::Error::Runtime(format!("Insert error: {e}")))?;
184 Ok(())
185 }
186
187 async fn get_context(&self, session_id: &str, key: &str) -> Result<Option<Value>> {
188 let conn = self
189 .conn
190 .lock()
191 .map_err(|e| traitclaw_core::Error::Runtime(format!("Lock error: {e}")))?;
192 let result: rusqlite::Result<String> = conn.query_row(
193 "SELECT value FROM working_memory WHERE session_id = ?1 AND key = ?2",
194 rusqlite::params![session_id, key],
195 |row| row.get(0),
196 );
197 match result {
198 Ok(json_str) => {
199 let val: Value = serde_json::from_str(&json_str).unwrap_or(Value::String(json_str));
200 Ok(Some(val))
201 }
202 Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
203 Err(e) => Err(traitclaw_core::Error::Runtime(format!("Query error: {e}"))),
204 }
205 }
206
207 async fn set_context(&self, session_id: &str, key: &str, value: Value) -> Result<()> {
208 let conn = self
209 .conn
210 .lock()
211 .map_err(|e| traitclaw_core::Error::Runtime(format!("Lock error: {e}")))?;
212 let json_str = serde_json::to_string(&value)
213 .map_err(|e| traitclaw_core::Error::Runtime(format!("JSON error: {e}")))?;
214 conn.execute(
215 "INSERT OR REPLACE INTO working_memory (session_id, key, value) VALUES (?1, ?2, ?3)",
216 rusqlite::params![session_id, key, json_str],
217 )
218 .map_err(|e| traitclaw_core::Error::Runtime(format!("Insert error: {e}")))?;
219 Ok(())
220 }
221
222 async fn recall(&self, query: &str, limit: usize) -> Result<Vec<MemoryEntry>> {
223 let conn = self
224 .conn
225 .lock()
226 .map_err(|e| traitclaw_core::Error::Runtime(format!("Lock error: {e}")))?;
227 let mut stmt = conn
228 .prepare(
229 "SELECT m.id, m.content, m.metadata, m.created_at
230 FROM long_term_fts f
231 JOIN long_term_memory m ON m.rowid = f.rowid
232 WHERE long_term_fts MATCH ?1
233 ORDER BY rank
234 LIMIT ?2",
235 )
236 .map_err(|e| traitclaw_core::Error::Runtime(format!("Query error: {e}")))?;
237
238 let rows = stmt
239 .query_map(rusqlite::params![query, limit], |row| {
240 let id: String = row.get(0)?;
241 let content: String = row.get(1)?;
242 let metadata_str: Option<String> = row.get(2)?;
243 let created_at: u64 = row.get(3)?;
244 let metadata = metadata_str.and_then(|s| serde_json::from_str::<Value>(&s).ok());
245 let mut entry = MemoryEntry::now(id, content);
246 entry.metadata = metadata;
247 entry.created_at = created_at;
248 Ok(entry)
249 })
250 .map_err(|e| traitclaw_core::Error::Runtime(format!("Query error: {e}")))?;
251
252 let mut entries = Vec::new();
253 for row in rows {
254 entries
255 .push(row.map_err(|e| traitclaw_core::Error::Runtime(format!("Row error: {e}")))?);
256 }
257 Ok(entries)
258 }
259
260 async fn store(&self, entry: MemoryEntry) -> Result<()> {
261 let conn = self
262 .conn
263 .lock()
264 .map_err(|e| traitclaw_core::Error::Runtime(format!("Lock error: {e}")))?;
265 let metadata_str = entry
266 .metadata
267 .as_ref()
268 .map(|v| serde_json::to_string(v).unwrap_or_default());
269
270 conn.execute(
271 "INSERT OR REPLACE INTO long_term_memory (id, content, metadata, created_at) VALUES (?1, ?2, ?3, ?4)",
272 rusqlite::params![entry.id, entry.content, metadata_str, entry.created_at],
273 )
274 .map_err(|e| traitclaw_core::Error::Runtime(format!("Insert error: {e}")))?;
275
276 let rowid: i64 = conn
278 .query_row(
279 "SELECT rowid FROM long_term_memory WHERE id = ?1",
280 rusqlite::params![entry.id],
281 |row| row.get(0),
282 )
283 .map_err(|e| traitclaw_core::Error::Runtime(format!("Rowid error: {e}")))?;
284
285 conn.execute(
286 "INSERT OR REPLACE INTO long_term_fts (rowid, content) VALUES (?1, ?2)",
287 rusqlite::params![rowid, entry.content],
288 )
289 .map_err(|e| traitclaw_core::Error::Runtime(format!("FTS insert error: {e}")))?;
290
291 Ok(())
292 }
293
294 async fn create_session(&self) -> Result<String> {
295 let id = uuid::Uuid::new_v4().to_string();
296 let conn = self
297 .conn
298 .lock()
299 .map_err(|e| traitclaw_core::Error::Runtime(format!("Lock error: {e}")))?;
300 conn.execute("INSERT INTO sessions (id) VALUES (?1)", [&id])
301 .map_err(|e| traitclaw_core::Error::Runtime(format!("Insert error: {e}")))?;
302 Ok(id)
303 }
304
305 async fn list_sessions(&self) -> Result<Vec<String>> {
306 let conn = self
307 .conn
308 .lock()
309 .map_err(|e| traitclaw_core::Error::Runtime(format!("Lock error: {e}")))?;
310 let mut stmt = conn
311 .prepare("SELECT id FROM sessions ORDER BY created_at")
312 .map_err(|e| traitclaw_core::Error::Runtime(format!("Query error: {e}")))?;
313 let rows = stmt
314 .query_map([], |row| row.get::<_, String>(0))
315 .map_err(|e| traitclaw_core::Error::Runtime(format!("Query error: {e}")))?;
316
317 let mut ids = Vec::new();
318 for row in rows {
319 ids.push(row.map_err(|e| traitclaw_core::Error::Runtime(format!("Row error: {e}")))?);
320 }
321 Ok(ids)
322 }
323
324 async fn delete_session(&self, session_id: &str) -> Result<()> {
325 let conn = self
326 .conn
327 .lock()
328 .map_err(|e| traitclaw_core::Error::Runtime(format!("Lock error: {e}")))?;
329 conn.execute(
331 "DELETE FROM working_memory WHERE session_id = ?1",
332 [session_id],
333 )
334 .map_err(|e| traitclaw_core::Error::Runtime(format!("Delete error: {e}")))?;
335 conn.execute("DELETE FROM messages WHERE session_id = ?1", [session_id])
336 .map_err(|e| traitclaw_core::Error::Runtime(format!("Delete error: {e}")))?;
337 conn.execute("DELETE FROM sessions WHERE id = ?1", [session_id])
338 .map_err(|e| traitclaw_core::Error::Runtime(format!("Delete error: {e}")))?;
339 Ok(())
340 }
341}
342
#[cfg(test)]
mod tests {
    use super::*;

    /// Opens an isolated in-memory backend for a single test.
    fn fresh_memory() -> SqliteMemory {
        SqliteMemory::in_memory().expect("in-memory database should open")
    }

    #[tokio::test]
    async fn test_session_lifecycle() {
        let mem = fresh_memory();
        let sid = mem.create_session().await.expect("create_session");

        let before = mem.list_sessions().await.expect("list_sessions");
        assert!(before.contains(&sid));

        mem.delete_session(&sid).await.expect("delete_session");

        let after = mem.list_sessions().await.expect("list_sessions");
        assert!(!after.contains(&sid));
    }

    #[tokio::test]
    async fn test_conversation_persistence() {
        let mem = fresh_memory();
        let sid = mem.create_session().await.expect("create_session");

        for msg in [Message::user("Hello"), Message::assistant("Hi there!")] {
            mem.append(&sid, msg).await.expect("append");
        }

        let history = mem.messages(&sid).await.expect("messages");
        assert_eq!(history.len(), 2);

        let (first, second) = (&history[0], &history[1]);
        assert!(matches!(first.role, MessageRole::User));
        assert_eq!(first.content, "Hello");
        assert!(matches!(second.role, MessageRole::Assistant));
        assert_eq!(second.content, "Hi there!");
    }

    #[tokio::test]
    async fn test_working_memory() {
        let mem = fresh_memory();
        let sid = mem.create_session().await.expect("create_session");

        let initial = mem.get_context(&sid, "task").await.expect("get_context");
        assert!(initial.is_none());

        // The second write must overwrite the first under the same key.
        for expected in ["coding", "testing"] {
            mem.set_context(&sid, "task", serde_json::json!(expected))
                .await
                .expect("set_context");
            let stored = mem.get_context(&sid, "task").await.expect("get_context");
            assert_eq!(stored, Some(serde_json::json!(expected)));
        }
    }

    #[tokio::test]
    async fn test_long_term_store_and_recall() {
        let mem = fresh_memory();
        let corpus = [
            ("1", "Rust is a systems programming language"),
            ("2", "Python is great for data science"),
            ("3", "Rust has zero-cost abstractions"),
        ];
        for (id, text) in corpus {
            mem.store(MemoryEntry::now(id, text)).await.expect("store");
        }

        let hits = mem.recall("Rust programming", 10).await.expect("recall");
        assert!(!hits.is_empty());
        assert!(hits.iter().any(|hit| hit.content.contains("Rust")));
    }

    #[tokio::test]
    async fn test_recall_empty() {
        let hits = fresh_memory()
            .recall("anything", 10)
            .await
            .expect("recall");
        assert!(hits.is_empty());
    }

    #[tokio::test]
    async fn test_delete_session_clears_messages_and_context() {
        let mem = fresh_memory();
        let sid = mem.create_session().await.expect("create_session");

        mem.append(&sid, Message::user("test"))
            .await
            .expect("append");
        mem.set_context(&sid, "key", serde_json::json!("val"))
            .await
            .expect("set_context");

        mem.delete_session(&sid).await.expect("delete_session");

        assert!(mem.messages(&sid).await.expect("messages").is_empty());
        assert!(mem
            .get_context(&sid, "key")
            .await
            .expect("get_context")
            .is_none());
    }

    #[tokio::test]
    async fn test_tool_message_with_call_id() {
        let mem = fresh_memory();
        let sid = mem.create_session().await.expect("create_session");

        let tool_reply = Message {
            role: MessageRole::Tool,
            content: "result".into(),
            tool_call_id: Some("call_1".into()),
        };
        mem.append(&sid, tool_reply).await.expect("append");

        let history = mem.messages(&sid).await.expect("messages");
        assert_eq!(history.len(), 1);
        assert!(matches!(history[0].role, MessageRole::Tool));
        assert_eq!(history[0].tool_call_id, Some("call_1".into()));
    }
}