tulpje_framework/
scheduler.rs1use std::collections::HashMap;
2
3use async_cron_scheduler::{Job, JobId, Scheduler as CronScheduler};
4use chrono::Utc;
5use tokio::{sync::mpsc, task::JoinHandle};
6use tokio_util::sync::CancellationToken;
7
8use crate::{
9 Error,
10 context::{Context, TaskContext},
11 handler::task_handler::TaskHandler,
12};
13
14pub enum SchedulerTaskMessage<T: Clone + Send + Sync> {
15 Start(Vec<TaskHandler<T>>),
16 Enable(Box<TaskHandler<T>>),
17 Disable(String),
18}
19
20pub struct SchedulerHandle<T: Clone + Send + Sync> {
21 tasks: Vec<TaskHandler<T>>,
22 sender: mpsc::UnboundedSender<SchedulerTaskMessage<T>>,
23 shutdown: CancellationToken,
24 handle: Option<JoinHandle<()>>,
25}
26impl<T: Clone + Send + Sync + 'static> SchedulerHandle<T> {
27 pub(crate) fn new(tasks: Vec<TaskHandler<T>>, ctx: Context<T>) -> Self {
28 let (sender, receiver) = mpsc::unbounded_channel();
29 let shutdown = CancellationToken::new();
30
31 let mut scheduler = Scheduler::new(ctx, receiver, shutdown.clone());
32 let handle = Some(tokio::spawn(async move { scheduler.run().await }));
33
34 Self {
35 tasks,
36 sender,
37 shutdown,
38 handle,
39 }
40 }
41
42 pub(crate) fn shutdown(&mut self) {
43 self.shutdown.cancel();
44 }
45
46 pub(crate) fn start(
47 &mut self,
48 ) -> Result<(), Box<mpsc::error::SendError<SchedulerTaskMessage<T>>>> {
49 Ok(self
50 .sender
51 .send(SchedulerTaskMessage::Start(self.tasks.clone()))?)
52 }
53
54 pub fn enable_task(
55 &mut self,
56 handler: TaskHandler<T>,
57 ) -> Result<(), Box<mpsc::error::SendError<SchedulerTaskMessage<T>>>> {
58 Ok(self
59 .sender
60 .send(SchedulerTaskMessage::Enable(Box::new(handler)))?)
61 }
62
63 pub fn disable_task(
64 &mut self,
65 name: String,
66 ) -> Result<(), Box<mpsc::error::SendError<SchedulerTaskMessage<T>>>> {
67 Ok(self.sender.send(SchedulerTaskMessage::Disable(name))?)
68 }
69
70 pub(crate) async fn join(&mut self) -> Result<(), Error> {
71 Ok(self
72 .handle
73 .take()
74 .ok_or("Scheduler already shutdown")?
75 .await?)
76 }
77}
78
79struct Scheduler<T: Clone + Send + Sync> {
80 job_map: HashMap<String, JobId>,
81 scheduler: Option<CronScheduler<Utc>>,
82 handle: Option<JoinHandle<()>>,
83
84 ctx: Context<T>,
85 receiver: mpsc::UnboundedReceiver<SchedulerTaskMessage<T>>,
86 shutdown: CancellationToken,
87}
88
89impl<T: Clone + Send + Sync + 'static> Scheduler<T> {
90 fn new(
91 ctx: Context<T>,
92 receiver: mpsc::UnboundedReceiver<SchedulerTaskMessage<T>>,
93 shutdown: CancellationToken,
94 ) -> Self {
95 let (scheduler, service) = CronScheduler::<Utc>::launch(tokio::time::sleep);
96
97 Self {
98 ctx,
99 receiver,
100 shutdown,
101
102 job_map: HashMap::new(),
103 scheduler: Some(scheduler),
104 handle: Some(tokio::spawn(service)),
105 }
106 }
107
108 pub async fn enable_task(&mut self, handler: TaskHandler<T>) {
109 let local_ctx = self.ctx.clone();
110
111 let job = Job::<Utc>::cron_schedule(handler.cron.clone());
112 let job_name = handler.name.clone();
113 let job_id = self
114 .scheduler
115 .as_mut()
116 .unwrap()
117 .insert(job, move |_id| {
118 let job_ctx = local_ctx.clone();
119 let job_handler = handler.clone();
120
121 tokio::spawn(async move {
122 if let Err(err) = job_handler.run(TaskContext::from_context(job_ctx)).await {
123 tracing::error!("error running task {}: {}", job_handler.name, err);
124 };
125 });
126 })
127 .await;
128
129 self.job_map.insert(job_name, job_id);
130 }
131
132 pub async fn disable_task(&mut self, name: &str) {
133 let Some(job_id) = self.job_map.remove(name) else {
134 return;
135 };
136
137 self.scheduler.as_mut().unwrap().remove(job_id).await;
138 }
139
140 async fn run(&mut self) {
141 loop {
142 tokio::select! {
143 Some(msg) = self.receiver.recv() => {
144 match msg {
145 SchedulerTaskMessage::Start(tasks) => {
146 for task in tasks {
147 self.enable_task(task.clone()).await;
148 }
149 },
150 SchedulerTaskMessage::Enable(task) => self.enable_task(*task).await,
151 SchedulerTaskMessage::Disable(name) => self.disable_task(&name).await,
152 }
153 },
154 () = self.shutdown.cancelled() => break,
155 }
156 }
157
158 {
164 let Some(mut scheduler) = self.scheduler.take() else {
165 tracing::warn!("Scheduler already removed");
166 return;
167 };
168
169 for (_, job) in self.job_map.drain() {
170 scheduler.remove(job).await;
171 }
172 }
173
174 let Some(handle) = self.handle.take() else {
175 tracing::warn!("CronScheduler already shutdown");
176 return;
177 };
178
179 if let Err(err) = handle.await {
180 tracing::warn!("Error joining CronScheduler: {}", err);
181 }
182 }
183}