taskflow_rs/storage/backends/
memory.rs

1use async_trait::async_trait;
2use std::collections::HashMap;
3use tokio::sync::RwLock;
4
5use super::super::traits::TaskStorage;
6use crate::error::Result;
7use crate::task::{Task, TaskStatus};
8
9pub struct InMemoryStorage {
10    tasks: RwLock<HashMap<String, Task>>,
11}
12
13impl InMemoryStorage {
14    pub fn new() -> Self {
15        Self {
16            tasks: RwLock::new(HashMap::new()),
17        }
18    }
19}
20
21#[async_trait]
22impl TaskStorage for InMemoryStorage {
23    async fn save_task(&self, task: &Task) -> Result<()> {
24        let mut tasks = self.tasks.write().await;
25        tasks.insert(task.definition.id.clone(), task.clone());
26        Ok(())
27    }
28
29    async fn get_task(&self, task_id: &str) -> Result<Option<Task>> {
30        let tasks = self.tasks.read().await;
31        Ok(tasks.get(task_id).cloned())
32    }
33
34    async fn update_task_status(&self, task_id: &str, status: TaskStatus) -> Result<()> {
35        let mut tasks = self.tasks.write().await;
36        if let Some(task) = tasks.get_mut(task_id) {
37            task.status = status;
38        }
39        Ok(())
40    }
41
42    async fn list_tasks_by_status(&self, status: TaskStatus) -> Result<Vec<Task>> {
43        let tasks = self.tasks.read().await;
44        let filtered_tasks: Vec<Task> = tasks
45            .values()
46            .filter(|task| task.status == status)
47            .cloned()
48            .collect();
49        Ok(filtered_tasks)
50    }
51
52    async fn list_tasks_by_worker(&self, worker_id: &str) -> Result<Vec<Task>> {
53        let tasks = self.tasks.read().await;
54        let filtered_tasks: Vec<Task> = tasks
55            .values()
56            .filter(|task| task.assigned_worker.as_ref() == Some(&worker_id.to_string()))
57            .cloned()
58            .collect();
59        Ok(filtered_tasks)
60    }
61
62    async fn delete_task(&self, task_id: &str) -> Result<()> {
63        let mut tasks = self.tasks.write().await;
64        tasks.remove(task_id);
65        Ok(())
66    }
67
68    async fn get_pending_tasks(&self, limit: usize) -> Result<Vec<Task>> {
69        let mut tasks = self.list_tasks_by_status(TaskStatus::Pending).await?;
70        tasks.sort_by(|a, b| b.definition.priority.cmp(&a.definition.priority));
71        tasks.truncate(limit);
72        Ok(tasks)
73    }
74
75    async fn get_tasks_by_tags(&self, tags: &[String]) -> Result<Vec<Task>> {
76        let tasks = self.tasks.read().await;
77        let filtered_tasks: Vec<Task> = tasks
78            .values()
79            .filter(|task| task.definition.tags.iter().any(|tag| tags.contains(tag)))
80            .cloned()
81            .collect();
82        Ok(filtered_tasks)
83    }
84}