1use crate::{models::*, Storage, StorageError, Transaction as TransactionTrait};
4use async_trait::async_trait;
5use parking_lot::RwLock;
6use std::collections::HashMap;
7use std::sync::Arc;
8use tracing::debug;
9
10pub struct MemoryStorage {
12 agents: Arc<RwLock<HashMap<String, AgentModel>>>,
13 tasks: Arc<RwLock<HashMap<String, TaskModel>>>,
14 events: Arc<RwLock<Vec<EventModel>>>,
15 messages: Arc<RwLock<Vec<MessageModel>>>,
16 metrics: Arc<RwLock<Vec<MetricModel>>>,
17 next_sequence: Arc<RwLock<u64>>,
18}
19
20impl MemoryStorage {
21 pub fn new() -> Self {
23 Self {
24 agents: Arc::new(RwLock::new(HashMap::new())),
25 tasks: Arc::new(RwLock::new(HashMap::new())),
26 events: Arc::new(RwLock::new(Vec::new())),
27 messages: Arc::new(RwLock::new(Vec::new())),
28 metrics: Arc::new(RwLock::new(Vec::new())),
29 next_sequence: Arc::new(RwLock::new(1)),
30 }
31 }
32}
33
34impl Default for MemoryStorage {
35 fn default() -> Self {
36 Self::new()
37 }
38}
39
40#[async_trait]
41impl Storage for MemoryStorage {
42 type Error = StorageError;
43
44 async fn store_agent(&self, agent: &AgentModel) -> Result<(), Self::Error> {
46 let mut agents = self.agents.write();
47 agents.insert(agent.id.clone(), agent.clone());
48 debug!("Stored agent in memory: {}", agent.id);
49 Ok(())
50 }
51
52 async fn get_agent(&self, id: &str) -> Result<Option<AgentModel>, Self::Error> {
53 let agents = self.agents.read();
54 Ok(agents.get(id).cloned())
55 }
56
57 async fn update_agent(&self, agent: &AgentModel) -> Result<(), Self::Error> {
58 let mut agents = self.agents.write();
59 if !agents.contains_key(&agent.id) {
60 return Err(StorageError::NotFound(format!(
61 "Agent {} not found",
62 agent.id
63 )));
64 }
65 agents.insert(agent.id.clone(), agent.clone());
66 debug!("Updated agent in memory: {}", agent.id);
67 Ok(())
68 }
69
70 async fn delete_agent(&self, id: &str) -> Result<(), Self::Error> {
71 let mut agents = self.agents.write();
72 if agents.remove(id).is_some() {
73 debug!("Deleted agent from memory: {}", id);
74 } else {
75 debug!("Agent {} not found, delete is idempotent", id);
76 }
77 Ok(())
78 }
79
80 async fn list_agents(&self) -> Result<Vec<AgentModel>, Self::Error> {
81 let agents = self.agents.read();
82 let mut list: Vec<_> = agents.values().cloned().collect();
83 list.sort_by(|a, b| b.created_at.cmp(&a.created_at));
84 Ok(list)
85 }
86
87 async fn list_agents_by_status(&self, status: &str) -> Result<Vec<AgentModel>, Self::Error> {
88 let agents = self.agents.read();
89 let mut list: Vec<_> = agents
90 .values()
91 .filter(|a| a.status.to_string() == status)
92 .cloned()
93 .collect();
94 list.sort_by(|a, b| b.created_at.cmp(&a.created_at));
95 Ok(list)
96 }
97
98 async fn store_task(&self, task: &TaskModel) -> Result<(), Self::Error> {
100 let mut tasks = self.tasks.write();
101 tasks.insert(task.id.clone(), task.clone());
102 debug!("Stored task in memory: {}", task.id);
103 Ok(())
104 }
105
106 async fn get_task(&self, id: &str) -> Result<Option<TaskModel>, Self::Error> {
107 let tasks = self.tasks.read();
108 Ok(tasks.get(id).cloned())
109 }
110
111 async fn update_task(&self, task: &TaskModel) -> Result<(), Self::Error> {
112 let mut tasks = self.tasks.write();
113 if !tasks.contains_key(&task.id) {
114 return Err(StorageError::NotFound(format!(
115 "Task {} not found",
116 task.id
117 )));
118 }
119 tasks.insert(task.id.clone(), task.clone());
120 debug!("Updated task in memory: {}", task.id);
121 Ok(())
122 }
123
124 async fn get_pending_tasks(&self) -> Result<Vec<TaskModel>, Self::Error> {
125 let tasks = self.tasks.read();
126 let mut pending: Vec<_> = tasks
127 .values()
128 .filter(|t| t.status == TaskStatus::Pending)
129 .cloned()
130 .collect();
131 pending.sort_by(|a, b| {
132 b.priority
133 .cmp(&a.priority)
134 .then_with(|| a.created_at.cmp(&b.created_at))
135 });
136 Ok(pending)
137 }
138
139 async fn get_tasks_by_agent(&self, agent_id: &str) -> Result<Vec<TaskModel>, Self::Error> {
140 let tasks = self.tasks.read();
141 let mut agent_tasks: Vec<_> = tasks
142 .values()
143 .filter(|t| t.assigned_to.as_deref() == Some(agent_id))
144 .cloned()
145 .collect();
146 agent_tasks.sort_by(|a, b| {
147 b.priority
148 .cmp(&a.priority)
149 .then_with(|| a.created_at.cmp(&b.created_at))
150 });
151 Ok(agent_tasks)
152 }
153
154 async fn claim_task(&self, task_id: &str, agent_id: &str) -> Result<bool, Self::Error> {
155 let mut tasks = self.tasks.write();
156 if let Some(task) = tasks.get_mut(task_id) {
157 if task.status == TaskStatus::Pending {
158 task.assign_to(agent_id);
159 return Ok(true);
160 }
161 }
162 Ok(false)
163 }
164
165 async fn store_event(&self, event: &EventModel) -> Result<(), Self::Error> {
167 let mut events = self.events.write();
168 let mut sequence = self.next_sequence.write();
169
170 let mut event_with_seq = event.clone();
171 event_with_seq.sequence = *sequence;
172 *sequence += 1;
173
174 events.push(event_with_seq);
175 debug!("Stored event in memory: {}", event.id);
176 Ok(())
177 }
178
179 async fn get_events_by_agent(
180 &self,
181 agent_id: &str,
182 limit: usize,
183 ) -> Result<Vec<EventModel>, Self::Error> {
184 let events = self.events.read();
185 let mut agent_events: Vec<_> = events
186 .iter()
187 .filter(|e| e.agent_id.as_deref() == Some(agent_id))
188 .cloned()
189 .collect();
190 agent_events.sort_by(|a, b| b.timestamp.cmp(&a.timestamp));
191 agent_events.truncate(limit);
192 Ok(agent_events)
193 }
194
195 async fn get_events_by_type(
196 &self,
197 event_type: &str,
198 limit: usize,
199 ) -> Result<Vec<EventModel>, Self::Error> {
200 let events = self.events.read();
201 let mut type_events: Vec<_> = events
202 .iter()
203 .filter(|e| e.event_type == event_type)
204 .cloned()
205 .collect();
206 type_events.sort_by(|a, b| b.timestamp.cmp(&a.timestamp));
207 type_events.truncate(limit);
208 Ok(type_events)
209 }
210
211 async fn get_events_since(&self, timestamp: i64) -> Result<Vec<EventModel>, Self::Error> {
212 let events = self.events.read();
213 let mut since_events: Vec<_> = events
214 .iter()
215 .filter(|e| e.timestamp.timestamp() >= timestamp)
216 .cloned()
217 .collect();
218
219 since_events.sort_by(|a, b| {
221 a.timestamp.cmp(&b.timestamp)
222 .then_with(|| a.id.cmp(&b.id))
223 });
224
225 Ok(since_events)
226 }
227
228 async fn store_message(&self, message: &MessageModel) -> Result<(), Self::Error> {
230 let mut messages = self.messages.write();
231 messages.push(message.clone());
232 debug!("Stored message in memory: {}", message.id);
233 Ok(())
234 }
235
236 async fn get_messages_between(
237 &self,
238 agent1: &str,
239 agent2: &str,
240 limit: usize,
241 ) -> Result<Vec<MessageModel>, Self::Error> {
242 let messages = self.messages.read();
243 let mut between: Vec<_> = messages
244 .iter()
245 .filter(|m| {
246 (m.from_agent == agent1 && m.to_agent == agent2)
247 || (m.from_agent == agent2 && m.to_agent == agent1)
248 })
249 .cloned()
250 .collect();
251 between.sort_by(|a, b| b.created_at.cmp(&a.created_at));
252 between.truncate(limit);
253 Ok(between)
254 }
255
256 async fn get_unread_messages(&self, agent_id: &str) -> Result<Vec<MessageModel>, Self::Error> {
257 let messages = self.messages.read();
258 let unread: Vec<_> = messages
259 .iter()
260 .filter(|m| m.to_agent == agent_id && !m.read)
261 .cloned()
262 .collect();
263 Ok(unread)
264 }
265
266 async fn mark_message_read(&self, message_id: &str) -> Result<(), Self::Error> {
267 let mut messages = self.messages.write();
268 if let Some(msg) = messages.iter_mut().find(|m| m.id == message_id) {
269 msg.mark_read();
270 Ok(())
271 } else {
272 Err(StorageError::NotFound(format!(
273 "Message {} not found",
274 message_id
275 )))
276 }
277 }
278
279 async fn store_metric(&self, metric: &MetricModel) -> Result<(), Self::Error> {
281 let mut metrics = self.metrics.write();
282 metrics.push(metric.clone());
283 debug!("Stored metric in memory: {}", metric.id);
284 Ok(())
285 }
286
287 async fn get_metrics_by_agent(
288 &self,
289 agent_id: &str,
290 metric_type: &str,
291 ) -> Result<Vec<MetricModel>, Self::Error> {
292 let metrics = self.metrics.read();
293 let agent_metrics: Vec<_> = metrics
294 .iter()
295 .filter(|m| m.agent_id.as_deref() == Some(agent_id) && m.metric_type == metric_type)
296 .cloned()
297 .collect();
298 Ok(agent_metrics)
299 }
300
301 async fn get_aggregated_metrics(
302 &self,
303 metric_type: &str,
304 start_time: i64,
305 end_time: i64,
306 ) -> Result<Vec<MetricModel>, Self::Error> {
307 let metrics = self.metrics.read();
308
309 let mut grouped: HashMap<(Option<String>, String), Vec<f64>> = HashMap::new();
311
312 for metric in metrics.iter() {
313 let timestamp = metric.timestamp.timestamp();
314 if metric.metric_type == metric_type && timestamp >= start_time && timestamp <= end_time
315 {
316 let key = (metric.agent_id.clone(), metric.unit.clone());
317 grouped
318 .entry(key)
319 .or_default()
320 .push(metric.value);
321 }
322 }
323
324 let mut results = Vec::new();
326 for ((agent_id, unit), values) in grouped {
327 if !values.is_empty() {
328 let avg = values.iter().sum::<f64>() / values.len() as f64;
329 let mut metric = MetricModel::new(metric_type.to_string(), avg, unit);
330 metric.agent_id = agent_id;
331 metric
332 .tags
333 .insert("count".to_string(), values.len().to_string());
334 results.push(metric);
335 }
336 }
337
338 Ok(results)
339 }
340
341 async fn begin_transaction(&self) -> Result<Box<dyn TransactionTrait>, Self::Error> {
343 Ok(Box::new(MemoryTransaction::new()))
344 }
345
346 async fn vacuum(&self) -> Result<(), Self::Error> {
348 debug!("Vacuum called on memory storage (no-op)");
350 Ok(())
351 }
352
353 async fn checkpoint(&self) -> Result<(), Self::Error> {
354 debug!("Checkpoint called on memory storage (no-op)");
356 Ok(())
357 }
358
359 async fn get_storage_size(&self) -> Result<u64, Self::Error> {
360 let agents_count = self.agents.read().len();
362 let tasks_count = self.tasks.read().len();
363 let events_count = self.events.read().len();
364 let messages_count = self.messages.read().len();
365 let metrics_count = self.metrics.read().len();
366
367 let total_items =
369 agents_count + tasks_count + events_count + messages_count + metrics_count;
370 Ok((total_items * 1024) as u64)
371 }
372}
373
374struct MemoryTransaction;
376
377impl MemoryTransaction {
378 fn new() -> Self {
379 Self
380 }
381}
382
383#[async_trait]
384impl TransactionTrait for MemoryTransaction {
385 async fn commit(self: Box<Self>) -> Result<(), StorageError> {
386 Ok(())
388 }
389
390 async fn rollback(self: Box<Self>) -> Result<(), StorageError> {
391 Ok(())
393 }
394}
395
#[cfg(test)]
mod tests {
    use super::*;

    /// End-to-end smoke test: agent round-trip, then task store -> pending -> claim.
    #[tokio::test]
    async fn test_memory_storage() {
        let storage = MemoryStorage::new();

        // --- agents: store, then read back ---
        let agent = AgentModel::new(
            "test-agent".to_string(),
            "worker".to_string(),
            vec!["compute".to_string()],
        );
        storage.store_agent(&agent).await.expect("store_agent failed");

        let fetched_agent = storage
            .get_agent(&agent.id)
            .await
            .expect("get_agent failed")
            .expect("agent should exist after store");
        assert_eq!(fetched_agent.name, "test-agent");

        // --- tasks: store, observe as pending, claim ---
        let task = TaskModel::new(
            "process".to_string(),
            serde_json::json!({"data": "test"}),
            TaskPriority::High,
        );
        storage.store_task(&task).await.expect("store_task failed");

        let pending_count = storage
            .get_pending_tasks()
            .await
            .expect("get_pending_tasks failed")
            .len();
        assert_eq!(pending_count, 1);

        let was_claimed = storage
            .claim_task(&task.id, &agent.id)
            .await
            .expect("claim_task failed");
        assert!(was_claimed);

        let claimed_task = storage
            .get_task(&task.id)
            .await
            .expect("get_task failed")
            .expect("task should exist after store");
        assert_eq!(claimed_task.assigned_to, Some(agent.id.clone()));
    }
}