use crate::{
    models::*, Storage, StorageError, Transaction as TransactionTrait,
};
use async_trait::async_trait;
use parking_lot::RwLock;
use std::collections::{BTreeMap, HashMap};
use std::sync::Arc;
use tracing::debug;
11
12pub struct MemoryStorage {
14 agents: Arc<RwLock<HashMap<String, AgentModel>>>,
15 tasks: Arc<RwLock<HashMap<String, TaskModel>>>,
16 events: Arc<RwLock<Vec<EventModel>>>,
17 messages: Arc<RwLock<Vec<MessageModel>>>,
18 metrics: Arc<RwLock<Vec<MetricModel>>>,
19 next_sequence: Arc<RwLock<u64>>,
20}
21
22impl MemoryStorage {
23 pub fn new() -> Self {
25 Self {
26 agents: Arc::new(RwLock::new(HashMap::new())),
27 tasks: Arc::new(RwLock::new(HashMap::new())),
28 events: Arc::new(RwLock::new(Vec::new())),
29 messages: Arc::new(RwLock::new(Vec::new())),
30 metrics: Arc::new(RwLock::new(Vec::new())),
31 next_sequence: Arc::new(RwLock::new(1)),
32 }
33 }
34}
35
36impl Default for MemoryStorage {
37 fn default() -> Self {
38 Self::new()
39 }
40}
41
42#[async_trait]
43impl Storage for MemoryStorage {
44 type Error = StorageError;
45
46 async fn store_agent(&self, agent: &AgentModel) -> Result<(), Self::Error> {
48 let mut agents = self.agents.write();
49 agents.insert(agent.id.clone(), agent.clone());
50 debug!("Stored agent in memory: {}", agent.id);
51 Ok(())
52 }
53
54 async fn get_agent(&self, id: &str) -> Result<Option<AgentModel>, Self::Error> {
55 let agents = self.agents.read();
56 Ok(agents.get(id).cloned())
57 }
58
59 async fn update_agent(&self, agent: &AgentModel) -> Result<(), Self::Error> {
60 let mut agents = self.agents.write();
61 if !agents.contains_key(&agent.id) {
62 return Err(StorageError::NotFound(format!("Agent {} not found", agent.id)));
63 }
64 agents.insert(agent.id.clone(), agent.clone());
65 debug!("Updated agent in memory: {}", agent.id);
66 Ok(())
67 }
68
69 async fn delete_agent(&self, id: &str) -> Result<(), Self::Error> {
70 let mut agents = self.agents.write();
71 if agents.remove(id).is_none() {
72 return Err(StorageError::NotFound(format!("Agent {} not found", id)));
73 }
74 debug!("Deleted agent from memory: {}", id);
75 Ok(())
76 }
77
78 async fn list_agents(&self) -> Result<Vec<AgentModel>, Self::Error> {
79 let agents = self.agents.read();
80 let mut list: Vec<_> = agents.values().cloned().collect();
81 list.sort_by(|a, b| b.created_at.cmp(&a.created_at));
82 Ok(list)
83 }
84
85 async fn list_agents_by_status(&self, status: &str) -> Result<Vec<AgentModel>, Self::Error> {
86 let agents = self.agents.read();
87 let mut list: Vec<_> = agents
88 .values()
89 .filter(|a| a.status.to_string() == status)
90 .cloned()
91 .collect();
92 list.sort_by(|a, b| b.created_at.cmp(&a.created_at));
93 Ok(list)
94 }
95
96 async fn store_task(&self, task: &TaskModel) -> Result<(), Self::Error> {
98 let mut tasks = self.tasks.write();
99 tasks.insert(task.id.clone(), task.clone());
100 debug!("Stored task in memory: {}", task.id);
101 Ok(())
102 }
103
104 async fn get_task(&self, id: &str) -> Result<Option<TaskModel>, Self::Error> {
105 let tasks = self.tasks.read();
106 Ok(tasks.get(id).cloned())
107 }
108
109 async fn update_task(&self, task: &TaskModel) -> Result<(), Self::Error> {
110 let mut tasks = self.tasks.write();
111 if !tasks.contains_key(&task.id) {
112 return Err(StorageError::NotFound(format!("Task {} not found", task.id)));
113 }
114 tasks.insert(task.id.clone(), task.clone());
115 debug!("Updated task in memory: {}", task.id);
116 Ok(())
117 }
118
119 async fn get_pending_tasks(&self) -> Result<Vec<TaskModel>, Self::Error> {
120 let tasks = self.tasks.read();
121 let mut pending: Vec<_> = tasks
122 .values()
123 .filter(|t| t.status == TaskStatus::Pending)
124 .cloned()
125 .collect();
126 pending.sort_by(|a, b| {
127 b.priority.cmp(&a.priority)
128 .then_with(|| a.created_at.cmp(&b.created_at))
129 });
130 Ok(pending)
131 }
132
133 async fn get_tasks_by_agent(&self, agent_id: &str) -> Result<Vec<TaskModel>, Self::Error> {
134 let tasks = self.tasks.read();
135 let mut agent_tasks: Vec<_> = tasks
136 .values()
137 .filter(|t| t.assigned_to.as_deref() == Some(agent_id))
138 .cloned()
139 .collect();
140 agent_tasks.sort_by(|a, b| {
141 b.priority.cmp(&a.priority)
142 .then_with(|| a.created_at.cmp(&b.created_at))
143 });
144 Ok(agent_tasks)
145 }
146
147 async fn claim_task(&self, task_id: &str, agent_id: &str) -> Result<bool, Self::Error> {
148 let mut tasks = self.tasks.write();
149 if let Some(task) = tasks.get_mut(task_id) {
150 if task.status == TaskStatus::Pending {
151 task.assign_to(agent_id);
152 return Ok(true);
153 }
154 }
155 Ok(false)
156 }
157
158 async fn store_event(&self, event: &EventModel) -> Result<(), Self::Error> {
160 let mut events = self.events.write();
161 let mut sequence = self.next_sequence.write();
162
163 let mut event_with_seq = event.clone();
164 event_with_seq.sequence = *sequence;
165 *sequence += 1;
166
167 events.push(event_with_seq);
168 debug!("Stored event in memory: {}", event.id);
169 Ok(())
170 }
171
172 async fn get_events_by_agent(&self, agent_id: &str, limit: usize) -> Result<Vec<EventModel>, Self::Error> {
173 let events = self.events.read();
174 let mut agent_events: Vec<_> = events
175 .iter()
176 .filter(|e| e.agent_id.as_deref() == Some(agent_id))
177 .cloned()
178 .collect();
179 agent_events.sort_by(|a, b| b.timestamp.cmp(&a.timestamp));
180 agent_events.truncate(limit);
181 Ok(agent_events)
182 }
183
184 async fn get_events_by_type(&self, event_type: &str, limit: usize) -> Result<Vec<EventModel>, Self::Error> {
185 let events = self.events.read();
186 let mut type_events: Vec<_> = events
187 .iter()
188 .filter(|e| e.event_type == event_type)
189 .cloned()
190 .collect();
191 type_events.sort_by(|a, b| b.timestamp.cmp(&a.timestamp));
192 type_events.truncate(limit);
193 Ok(type_events)
194 }
195
196 async fn get_events_since(&self, timestamp: i64) -> Result<Vec<EventModel>, Self::Error> {
197 let events = self.events.read();
198 let since_events: Vec<_> = events
199 .iter()
200 .filter(|e| e.timestamp.timestamp() > timestamp)
201 .cloned()
202 .collect();
203 Ok(since_events)
204 }
205
206 async fn store_message(&self, message: &MessageModel) -> Result<(), Self::Error> {
208 let mut messages = self.messages.write();
209 messages.push(message.clone());
210 debug!("Stored message in memory: {}", message.id);
211 Ok(())
212 }
213
214 async fn get_messages_between(&self, agent1: &str, agent2: &str, limit: usize) -> Result<Vec<MessageModel>, Self::Error> {
215 let messages = self.messages.read();
216 let mut between: Vec<_> = messages
217 .iter()
218 .filter(|m| {
219 (m.from_agent == agent1 && m.to_agent == agent2) ||
220 (m.from_agent == agent2 && m.to_agent == agent1)
221 })
222 .cloned()
223 .collect();
224 between.sort_by(|a, b| b.created_at.cmp(&a.created_at));
225 between.truncate(limit);
226 Ok(between)
227 }
228
229 async fn get_unread_messages(&self, agent_id: &str) -> Result<Vec<MessageModel>, Self::Error> {
230 let messages = self.messages.read();
231 let unread: Vec<_> = messages
232 .iter()
233 .filter(|m| m.to_agent == agent_id && !m.read)
234 .cloned()
235 .collect();
236 Ok(unread)
237 }
238
239 async fn mark_message_read(&self, message_id: &str) -> Result<(), Self::Error> {
240 let mut messages = self.messages.write();
241 if let Some(msg) = messages.iter_mut().find(|m| m.id == message_id) {
242 msg.mark_read();
243 Ok(())
244 } else {
245 Err(StorageError::NotFound(format!("Message {} not found", message_id)))
246 }
247 }
248
249 async fn store_metric(&self, metric: &MetricModel) -> Result<(), Self::Error> {
251 let mut metrics = self.metrics.write();
252 metrics.push(metric.clone());
253 debug!("Stored metric in memory: {}", metric.id);
254 Ok(())
255 }
256
257 async fn get_metrics_by_agent(&self, agent_id: &str, metric_type: &str) -> Result<Vec<MetricModel>, Self::Error> {
258 let metrics = self.metrics.read();
259 let agent_metrics: Vec<_> = metrics
260 .iter()
261 .filter(|m| m.agent_id.as_deref() == Some(agent_id) && m.metric_type == metric_type)
262 .cloned()
263 .collect();
264 Ok(agent_metrics)
265 }
266
267 async fn get_aggregated_metrics(&self, metric_type: &str, start_time: i64, end_time: i64) -> Result<Vec<MetricModel>, Self::Error> {
268 let metrics = self.metrics.read();
269
270 let mut grouped: HashMap<(Option<String>, String), Vec<f64>> = HashMap::new();
272
273 for metric in metrics.iter() {
274 let timestamp = metric.timestamp.timestamp();
275 if metric.metric_type == metric_type && timestamp >= start_time && timestamp <= end_time {
276 let key = (metric.agent_id.clone(), metric.unit.clone());
277 grouped.entry(key).or_insert_with(Vec::new).push(metric.value);
278 }
279 }
280
281 let mut results = Vec::new();
283 for ((agent_id, unit), values) in grouped {
284 if !values.is_empty() {
285 let avg = values.iter().sum::<f64>() / values.len() as f64;
286 let mut metric = MetricModel::new(metric_type.to_string(), avg, unit);
287 metric.agent_id = agent_id;
288 metric.tags.insert("count".to_string(), values.len().to_string());
289 results.push(metric);
290 }
291 }
292
293 Ok(results)
294 }
295
296 async fn begin_transaction(&self) -> Result<Box<dyn TransactionTrait>, Self::Error> {
298 Ok(Box::new(MemoryTransaction::new()))
299 }
300
301 async fn vacuum(&self) -> Result<(), Self::Error> {
303 debug!("Vacuum called on memory storage (no-op)");
305 Ok(())
306 }
307
308 async fn checkpoint(&self) -> Result<(), Self::Error> {
309 debug!("Checkpoint called on memory storage (no-op)");
311 Ok(())
312 }
313
314 async fn get_storage_size(&self) -> Result<u64, Self::Error> {
315 let agents_count = self.agents.read().len();
317 let tasks_count = self.tasks.read().len();
318 let events_count = self.events.read().len();
319 let messages_count = self.messages.read().len();
320 let metrics_count = self.metrics.read().len();
321
322 let total_items = agents_count + tasks_count + events_count + messages_count + metrics_count;
324 Ok((total_items * 1024) as u64)
325 }
326}
327
/// Transaction handle handed out by the in-memory storage.
///
/// It carries no state and holds no reference to the storage: writes made
/// through `MemoryStorage` are applied immediately, so this handle has nothing
/// to commit or undo.
struct MemoryTransaction;

impl MemoryTransaction {
    /// Creates a new stateless transaction handle.
    fn new() -> Self {
        MemoryTransaction
    }
}
336
337#[async_trait]
338impl TransactionTrait for MemoryTransaction {
339 async fn commit(self: Box<Self>) -> Result<(), StorageError> {
340 Ok(())
342 }
343
344 async fn rollback(self: Box<Self>) -> Result<(), StorageError> {
345 Ok(())
347 }
348}
349
#[cfg(test)]
mod tests {
    use super::*;

    /// Round-trips an agent and a task, then claims the task for the agent.
    #[tokio::test]
    async fn test_memory_storage() {
        let storage = MemoryStorage::new();

        let agent = AgentModel::new(
            "test-agent".to_string(),
            "worker".to_string(),
            vec!["compute".to_string()],
        );

        storage.store_agent(&agent).await.unwrap();
        let retrieved = storage.get_agent(&agent.id).await.unwrap();
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().name, "test-agent");

        let task = TaskModel::new(
            "process".to_string(),
            serde_json::json!({"data": "test"}),
            TaskPriority::High,
        );

        storage.store_task(&task).await.unwrap();
        let pending = storage.get_pending_tasks().await.unwrap();
        assert_eq!(pending.len(), 1);

        let claimed = storage.claim_task(&task.id, &agent.id).await.unwrap();
        assert!(claimed);

        let updated_task = storage.get_task(&task.id).await.unwrap().unwrap();
        assert_eq!(updated_task.assigned_to, Some(agent.id.clone()));
    }

    /// Operations on ids that were never stored, or were already deleted, must
    /// report `NotFound` (or `false` for claims) instead of silently succeeding.
    #[tokio::test]
    async fn test_missing_records() {
        let storage = MemoryStorage::new();
        let agent = AgentModel::new(
            "ghost-agent".to_string(),
            "worker".to_string(),
            vec!["compute".to_string()],
        );

        // update_agent must not upsert an unknown agent.
        assert!(matches!(
            storage.update_agent(&agent).await,
            Err(StorageError::NotFound(_))
        ));
        assert!(storage.get_agent(&agent.id).await.unwrap().is_none());

        // A stored agent can be deleted exactly once.
        storage.store_agent(&agent).await.unwrap();
        storage.delete_agent(&agent.id).await.unwrap();
        assert!(storage.get_agent(&agent.id).await.unwrap().is_none());
        assert!(matches!(
            storage.delete_agent(&agent.id).await,
            Err(StorageError::NotFound(_))
        ));

        // update_task must not upsert an unknown task either.
        let task = TaskModel::new(
            "process".to_string(),
            serde_json::json!({"data": "test"}),
            TaskPriority::High,
        );
        assert!(matches!(
            storage.update_task(&task).await,
            Err(StorageError::NotFound(_))
        ));

        // Claiming an unknown task is a miss, not an error.
        assert!(!storage.claim_task("missing-task", &agent.id).await.unwrap());

        assert!(matches!(
            storage.mark_message_read("missing-message").await,
            Err(StorageError::NotFound(_))
        ));
    }
}