task_supervisor/supervisor/
mod.rs1pub(crate) mod builder;
2pub(crate) mod handle;
3
4use std::{
5 collections::HashMap,
6 time::{Duration, Instant},
7};
8
9use handle::SupervisorMessage;
10use tokio::{sync::mpsc, time::interval_at};
11use tokio_util::sync::CancellationToken;
12
13use crate::{
14 supervisor::handle::SupervisorHandle,
15 task::{TaskHandle, TaskOutcome, TaskStatus},
16};
17
/// A liveness signal emitted on behalf of one supervised task.
#[derive(Debug, Clone)]
pub(crate) struct Heartbeat {
    /// Name of the task this heartbeat belongs to.
    pub(crate) task_name: String,
    /// Moment the heartbeat was created.
    pub(crate) timestamp: Instant,
}

impl Heartbeat {
    /// Creates a heartbeat for `task_name`, stamped with the current time.
    ///
    /// Accepts `&str`, so both string literals and `&String` (via deref coercion)
    /// can be passed.
    pub fn new(task_name: &str) -> Self {
        Self {
            task_name: task_name.to_owned(),
            timestamp: Instant::now(),
        }
    }
}
34
35#[derive(Debug)]
37pub(crate) enum SupervisedTaskMessage {
38 Heartbeat(Heartbeat),
40 Restart(String),
42 Completed(String, TaskOutcome),
44 Shutdown,
46}
47
48pub struct Supervisor {
54 tasks: HashMap<String, TaskHandle>,
56 timeout_threshold: Duration,
58 heartbeat_interval: Duration,
59 health_check_initial_delay: Duration,
60 health_check_interval: Duration,
61 base_restart_delay: Duration,
62 task_is_stable_after: Duration,
63 max_restart_attempts: u32,
64 external_tx: mpsc::UnboundedSender<SupervisorMessage>,
66 external_rx: mpsc::UnboundedReceiver<SupervisorMessage>,
67 internal_tx: mpsc::UnboundedSender<SupervisedTaskMessage>,
69 internal_rx: mpsc::UnboundedReceiver<SupervisedTaskMessage>,
70}
71
72impl Supervisor {
73 pub fn run(self) -> SupervisorHandle {
77 let user_tx = self.external_tx.clone();
78 let handle = tokio::spawn(async move {
79 self.run_and_supervise().await;
80 });
81 SupervisorHandle::new(handle, user_tx)
82 }
83
84 async fn run_and_supervise(mut self) {
86 self.start_all_tasks().await;
87 self.supervise_all_tasks().await;
88 }
89
90 async fn start_all_tasks(&mut self) {
92 for (task_name, task_handle) in self.tasks.iter_mut() {
93 Self::start_task(
94 task_name.to_string(),
95 task_handle,
96 self.internal_tx.clone(),
97 self.heartbeat_interval,
98 )
99 .await;
100 }
101 }
102
103 async fn supervise_all_tasks(&mut self) {
105 let mut health_check_interval = interval_at(
106 tokio::time::Instant::now() + self.health_check_initial_delay,
107 self.health_check_interval,
108 );
109
110 loop {
111 tokio::select! {
112 Some(internal_msg) = self.internal_rx.recv() => {
113 if matches!(internal_msg, SupervisedTaskMessage::Shutdown) {
115 return;
116 }
117 self.handle_internal_message(internal_msg).await;
118 },
119 Some(user_msg) = self.external_rx.recv() => {
120 self.handle_user_message(user_msg).await;
121 },
122 _ = health_check_interval.tick() => {
123 self.check_all_health();
124 }
125 }
126 }
127 }
128
129 async fn handle_internal_message(&mut self, msg: SupervisedTaskMessage) {
130 match msg {
131 SupervisedTaskMessage::Heartbeat(heartbeat) => {
132 self.register_heartbeat(heartbeat);
133 }
134 SupervisedTaskMessage::Restart(task_name) => {
135 self.restart_task(task_name).await;
136 }
137 SupervisedTaskMessage::Completed(task_name, outcome) => {
138 self.handle_task_completion(task_name, outcome).await;
139 }
140 SupervisedTaskMessage::Shutdown => unreachable!(),
141 }
142 }
143
144 async fn handle_user_message(&mut self, msg: SupervisorMessage) {
146 match msg {
147 SupervisorMessage::AddTask(task_name, task) => {
148 if self.tasks.contains_key(&task_name) {
149 return;
150 }
151 let mut task_handle =
152 TaskHandle::new(task, self.max_restart_attempts, self.base_restart_delay);
153 Self::start_task(
154 task_name.clone(),
155 &mut task_handle,
156 self.internal_tx.clone(),
157 self.heartbeat_interval,
158 )
159 .await;
160 self.tasks.insert(task_name, task_handle);
161 }
162 SupervisorMessage::RestartTask(task_name) => {
163 self.restart_task(task_name).await;
164 }
165 SupervisorMessage::KillTask(task_name) => {
166 let Some(task_handle) = self.tasks.get_mut(&task_name) else {
167 return;
168 };
169 if task_handle.status == TaskStatus::Dead {
170 return;
171 }
172 task_handle.mark(TaskStatus::Dead);
173 task_handle.clean().await;
174 }
175 SupervisorMessage::GetTaskStatus(task_name, sender) => {
176 let status = self.tasks.get(&task_name).map(|handle| handle.status);
177 let _ = sender.send(status);
178 }
179 SupervisorMessage::GetAllTaskStatuses(sender) => {
180 let statuses = self
181 .tasks
182 .iter()
183 .map(|(name, handle)| (name.clone(), handle.status))
184 .collect();
185 let _ = sender.send(statuses);
186 }
187 SupervisorMessage::Shutdown => {
188 for (_, task_handle) in self.tasks.iter_mut() {
189 if task_handle.status != TaskStatus::Dead {
190 task_handle.clean().await;
191 task_handle.mark(TaskStatus::Dead);
192 }
193 }
194 let _ = self.internal_tx.send(SupervisedTaskMessage::Shutdown);
195 }
196 }
197 }
198
199 async fn start_task(
201 task_name: String,
202 task_handle: &mut TaskHandle,
203 internal_tx: mpsc::UnboundedSender<SupervisedTaskMessage>,
204 heartbeat_interval: Duration,
205 ) {
206 let token = CancellationToken::new();
207
208 let (completion_tx, mut completion_rx) = mpsc::channel::<TaskOutcome>(1);
210 let tx_heartbeat = internal_tx.clone();
211 let task_name_clone = task_name.clone();
212 let token_completion = token.clone();
213 let completion_task = tokio::spawn(async move {
214 tokio::select! {
215 _ = token_completion.cancelled() => {
216 }
218 Some(outcome) = completion_rx.recv() => {
219 let completion_msg = SupervisedTaskMessage::Completed(task_name_clone, outcome);
220 let _ = tx_heartbeat.send(completion_msg);
221 token_completion.cancel(); }
223 }
224 });
225
226 let task_name_heartbeat = task_name.clone();
228 let token_heartbeat = token.clone();
229 let heartbeat_task = tokio::spawn(async move {
230 let mut beat_interval = tokio::time::interval(heartbeat_interval);
231 loop {
232 tokio::select! {
233 _ = beat_interval.tick() => {
234 let beat = SupervisedTaskMessage::Heartbeat(Heartbeat::new(&task_name_heartbeat));
235 if internal_tx.send(beat).is_err() {
236 break;
237 }
238 }
239 _ = token_heartbeat.cancelled() => {
240 break; }
242 }
243 }
244 });
245
246 let mut task = task_handle.task.clone_box();
248 let token_main = token.clone();
249 let ran_task = tokio::spawn(async move {
250 tokio::select! {
251 _ = token_main.cancelled() => {
252 }
254 _ = async {
255 match task.run().await {
256 Ok(outcome) => {
257 let _ = completion_tx.send(outcome).await;
258 }
259 Err(e) => {
260 let _ = completion_tx.send(TaskOutcome::Failed(e.to_string())).await;
261 }
262 }
263 } => {}
264 }
265 });
266
267 task_handle.mark(TaskStatus::Starting);
269 task_handle.cancellation_token = Some(token);
271 task_handle.handles = Some(vec![ran_task, heartbeat_task, completion_task]);
272 }
273
274 fn register_heartbeat(&mut self, heartbeat: Heartbeat) {
276 let Some(task_handle) = self.tasks.get_mut(&heartbeat.task_name) else {
277 return;
278 };
279
280 if task_handle.status == TaskStatus::Dead {
281 return;
282 }
283
284 task_handle.ticked_at(heartbeat.timestamp);
285
286 match task_handle.status {
287 TaskStatus::Starting => {
288 task_handle.mark(TaskStatus::Healthy);
289 task_handle.healthy_since = Some(heartbeat.timestamp);
290 }
291 TaskStatus::Healthy => {
292 if let Some(healthy_since) = task_handle.healthy_since {
294 if heartbeat.timestamp.duration_since(healthy_since) > self.task_is_stable_after
295 {
296 task_handle.restart_attempts = 0;
297 }
298 } else {
299 task_handle.healthy_since = Some(heartbeat.timestamp);
300 }
301 }
302 _ => {}
303 }
304 }
305
306 async fn restart_task(&mut self, task_name: String) {
308 let Some(task_handle) = self.tasks.get_mut(&task_name) else {
309 return;
310 };
311 task_handle.clean().await;
312 task_handle.mark(TaskStatus::Created);
313 Self::start_task(
314 task_name,
315 task_handle,
316 self.internal_tx.clone(),
317 self.heartbeat_interval,
318 )
319 .await;
320 }
321
322 fn check_all_health(&mut self) {
324 let crashed_tasks = self
325 .tasks
326 .iter()
327 .filter(|(_, handle)| handle.has_crashed(self.timeout_threshold))
328 .map(|(name, _)| name.clone())
329 .collect::<Vec<_>>();
330
331 for crashed_task in crashed_tasks {
332 let Some(task_handle) = self.tasks.get_mut(&crashed_task) else {
333 continue;
334 };
335 if task_handle.has_exceeded_max_retries() && task_handle.status != TaskStatus::Dead {
336 task_handle.mark(TaskStatus::Dead);
337 continue;
338 }
339 let restart_delay = task_handle.restart_delay();
340 task_handle.mark(TaskStatus::Failed);
341 task_handle.restart_attempts = task_handle.restart_attempts.saturating_add(1);
342 let internal_tx = self.internal_tx.clone();
343 tokio::spawn(async move {
344 tokio::time::sleep(restart_delay).await;
345 let _ = internal_tx.send(SupervisedTaskMessage::Restart(crashed_task));
346 });
347 }
348 }
349
350 async fn handle_task_completion(&mut self, task_name: String, outcome: TaskOutcome) {
352 let Some(task_handle) = self.tasks.get_mut(&task_name) else {
353 return;
354 };
355
356 match outcome {
357 TaskOutcome::Completed => {
358 task_handle.mark(TaskStatus::Completed);
359 }
360 TaskOutcome::Failed(_) => {
361 task_handle.mark(TaskStatus::Failed);
362 if task_handle.has_exceeded_max_retries() {
363 task_handle.mark(TaskStatus::Dead);
364 return;
365 }
366 let restart_delay = task_handle.restart_delay();
367 task_handle.restart_attempts = task_handle.restart_attempts.saturating_add(1);
368 let internal_tx = self.internal_tx.clone();
369 tokio::spawn(async move {
370 tokio::time::sleep(restart_delay).await;
371 let _ = internal_tx.send(SupervisedTaskMessage::Restart(task_name));
372 });
373 }
374 }
375 }
376}