Skip to main content

wae_scheduler/
cron.rs

1use chrono::{DateTime, Datelike, Timelike, Utc};
2use serde::{Deserialize, Serialize};
3use std::{
4    collections::BTreeMap,
5    sync::{
6        Arc,
7        atomic::{AtomicBool, Ordering},
8    },
9    time::Duration,
10};
11use tokio::sync::{RwLock, broadcast};
12use tracing::{debug, error, info};
13use uuid::Uuid;
14
15use crate::{
16    interval::IntervalScheduler,
17    task::{ScheduledTask, TaskHandle, TaskId, TaskState},
18};
19use wae_types::{WaeError, WaeResult};
20
21/// Cron 字段
22#[derive(Debug, Clone, Serialize, Deserialize)]
23pub enum CronField {
24    /// 任意值 (*)
25    Any,
26    /// 固定值
27    Value(u32),
28    /// 范围值
29    Range(u32, u32),
30    /// 步进值
31    Step(u32, u32),
32    /// 多个值
33    List(Vec<u32>),
34}
35
36impl CronField {
37    /// 检查值是否匹配
38    pub fn matches(&self, value: u32) -> bool {
39        match self {
40            CronField::Any => true,
41            CronField::Value(v) => *v == value,
42            CronField::Range(start, end) => value >= *start && value <= *end,
43            CronField::Step(start, step) => (value - start).is_multiple_of(*step),
44            CronField::List(values) => values.contains(&value),
45        }
46    }
47}
48
49/// Cron 表达式
50///
51/// 解析和计算标准 cron 表达式 (秒 分 时 日 月 周)。
52#[derive(Debug, Clone, Serialize, Deserialize)]
53pub struct CronExpression {
54    /// 秒字段
55    pub second: CronField,
56    /// 分字段
57    pub minute: CronField,
58    /// 时字段
59    pub hour: CronField,
60    /// 日字段
61    pub day_of_month: CronField,
62    /// 月字段
63    pub month: CronField,
64    /// 周字段
65    pub day_of_week: CronField,
66}
67
68impl CronExpression {
69    /// 解析 cron 表达式
70    ///
71    /// 支持标准格式: `秒 分 时 日 月 周`
72    ///
73    /// # 示例
74    ///
75    /// - `0 * * * * *` - 每分钟执行
76    /// - `0 0 * * * *` - 每小时执行
77    /// - `0 0 0 * * *` - 每天执行
78    /// - `0 30 14 * * *` - 每天 14:30 执行
79    /// - `0 0 9-17 * * 1-5` - 周一到周五 9点到17点每小时执行
80    pub fn parse(expression: &str) -> WaeResult<Self> {
81        let parts: Vec<&str> = expression.split_whitespace().collect();
82        if parts.len() != 6 {
83            return Err(WaeError::invalid_cron_expression("Expected 6 fields: second minute hour day month weekday"));
84        }
85
86        Ok(Self {
87            second: Self::parse_field(parts[0], 0, 59)?,
88            minute: Self::parse_field(parts[1], 0, 59)?,
89            hour: Self::parse_field(parts[2], 0, 23)?,
90            day_of_month: Self::parse_field(parts[3], 1, 31)?,
91            month: Self::parse_field(parts[4], 1, 12)?,
92            day_of_week: Self::parse_field(parts[5], 0, 6)?,
93        })
94    }
95
96    fn parse_field(s: &str, min: u32, max: u32) -> WaeResult<CronField> {
97        if s == "*" {
98            return Ok(CronField::Any);
99        }
100
101        if s.contains('/') {
102            let parts: Vec<&str> = s.split('/').collect();
103            if parts.len() != 2 {
104                return Err(WaeError::invalid_cron_expression(format!("Invalid step expression: {}", s)));
105            }
106            let start = if parts[0] == "*" {
107                min
108            }
109            else {
110                parts[0]
111                    .parse::<u32>()
112                    .map_err(|_| WaeError::invalid_cron_expression(format!("Invalid value: {}", parts[0])))?
113            };
114            let step = parts[1]
115                .parse::<u32>()
116                .map_err(|_| WaeError::invalid_cron_expression(format!("Invalid step: {}", parts[1])))?;
117            return Ok(CronField::Step(start, step));
118        }
119
120        if s.contains('-') {
121            let parts: Vec<&str> = s.split('-').collect();
122            if parts.len() != 2 {
123                return Err(WaeError::invalid_cron_expression(format!("Invalid range expression: {}", s)));
124            }
125            let start = parts[0]
126                .parse::<u32>()
127                .map_err(|_| WaeError::invalid_cron_expression(format!("Invalid start: {}", parts[0])))?;
128            let end =
129                parts[1].parse::<u32>().map_err(|_| WaeError::invalid_cron_expression(format!("Invalid end: {}", parts[1])))?;
130            return Ok(CronField::Range(start, end));
131        }
132
133        if s.contains(',') {
134            let values: Result<Vec<u32>, _> = s.split(',').map(|v| v.parse::<u32>()).collect();
135            let values = values.map_err(|_| WaeError::invalid_cron_expression(format!("Invalid list: {}", s)))?;
136            return Ok(CronField::List(values));
137        }
138
139        let value = s.parse::<u32>().map_err(|_| WaeError::invalid_cron_expression(format!("Invalid value: {}", s)))?;
140
141        if value < min || value > max {
142            return Err(WaeError::invalid_cron_expression(format!("Value {} out of range [{}, {}]", value, min, max)));
143        }
144
145        Ok(CronField::Value(value))
146    }
147
148    /// 计算下一次执行时间
149    pub fn next_execution(&self, from: DateTime<Utc>) -> Option<DateTime<Utc>> {
150        let mut current = from + chrono::Duration::seconds(1);
151
152        for _ in 0..366 * 24 * 60 * 60 {
153            let second = current.second();
154            let minute = current.minute();
155            let hour = current.hour();
156            let day = current.day();
157            let month = current.month();
158            let weekday = current.weekday().num_days_from_sunday();
159
160            if !self.second.matches(second) {
161                current += chrono::Duration::seconds(1);
162                continue;
163            }
164            if !self.minute.matches(minute) {
165                current += chrono::Duration::minutes(1);
166                current = current.with_second(0).unwrap();
167                continue;
168            }
169            if !self.hour.matches(hour) {
170                current += chrono::Duration::hours(1);
171                current = current.with_second(0).unwrap().with_minute(0).unwrap();
172                continue;
173            }
174            if !self.day_of_month.matches(day) {
175                current += chrono::Duration::days(1);
176                current = current.with_second(0).unwrap().with_minute(0).unwrap().with_hour(0).unwrap();
177                continue;
178            }
179            if !self.month.matches(month) {
180                current += chrono::Duration::days(28);
181                continue;
182            }
183            if !self.day_of_week.matches(weekday) {
184                current += chrono::Duration::days(1);
185                current = current.with_second(0).unwrap().with_minute(0).unwrap().with_hour(0).unwrap();
186                continue;
187            }
188
189            return Some(current);
190        }
191
192        None
193    }
194}
195
196/// Cron 任务
197pub struct CronTask {
198    /// 任务 ID
199    pub id: TaskId,
200    /// 任务名称
201    pub name: String,
202    /// Cron 表达式
203    pub expression: CronExpression,
204    /// 任务
205    pub task: Arc<dyn ScheduledTask>,
206}
207
208impl std::fmt::Debug for CronTask {
209    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
210        f.debug_struct("CronTask")
211            .field("id", &self.id)
212            .field("name", &self.name)
213            .field("expression", &self.expression)
214            .finish()
215    }
216}
217
/// Tuning knobs for a [`CronScheduler`].
#[derive(Clone, Debug)]
pub struct CronSchedulerConfig {
    /// How long the scheduler's background loop sleeps between checks for due tasks.
    pub poll_interval: Duration,
    /// Intended upper bound on the number of tasks running at once.
    ///
    /// NOTE(review): `CronScheduler` in this file only reads `poll_interval`, so this
    /// limit does not appear to be enforced. Confirm before relying on it.
    pub max_concurrent_tasks: usize,
}
226
227impl Default for CronSchedulerConfig {
228    fn default() -> Self {
229        Self { poll_interval: Duration::from_secs(1), max_concurrent_tasks: 100 }
230    }
231}
232
233/// Cron 调度器
234///
235/// 基于 Cron 表达式的任务调度器。
236pub struct CronScheduler {
237    /// 配置
238    #[allow(dead_code)]
239    config: CronSchedulerConfig,
240    /// Cron 任务映射
241    cron_tasks: Arc<RwLock<BTreeMap<TaskId, CronTask>>>,
242    /// 任务句柄映射
243    handles: Arc<RwLock<BTreeMap<TaskId, TaskHandle>>>,
244    /// 下次执行时间映射
245    next_executions: Arc<RwLock<BTreeMap<TaskId, DateTime<Utc>>>>,
246    /// 关闭信号
247    shutdown_tx: broadcast::Sender<()>,
248    /// 是否已关闭
249    is_shutdown: Arc<AtomicBool>,
250}
251
252impl CronScheduler {
253    /// 创建新的 Cron 调度器
254    pub fn new(config: CronSchedulerConfig) -> Self {
255        let (shutdown_tx, _) = broadcast::channel(1);
256        let cron_tasks: Arc<RwLock<BTreeMap<TaskId, CronTask>>> = Arc::new(RwLock::new(BTreeMap::new()));
257        let handles: Arc<RwLock<BTreeMap<TaskId, TaskHandle>>> = Arc::new(RwLock::new(BTreeMap::new()));
258        let next_executions = Arc::new(RwLock::new(BTreeMap::new()));
259        let is_shutdown = Arc::new(AtomicBool::new(false));
260
261        let cron_tasks_clone = cron_tasks.clone();
262        let handles_clone = handles.clone();
263        let next_executions_clone = next_executions.clone();
264        let is_shutdown_clone = is_shutdown.clone();
265        let mut shutdown_rx = shutdown_tx.subscribe();
266        let poll_interval = config.poll_interval;
267
268        tokio::spawn(async move {
269            loop {
270                tokio::select! {
271                    _ = shutdown_rx.recv() => {
272                        debug!("Cron scheduler received shutdown signal");
273                        break;
274                    }
275                    _ = tokio::time::sleep(poll_interval) => {
276                        if is_shutdown_clone.load(Ordering::SeqCst) {
277                            break;
278                        }
279
280                        let now = Utc::now();
281                        let tasks = cron_tasks_clone.read().await;
282
283                        for (task_id, cron_task) in tasks.iter() {
284                            let next_exec = next_executions_clone.read().await.get(task_id).cloned();
285
286                            if let Some(next) = next_exec
287                                && now >= next
288                            {
289                                let task = cron_task.task.clone();
290                                let handle = handles_clone.read().await.get(task_id).cloned();
291                                let expression = cron_task.expression.clone();
292                                let next_executions_ref = next_executions_clone.clone();
293                                let task_id_clone = task_id.clone();
294
295                                tokio::spawn(async move {
296                                    if let Some(h) = &handle {
297                                        h.set_state(TaskState::Running).await;
298                                    }
299
300                                    let result = task.execute().await;
301
302                                    if let Some(h) = &handle {
303                                        match result {
304                                            Ok(()) => {
305                                                h.record_execution().await;
306                                                h.set_state(TaskState::Pending).await;
307                                            }
308                                            Err(e) => {
309                                                h.record_error(e.to_string()).await;
310                                                h.set_state(TaskState::Failed).await;
311                                                error!("Cron task {} execution failed: {}", h.name, e);
312                                            }
313                                        }
314                                    }
315
316                                    if let Some(next_time) = expression.next_execution(now) {
317                                        next_executions_ref.write().await.insert(task_id_clone, next_time);
318                                    }
319                                });
320                            }
321                        }
322                    }
323                }
324            }
325        });
326
327        Self { config, cron_tasks, handles, next_executions, shutdown_tx, is_shutdown }
328    }
329
330    /// 使用默认配置创建调度器
331    pub fn default_config() -> Self {
332        Self::new(CronSchedulerConfig::default())
333    }
334
335    /// 注册 Cron 任务
336    ///
337    /// # 参数
338    ///
339    /// - `task`: 要执行的任务
340    /// - `expression`: Cron 表达式
341    ///
342    /// # 返回值
343    ///
344    /// 返回任务句柄。
345    pub async fn schedule_cron(&self, task: Arc<dyn ScheduledTask>, expression: &str) -> WaeResult<TaskHandle> {
346        if self.is_shutdown.load(Ordering::SeqCst) {
347            return Err(WaeError::scheduler_shutdown());
348        }
349
350        let cron_expr = CronExpression::parse(expression)?;
351        let task_id = Uuid::new_v4().to_string();
352        let handle = TaskHandle::new(task_id.clone(), task.name().to_string());
353
354        let next_execution = cron_expr
355            .next_execution(Utc::now())
356            .ok_or_else(|| WaeError::invalid_cron_expression("Cannot determine next execution time"))?;
357
358        {
359            let mut handles = self.handles.write().await;
360            handles.insert(task_id.clone(), handle.clone());
361        }
362
363        {
364            let cron_task = CronTask { id: task_id.clone(), name: task.name().to_string(), expression: cron_expr, task };
365            let mut tasks = self.cron_tasks.write().await;
366            tasks.insert(task_id.clone(), cron_task);
367        }
368
369        {
370            let mut next_execs = self.next_executions.write().await;
371            next_execs.insert(task_id.clone(), next_execution);
372        }
373
374        info!("Scheduled cron task: {} (expression: {})", handle.name, expression);
375        Ok(handle)
376    }
377
378    /// 取消任务
379    pub async fn cancel_task(&self, task_id: &str) -> WaeResult<bool> {
380        {
381            let mut tasks = self.cron_tasks.write().await;
382            tasks.remove(task_id);
383        }
384
385        {
386            let mut next_execs = self.next_executions.write().await;
387            next_execs.remove(task_id);
388        }
389
390        let mut handles = self.handles.write().await;
391        if let Some(handle) = handles.remove(task_id) {
392            handle.cancel();
393            handle.set_state(TaskState::Cancelled).await;
394            info!("Cancelled cron task: {}", task_id);
395            Ok(true)
396        }
397        else {
398            Err(WaeError::task_not_found(task_id))
399        }
400    }
401
402    /// 获取任务句柄
403    pub async fn get_handle(&self, task_id: &str) -> Option<TaskHandle> {
404        self.handles.read().await.get(task_id).cloned()
405    }
406
407    /// 获取所有任务句柄
408    pub async fn get_all_handles(&self) -> Vec<TaskHandle> {
409        self.handles.read().await.values().cloned().collect()
410    }
411
412    /// 获取下次执行时间
413    pub async fn get_next_execution(&self, task_id: &str) -> Option<DateTime<Utc>> {
414        self.next_executions.read().await.get(task_id).cloned()
415    }
416
417    /// 关闭调度器
418    pub fn shutdown(&self) {
419        self.is_shutdown.store(true, Ordering::SeqCst);
420        let _ = self.shutdown_tx.send(());
421        info!("Cron scheduler shutdown initiated");
422    }
423
424    /// 检查是否已关闭
425    pub fn is_shutdown(&self) -> bool {
426        self.is_shutdown.load(Ordering::SeqCst)
427    }
428}
429
430/// 便捷函数:创建间隔任务调度器
431pub fn interval_scheduler() -> IntervalScheduler {
432    IntervalScheduler::default_config()
433}
434
435/// 便捷函数:创建 Cron 调度器
436pub fn cron_scheduler() -> CronScheduler {
437    CronScheduler::default_config()
438}