taskflow_rs/storage/backends/
memory.rs1use async_trait::async_trait;
2use std::collections::HashMap;
3use tokio::sync::RwLock;
4
5use super::super::traits::TaskStorage;
6use crate::error::Result;
7use crate::task::{Task, TaskStatus};
8
9pub struct InMemoryStorage {
10 tasks: RwLock<HashMap<String, Task>>,
11}
12
13impl InMemoryStorage {
14 pub fn new() -> Self {
15 Self {
16 tasks: RwLock::new(HashMap::new()),
17 }
18 }
19}
20
21#[async_trait]
22impl TaskStorage for InMemoryStorage {
23 async fn save_task(&self, task: &Task) -> Result<()> {
24 let mut tasks = self.tasks.write().await;
25 tasks.insert(task.definition.id.clone(), task.clone());
26 Ok(())
27 }
28
29 async fn get_task(&self, task_id: &str) -> Result<Option<Task>> {
30 let tasks = self.tasks.read().await;
31 Ok(tasks.get(task_id).cloned())
32 }
33
34 async fn update_task_status(&self, task_id: &str, status: TaskStatus) -> Result<()> {
35 let mut tasks = self.tasks.write().await;
36 if let Some(task) = tasks.get_mut(task_id) {
37 task.status = status;
38 }
39 Ok(())
40 }
41
42 async fn list_tasks_by_status(&self, status: TaskStatus) -> Result<Vec<Task>> {
43 let tasks = self.tasks.read().await;
44 let filtered_tasks: Vec<Task> = tasks
45 .values()
46 .filter(|task| task.status == status)
47 .cloned()
48 .collect();
49 Ok(filtered_tasks)
50 }
51
52 async fn list_tasks_by_worker(&self, worker_id: &str) -> Result<Vec<Task>> {
53 let tasks = self.tasks.read().await;
54 let filtered_tasks: Vec<Task> = tasks
55 .values()
56 .filter(|task| task.assigned_worker.as_ref() == Some(&worker_id.to_string()))
57 .cloned()
58 .collect();
59 Ok(filtered_tasks)
60 }
61
62 async fn delete_task(&self, task_id: &str) -> Result<()> {
63 let mut tasks = self.tasks.write().await;
64 tasks.remove(task_id);
65 Ok(())
66 }
67
68 async fn get_pending_tasks(&self, limit: usize) -> Result<Vec<Task>> {
69 let mut tasks = self.list_tasks_by_status(TaskStatus::Pending).await?;
70 tasks.sort_by(|a, b| b.definition.priority.cmp(&a.definition.priority));
71 tasks.truncate(limit);
72 Ok(tasks)
73 }
74
75 async fn get_tasks_by_tags(&self, tags: &[String]) -> Result<Vec<Task>> {
76 let tasks = self.tasks.read().await;
77 let filtered_tasks: Vec<Task> = tasks
78 .values()
79 .filter(|task| task.definition.tags.iter().any(|tag| tags.contains(tag)))
80 .cloned()
81 .collect();
82 Ok(filtered_tasks)
83 }
84}