task_supervisor/supervisor/
mod.rs

1pub(crate) mod builder;
2pub(crate) mod handle;
3
4use std::{
5    collections::HashMap,
6    time::{Duration, Instant},
7};
8
9use tokio::{sync::mpsc, time::interval};
10use tokio_util::sync::CancellationToken;
11
12#[cfg(feature = "with_tracing")]
13use tracing::{debug, error, info, warn};
14
15use crate::{
16    supervisor::handle::{SupervisorHandle, SupervisorMessage},
17    task::{TaskHandle, TaskResult, TaskStatus},
18};
19
20#[derive(Clone, Debug, thiserror::Error)]
21pub enum SupervisorError {
22    #[error("Too many tasks are dead (threshold exceeded: {current_percentage:.2}% > {threshold:.2}%), supervisor shutting down.")]
23    TooManyDeadTasks {
24        current_percentage: f64,
25        threshold: f64,
26    },
27}
28
/// Internal messages sent from tasks and by the `Supervisor` to manage task lifecycle.
///
/// These travel over the supervisor's internal unbounded channel and are consumed by
/// the supervision loop.
#[derive(Debug)]
pub(crate) enum SupervisedTaskMessage {
    /// Sent by the supervisor to itself (from a delayed timer task) to trigger a restart
    /// of the task with the given name.
    Restart(String),
    /// Sent by a task's completion listener when the task finishes, either successfully
    /// or with a failure. Carries the task name and its result.
    Completed(String, TaskResult),
    /// Sent after the user asked for a shutdown and every running task has been cleaned
    /// up; makes the supervision loop return.
    Shutdown,
}
39
/// Manages a set of tasks, ensuring they remain operational through restarts.
///
/// The `Supervisor` spawns each task on the Tokio runtime and periodically checks whether
/// its main join handle has finished. A task that finished unexpectedly or reported an
/// error is restarted after a delay obtained from its `TaskHandle` (presumably an
/// exponential backoff derived from `base_restart_delay` and `max_backoff_exponent` —
/// confirm in `TaskHandle::restart_delay`).
/// User commands such as adding, restarting, or killing tasks are supported via the `SupervisorHandle`.
pub struct Supervisor {
    // Supervised tasks, keyed by task name
    tasks: HashMap<String, TaskHandle>,
    // Period of the health-check ticker in the supervision loop
    health_check_interval: Duration,
    // Restart tuning handed to every `TaskHandle` created for a task added at runtime
    base_restart_delay: Duration,
    // How long a task must stay healthy before its restart-attempt counter is reset
    task_is_stable_after: Duration,
    max_restart_attempts: u32,
    max_backoff_exponent: u32,
    // Maximum tolerated fraction (dead / total) of dead tasks; `None` disables the check
    max_dead_tasks_percentage_threshold: Option<f64>,
    // Channels between the User & the Supervisor (the sender is cloned into the handle)
    external_tx: mpsc::UnboundedSender<SupervisorMessage>,
    external_rx: mpsc::UnboundedReceiver<SupervisorMessage>,
    // Internal channels used in the Supervisor for actions propagation
    internal_tx: mpsc::UnboundedSender<SupervisedTaskMessage>,
    internal_rx: mpsc::UnboundedReceiver<SupervisedTaskMessage>,
}
62
63impl Supervisor {
64    /// Runs the supervisor, consuming it and returning a handle for external control.
65    ///
66    /// This method initiates all tasks and starts the supervision loop.
67    pub fn run(self) -> SupervisorHandle {
68        let user_tx = self.external_tx.clone();
69        let handle = tokio::spawn(async move { self.run_and_supervise().await });
70        SupervisorHandle::new(handle, user_tx)
71    }
72
73    /// Starts and supervises all tasks, running the main supervision loop.
74    async fn run_and_supervise(mut self) -> Result<(), SupervisorError> {
75        self.start_all_tasks().await;
76        self.supervise_all_tasks().await
77    }
78
79    /// Initiates all tasks managed by the supervisor.
80    async fn start_all_tasks(&mut self) {
81        for (task_name, task_handle) in self.tasks.iter_mut() {
82            Self::start_task(task_name.to_string(), task_handle, self.internal_tx.clone()).await;
83        }
84    }
85
86    /// Supervises tasks by processing messages and performing periodic health checks.
87    async fn supervise_all_tasks(&mut self) -> Result<(), SupervisorError> {
88        let mut health_check_ticker = interval(self.health_check_interval);
89
90        loop {
91            tokio::select! {
92                biased;
93                Some(internal_msg) = self.internal_rx.recv() => {
94                    match internal_msg {
95                        SupervisedTaskMessage::Shutdown => {
96                            #[cfg(feature = "with_tracing")]
97                            info!("Supervisor received shutdown signal");
98                            return Ok(());
99                        }
100                        _ => self.handle_internal_message(internal_msg).await,
101                    }
102                },
103                Some(user_msg) = self.external_rx.recv() => {
104                    self.handle_user_message(user_msg).await;
105                },
106                _ = health_check_ticker.tick() => {
107                    #[cfg(feature = "with_tracing")]
108                    debug!("Supervisor checking health of all tasks");
109                    self.check_all_health().await;
110                    self.check_dead_tasks_threshold().await?;
111                }
112            }
113        }
114    }
115
116    async fn handle_internal_message(&mut self, msg: SupervisedTaskMessage) {
117        match msg {
118            SupervisedTaskMessage::Restart(task_name) => {
119                #[cfg(feature = "with_tracing")]
120                info!("Processing restart request for task: {task_name}");
121                self.restart_task(task_name).await;
122            }
123            SupervisedTaskMessage::Completed(task_name, outcome) => {
124                #[cfg(feature = "with_tracing")]
125                match &outcome {
126                    Ok(()) => info!("Task '{task_name}' completed successfully"),
127                    Err(e) => warn!("Task '{task_name}' completed with error: {e}"),
128                }
129                self.handle_task_completion(task_name, outcome).await;
130            }
131            SupervisedTaskMessage::Shutdown => {
132                unreachable!("Shutdown is handled by the main select loop.");
133            }
134        }
135    }
136
137    /// Processes user commands received via the `SupervisorHandle`.
138    async fn handle_user_message(&mut self, msg: SupervisorMessage) {
139        match msg {
140            SupervisorMessage::AddTask(task_name, task_dyn) => {
141                // TODO: This branch should return an error
142                if self.tasks.contains_key(&task_name) {
143                    #[cfg(feature = "with_tracing")]
144                    warn!("Attempted to add task '{task_name}' but it already exists");
145                    return;
146                }
147
148                let mut task_handle = TaskHandle::new(
149                    task_dyn,
150                    self.max_restart_attempts,
151                    self.base_restart_delay,
152                    self.max_backoff_exponent,
153                );
154
155                Self::start_task(
156                    task_name.clone(),
157                    &mut task_handle,
158                    self.internal_tx.clone(),
159                )
160                .await;
161                self.tasks.insert(task_name, task_handle);
162            }
163            SupervisorMessage::RestartTask(task_name) => {
164                #[cfg(feature = "with_tracing")]
165                info!("User requested restart for task: {task_name}");
166                self.restart_task(task_name).await;
167            }
168            SupervisorMessage::KillTask(task_name) => {
169                if let Some(task_handle) = self.tasks.get_mut(&task_name) {
170                    if task_handle.status != TaskStatus::Dead {
171                        task_handle.mark(TaskStatus::Dead);
172                        task_handle.clean().await;
173                    }
174                } else {
175                    #[cfg(feature = "with_tracing")]
176                    warn!("Attempted to kill non-existent task: {task_name}");
177                }
178            }
179            SupervisorMessage::GetTaskStatus(task_name, sender) => {
180                let status = self.tasks.get(&task_name).map(|handle| handle.status);
181
182                #[cfg(feature = "with_tracing")]
183                debug!("Status query for task '{task_name}': {status:?}");
184
185                let _ = sender.send(status);
186            }
187            SupervisorMessage::GetAllTaskStatuses(sender) => {
188                let statuses = self
189                    .tasks
190                    .iter()
191                    .map(|(name, handle)| (name.clone(), handle.status))
192                    .collect();
193                let _ = sender.send(statuses);
194            }
195            SupervisorMessage::Shutdown => {
196                #[cfg(feature = "with_tracing")]
197                info!("User requested supervisor shutdown");
198
199                for (_, task_handle) in self.tasks.iter_mut() {
200                    if task_handle.status != TaskStatus::Dead
201                        && task_handle.status != TaskStatus::Completed
202                    {
203                        task_handle.clean().await;
204                        task_handle.mark(TaskStatus::Dead);
205                    }
206                }
207                let _ = self.internal_tx.send(SupervisedTaskMessage::Shutdown);
208            }
209        }
210    }
211
212    /// Starts a task, setting up its execution, heartbeat, and completion handling.
213    async fn start_task(
214        task_name: String,
215        task_handle: &mut TaskHandle,
216        internal_tx: mpsc::UnboundedSender<SupervisedTaskMessage>,
217    ) {
218        task_handle.started_at = Some(Instant::now());
219        task_handle.mark(TaskStatus::Healthy);
220
221        let token = CancellationToken::new();
222        task_handle.cancellation_token = Some(token.clone());
223
224        let (completion_tx, mut completion_rx) = mpsc::channel::<TaskResult>(1);
225
226        // Completion Listener Task
227        let task_name_completion = task_name.clone();
228        let token_completion = token.clone();
229        let internal_tx_completion = internal_tx.clone();
230        let completion_listener_handle = tokio::spawn(async move {
231            tokio::select! {
232                _ = token_completion.cancelled() => { }
233                Some(outcome) = completion_rx.recv() => {
234                    let completion_msg = SupervisedTaskMessage::Completed(task_name_completion.clone(), outcome);
235                    let _ = internal_tx_completion.send(completion_msg);
236                    token_completion.cancel();
237                }
238            }
239        });
240
241        // Main Task Execution
242        let mut task_instance = task_handle.task.clone_box();
243        let token_main = token.clone();
244        let main_task_execution_handle = tokio::spawn(async move {
245            tokio::select! {
246                _ = token_main.cancelled() => { }
247                run_result = task_instance.run() => {
248                    let _ = completion_tx.send(run_result).await;
249                }
250            }
251        });
252
253        task_handle.main_task_handle = Some(main_task_execution_handle);
254        task_handle.completion_task_handle = Some(completion_listener_handle);
255    }
256
257    /// Restarts a task after cleaning up its previous execution.
258    async fn restart_task(&mut self, task_name: String) {
259        if let Some(task_handle) = self.tasks.get_mut(&task_name) {
260            task_handle.clean().await;
261            Self::start_task(task_name, task_handle, self.internal_tx.clone()).await;
262        }
263    }
264
265    async fn check_all_health(&mut self) {
266        let mut tasks_needing_restart: Vec<String> = Vec::new();
267        let now = Instant::now();
268
269        for (task_name, task_handle) in self.tasks.iter_mut() {
270            if task_handle.status == TaskStatus::Healthy {
271                if let Some(main_handle) = &task_handle.main_task_handle {
272                    if main_handle.is_finished() {
273                        #[cfg(feature = "with_tracing")]
274                        warn!("Task '{task_name}' unexpectedly finished, marking as failed");
275
276                        task_handle.mark(TaskStatus::Failed);
277                        tasks_needing_restart.push(task_name.clone());
278                    } else {
279                        // Task is Healthy and running. Check for stability.
280                        if let Some(healthy_since) = task_handle.healthy_since {
281                            if (now.duration_since(healthy_since) > self.task_is_stable_after)
282                                && task_handle.restart_attempts > 0
283                            {
284                                #[cfg(feature = "with_tracing")]
285                                info!(
286                                    "Task '{task_name}' is now stable, resetting restart attempts",
287                                );
288
289                                task_handle.restart_attempts = 0;
290                            }
291                        } else {
292                            task_handle.healthy_since = Some(now);
293                        }
294                    }
295                } else {
296                    #[cfg(feature = "with_tracing")]
297                    error!("Task '{task_name}' has no main handle, marking as failed");
298
299                    task_handle.mark(TaskStatus::Failed);
300                    tasks_needing_restart.push(task_name.clone());
301                }
302            }
303        }
304
305        for task_name in tasks_needing_restart {
306            let Some(task_handle) = self.tasks.get_mut(&task_name) else {
307                continue;
308            };
309
310            // Ensure it's still failed
311            if task_handle.has_exceeded_max_retries() {
312                #[cfg(feature = "with_tracing")]
313                error!(
314                    "Task '{task_name}' exceeded max restart attempts ({}), marking as dead",
315                    self.max_restart_attempts
316                );
317
318                task_handle.mark(TaskStatus::Dead);
319                task_handle.clean().await;
320                continue;
321            }
322
323            // Increment before calculating delay for the current attempt
324            task_handle.restart_attempts = task_handle.restart_attempts.saturating_add(1);
325            let restart_delay = task_handle.restart_delay();
326
327            #[cfg(feature = "with_tracing")]
328            info!(
329                "Scheduling restart for task '{task_name}' in {restart_delay:?} (attempt {}/{})",
330                task_handle.restart_attempts, self.max_restart_attempts
331            );
332
333            let internal_tx_clone = self.internal_tx.clone();
334            tokio::spawn(async move {
335                tokio::time::sleep(restart_delay).await;
336                let _ = internal_tx_clone.send(SupervisedTaskMessage::Restart(task_name.clone()));
337            });
338        }
339    }
340
341    async fn handle_task_completion(&mut self, task_name: String, outcome: TaskResult) {
342        let Some(task_handle) = self.tasks.get_mut(&task_name) else {
343            #[cfg(feature = "with_tracing")]
344            warn!("Received completion for non-existent task: {}", task_name);
345            return;
346        };
347
348        task_handle.clean().await;
349
350        match outcome {
351            Ok(()) => {
352                #[cfg(feature = "with_tracing")]
353                info!("Task '{task_name}' completed successfully");
354
355                task_handle.mark(TaskStatus::Completed);
356            }
357            #[allow(unused_variables)]
358            Err(ref e) => {
359                #[cfg(feature = "with_tracing")]
360                error!("Task '{task_name}' failed with error: {e:?}");
361
362                task_handle.mark(TaskStatus::Failed);
363                if task_handle.has_exceeded_max_retries() {
364                    #[cfg(feature = "with_tracing")]
365                    error!(
366                        "Task '{task_name}' exceeded max restart attempts after failure, marking as dead",
367                    );
368
369                    task_handle.mark(TaskStatus::Dead);
370                    return;
371                }
372
373                task_handle.restart_attempts = task_handle.restart_attempts.saturating_add(1);
374                let restart_delay = task_handle.restart_delay();
375
376                #[cfg(feature = "with_tracing")]
377                info!(
378                    "Scheduling restart for failed task '{task_name}' in {restart_delay:?} (attempt {}/{})",
379                    task_handle.restart_attempts,
380                    self.max_restart_attempts
381                );
382
383                let internal_tx_clone = self.internal_tx.clone();
384                tokio::spawn(async move {
385                    tokio::time::sleep(restart_delay).await;
386                    let _ =
387                        internal_tx_clone.send(SupervisedTaskMessage::Restart(task_name.clone()));
388                });
389            }
390        }
391    }
392
393    async fn check_dead_tasks_threshold(&mut self) -> Result<(), SupervisorError> {
394        if let Some(threshold) = self.max_dead_tasks_percentage_threshold {
395            if !self.tasks.is_empty() {
396                let dead_task_count = self
397                    .tasks
398                    .values()
399                    .filter(|handle| handle.status == TaskStatus::Dead)
400                    .count();
401
402                let total_task_count = self.tasks.len();
403                let current_dead_percentage = dead_task_count as f64 / total_task_count as f64;
404
405                if current_dead_percentage > threshold {
406                    #[cfg(feature = "with_tracing")]
407                    error!(
408                        "Dead tasks threshold exceeded: {:.2}% > {:.2}% ({}/{} tasks dead)",
409                        current_dead_percentage * 100.0,
410                        threshold * 100.0,
411                        dead_task_count,
412                        total_task_count
413                    );
414
415                    // Kill all remaining non-dead/non-completed tasks
416                    #[allow(unused_variables)]
417                    for (task_name, task_handle) in self.tasks.iter_mut() {
418                        if task_handle.status != TaskStatus::Dead
419                            && task_handle.status != TaskStatus::Completed
420                        {
421                            #[cfg(feature = "with_tracing")]
422                            debug!("Killing task '{task_name}' due to threshold breach");
423
424                            task_handle.clean().await;
425                            task_handle.mark(TaskStatus::Dead);
426                        }
427                    }
428
429                    return Err(SupervisorError::TooManyDeadTasks {
430                        current_percentage: current_dead_percentage,
431                        threshold,
432                    });
433                }
434            }
435        };
436
437        Ok(())
438    }
439}