// tulpje_framework/scheduler.rs

1use std::collections::HashMap;
2
3use async_cron_scheduler::{Job, JobId, Scheduler as CronScheduler};
4use chrono::Utc;
5use tokio::{sync::mpsc, task::JoinHandle};
6use tokio_util::sync::CancellationToken;
7
8use crate::{
9    Error,
10    context::{Context, TaskContext},
11    handler::task_handler::TaskHandler,
12};
13
14pub enum SchedulerTaskMessage<T: Clone + Send + Sync> {
15    Start(Vec<TaskHandler<T>>),
16    Enable(Box<TaskHandler<T>>),
17    Disable(String),
18}
19
20pub struct SchedulerHandle<T: Clone + Send + Sync> {
21    tasks: Vec<TaskHandler<T>>,
22    sender: mpsc::UnboundedSender<SchedulerTaskMessage<T>>,
23    shutdown: CancellationToken,
24    handle: Option<JoinHandle<()>>,
25}
26impl<T: Clone + Send + Sync + 'static> SchedulerHandle<T> {
27    pub(crate) fn new(tasks: Vec<TaskHandler<T>>, ctx: Context<T>) -> Self {
28        let (sender, receiver) = mpsc::unbounded_channel();
29        let shutdown = CancellationToken::new();
30
31        let mut scheduler = Scheduler::new(ctx, receiver, shutdown.clone());
32        let handle = Some(tokio::spawn(async move { scheduler.run().await }));
33
34        Self {
35            tasks,
36            sender,
37            shutdown,
38            handle,
39        }
40    }
41
42    pub(crate) fn shutdown(&mut self) {
43        self.shutdown.cancel();
44    }
45
46    pub(crate) fn start(
47        &mut self,
48    ) -> Result<(), Box<mpsc::error::SendError<SchedulerTaskMessage<T>>>> {
49        Ok(self
50            .sender
51            .send(SchedulerTaskMessage::Start(self.tasks.clone()))?)
52    }
53
54    pub fn enable_task(
55        &mut self,
56        handler: TaskHandler<T>,
57    ) -> Result<(), Box<mpsc::error::SendError<SchedulerTaskMessage<T>>>> {
58        Ok(self
59            .sender
60            .send(SchedulerTaskMessage::Enable(Box::new(handler)))?)
61    }
62
63    pub fn disable_task(
64        &mut self,
65        name: String,
66    ) -> Result<(), Box<mpsc::error::SendError<SchedulerTaskMessage<T>>>> {
67        Ok(self.sender.send(SchedulerTaskMessage::Disable(name))?)
68    }
69
70    pub(crate) async fn join(&mut self) -> Result<(), Error> {
71        Ok(self
72            .handle
73            .take()
74            .ok_or("Scheduler already shutdown")?
75            .await?)
76    }
77}
78
79struct Scheduler<T: Clone + Send + Sync> {
80    job_map: HashMap<String, JobId>,
81    scheduler: Option<CronScheduler<Utc>>,
82    handle: Option<JoinHandle<()>>,
83
84    ctx: Context<T>,
85    receiver: mpsc::UnboundedReceiver<SchedulerTaskMessage<T>>,
86    shutdown: CancellationToken,
87}
88
89impl<T: Clone + Send + Sync + 'static> Scheduler<T> {
90    fn new(
91        ctx: Context<T>,
92        receiver: mpsc::UnboundedReceiver<SchedulerTaskMessage<T>>,
93        shutdown: CancellationToken,
94    ) -> Self {
95        let (scheduler, service) = CronScheduler::<Utc>::launch(tokio::time::sleep);
96
97        Self {
98            ctx,
99            receiver,
100            shutdown,
101
102            job_map: HashMap::new(),
103            scheduler: Some(scheduler),
104            handle: Some(tokio::spawn(service)),
105        }
106    }
107
108    pub async fn enable_task(&mut self, handler: TaskHandler<T>) {
109        let local_ctx = self.ctx.clone();
110
111        let job = Job::<Utc>::cron_schedule(handler.cron.clone());
112        let job_name = handler.name.clone();
113        let job_id = self
114            .scheduler
115            .as_mut()
116            .unwrap()
117            .insert(job, move |_id| {
118                let job_ctx = local_ctx.clone();
119                let job_handler = handler.clone();
120
121                tokio::spawn(async move {
122                    if let Err(err) = job_handler.run(TaskContext::from_context(job_ctx)).await {
123                        tracing::error!("error running task {}: {}", job_handler.name, err);
124                    };
125                });
126            })
127            .await;
128
129        self.job_map.insert(job_name, job_id);
130    }
131
132    pub async fn disable_task(&mut self, name: &str) {
133        let Some(job_id) = self.job_map.remove(name) else {
134            return;
135        };
136
137        self.scheduler.as_mut().unwrap().remove(job_id).await;
138    }
139
140    async fn run(&mut self) {
141        loop {
142            tokio::select! {
143                Some(msg) = self.receiver.recv() => {
144                    match msg {
145                        SchedulerTaskMessage::Start(tasks) => {
146                            for task in tasks {
147                                self.enable_task(task.clone()).await;
148                            }
149                        },
150                        SchedulerTaskMessage::Enable(task) => self.enable_task(*task).await,
151                        SchedulerTaskMessage::Disable(name) => self.disable_task(&name).await,
152                    }
153                },
154                () = self.shutdown.cancelled() => break,
155            }
156        }
157
158        // drain the jobs from the job map and also take the scheduler
159        // removing the jobs, and taking the scheduler from the runner should
160        // cause the scheduler to be dropped and thus stop
161        //
162        // NOTE: Separate scope so we drop correctly after removing jobs
163        {
164            let Some(mut scheduler) = self.scheduler.take() else {
165                tracing::warn!("Scheduler already removed");
166                return;
167            };
168
169            for (_, job) in self.job_map.drain() {
170                scheduler.remove(job).await;
171            }
172        }
173
174        let Some(handle) = self.handle.take() else {
175            tracing::warn!("CronScheduler already shutdown");
176            return;
177        };
178
179        if let Err(err) = handle.await {
180            tracing::warn!("Error joining CronScheduler: {}", err);
181        }
182    }
183}