Skip to main content

sh_layer2/
tasks.rs

1//! # Task Manager
2//!
3//! 任务队列管理,支持优先级和依赖关系。
4
5use async_trait::async_trait;
6use parking_lot::RwLock;
7use serde::{Deserialize, Serialize};
8use std::collections::{BinaryHeap, HashMap};
9use std::time::Duration;
10
11use crate::types::{Layer2Result, TaskId};
12
13/// 任务状态
14#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
15pub enum TaskStatus {
16    #[default]
17    Pending,
18    Running,
19    Completed,
20    Failed,
21    Cancelled,
22}
23
24/// 任务优先级
25#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, Default)]
26pub enum TaskPriority {
27    Low = 0,
28    #[default]
29    Normal = 1,
30    High = 2,
31    Urgent = 3,
32}
33
34/// 任务定义
35#[derive(Debug, Clone, Serialize, Deserialize)]
36pub struct Task {
37    pub id: TaskId,
38    pub name: String,
39    pub description: String,
40    pub status: TaskStatus,
41    pub priority: TaskPriority,
42    pub dependencies: Vec<TaskId>,
43    pub created_at: chrono::DateTime<chrono::Utc>,
44    pub started_at: Option<chrono::DateTime<chrono::Utc>>,
45    pub completed_at: Option<chrono::DateTime<chrono::Utc>>,
46    pub timeout: Option<Duration>,
47    pub retry_count: u32,
48    pub max_retries: u32,
49    pub metadata: HashMap<String, serde_json::Value>,
50}
51
52impl Task {
53    pub fn new(name: impl Into<String>, description: impl Into<String>) -> Self {
54        Self {
55            id: TaskId::new(),
56            name: name.into(),
57            description: description.into(),
58            status: TaskStatus::Pending,
59            priority: TaskPriority::Normal,
60            dependencies: Vec::new(),
61            created_at: chrono::Utc::now(),
62            started_at: None,
63            completed_at: None,
64            timeout: None,
65            retry_count: 0,
66            max_retries: 3,
67            metadata: HashMap::new(),
68        }
69    }
70
71    pub fn with_priority(mut self, priority: TaskPriority) -> Self {
72        self.priority = priority;
73        self
74    }
75
76    pub fn with_timeout(mut self, timeout: Duration) -> Self {
77        self.timeout = Some(timeout);
78        self
79    }
80
81    pub fn with_dependency(mut self, task_id: TaskId) -> Self {
82        self.dependencies.push(task_id);
83        self
84    }
85
86    pub fn with_metadata(mut self, key: &str, value: serde_json::Value) -> Self {
87        self.metadata.insert(key.to_string(), value);
88        self
89    }
90
91    /// 检查是否可以执行(所有依赖已完成)
92    pub fn can_execute(&self, completed: &HashMap<TaskId, TaskStatus>) -> bool {
93        self.dependencies
94            .iter()
95            .all(|dep_id| completed.get(dep_id) == Some(&TaskStatus::Completed))
96    }
97
98    /// 获取执行时长
99    pub fn duration(&self) -> Option<Duration> {
100        self.started_at.and_then(|start| {
101            self.completed_at
102                .map(|end| Duration::from_secs((end - start).num_seconds() as u64))
103        })
104    }
105}
106
107impl Eq for Task {}
108
109impl PartialEq for Task {
110    fn eq(&self, other: &Self) -> bool {
111        self.id == other.id
112    }
113}
114
115impl Ord for Task {
116    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
117        // 优先级高的排在前面
118        other
119            .priority
120            .cmp(&self.priority)
121            .then_with(|| other.created_at.cmp(&self.created_at))
122    }
123}
124
125impl PartialOrd for Task {
126    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
127        Some(self.cmp(other))
128    }
129}
130
131/// 任务管理器接口
132#[async_trait]
133pub trait TaskManagerTrait: Send + Sync {
134    /// 添加任务
135    fn add(&self, task: Task) -> Layer2Result<TaskId>;
136
137    /// 获取任务
138    fn get(&self, id: &TaskId) -> Option<Task>;
139
140    /// 更新任务状态
141    async fn update_status(&self, id: &TaskId, status: TaskStatus) -> Layer2Result<bool>;
142
143    /// 取消任务
144    async fn cancel(&self, id: &TaskId) -> Layer2Result<bool>;
145
146    /// 获取下一个可执行任务
147    fn next(&self) -> Option<Task>;
148
149    /// 获取任务数量
150    fn count(&self) -> usize;
151
152    /// 获取特定状态的任务数量
153    fn count_by_status(&self, status: TaskStatus) -> usize;
154
155    /// 清理已完成任务
156    fn cleanup_completed(&self) -> usize;
157}
158
159/// 任务管理器实现
160pub struct TaskManager {
161    tasks: RwLock<HashMap<TaskId, Task>>,
162    queue: RwLock<BinaryHeap<Task>>,
163}
164
165impl TaskManager {
166    pub fn new() -> Self {
167        Self {
168            tasks: RwLock::new(HashMap::new()),
169            queue: RwLock::new(BinaryHeap::new()),
170        }
171    }
172}
173
174impl Default for TaskManager {
175    fn default() -> Self {
176        Self::new()
177    }
178}
179
180#[async_trait]
181impl TaskManagerTrait for TaskManager {
182    fn add(&self, task: Task) -> Layer2Result<TaskId> {
183        let id = task.id.clone();
184
185        self.queue.write().push(task.clone());
186        self.tasks.write().insert(id.clone(), task);
187
188        Ok(id)
189    }
190
191    fn get(&self, id: &TaskId) -> Option<Task> {
192        self.tasks.read().get(id).cloned()
193    }
194
195    async fn update_status(&self, id: &TaskId, status: TaskStatus) -> Layer2Result<bool> {
196        let mut tasks = self.tasks.write();
197
198        if let Some(task) = tasks.get_mut(id) {
199            task.status = status;
200
201            if status == TaskStatus::Running {
202                task.started_at = Some(chrono::Utc::now());
203            } else if matches!(status, TaskStatus::Completed | TaskStatus::Failed) {
204                task.completed_at = Some(chrono::Utc::now());
205            }
206
207            Ok(true)
208        } else {
209            Ok(false)
210        }
211    }
212
213    async fn cancel(&self, id: &TaskId) -> Layer2Result<bool> {
214        self.update_status(id, TaskStatus::Cancelled).await
215    }
216
217    fn next(&self) -> Option<Task> {
218        let tasks = self.tasks.read();
219        let completed: HashMap<TaskId, TaskStatus> = tasks
220            .iter()
221            .filter(|(_, t)| t.status == TaskStatus::Completed)
222            .map(|(id, t)| (id.clone(), t.status))
223            .collect();
224
225        self.queue
226            .write()
227            .pop()
228            .filter(|t| t.can_execute(&completed))
229    }
230
231    fn count(&self) -> usize {
232        self.tasks.read().len()
233    }
234
235    fn count_by_status(&self, status: TaskStatus) -> usize {
236        self.tasks
237            .read()
238            .values()
239            .filter(|t| t.status == status)
240            .count()
241    }
242
243    fn cleanup_completed(&self) -> usize {
244        let mut tasks = self.tasks.write();
245        let completed: Vec<TaskId> = tasks
246            .iter()
247            .filter(|(_, t)| t.status == TaskStatus::Completed)
248            .map(|(id, _)| id.clone())
249            .collect();
250
251        let count = completed.len();
252        for id in completed {
253            tasks.remove(&id);
254        }
255
256        // 重建队列
257        let mut queue = self.queue.write();
258        *queue = tasks.values().cloned().collect();
259
260        count
261    }
262}
263
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly created task is pending and has normal priority.
    #[test]
    fn test_task_creation() {
        let created = Task::new("test", "Test task");

        assert_eq!(created.status, TaskStatus::Pending);
        assert_eq!(created.priority, TaskPriority::Normal);
    }

    /// `with_priority` overrides the default priority.
    #[test]
    fn test_task_priority() {
        let prioritized = Task::new("test", "Test").with_priority(TaskPriority::High);

        assert_eq!(prioritized.priority, TaskPriority::High);
    }

    /// Adding a task to an empty manager raises the count from 0 to 1.
    #[test]
    fn test_task_manager() {
        let manager = TaskManager::new();
        assert_eq!(manager.count(), 0);

        manager.add(Task::new("test", "Test task")).unwrap();
        assert_eq!(manager.count(), 1);
    }
}
291}