1use chrono::{DateTime, Datelike, Timelike, Utc};
2use serde::{Deserialize, Serialize};
3use std::{
4 collections::BTreeMap,
5 sync::{
6 Arc,
7 atomic::{AtomicBool, Ordering},
8 },
9 time::Duration,
10};
11use tokio::sync::{RwLock, broadcast};
12use tracing::{debug, error, info};
13use uuid::Uuid;
14
15use crate::{
16 interval::IntervalScheduler,
17 task::{ScheduledTask, TaskHandle, TaskId, TaskState},
18};
19use wae_types::{WaeError, WaeResult};
20
/// One field of a [`CronExpression`], describing which values of that time unit match.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum CronField {
    /// `*` — matches every value.
    Any,
    /// A single value, e.g. `5`.
    Value(u32),
    /// An inclusive range `start-end`, e.g. `1-5`.
    Range(u32, u32),
    /// `start/step` (or `*/step`, where the parser uses the field's minimum as `start`):
    /// matches `start` and every `step`-th value after it.
    Step(u32, u32),
    /// An explicit list of values, e.g. `1,15,30`.
    List(Vec<u32>),
}
35
36impl CronField {
37 pub fn matches(&self, value: u32) -> bool {
39 match self {
40 CronField::Any => true,
41 CronField::Value(v) => *v == value,
42 CronField::Range(start, end) => value >= *start && value <= *end,
43 CronField::Step(start, step) => (value - start).is_multiple_of(*step),
44 CronField::List(values) => values.contains(&value),
45 }
46 }
47}
48
/// A parsed six-field cron expression: `second minute hour day-of-month month day-of-week`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CronExpression {
    /// Second of the minute, 0–59.
    pub second: CronField,
    /// Minute of the hour, 0–59.
    pub minute: CronField,
    /// Hour of the day, 0–23.
    pub hour: CronField,
    /// Day of the month, 1–31.
    pub day_of_month: CronField,
    /// Month of the year, 1–12.
    pub month: CronField,
    /// Day of the week, 0–6, counted from Sunday (0 = Sunday).
    pub day_of_week: CronField,
}
67
68impl CronExpression {
69 pub fn parse(expression: &str) -> WaeResult<Self> {
81 let parts: Vec<&str> = expression.split_whitespace().collect();
82 if parts.len() != 6 {
83 return Err(WaeError::invalid_cron_expression("Expected 6 fields: second minute hour day month weekday"));
84 }
85
86 Ok(Self {
87 second: Self::parse_field(parts[0], 0, 59)?,
88 minute: Self::parse_field(parts[1], 0, 59)?,
89 hour: Self::parse_field(parts[2], 0, 23)?,
90 day_of_month: Self::parse_field(parts[3], 1, 31)?,
91 month: Self::parse_field(parts[4], 1, 12)?,
92 day_of_week: Self::parse_field(parts[5], 0, 6)?,
93 })
94 }
95
96 fn parse_field(s: &str, min: u32, max: u32) -> WaeResult<CronField> {
97 if s == "*" {
98 return Ok(CronField::Any);
99 }
100
101 if s.contains('/') {
102 let parts: Vec<&str> = s.split('/').collect();
103 if parts.len() != 2 {
104 return Err(WaeError::invalid_cron_expression(format!("Invalid step expression: {}", s)));
105 }
106 let start = if parts[0] == "*" {
107 min
108 }
109 else {
110 parts[0]
111 .parse::<u32>()
112 .map_err(|_| WaeError::invalid_cron_expression(format!("Invalid value: {}", parts[0])))?
113 };
114 let step = parts[1]
115 .parse::<u32>()
116 .map_err(|_| WaeError::invalid_cron_expression(format!("Invalid step: {}", parts[1])))?;
117 return Ok(CronField::Step(start, step));
118 }
119
120 if s.contains('-') {
121 let parts: Vec<&str> = s.split('-').collect();
122 if parts.len() != 2 {
123 return Err(WaeError::invalid_cron_expression(format!("Invalid range expression: {}", s)));
124 }
125 let start = parts[0]
126 .parse::<u32>()
127 .map_err(|_| WaeError::invalid_cron_expression(format!("Invalid start: {}", parts[0])))?;
128 let end =
129 parts[1].parse::<u32>().map_err(|_| WaeError::invalid_cron_expression(format!("Invalid end: {}", parts[1])))?;
130 return Ok(CronField::Range(start, end));
131 }
132
133 if s.contains(',') {
134 let values: Result<Vec<u32>, _> = s.split(',').map(|v| v.parse::<u32>()).collect();
135 let values = values.map_err(|_| WaeError::invalid_cron_expression(format!("Invalid list: {}", s)))?;
136 return Ok(CronField::List(values));
137 }
138
139 let value = s.parse::<u32>().map_err(|_| WaeError::invalid_cron_expression(format!("Invalid value: {}", s)))?;
140
141 if value < min || value > max {
142 return Err(WaeError::invalid_cron_expression(format!("Value {} out of range [{}, {}]", value, min, max)));
143 }
144
145 Ok(CronField::Value(value))
146 }
147
148 pub fn next_execution(&self, from: DateTime<Utc>) -> Option<DateTime<Utc>> {
150 let mut current = from + chrono::Duration::seconds(1);
151
152 for _ in 0..366 * 24 * 60 * 60 {
153 let second = current.second();
154 let minute = current.minute();
155 let hour = current.hour();
156 let day = current.day();
157 let month = current.month();
158 let weekday = current.weekday().num_days_from_sunday();
159
160 if !self.second.matches(second) {
161 current += chrono::Duration::seconds(1);
162 continue;
163 }
164 if !self.minute.matches(minute) {
165 current += chrono::Duration::minutes(1);
166 current = current.with_second(0).unwrap();
167 continue;
168 }
169 if !self.hour.matches(hour) {
170 current += chrono::Duration::hours(1);
171 current = current.with_second(0).unwrap().with_minute(0).unwrap();
172 continue;
173 }
174 if !self.day_of_month.matches(day) {
175 current += chrono::Duration::days(1);
176 current = current.with_second(0).unwrap().with_minute(0).unwrap().with_hour(0).unwrap();
177 continue;
178 }
179 if !self.month.matches(month) {
180 current += chrono::Duration::days(28);
181 continue;
182 }
183 if !self.day_of_week.matches(weekday) {
184 current += chrono::Duration::days(1);
185 current = current.with_second(0).unwrap().with_minute(0).unwrap().with_hour(0).unwrap();
186 continue;
187 }
188
189 return Some(current);
190 }
191
192 None
193 }
194}
195
/// A task registered with a [`CronScheduler`], together with its schedule.
pub struct CronTask {
    /// Unique task id (`CronScheduler::schedule_cron` assigns a UUID v4 string).
    pub id: TaskId,
    /// Display name, copied from the task's `name()` at registration.
    pub name: String,
    /// Schedule on which the task runs.
    pub expression: CronExpression,
    /// The work executed on each scheduled run.
    pub task: Arc<dyn ScheduledTask>,
}
207
208impl std::fmt::Debug for CronTask {
209 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
210 f.debug_struct("CronTask")
211 .field("id", &self.id)
212 .field("name", &self.name)
213 .field("expression", &self.expression)
214 .finish()
215 }
216}
217
/// Tuning knobs for a [`CronScheduler`].
#[derive(Debug, Clone)]
pub struct CronSchedulerConfig {
    /// How long the background loop sleeps between checks for due tasks.
    pub poll_interval: Duration,
    /// Intended upper bound on concurrently running tasks.
    pub max_concurrent_tasks: usize,
}

impl Default for CronSchedulerConfig {
    /// Polls once per second and allows up to 100 concurrent tasks.
    fn default() -> Self {
        let poll_interval = Duration::from_millis(1_000);
        let max_concurrent_tasks = 100;
        Self { poll_interval, max_concurrent_tasks }
    }
}
232
/// Runs scheduled tasks on cron schedules.
///
/// A background loop spawned by `CronScheduler::new` checks the schedule every
/// `poll_interval` and spawns each due task on the Tokio runtime.
pub struct CronScheduler {
    /// Configuration the scheduler was created with.
    // Only `poll_interval` is used, copied out in `new` before the config is stored. The
    // stored field is never read, which is why `dead_code` is allowed here.
    // NOTE(review): `max_concurrent_tasks` is not enforced anywhere in this file.
    #[allow(dead_code)]
    config: CronSchedulerConfig,
    /// Registered tasks, keyed by task id.
    cron_tasks: Arc<RwLock<BTreeMap<TaskId, CronTask>>>,
    /// Handles handed out to callers, keyed by task id.
    handles: Arc<RwLock<BTreeMap<TaskId, TaskHandle>>>,
    /// Next time each task is due, keyed by task id.
    next_executions: Arc<RwLock<BTreeMap<TaskId, DateTime<Utc>>>>,
    /// Used to tell the background polling loop to stop.
    shutdown_tx: broadcast::Sender<()>,
    /// Set by `shutdown`. Read by the polling loop and by `schedule_cron`.
    is_shutdown: Arc<AtomicBool>,
}
251
252impl CronScheduler {
253 pub fn new(config: CronSchedulerConfig) -> Self {
255 let (shutdown_tx, _) = broadcast::channel(1);
256 let cron_tasks: Arc<RwLock<BTreeMap<TaskId, CronTask>>> = Arc::new(RwLock::new(BTreeMap::new()));
257 let handles: Arc<RwLock<BTreeMap<TaskId, TaskHandle>>> = Arc::new(RwLock::new(BTreeMap::new()));
258 let next_executions = Arc::new(RwLock::new(BTreeMap::new()));
259 let is_shutdown = Arc::new(AtomicBool::new(false));
260
261 let cron_tasks_clone = cron_tasks.clone();
262 let handles_clone = handles.clone();
263 let next_executions_clone = next_executions.clone();
264 let is_shutdown_clone = is_shutdown.clone();
265 let mut shutdown_rx = shutdown_tx.subscribe();
266 let poll_interval = config.poll_interval;
267
268 tokio::spawn(async move {
269 loop {
270 tokio::select! {
271 _ = shutdown_rx.recv() => {
272 debug!("Cron scheduler received shutdown signal");
273 break;
274 }
275 _ = tokio::time::sleep(poll_interval) => {
276 if is_shutdown_clone.load(Ordering::SeqCst) {
277 break;
278 }
279
280 let now = Utc::now();
281 let tasks = cron_tasks_clone.read().await;
282
283 for (task_id, cron_task) in tasks.iter() {
284 let next_exec = next_executions_clone.read().await.get(task_id).cloned();
285
286 if let Some(next) = next_exec
287 && now >= next
288 {
289 let task = cron_task.task.clone();
290 let handle = handles_clone.read().await.get(task_id).cloned();
291 let expression = cron_task.expression.clone();
292 let next_executions_ref = next_executions_clone.clone();
293 let task_id_clone = task_id.clone();
294
295 tokio::spawn(async move {
296 if let Some(h) = &handle {
297 h.set_state(TaskState::Running).await;
298 }
299
300 let result = task.execute().await;
301
302 if let Some(h) = &handle {
303 match result {
304 Ok(()) => {
305 h.record_execution().await;
306 h.set_state(TaskState::Pending).await;
307 }
308 Err(e) => {
309 h.record_error(e.to_string()).await;
310 h.set_state(TaskState::Failed).await;
311 error!("Cron task {} execution failed: {}", h.name, e);
312 }
313 }
314 }
315
316 if let Some(next_time) = expression.next_execution(now) {
317 next_executions_ref.write().await.insert(task_id_clone, next_time);
318 }
319 });
320 }
321 }
322 }
323 }
324 }
325 });
326
327 Self { config, cron_tasks, handles, next_executions, shutdown_tx, is_shutdown }
328 }
329
330 pub fn default_config() -> Self {
332 Self::new(CronSchedulerConfig::default())
333 }
334
335 pub async fn schedule_cron(&self, task: Arc<dyn ScheduledTask>, expression: &str) -> WaeResult<TaskHandle> {
346 if self.is_shutdown.load(Ordering::SeqCst) {
347 return Err(WaeError::scheduler_shutdown());
348 }
349
350 let cron_expr = CronExpression::parse(expression)?;
351 let task_id = Uuid::new_v4().to_string();
352 let handle = TaskHandle::new(task_id.clone(), task.name().to_string());
353
354 let next_execution = cron_expr
355 .next_execution(Utc::now())
356 .ok_or_else(|| WaeError::invalid_cron_expression("Cannot determine next execution time"))?;
357
358 {
359 let mut handles = self.handles.write().await;
360 handles.insert(task_id.clone(), handle.clone());
361 }
362
363 {
364 let cron_task = CronTask { id: task_id.clone(), name: task.name().to_string(), expression: cron_expr, task };
365 let mut tasks = self.cron_tasks.write().await;
366 tasks.insert(task_id.clone(), cron_task);
367 }
368
369 {
370 let mut next_execs = self.next_executions.write().await;
371 next_execs.insert(task_id.clone(), next_execution);
372 }
373
374 info!("Scheduled cron task: {} (expression: {})", handle.name, expression);
375 Ok(handle)
376 }
377
378 pub async fn cancel_task(&self, task_id: &str) -> WaeResult<bool> {
380 {
381 let mut tasks = self.cron_tasks.write().await;
382 tasks.remove(task_id);
383 }
384
385 {
386 let mut next_execs = self.next_executions.write().await;
387 next_execs.remove(task_id);
388 }
389
390 let mut handles = self.handles.write().await;
391 if let Some(handle) = handles.remove(task_id) {
392 handle.cancel();
393 handle.set_state(TaskState::Cancelled).await;
394 info!("Cancelled cron task: {}", task_id);
395 Ok(true)
396 }
397 else {
398 Err(WaeError::task_not_found(task_id))
399 }
400 }
401
402 pub async fn get_handle(&self, task_id: &str) -> Option<TaskHandle> {
404 self.handles.read().await.get(task_id).cloned()
405 }
406
407 pub async fn get_all_handles(&self) -> Vec<TaskHandle> {
409 self.handles.read().await.values().cloned().collect()
410 }
411
412 pub async fn get_next_execution(&self, task_id: &str) -> Option<DateTime<Utc>> {
414 self.next_executions.read().await.get(task_id).cloned()
415 }
416
417 pub fn shutdown(&self) {
419 self.is_shutdown.store(true, Ordering::SeqCst);
420 let _ = self.shutdown_tx.send(());
421 info!("Cron scheduler shutdown initiated");
422 }
423
424 pub fn is_shutdown(&self) -> bool {
426 self.is_shutdown.load(Ordering::SeqCst)
427 }
428}
429
/// Convenience constructor: an [`IntervalScheduler`] built with
/// `IntervalScheduler::default_config()`.
pub fn interval_scheduler() -> IntervalScheduler {
    IntervalScheduler::default_config()
}
434
435pub fn cron_scheduler() -> CronScheduler {
437 CronScheduler::default_config()
438}