1use crate::{
4 models::*, Storage, StorageError, Transaction as TransactionTrait,
5};
6use async_trait::async_trait;
7use parking_lot::Mutex;
8use r2d2::{Pool, PooledConnection};
9use r2d2_sqlite::SqliteConnectionManager;
10use rusqlite::{params, OptionalExtension};
11use std::sync::Arc;
12use std::time::Duration;
13use tracing::{debug, info};
14
15type SqlitePool = Pool<SqliteConnectionManager>;
16type SqliteConn = PooledConnection<SqliteConnectionManager>;
17
18pub struct SqliteStorage {
20 pool: Arc<SqlitePool>,
21 path: String,
22}
23
24impl SqliteStorage {
25 pub async fn new(path: &str) -> Result<Self, StorageError> {
27 let manager = SqliteConnectionManager::file(path);
28
29 let pool = Pool::builder()
30 .max_size(16)
31 .min_idle(Some(2))
32 .connection_timeout(Duration::from_secs(30))
33 .idle_timeout(Some(Duration::from_secs(300)))
34 .build(manager)
35 .map_err(|e| StorageError::Pool(e.to_string()))?;
36
37 let storage = Self {
38 pool: Arc::new(pool),
39 path: path.to_string(),
40 };
41
42 storage.init_schema().await?;
44
45 info!("SQLite storage initialized at: {}", path);
46 Ok(storage)
47 }
48
49 fn get_conn(&self) -> Result<SqliteConn, StorageError> {
51 self.pool
52 .get()
53 .map_err(|e| StorageError::Pool(e.to_string()))
54 }
55
56 async fn init_schema(&self) -> Result<(), StorageError> {
58 let conn = self.get_conn()?;
59
60 conn.execute_batch(include_str!("../sql/schema.sql"))
61 .map_err(|e| StorageError::Migration(format!("Schema initialization failed: {}", e)))?;
62
63 Ok(())
64 }
65}
66
67#[async_trait]
68impl Storage for SqliteStorage {
69 type Error = StorageError;
70
71 async fn store_agent(&self, agent: &AgentModel) -> Result<(), Self::Error> {
73 let conn = self.get_conn()?;
74 let json = serde_json::to_string(agent)?;
75
76 conn.execute(
77 "INSERT INTO agents (id, name, agent_type, status, capabilities, metadata, heartbeat, created_at, updated_at, data)
78 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10)",
79 params![
80 agent.id,
81 agent.name,
82 agent.agent_type,
83 agent.status.to_string(),
84 serde_json::to_string(&agent.capabilities)?,
85 serde_json::to_string(&agent.metadata)?,
86 agent.heartbeat.timestamp(),
87 agent.created_at.timestamp(),
88 agent.updated_at.timestamp(),
89 json
90 ],
91 )
92 .map_err(|e| StorageError::Database(e.to_string()))?;
93
94 debug!("Stored agent: {}", agent.id);
95 Ok(())
96 }
97
98 async fn get_agent(&self, id: &str) -> Result<Option<AgentModel>, Self::Error> {
99 let conn = self.get_conn()?;
100
101 let result = conn
102 .query_row(
103 "SELECT data FROM agents WHERE id = ?1",
104 params![id],
105 |row| row.get::<_, String>(0),
106 )
107 .optional()
108 .map_err(|e| StorageError::Database(e.to_string()))?;
109
110 match result {
111 Some(json) => Ok(Some(serde_json::from_str(&json)?)),
112 None => Ok(None),
113 }
114 }
115
116 async fn update_agent(&self, agent: &AgentModel) -> Result<(), Self::Error> {
117 let conn = self.get_conn()?;
118 let json = serde_json::to_string(agent)?;
119
120 let rows = conn.execute(
121 "UPDATE agents
122 SET name = ?2, agent_type = ?3, status = ?4, capabilities = ?5,
123 metadata = ?6, heartbeat = ?7, updated_at = ?8, data = ?9
124 WHERE id = ?1",
125 params![
126 agent.id,
127 agent.name,
128 agent.agent_type,
129 agent.status.to_string(),
130 serde_json::to_string(&agent.capabilities)?,
131 serde_json::to_string(&agent.metadata)?,
132 agent.heartbeat.timestamp(),
133 agent.updated_at.timestamp(),
134 json
135 ],
136 )
137 .map_err(|e| StorageError::Database(e.to_string()))?;
138
139 if rows == 0 {
140 return Err(StorageError::NotFound(format!("Agent {} not found", agent.id)));
141 }
142
143 debug!("Updated agent: {}", agent.id);
144 Ok(())
145 }
146
147 async fn delete_agent(&self, id: &str) -> Result<(), Self::Error> {
148 let conn = self.get_conn()?;
149
150 let rows = conn
151 .execute("DELETE FROM agents WHERE id = ?1", params![id])
152 .map_err(|e| StorageError::Database(e.to_string()))?;
153
154 if rows == 0 {
155 return Err(StorageError::NotFound(format!("Agent {} not found", id)));
156 }
157
158 debug!("Deleted agent: {}", id);
159 Ok(())
160 }
161
162 async fn list_agents(&self) -> Result<Vec<AgentModel>, Self::Error> {
163 let conn = self.get_conn()?;
164
165 let mut stmt = conn
166 .prepare("SELECT data FROM agents ORDER BY created_at DESC")
167 .map_err(|e| StorageError::Database(e.to_string()))?;
168
169 let agents = stmt
170 .query_map([], |row| Ok(row.get::<_, String>(0)?))
171 .map_err(|e| StorageError::Database(e.to_string()))?
172 .filter_map(|r| r.ok())
173 .filter_map(|json| serde_json::from_str(&json).ok())
174 .collect();
175
176 Ok(agents)
177 }
178
179 async fn list_agents_by_status(&self, status: &str) -> Result<Vec<AgentModel>, Self::Error> {
180 let conn = self.get_conn()?;
181
182 let mut stmt = conn
183 .prepare("SELECT data FROM agents WHERE status = ?1 ORDER BY created_at DESC")
184 .map_err(|e| StorageError::Database(e.to_string()))?;
185
186 let agents = stmt
187 .query_map(params![status], |row| Ok(row.get::<_, String>(0)?))
188 .map_err(|e| StorageError::Database(e.to_string()))?
189 .filter_map(|r| r.ok())
190 .filter_map(|json| serde_json::from_str(&json).ok())
191 .collect();
192
193 Ok(agents)
194 }
195
196 async fn store_task(&self, task: &TaskModel) -> Result<(), Self::Error> {
198 let conn = self.get_conn()?;
199 let json = serde_json::to_string(task)?;
200
201 conn.execute(
202 "INSERT INTO tasks (id, task_type, priority, status, assigned_to, payload,
203 created_at, updated_at, data)
204 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)",
205 params![
206 task.id,
207 task.task_type,
208 task.priority as i32,
209 serde_json::to_string(&task.status)?,
210 task.assigned_to,
211 serde_json::to_string(&task.payload)?,
212 task.created_at.timestamp(),
213 task.updated_at.timestamp(),
214 json
215 ],
216 )
217 .map_err(|e| StorageError::Database(e.to_string()))?;
218
219 debug!("Stored task: {}", task.id);
220 Ok(())
221 }
222
223 async fn get_task(&self, id: &str) -> Result<Option<TaskModel>, Self::Error> {
224 let conn = self.get_conn()?;
225
226 let result = conn
227 .query_row(
228 "SELECT data FROM tasks WHERE id = ?1",
229 params![id],
230 |row| row.get::<_, String>(0),
231 )
232 .optional()
233 .map_err(|e| StorageError::Database(e.to_string()))?;
234
235 match result {
236 Some(json) => Ok(Some(serde_json::from_str(&json)?)),
237 None => Ok(None),
238 }
239 }
240
241 async fn update_task(&self, task: &TaskModel) -> Result<(), Self::Error> {
242 let conn = self.get_conn()?;
243 let json = serde_json::to_string(task)?;
244
245 let rows = conn.execute(
246 "UPDATE tasks
247 SET task_type = ?2, priority = ?3, status = ?4, assigned_to = ?5,
248 payload = ?6, updated_at = ?7, data = ?8
249 WHERE id = ?1",
250 params![
251 task.id,
252 task.task_type,
253 task.priority as i32,
254 serde_json::to_string(&task.status)?,
255 task.assigned_to,
256 serde_json::to_string(&task.payload)?,
257 task.updated_at.timestamp(),
258 json
259 ],
260 )
261 .map_err(|e| StorageError::Database(e.to_string()))?;
262
263 if rows == 0 {
264 return Err(StorageError::NotFound(format!("Task {} not found", task.id)));
265 }
266
267 debug!("Updated task: {}", task.id);
268 Ok(())
269 }
270
271 async fn get_pending_tasks(&self) -> Result<Vec<TaskModel>, Self::Error> {
272 let conn = self.get_conn()?;
273
274 let mut stmt = conn
275 .prepare(
276 "SELECT data FROM tasks
277 WHERE status = 'pending'
278 ORDER BY priority DESC, created_at ASC"
279 )
280 .map_err(|e| StorageError::Database(e.to_string()))?;
281
282 let tasks = stmt
283 .query_map([], |row| Ok(row.get::<_, String>(0)?))
284 .map_err(|e| StorageError::Database(e.to_string()))?
285 .filter_map(|r| r.ok())
286 .filter_map(|json| serde_json::from_str(&json).ok())
287 .collect();
288
289 Ok(tasks)
290 }
291
292 async fn get_tasks_by_agent(&self, agent_id: &str) -> Result<Vec<TaskModel>, Self::Error> {
293 let conn = self.get_conn()?;
294
295 let mut stmt = conn
296 .prepare(
297 "SELECT data FROM tasks
298 WHERE assigned_to = ?1
299 ORDER BY priority DESC, created_at ASC"
300 )
301 .map_err(|e| StorageError::Database(e.to_string()))?;
302
303 let tasks = stmt
304 .query_map(params![agent_id], |row| Ok(row.get::<_, String>(0)?))
305 .map_err(|e| StorageError::Database(e.to_string()))?
306 .filter_map(|r| r.ok())
307 .filter_map(|json| serde_json::from_str(&json).ok())
308 .collect();
309
310 Ok(tasks)
311 }
312
313 async fn claim_task(&self, task_id: &str, agent_id: &str) -> Result<bool, Self::Error> {
314 let conn = self.get_conn()?;
315
316 let rows = conn.execute(
317 "UPDATE tasks
318 SET assigned_to = ?2, status = 'assigned', updated_at = ?3
319 WHERE id = ?1 AND status = 'pending'",
320 params![task_id, agent_id, chrono::Utc::now().timestamp()],
321 )
322 .map_err(|e| StorageError::Database(e.to_string()))?;
323
324 Ok(rows > 0)
325 }
326
327 async fn store_event(&self, event: &EventModel) -> Result<(), Self::Error> {
329 let conn = self.get_conn()?;
330 let json = serde_json::to_string(event)?;
331
332 conn.execute(
333 "INSERT INTO events (id, event_type, agent_id, task_id, payload, metadata,
334 timestamp, sequence, data)
335 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)",
336 params![
337 event.id,
338 event.event_type,
339 event.agent_id,
340 event.task_id,
341 serde_json::to_string(&event.payload)?,
342 serde_json::to_string(&event.metadata)?,
343 event.timestamp.timestamp(),
344 event.sequence as i64,
345 json
346 ],
347 )
348 .map_err(|e| StorageError::Database(e.to_string()))?;
349
350 debug!("Stored event: {}", event.id);
351 Ok(())
352 }
353
354 async fn get_events_by_agent(&self, agent_id: &str, limit: usize) -> Result<Vec<EventModel>, Self::Error> {
355 let conn = self.get_conn()?;
356
357 let mut stmt = conn
358 .prepare(
359 "SELECT data FROM events
360 WHERE agent_id = ?1
361 ORDER BY timestamp DESC
362 LIMIT ?2"
363 )
364 .map_err(|e| StorageError::Database(e.to_string()))?;
365
366 let events = stmt
367 .query_map(params![agent_id, limit], |row| Ok(row.get::<_, String>(0)?))
368 .map_err(|e| StorageError::Database(e.to_string()))?
369 .filter_map(|r| r.ok())
370 .filter_map(|json| serde_json::from_str(&json).ok())
371 .collect();
372
373 Ok(events)
374 }
375
376 async fn get_events_by_type(&self, event_type: &str, limit: usize) -> Result<Vec<EventModel>, Self::Error> {
377 let conn = self.get_conn()?;
378
379 let mut stmt = conn
380 .prepare(
381 "SELECT data FROM events
382 WHERE event_type = ?1
383 ORDER BY timestamp DESC
384 LIMIT ?2"
385 )
386 .map_err(|e| StorageError::Database(e.to_string()))?;
387
388 let events = stmt
389 .query_map(params![event_type, limit], |row| Ok(row.get::<_, String>(0)?))
390 .map_err(|e| StorageError::Database(e.to_string()))?
391 .filter_map(|r| r.ok())
392 .filter_map(|json| serde_json::from_str(&json).ok())
393 .collect();
394
395 Ok(events)
396 }
397
398 async fn get_events_since(&self, timestamp: i64) -> Result<Vec<EventModel>, Self::Error> {
399 let conn = self.get_conn()?;
400
401 let mut stmt = conn
402 .prepare(
403 "SELECT data FROM events
404 WHERE timestamp > ?1
405 ORDER BY timestamp ASC"
406 )
407 .map_err(|e| StorageError::Database(e.to_string()))?;
408
409 let events = stmt
410 .query_map(params![timestamp], |row| Ok(row.get::<_, String>(0)?))
411 .map_err(|e| StorageError::Database(e.to_string()))?
412 .filter_map(|r| r.ok())
413 .filter_map(|json| serde_json::from_str(&json).ok())
414 .collect();
415
416 Ok(events)
417 }
418
419 async fn store_message(&self, message: &MessageModel) -> Result<(), Self::Error> {
421 let conn = self.get_conn()?;
422 let json = serde_json::to_string(message)?;
423
424 conn.execute(
425 "INSERT INTO messages (id, from_agent, to_agent, message_type, content,
426 priority, read, created_at, data)
427 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)",
428 params![
429 message.id,
430 message.from_agent,
431 message.to_agent,
432 message.message_type,
433 serde_json::to_string(&message.content)?,
434 serde_json::to_string(&message.priority)?,
435 message.read as i32,
436 message.created_at.timestamp(),
437 json
438 ],
439 )
440 .map_err(|e| StorageError::Database(e.to_string()))?;
441
442 debug!("Stored message: {}", message.id);
443 Ok(())
444 }
445
446 async fn get_messages_between(&self, agent1: &str, agent2: &str, limit: usize) -> Result<Vec<MessageModel>, Self::Error> {
447 let conn = self.get_conn()?;
448
449 let mut stmt = conn
450 .prepare(
451 "SELECT data FROM messages
452 WHERE (from_agent = ?1 AND to_agent = ?2) OR (from_agent = ?2 AND to_agent = ?1)
453 ORDER BY created_at DESC
454 LIMIT ?3"
455 )
456 .map_err(|e| StorageError::Database(e.to_string()))?;
457
458 let messages = stmt
459 .query_map(params![agent1, agent2, limit], |row| Ok(row.get::<_, String>(0)?))
460 .map_err(|e| StorageError::Database(e.to_string()))?
461 .filter_map(|r| r.ok())
462 .filter_map(|json| serde_json::from_str(&json).ok())
463 .collect();
464
465 Ok(messages)
466 }
467
468 async fn get_unread_messages(&self, agent_id: &str) -> Result<Vec<MessageModel>, Self::Error> {
469 let conn = self.get_conn()?;
470
471 let mut stmt = conn
472 .prepare(
473 "SELECT data FROM messages
474 WHERE to_agent = ?1 AND read = 0
475 ORDER BY created_at ASC"
476 )
477 .map_err(|e| StorageError::Database(e.to_string()))?;
478
479 let messages = stmt
480 .query_map(params![agent_id], |row| Ok(row.get::<_, String>(0)?))
481 .map_err(|e| StorageError::Database(e.to_string()))?
482 .filter_map(|r| r.ok())
483 .filter_map(|json| serde_json::from_str(&json).ok())
484 .collect();
485
486 Ok(messages)
487 }
488
489 async fn mark_message_read(&self, message_id: &str) -> Result<(), Self::Error> {
490 let conn = self.get_conn()?;
491
492 let rows = conn.execute(
493 "UPDATE messages SET read = 1, read_at = ?2 WHERE id = ?1",
494 params![message_id, chrono::Utc::now().timestamp()],
495 )
496 .map_err(|e| StorageError::Database(e.to_string()))?;
497
498 if rows == 0 {
499 return Err(StorageError::NotFound(format!("Message {} not found", message_id)));
500 }
501
502 Ok(())
503 }
504
505 async fn store_metric(&self, metric: &MetricModel) -> Result<(), Self::Error> {
507 let conn = self.get_conn()?;
508 let json = serde_json::to_string(metric)?;
509
510 conn.execute(
511 "INSERT INTO metrics (id, metric_type, agent_id, value, unit, tags, timestamp, data)
512 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)",
513 params![
514 metric.id,
515 metric.metric_type,
516 metric.agent_id,
517 metric.value,
518 metric.unit,
519 serde_json::to_string(&metric.tags)?,
520 metric.timestamp.timestamp(),
521 json
522 ],
523 )
524 .map_err(|e| StorageError::Database(e.to_string()))?;
525
526 debug!("Stored metric: {}", metric.id);
527 Ok(())
528 }
529
530 async fn get_metrics_by_agent(&self, agent_id: &str, metric_type: &str) -> Result<Vec<MetricModel>, Self::Error> {
531 let conn = self.get_conn()?;
532
533 let mut stmt = conn
534 .prepare(
535 "SELECT data FROM metrics
536 WHERE agent_id = ?1 AND metric_type = ?2
537 ORDER BY timestamp DESC"
538 )
539 .map_err(|e| StorageError::Database(e.to_string()))?;
540
541 let metrics = stmt
542 .query_map(params![agent_id, metric_type], |row| Ok(row.get::<_, String>(0)?))
543 .map_err(|e| StorageError::Database(e.to_string()))?
544 .filter_map(|r| r.ok())
545 .filter_map(|json| serde_json::from_str(&json).ok())
546 .collect();
547
548 Ok(metrics)
549 }
550
551 async fn get_aggregated_metrics(&self, metric_type: &str, start_time: i64, end_time: i64) -> Result<Vec<MetricModel>, Self::Error> {
552 let conn = self.get_conn()?;
553
554 let mut stmt = conn
555 .prepare(
556 "SELECT metric_type, agent_id, AVG(value) as value, unit,
557 MIN(timestamp) as timestamp, COUNT(*) as count
558 FROM metrics
559 WHERE metric_type = ?1 AND timestamp >= ?2 AND timestamp <= ?3
560 GROUP BY metric_type, agent_id, unit"
561 )
562 .map_err(|e| StorageError::Database(e.to_string()))?;
563
564 let metrics = stmt
565 .query_map(params![metric_type, start_time, end_time], |row| {
566 let mut metric = MetricModel::new(
567 row.get::<_, String>(0)?,
568 row.get::<_, f64>(2)?,
569 row.get::<_, String>(3)?,
570 );
571 metric.agent_id = row.get::<_, Option<String>>(1)?;
572 metric.tags.insert("count".to_string(), row.get::<_, i64>(5)?.to_string());
573 Ok(metric)
574 })
575 .map_err(|e| StorageError::Database(e.to_string()))?
576 .filter_map(|r| r.ok())
577 .collect();
578
579 Ok(metrics)
580 }
581
582 async fn begin_transaction(&self) -> Result<Box<dyn TransactionTrait>, Self::Error> {
584 let conn = self.get_conn()?;
585 Ok(Box::new(SqliteTransaction::new(conn)))
586 }
587
588 async fn vacuum(&self) -> Result<(), Self::Error> {
590 let conn = self.get_conn()?;
591 conn.execute("VACUUM", [])
592 .map_err(|e| StorageError::Database(e.to_string()))?;
593 info!("Database vacuumed");
594 Ok(())
595 }
596
597 async fn checkpoint(&self) -> Result<(), Self::Error> {
598 let conn = self.get_conn()?;
599 conn.execute("PRAGMA wal_checkpoint(TRUNCATE)", [])
600 .map_err(|e| StorageError::Database(e.to_string()))?;
601 info!("Database checkpoint completed");
602 Ok(())
603 }
604
605 async fn get_storage_size(&self) -> Result<u64, Self::Error> {
606 let metadata = std::fs::metadata(&self.path)
607 .map_err(|e| StorageError::Other(e.to_string()))?;
608 Ok(metadata.len())
609 }
610}
611
612struct SqliteTransaction {
614 conn: Mutex<Option<SqliteConn>>,
615}
616
617impl SqliteTransaction {
618 fn new(conn: SqliteConn) -> Self {
619 Self {
620 conn: Mutex::new(Some(conn)),
621 }
622 }
623}
624
625#[async_trait]
626impl TransactionTrait for SqliteTransaction {
627 async fn commit(self: Box<Self>) -> Result<(), StorageError> {
628 if let Some(conn) = self.conn.lock().take() {
629 drop(conn);
632 }
633 Ok(())
634 }
635
636 async fn rollback(self: Box<Self>) -> Result<(), StorageError> {
637 if let Some(conn) = self.conn.lock().take() {
638 drop(conn);
641 }
642 Ok(())
643 }
644}
645
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::NamedTempFile;

    /// Smoke test against a throwaway on-disk database: an agent round-trips
    /// by id, and a freshly stored task is reported as the only pending task.
    #[tokio::test]
    async fn test_sqlite_storage() {
        let db_file = NamedTempFile::new().unwrap();
        let db_path = db_file.path().to_str().unwrap();
        let storage = SqliteStorage::new(db_path).await.unwrap();

        // Agent round-trip by id.
        let agent = AgentModel::new(
            "test-agent".to_string(),
            "worker".to_string(),
            vec!["compute".to_string()],
        );
        storage.store_agent(&agent).await.unwrap();

        let fetched = storage.get_agent(&agent.id).await.unwrap();
        assert!(fetched.is_some());
        assert_eq!(fetched.unwrap().name, "test-agent");

        // Task queue: the new task is the single pending entry.
        let task = TaskModel::new(
            "process".to_string(),
            serde_json::json!({"data": "test"}),
            TaskPriority::High,
        );
        storage.store_task(&task).await.unwrap();

        let pending_count = storage.get_pending_tasks().await.unwrap().len();
        assert_eq!(pending_count, 1);
    }
}