Skip to main content

wae_scheduler/
delayed.rs

1use async_trait::async_trait;
2use chrono::{DateTime, Utc};
3use std::{
4    collections::BTreeMap,
5    sync::{
6        Arc,
7        atomic::{AtomicBool, Ordering},
8    },
9    time::Duration,
10};
11use tokio::sync::{RwLock, broadcast, mpsc};
12use tracing::{debug, info};
13use uuid::Uuid;
14
15use crate::task::{TaskHandle, TaskId, TaskState};
16use wae_types::{WaeError, WaeResult};
17
18/// 延迟任务
19#[derive(Debug, Clone)]
20pub struct DelayedTask<T: Send + Sync + Clone + 'static> {
21    /// 任务 ID
22    pub id: TaskId,
23    /// 任务名称
24    pub name: String,
25    /// 执行时间
26    pub execute_at: DateTime<Utc>,
27    /// 优先级 (数值越小优先级越高)
28    pub priority: u32,
29    /// 任务数据
30    pub data: T,
31    /// 创建时间
32    pub created_at: DateTime<Utc>,
33}
34
35impl<T: Send + Sync + Clone + 'static> DelayedTask<T> {
36    /// 创建新的延迟任务
37    pub fn new(name: String, execute_at: DateTime<Utc>, data: T) -> Self {
38        Self { id: Uuid::new_v4().to_string(), name, execute_at, priority: 0, data, created_at: Utc::now() }
39    }
40
41    /// 设置优先级
42    pub fn with_priority(mut self, priority: u32) -> Self {
43        self.priority = priority;
44        self
45    }
46
47    /// 检查是否到期
48    pub fn is_due(&self) -> bool {
49        Utc::now() >= self.execute_at
50    }
51
52    /// 获取剩余时间
53    pub fn remaining(&self) -> Duration {
54        let now = Utc::now();
55        if now >= self.execute_at { Duration::ZERO } else { (self.execute_at - now).to_std().unwrap_or(Duration::ZERO) }
56    }
57}
58
59/// 延迟任务执行器 trait
60#[async_trait]
61pub trait DelayedTaskExecutor<T: Send + Sync + Clone + 'static>: Send + Sync {
62    /// 执行任务
63    async fn execute(&self, task: DelayedTask<T>) -> WaeResult<()>;
64}
65
/// Configuration for a [`DelayedQueue`].
#[derive(Clone, Debug)]
pub struct DelayedQueueConfig {
    /// Maximum queue size; [`DelayedQueue::new`] uses it as the capacity of the channel
    /// that carries newly scheduled tasks to the pending queue.
    pub max_queue_size: usize,
    /// How often the pending queue is checked for due tasks.
    pub poll_interval: Duration,
    /// Intended maximum number of task executions running at the same time.
    pub max_concurrent_executions: usize,
}
76
77impl Default for DelayedQueueConfig {
78    fn default() -> Self {
79        Self { max_queue_size: 10000, poll_interval: Duration::from_millis(100), max_concurrent_executions: 10 }
80    }
81}
82
83/// 延迟任务队列
84///
85/// 管理延迟执行的任务,支持优先级排序。
86pub struct DelayedQueue<T: Send + Sync + Clone + 'static> {
87    /// 配置
88    #[allow(dead_code)]
89    config: DelayedQueueConfig,
90    /// 任务队列 (按执行时间和优先级排序)
91    queue: Arc<RwLock<Vec<DelayedTask<T>>>>,
92    /// 任务句柄映射
93    handles: Arc<RwLock<BTreeMap<TaskId, TaskHandle>>>,
94    /// 任务发送器
95    task_tx: mpsc::Sender<DelayedTask<T>>,
96    /// 关闭信号
97    shutdown_tx: broadcast::Sender<()>,
98    /// 是否已关闭
99    is_shutdown: Arc<AtomicBool>,
100}
101
102impl<T: Send + Sync + Clone + 'static> DelayedQueue<T> {
103    /// 创建新的延迟任务队列
104    pub fn new<E: DelayedTaskExecutor<T> + 'static>(config: DelayedQueueConfig, executor: Arc<E>) -> Self {
105        let (task_tx, mut task_rx) = mpsc::channel(config.max_queue_size);
106        let (shutdown_tx, _) = broadcast::channel(1);
107        let queue: Arc<RwLock<Vec<DelayedTask<T>>>> = Arc::new(RwLock::new(Vec::new()));
108        let handles: Arc<RwLock<BTreeMap<TaskId, TaskHandle>>> = Arc::new(RwLock::new(BTreeMap::new()));
109        let is_shutdown = Arc::new(AtomicBool::new(false));
110
111        let queue_clone = queue.clone();
112        let handles_clone = handles.clone();
113        let is_shutdown_clone = is_shutdown.clone();
114        let mut shutdown_rx = shutdown_tx.subscribe();
115
116        tokio::spawn(async move {
117            loop {
118                tokio::select! {
119                    _ = shutdown_rx.recv() => {
120                        debug!("Delayed queue received shutdown signal");
121                        break;
122                    }
123                    _ = tokio::time::sleep(config.poll_interval) => {
124                        if is_shutdown_clone.load(Ordering::SeqCst) {
125                            break;
126                        }
127
128                        let now = Utc::now();
129                        let mut queue_guard = queue_clone.write().await;
130
131                        let due_tasks: Vec<_> = queue_guard
132                            .iter()
133                            .filter(|t| t.execute_at <= now)
134                            .cloned()
135                            .collect();
136
137                        queue_guard.retain(|t| t.execute_at > now);
138
139                        drop(queue_guard);
140
141                        for task in due_tasks {
142                            let executor_clone = executor.clone();
143                            let handles_ref = handles_clone.clone();
144
145                            tokio::spawn(async move {
146                                if let Some(handle) = handles_ref.read().await.get(&task.id) {
147                                    handle.set_state(TaskState::Running).await;
148                                }
149
150                                let result = executor_clone.execute(task.clone()).await;
151
152                                if let Some(handle) = handles_ref.read().await.get(&task.id) {
153                                    match result {
154                                        Ok(()) => {
155                                            handle.record_execution().await;
156                                            handle.set_state(TaskState::Completed).await;
157                                        }
158                                        Err(e) => {
159                                            handle.record_error(e.to_string()).await;
160                                            handle.set_state(TaskState::Failed).await;
161                                        }
162                                    }
163                                }
164                            });
165                        }
166                    }
167                }
168            }
169        });
170
171        let queue_clone = queue.clone();
172        tokio::spawn(async move {
173            while let Some(task) = task_rx.recv().await {
174                let mut queue_guard = queue_clone.write().await;
175                queue_guard.push(task);
176                queue_guard.sort_by(|a, b| a.execute_at.cmp(&b.execute_at).then_with(|| a.priority.cmp(&b.priority)));
177            }
178        });
179
180        Self { config, queue, handles, task_tx, shutdown_tx, is_shutdown }
181    }
182
183    /// 注册延迟任务
184    ///
185    /// # 参数
186    ///
187    /// - `task`: 延迟任务
188    ///
189    /// # 返回值
190    ///
191    /// 返回任务句柄。
192    pub async fn schedule_delayed(&self, task: DelayedTask<T>) -> WaeResult<TaskHandle> {
193        if self.is_shutdown.load(Ordering::SeqCst) {
194            return Err(WaeError::scheduler_shutdown());
195        }
196
197        let handle = TaskHandle::new(task.id.clone(), task.name.clone());
198
199        {
200            let mut handles = self.handles.write().await;
201            handles.insert(task.id.clone(), handle.clone());
202        }
203
204        self.task_tx.send(task).await.map_err(|e| WaeError::internal(format!("Failed to send task: {}", e)))?;
205
206        info!("Scheduled delayed task: {}", handle.name);
207        Ok(handle)
208    }
209
210    /// 获取队列大小
211    pub async fn queue_size(&self) -> usize {
212        self.queue.read().await.len()
213    }
214
215    /// 取消任务
216    pub async fn cancel_task(&self, task_id: &str) -> WaeResult<bool> {
217        let mut queue = self.queue.write().await;
218        let initial_len = queue.len();
219        queue.retain(|t| t.id != task_id);
220
221        if queue.len() < initial_len {
222            if let Some(handle) = self.handles.read().await.get(task_id) {
223                handle.cancel();
224                handle.set_state(TaskState::Cancelled).await;
225            }
226            info!("Cancelled delayed task: {}", task_id);
227            Ok(true)
228        }
229        else {
230            Err(WaeError::task_not_found(task_id))
231        }
232    }
233
234    /// 获取任务句柄
235    pub async fn get_handle(&self, task_id: &str) -> Option<TaskHandle> {
236        self.handles.read().await.get(task_id).cloned()
237    }
238
239    /// 关闭队列
240    pub fn shutdown(&self) {
241        self.is_shutdown.store(true, Ordering::SeqCst);
242        let _ = self.shutdown_tx.send(());
243        info!("Delayed queue shutdown initiated");
244    }
245
246    /// 检查是否已关闭
247    pub fn is_shutdown(&self) -> bool {
248        self.is_shutdown.load(Ordering::SeqCst)
249    }
250}