1use crate::task::{Task, TaskId};
6use async_trait::async_trait;
7use std::collections::HashMap;
8use std::sync::Arc;
9use thiserror::Error;
10use tokio::sync::RwLock;
11
12#[derive(Debug, Error)]
14pub enum StorageError {
15 #[error("Task not found: {0}")]
17 NotFound(TaskId),
18
19 #[error("Task already exists: {0}")]
21 AlreadyExists(TaskId),
22
23 #[error("Storage error: {0}")]
25 Backend(String),
26
27 #[error("Serialization error: {0}")]
29 Serialization(#[from] serde_json::Error),
30}
31
32#[async_trait]
37pub trait Storage: Send + Sync {
38 async fn get_task(&self, task_id: &str) -> Result<Option<Task>, StorageError>;
40
41 async fn save_task(&self, task: &Task) -> Result<(), StorageError>;
43
44 async fn update_task(&self, task: &Task) -> Result<(), StorageError>;
46
47 async fn delete_task(&self, task_id: &str) -> Result<(), StorageError>;
49
50 async fn list_tasks_by_thread(&self, thread_id: &str) -> Result<Vec<Task>, StorageError>;
52
53 async fn list_tasks(&self, limit: Option<usize>) -> Result<Vec<Task>, StorageError>;
55}
56
57#[derive(Debug, Default)]
62pub struct InMemoryStorage {
63 tasks: Arc<RwLock<HashMap<TaskId, Task>>>,
64}
65
66impl InMemoryStorage {
67 pub fn new() -> Self {
69 Self {
70 tasks: Arc::new(RwLock::new(HashMap::new())),
71 }
72 }
73
74 pub async fn len(&self) -> usize {
76 self.tasks.read().await.len()
77 }
78
79 pub async fn is_empty(&self) -> bool {
81 self.tasks.read().await.is_empty()
82 }
83
84 pub async fn clear(&self) {
86 self.tasks.write().await.clear();
87 }
88}
89
90#[async_trait]
91impl Storage for InMemoryStorage {
92 async fn get_task(&self, task_id: &str) -> Result<Option<Task>, StorageError> {
93 let tasks = self.tasks.read().await;
94 Ok(tasks.get(task_id).cloned())
95 }
96
97 async fn save_task(&self, task: &Task) -> Result<(), StorageError> {
98 let mut tasks = self.tasks.write().await;
99 if tasks.contains_key(&task.id) {
100 return Err(StorageError::AlreadyExists(task.id.clone()));
101 }
102 tasks.insert(task.id.clone(), task.clone());
103 Ok(())
104 }
105
106 async fn update_task(&self, task: &Task) -> Result<(), StorageError> {
107 let mut tasks = self.tasks.write().await;
108 if !tasks.contains_key(&task.id) {
109 return Err(StorageError::NotFound(task.id.clone()));
110 }
111 tasks.insert(task.id.clone(), task.clone());
112 Ok(())
113 }
114
115 async fn delete_task(&self, task_id: &str) -> Result<(), StorageError> {
116 let mut tasks = self.tasks.write().await;
117 if tasks.remove(task_id).is_none() {
118 return Err(StorageError::NotFound(task_id.to_string()));
119 }
120 Ok(())
121 }
122
123 async fn list_tasks_by_thread(&self, thread_id: &str) -> Result<Vec<Task>, StorageError> {
124 let tasks = self.tasks.read().await;
125 let thread_tasks: Vec<Task> = tasks
126 .values()
127 .filter(|t| t.thread_id == thread_id)
128 .cloned()
129 .collect();
130 Ok(thread_tasks)
131 }
132
133 async fn list_tasks(&self, limit: Option<usize>) -> Result<Vec<Task>, StorageError> {
134 let tasks = self.tasks.read().await;
135 let mut all_tasks: Vec<Task> = tasks.values().cloned().collect();
136
137 all_tasks.sort_by(|a, b| b.created_at.cmp(&a.created_at));
139
140 if let Some(limit) = limit {
141 all_tasks.truncate(limit);
142 }
143
144 Ok(all_tasks)
145 }
146}
147
#[cfg(test)]
mod tests {
    use super::*;
    use crate::schema::Message;

    /// A saved task can be fetched back by its id.
    #[tokio::test]
    async fn test_save_and_get_task() {
        let store = InMemoryStorage::new();
        let original = Task::new("thread-1", Message::user("Hello"));
        let expected_id = original.id.clone();

        store.save_task(&original).await.unwrap();

        let fetched = store.get_task(&expected_id).await.unwrap();
        assert!(fetched.is_some());
        assert_eq!(fetched.unwrap().id, expected_id);
    }

    /// Saving the same task twice is rejected with `AlreadyExists`.
    #[tokio::test]
    async fn test_save_duplicate_task() {
        let store = InMemoryStorage::new();
        let original = Task::new("thread-1", Message::user("Hello"));

        store.save_task(&original).await.unwrap();
        let second_attempt = store.save_task(&original).await;

        assert!(matches!(
            second_attempt,
            Err(StorageError::AlreadyExists(_))
        ));
    }

    /// Updating a stored task makes the new state visible to readers.
    #[tokio::test]
    async fn test_update_task() {
        let store = InMemoryStorage::new();
        let mut job = Task::new("thread-1", Message::user("Hello"));

        store.save_task(&job).await.unwrap();

        let _ = job.start();
        store.update_task(&job).await.unwrap();

        let fetched = store.get_task(&job.id).await.unwrap().unwrap();
        assert!(fetched.is_running());
    }

    /// Updating a task that was never saved yields `NotFound`.
    #[tokio::test]
    async fn test_update_nonexistent_task() {
        let store = InMemoryStorage::new();
        let unsaved = Task::new("thread-1", Message::user("Hello"));

        let outcome = store.update_task(&unsaved).await;
        assert!(matches!(outcome, Err(StorageError::NotFound(_))));
    }

    /// A deleted task can no longer be fetched.
    #[tokio::test]
    async fn test_delete_task() {
        let store = InMemoryStorage::new();
        let doomed = Task::new("thread-1", Message::user("Hello"));
        let doomed_id = doomed.id.clone();

        store.save_task(&doomed).await.unwrap();
        store.delete_task(&doomed_id).await.unwrap();

        let fetched = store.get_task(&doomed_id).await.unwrap();
        assert!(fetched.is_none());
    }

    /// Listing by thread returns only the tasks of that thread.
    #[tokio::test]
    async fn test_list_tasks_by_thread() {
        let store = InMemoryStorage::new();

        let fixtures = [
            ("thread-1", "Hello"),
            ("thread-1", "World"),
            ("thread-2", "Other"),
        ];
        for (thread, text) in fixtures {
            let task = Task::new(thread, Message::user(text));
            store.save_task(&task).await.unwrap();
        }

        let in_first = store.list_tasks_by_thread("thread-1").await.unwrap();
        assert_eq!(in_first.len(), 2);

        let in_second = store.list_tasks_by_thread("thread-2").await.unwrap();
        assert_eq!(in_second.len(), 1);
    }

    /// `list_tasks` honours the limit and returns everything without one.
    #[tokio::test]
    async fn test_list_tasks_with_limit() {
        let store = InMemoryStorage::new();

        for i in 0..10 {
            let task = Task::new("thread-1", Message::user(format!("Task {}", i)));
            store.save_task(&task).await.unwrap();
        }

        let capped = store.list_tasks(Some(5)).await.unwrap();
        assert_eq!(capped.len(), 5);

        let everything = store.list_tasks(None).await.unwrap();
        assert_eq!(everything.len(), 10);
    }

    /// `clear` empties a non-empty store.
    #[tokio::test]
    async fn test_clear_storage() {
        let store = InMemoryStorage::new();

        let only_task = Task::new("thread-1", Message::user("Hello"));
        store.save_task(&only_task).await.unwrap();

        assert!(!store.is_empty().await);

        store.clear().await;
        assert!(store.is_empty().await);
    }
}