Skip to main content

serdes_ai_a2a/
storage.rs

//! Storage abstraction for A2A tasks.
//!
//! This module defines the [`Storage`] trait for persisting tasks, together with
//! [`InMemoryStorage`], a `HashMap`-backed implementation.

5use crate::task::{Task, TaskId};
6use async_trait::async_trait;
7use std::collections::HashMap;
8use std::sync::Arc;
9use thiserror::Error;
10use tokio::sync::RwLock;
11
12/// Errors that can occur during storage operations.
13#[derive(Debug, Error)]
14pub enum StorageError {
15    /// Task was not found.
16    #[error("Task not found: {0}")]
17    NotFound(TaskId),
18
19    /// Task already exists.
20    #[error("Task already exists: {0}")]
21    AlreadyExists(TaskId),
22
23    /// Storage backend error.
24    #[error("Storage error: {0}")]
25    Backend(String),
26
27    /// Serialization/deserialization error.
28    #[error("Serialization error: {0}")]
29    Serialization(#[from] serde_json::Error),
30}
31
32/// Storage trait for task persistence.
33///
34/// Implementations of this trait handle storing and retrieving tasks.
35/// The default implementation uses an in-memory HashMap.
36#[async_trait]
37pub trait Storage: Send + Sync {
38    /// Get a task by ID.
39    async fn get_task(&self, task_id: &str) -> Result<Option<Task>, StorageError>;
40
41    /// Save a new task.
42    async fn save_task(&self, task: &Task) -> Result<(), StorageError>;
43
44    /// Update an existing task.
45    async fn update_task(&self, task: &Task) -> Result<(), StorageError>;
46
47    /// Delete a task.
48    async fn delete_task(&self, task_id: &str) -> Result<(), StorageError>;
49
50    /// List tasks by thread ID.
51    async fn list_tasks_by_thread(&self, thread_id: &str) -> Result<Vec<Task>, StorageError>;
52
53    /// List all tasks (with optional limit).
54    async fn list_tasks(&self, limit: Option<usize>) -> Result<Vec<Task>, StorageError>;
55}
56
57/// In-memory storage implementation.
58///
59/// Suitable for development and testing. Not recommended for production
60/// as data is lost when the process terminates.
61#[derive(Debug, Default)]
62pub struct InMemoryStorage {
63    tasks: Arc<RwLock<HashMap<TaskId, Task>>>,
64}
65
66impl InMemoryStorage {
67    /// Create a new in-memory storage.
68    pub fn new() -> Self {
69        Self {
70            tasks: Arc::new(RwLock::new(HashMap::new())),
71        }
72    }
73
74    /// Get the number of stored tasks.
75    pub async fn len(&self) -> usize {
76        self.tasks.read().await.len()
77    }
78
79    /// Check if storage is empty.
80    pub async fn is_empty(&self) -> bool {
81        self.tasks.read().await.is_empty()
82    }
83
84    /// Clear all tasks.
85    pub async fn clear(&self) {
86        self.tasks.write().await.clear();
87    }
88}
89
90#[async_trait]
91impl Storage for InMemoryStorage {
92    async fn get_task(&self, task_id: &str) -> Result<Option<Task>, StorageError> {
93        let tasks = self.tasks.read().await;
94        Ok(tasks.get(task_id).cloned())
95    }
96
97    async fn save_task(&self, task: &Task) -> Result<(), StorageError> {
98        let mut tasks = self.tasks.write().await;
99        if tasks.contains_key(&task.id) {
100            return Err(StorageError::AlreadyExists(task.id.clone()));
101        }
102        tasks.insert(task.id.clone(), task.clone());
103        Ok(())
104    }
105
106    async fn update_task(&self, task: &Task) -> Result<(), StorageError> {
107        let mut tasks = self.tasks.write().await;
108        if !tasks.contains_key(&task.id) {
109            return Err(StorageError::NotFound(task.id.clone()));
110        }
111        tasks.insert(task.id.clone(), task.clone());
112        Ok(())
113    }
114
115    async fn delete_task(&self, task_id: &str) -> Result<(), StorageError> {
116        let mut tasks = self.tasks.write().await;
117        if tasks.remove(task_id).is_none() {
118            return Err(StorageError::NotFound(task_id.to_string()));
119        }
120        Ok(())
121    }
122
123    async fn list_tasks_by_thread(&self, thread_id: &str) -> Result<Vec<Task>, StorageError> {
124        let tasks = self.tasks.read().await;
125        let thread_tasks: Vec<Task> = tasks
126            .values()
127            .filter(|t| t.thread_id == thread_id)
128            .cloned()
129            .collect();
130        Ok(thread_tasks)
131    }
132
133    async fn list_tasks(&self, limit: Option<usize>) -> Result<Vec<Task>, StorageError> {
134        let tasks = self.tasks.read().await;
135        let mut all_tasks: Vec<Task> = tasks.values().cloned().collect();
136
137        // Sort by created_at descending (newest first)
138        all_tasks.sort_by(|a, b| b.created_at.cmp(&a.created_at));
139
140        if let Some(limit) = limit {
141            all_tasks.truncate(limit);
142        }
143
144        Ok(all_tasks)
145    }
146}
147
#[cfg(test)]
mod tests {
    use super::*;
    use crate::schema::Message;

    #[tokio::test]
    async fn test_save_and_get_task() {
        let storage = InMemoryStorage::new();
        let task = Task::new("thread-1", Message::user("Hello"));
        let task_id = task.id.clone();

        storage.save_task(&task).await.unwrap();

        let retrieved = storage.get_task(&task_id).await.unwrap();
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().id, task_id);
    }

    #[tokio::test]
    async fn test_get_missing_task_returns_none() {
        let storage = InMemoryStorage::new();
        // A missing task is Ok(None), not an error.
        assert!(storage.get_task("does-not-exist").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_save_duplicate_task() {
        let storage = InMemoryStorage::new();
        let task = Task::new("thread-1", Message::user("Hello"));

        storage.save_task(&task).await.unwrap();
        let result = storage.save_task(&task).await;

        assert!(matches!(result, Err(StorageError::AlreadyExists(_))));
        // The failed save must not have added a second copy.
        assert_eq!(storage.len().await, 1);
    }

    #[tokio::test]
    async fn test_update_task() {
        let storage = InMemoryStorage::new();
        let mut task = Task::new("thread-1", Message::user("Hello"));

        storage.save_task(&task).await.unwrap();

        let _ = task.start();
        storage.update_task(&task).await.unwrap();

        let retrieved = storage.get_task(&task.id).await.unwrap().unwrap();
        assert!(retrieved.is_running());
    }

    #[tokio::test]
    async fn test_update_nonexistent_task() {
        let storage = InMemoryStorage::new();
        let task = Task::new("thread-1", Message::user("Hello"));

        let result = storage.update_task(&task).await;
        assert!(matches!(result, Err(StorageError::NotFound(_))));
        // A failed update must not insert the task.
        assert!(storage.is_empty().await);
    }

    #[tokio::test]
    async fn test_delete_task() {
        let storage = InMemoryStorage::new();
        let task = Task::new("thread-1", Message::user("Hello"));
        let task_id = task.id.clone();

        storage.save_task(&task).await.unwrap();
        storage.delete_task(&task_id).await.unwrap();

        let retrieved = storage.get_task(&task_id).await.unwrap();
        assert!(retrieved.is_none());
    }

    #[tokio::test]
    async fn test_delete_nonexistent_task() {
        let storage = InMemoryStorage::new();
        let result = storage.delete_task("does-not-exist").await;
        assert!(matches!(result, Err(StorageError::NotFound(_))));
    }

    #[tokio::test]
    async fn test_list_tasks_by_thread() {
        let storage = InMemoryStorage::new();

        let task1 = Task::new("thread-1", Message::user("Hello"));
        let task2 = Task::new("thread-1", Message::user("World"));
        let task3 = Task::new("thread-2", Message::user("Other"));

        storage.save_task(&task1).await.unwrap();
        storage.save_task(&task2).await.unwrap();
        storage.save_task(&task3).await.unwrap();

        let thread1_tasks = storage.list_tasks_by_thread("thread-1").await.unwrap();
        assert_eq!(thread1_tasks.len(), 2);
        assert!(thread1_tasks.iter().all(|t| t.thread_id == "thread-1"));

        let thread2_tasks = storage.list_tasks_by_thread("thread-2").await.unwrap();
        assert_eq!(thread2_tasks.len(), 1);

        let unknown = storage.list_tasks_by_thread("thread-3").await.unwrap();
        assert!(unknown.is_empty());
    }

    #[tokio::test]
    async fn test_list_tasks_with_limit() {
        let storage = InMemoryStorage::new();

        for i in 0..10 {
            let task = Task::new("thread-1", Message::user(format!("Task {}", i)));
            storage.save_task(&task).await.unwrap();
        }

        let limited = storage.list_tasks(Some(5)).await.unwrap();
        assert_eq!(limited.len(), 5);

        let unlimited = storage.list_tasks(None).await.unwrap();
        assert_eq!(unlimited.len(), 10);

        // Boundary limits: zero yields nothing, an oversized limit yields everything.
        assert!(storage.list_tasks(Some(0)).await.unwrap().is_empty());
        assert_eq!(storage.list_tasks(Some(100)).await.unwrap().len(), 10);
    }

    #[tokio::test]
    async fn test_list_tasks_newest_first() {
        let storage = InMemoryStorage::new();

        for i in 0..5 {
            let task = Task::new("thread-1", Message::user(format!("Task {}", i)));
            storage.save_task(&task).await.unwrap();
        }

        let listed = storage.list_tasks(None).await.unwrap();
        // Non-increasing check tolerates tasks that share a timestamp.
        assert!(listed
            .windows(2)
            .all(|pair| pair[0].created_at >= pair[1].created_at));
    }

    #[tokio::test]
    async fn test_clear_storage() {
        let storage = InMemoryStorage::new();

        let task = Task::new("thread-1", Message::user("Hello"));
        storage.save_task(&task).await.unwrap();

        assert!(!storage.is_empty().await);

        storage.clear().await;
        assert!(storage.is_empty().await);
    }
}