task_supervisor/supervisor/
mod.rs

1pub(crate) mod builder;
2pub(crate) mod handle;
3
4use std::{
5    collections::HashMap,
6    time::{Duration, Instant},
7};
8
9use tokio::{sync::mpsc, time::interval};
10use tokio_util::sync::CancellationToken;
11
12use crate::{
13    supervisor::handle::{SupervisorHandle, SupervisorMessage},
14    task::{TaskHandle, TaskOutcome, TaskStatus},
15};
16
17#[derive(Debug, thiserror::Error)]
18pub enum SupervisorError {
19    #[error("Too many tasks are dead (threshold exceeded: {current_percentage:.2}% > {threshold:.2}%), supervisor shutting down.")]
20    TooManyDeadTasks {
21        current_percentage: f64,
22        threshold: f64,
23    },
24}
25
26/// Internal messages sent from tasks and by the `Supervisor` to manage task lifecycle.
27#[derive(Debug)]
28pub(crate) enum SupervisedTaskMessage {
29    /// Sent by the supervisor to itself to trigger a task restart.
30    Restart(String),
31    /// Sent when a task completes, either successfully or with a failure.
32    Completed(String, TaskOutcome),
33    /// Sent when a shutdown signal is asked by the User. Close every tasks & the supervisor.
34    Shutdown,
35}
36
37/// Manages a set of tasks, ensuring they remain operational through restarts.
38///
39/// The `Supervisor` spawns each task with a heartbeat mechanism to monitor liveness.
40/// If a task stops sending heartbeats or fails, it is restarted with an exponential backoff.
41/// User commands such as adding, restarting, or killing tasks are supported via the `SupervisorHandle`.
42pub struct Supervisor {
43    // List of current tasks
44    tasks: HashMap<String, TaskHandle>,
45    // Durations for tasks lifecycle
46    health_check_interval: Duration,
47    base_restart_delay: Duration,
48    task_is_stable_after: Duration,
49    max_restart_attempts: u32,
50    max_dead_tasks_percentage_threshold: Option<f64>,
51    // Channels between the User & the Supervisor
52    external_tx: mpsc::UnboundedSender<SupervisorMessage>,
53    external_rx: mpsc::UnboundedReceiver<SupervisorMessage>,
54    // Internal channels used in the Supervisor for actions propagation
55    internal_tx: mpsc::UnboundedSender<SupervisedTaskMessage>,
56    internal_rx: mpsc::UnboundedReceiver<SupervisedTaskMessage>,
57}
58
59impl Supervisor {
60    /// Runs the supervisor, consuming it and returning a handle for external control.
61    ///
62    /// This method initiates all tasks and starts the supervision loop.
63    pub fn run(self) -> SupervisorHandle {
64        let user_tx = self.external_tx.clone();
65        let handle = tokio::spawn(async move { self.run_and_supervise().await });
66        SupervisorHandle::new(handle, user_tx)
67    }
68
69    /// Starts and supervises all tasks, running the main supervision loop.
70    async fn run_and_supervise(mut self) -> Result<(), SupervisorError> {
71        self.start_all_tasks().await;
72        self.supervise_all_tasks().await
73    }
74
75    /// Initiates all tasks managed by the supervisor.
76    async fn start_all_tasks(&mut self) {
77        for (task_name, task_handle) in self.tasks.iter_mut() {
78            Self::start_task(task_name.to_string(), task_handle, self.internal_tx.clone()).await;
79        }
80    }
81
82    /// Supervises tasks by processing messages and performing periodic health checks.
83    async fn supervise_all_tasks(&mut self) -> Result<(), SupervisorError> {
84        let mut health_check_ticker = interval(self.health_check_interval);
85
86        loop {
87            tokio::select! {
88                biased;
89                Some(internal_msg) = self.internal_rx.recv() => {
90                    match internal_msg {
91                        SupervisedTaskMessage::Shutdown => {
92                            return Ok(());
93                        }
94                        _ => self.handle_internal_message(internal_msg).await,
95                    }
96                },
97                Some(user_msg) = self.external_rx.recv() => {
98                    self.handle_user_message(user_msg).await;
99                },
100                _ = health_check_ticker.tick() => {
101                    self.check_all_health().await;
102                }
103            }
104
105            self.check_dead_tasks_threshold().await?;
106        }
107    }
108
109    async fn handle_internal_message(&mut self, msg: SupervisedTaskMessage) {
110        match msg {
111            SupervisedTaskMessage::Restart(task_name) => {
112                self.restart_task(task_name).await;
113            }
114            SupervisedTaskMessage::Completed(task_name, outcome) => {
115                self.handle_task_completion(task_name, outcome).await;
116            }
117            SupervisedTaskMessage::Shutdown => {
118                unreachable!("Shutdown should be handled by the main select loop to break.");
119            }
120        }
121    }
122
123    /// Processes user commands received via the `SupervisorHandle`.
124    async fn handle_user_message(&mut self, msg: SupervisorMessage) {
125        match msg {
126            SupervisorMessage::AddTask(task_name, task_dyn) => {
127                // TODO: This branch should return an error
128                if self.tasks.contains_key(&task_name) {
129                    return;
130                }
131                let mut task_handle =
132                    TaskHandle::new(task_dyn, self.max_restart_attempts, self.base_restart_delay);
133                Self::start_task(
134                    task_name.clone(),
135                    &mut task_handle,
136                    self.internal_tx.clone(),
137                )
138                .await;
139                self.tasks.insert(task_name, task_handle);
140            }
141            SupervisorMessage::RestartTask(task_name) => {
142                self.restart_task(task_name).await;
143            }
144            SupervisorMessage::KillTask(task_name) => {
145                if let Some(task_handle) = self.tasks.get_mut(&task_name) {
146                    if task_handle.status != TaskStatus::Dead {
147                        task_handle.mark(TaskStatus::Dead);
148                        task_handle.clean().await;
149                    }
150                }
151            }
152            SupervisorMessage::GetTaskStatus(task_name, sender) => {
153                let status = self.tasks.get(&task_name).map(|handle| handle.status);
154                let _ = sender.send(status);
155            }
156            SupervisorMessage::GetAllTaskStatuses(sender) => {
157                let statuses = self
158                    .tasks
159                    .iter()
160                    .map(|(name, handle)| (name.clone(), handle.status))
161                    .collect();
162                let _ = sender.send(statuses);
163            }
164            SupervisorMessage::Shutdown => {
165                for (_, task_handle) in self.tasks.iter_mut() {
166                    if task_handle.status != TaskStatus::Dead
167                        && task_handle.status != TaskStatus::Completed
168                    {
169                        task_handle.clean().await;
170                        task_handle.mark(TaskStatus::Dead);
171                    }
172                }
173                let _ = self.internal_tx.send(SupervisedTaskMessage::Shutdown);
174            }
175        }
176    }
177
178    /// Starts a task, setting up its execution, heartbeat, and completion handling.
179    async fn start_task(
180        task_name: String,
181        task_handle: &mut TaskHandle,
182        internal_tx: mpsc::UnboundedSender<SupervisedTaskMessage>,
183    ) {
184        task_handle.started_at = Some(Instant::now());
185        task_handle.mark(TaskStatus::Healthy);
186
187        let token = CancellationToken::new();
188        task_handle.cancellation_token = Some(token.clone());
189
190        let (completion_tx, mut completion_rx) = mpsc::channel::<TaskOutcome>(1);
191
192        // Completion Listener Task
193        let task_name_completion = task_name.clone();
194        let token_completion = token.clone();
195        let internal_tx_completion = internal_tx.clone();
196        let completion_listener_handle = tokio::spawn(async move {
197            tokio::select! {
198                _ = token_completion.cancelled() => { }
199                Some(outcome) = completion_rx.recv() => {
200                    let completion_msg = SupervisedTaskMessage::Completed(task_name_completion.clone(), outcome);
201                    let _ = internal_tx_completion.send(completion_msg);
202                    token_completion.cancel();
203                }
204            }
205        });
206
207        // Main Task Execution
208        let mut task_instance = task_handle.task.clone_box();
209        let token_main = token.clone();
210        let main_task_execution_handle = tokio::spawn(async move {
211            tokio::select! {
212                _ = token_main.cancelled() => { }
213                run_result = task_instance.run() => {
214                    match run_result {
215                        Ok(outcome) => {
216                            let _ = completion_tx.send(outcome).await;
217                        }
218                        Err(e) => {
219                            let _ = completion_tx.send(TaskOutcome::Failed(e.to_string())).await;
220                        }
221                    }
222                }
223            }
224        });
225
226        task_handle.main_task_handle = Some(main_task_execution_handle);
227        task_handle.completion_task_handle = Some(completion_listener_handle);
228    }
229
230    /// Restarts a task after cleaning up its previous execution.
231    async fn restart_task(&mut self, task_name: String) {
232        if let Some(task_handle) = self.tasks.get_mut(&task_name) {
233            task_handle.clean().await;
234            Self::start_task(task_name, task_handle, self.internal_tx.clone()).await;
235        }
236    }
237
238    async fn check_all_health(&mut self) {
239        let mut tasks_needing_restart: Vec<String> = Vec::new();
240        let now = Instant::now();
241
242        for (task_name, task_handle) in self.tasks.iter_mut() {
243            if task_handle.status == TaskStatus::Healthy {
244                if let Some(main_handle) = &task_handle.main_task_handle {
245                    if main_handle.is_finished() {
246                        task_handle.mark(TaskStatus::Failed);
247                        tasks_needing_restart.push(task_name.clone());
248                    } else {
249                        // Task is Healthy and running. Check for stability.
250                        if let Some(healthy_since) = task_handle.healthy_since {
251                            if (now.duration_since(healthy_since) > self.task_is_stable_after)
252                                && task_handle.restart_attempts > 0
253                            {
254                                task_handle.restart_attempts = 0;
255                            }
256                        } else {
257                            task_handle.healthy_since = Some(now);
258                        }
259                    }
260                } else {
261                    task_handle.mark(TaskStatus::Failed);
262                    tasks_needing_restart.push(task_name.clone());
263                }
264            }
265        }
266
267        for task_name in tasks_needing_restart {
268            let Some(task_handle) = self.tasks.get_mut(&task_name) else {
269                continue;
270            };
271
272            // Ensure it's still failed
273            if task_handle.has_exceeded_max_retries() {
274                task_handle.mark(TaskStatus::Dead);
275                task_handle.clean().await;
276                continue;
277            }
278
279            // Increment before calculating delay for the current attempt
280            task_handle.restart_attempts = task_handle.restart_attempts.saturating_add(1);
281            let restart_delay = task_handle.restart_delay();
282
283            let internal_tx_clone = self.internal_tx.clone();
284            tokio::spawn(async move {
285                tokio::time::sleep(restart_delay).await;
286                let _ = internal_tx_clone.send(SupervisedTaskMessage::Restart(task_name.clone()));
287            });
288        }
289    }
290
291    async fn handle_task_completion(&mut self, task_name: String, outcome: TaskOutcome) {
292        let Some(task_handle) = self.tasks.get_mut(&task_name) else {
293            return;
294        };
295
296        task_handle.clean().await;
297
298        match outcome {
299            TaskOutcome::Completed => {
300                task_handle.mark(TaskStatus::Completed);
301            }
302            TaskOutcome::Failed(_) => {
303                task_handle.mark(TaskStatus::Failed);
304
305                if task_handle.has_exceeded_max_retries() {
306                    task_handle.mark(TaskStatus::Dead);
307                    return;
308                }
309
310                task_handle.restart_attempts = task_handle.restart_attempts.saturating_add(1);
311                let restart_delay = task_handle.restart_delay();
312
313                let internal_tx_clone = self.internal_tx.clone();
314                tokio::spawn(async move {
315                    tokio::time::sleep(restart_delay).await;
316                    let _ =
317                        internal_tx_clone.send(SupervisedTaskMessage::Restart(task_name.clone()));
318                });
319            }
320        }
321    }
322
323    async fn check_dead_tasks_threshold(&mut self) -> Result<(), SupervisorError> {
324        if let Some(threshold) = self.max_dead_tasks_percentage_threshold {
325            if !self.tasks.is_empty() {
326                let dead_task_count = self
327                    .tasks
328                    .values()
329                    .filter(|handle| handle.status == TaskStatus::Dead)
330                    .count();
331
332                let total_task_count = self.tasks.len();
333                let current_dead_percentage = dead_task_count as f64 / total_task_count as f64;
334
335                if current_dead_percentage > threshold {
336                    // Kill all remaining non-dead/non-completed tasks
337                    for (_, task_handle) in self.tasks.iter_mut() {
338                        if task_handle.status != TaskStatus::Dead
339                            && task_handle.status != TaskStatus::Completed
340                        {
341                            task_handle.clean().await;
342                            task_handle.mark(TaskStatus::Dead);
343                        }
344                    }
345
346                    return Err(SupervisorError::TooManyDeadTasks {
347                        current_percentage: current_dead_percentage,
348                        threshold,
349                    });
350                }
351            }
352        };
353
354        Ok(())
355    }
356}