// task_supervisor/supervisor/mod.rs

1pub(crate) mod builder;
2pub(crate) mod handle;
3
4use std::{
5    collections::HashMap,
6    time::{Duration, Instant},
7};
8
9use tokio::{sync::mpsc, time::interval};
10use tokio_util::sync::CancellationToken;
11
12#[cfg(feature = "with_tracing")]
13use tracing::{debug, error, info, warn};
14
15use crate::{
16    supervisor::handle::{SupervisorHandle, SupervisorMessage},
17    task::{TaskHandle, TaskResult, TaskStatus},
18};
19
20#[derive(Clone, Debug, thiserror::Error)]
21pub enum SupervisorError {
22    #[error("Too many tasks are dead (threshold exceeded: {current_percentage:.2}% > {threshold:.2}%), supervisor shutting down.")]
23    TooManyDeadTasks {
24        current_percentage: f64,
25        threshold: f64,
26    },
27}
28
29/// Internal messages sent from tasks and by the `Supervisor` to manage task lifecycle.
30#[derive(Debug)]
31pub(crate) enum SupervisedTaskMessage {
32    /// Sent by the supervisor to itself to trigger a task restart.
33    Restart(String),
34    /// Sent when a task completes, either successfully or with a failure.
35    Completed(String, TaskResult),
36    /// Sent when a shutdown signal is asked by the User. Close every tasks & the supervisor.
37    Shutdown,
38}
39
40/// Manages a set of tasks, ensuring they remain operational through restarts.
41///
42/// The `Supervisor` spawns each task with a heartbeat mechanism to monitor liveness.
43/// If a task stops sending heartbeats or fails, it is restarted with an exponential backoff.
44/// User commands such as adding, restarting, or killing tasks are supported via the `SupervisorHandle`.
45pub struct Supervisor {
46    // List of current tasks
47    tasks: HashMap<String, TaskHandle>,
48    // Durations for tasks lifecycle
49    health_check_interval: Duration,
50    base_restart_delay: Duration,
51    task_is_stable_after: Duration,
52    max_restart_attempts: u32,
53    max_dead_tasks_percentage_threshold: Option<f64>,
54    // Channels between the User & the Supervisor
55    external_tx: mpsc::UnboundedSender<SupervisorMessage>,
56    external_rx: mpsc::UnboundedReceiver<SupervisorMessage>,
57    // Internal channels used in the Supervisor for actions propagation
58    internal_tx: mpsc::UnboundedSender<SupervisedTaskMessage>,
59    internal_rx: mpsc::UnboundedReceiver<SupervisedTaskMessage>,
60}
61
62impl Supervisor {
63    /// Runs the supervisor, consuming it and returning a handle for external control.
64    ///
65    /// This method initiates all tasks and starts the supervision loop.
66    pub fn run(self) -> SupervisorHandle {
67        let user_tx = self.external_tx.clone();
68        let handle = tokio::spawn(async move { self.run_and_supervise().await });
69        SupervisorHandle::new(handle, user_tx)
70    }
71
72    /// Starts and supervises all tasks, running the main supervision loop.
73    async fn run_and_supervise(mut self) -> Result<(), SupervisorError> {
74        self.start_all_tasks().await;
75        self.supervise_all_tasks().await
76    }
77
78    /// Initiates all tasks managed by the supervisor.
79    async fn start_all_tasks(&mut self) {
80        for (task_name, task_handle) in self.tasks.iter_mut() {
81            Self::start_task(task_name.to_string(), task_handle, self.internal_tx.clone()).await;
82        }
83    }
84
85    /// Supervises tasks by processing messages and performing periodic health checks.
86    async fn supervise_all_tasks(&mut self) -> Result<(), SupervisorError> {
87        let mut health_check_ticker = interval(self.health_check_interval);
88
89        loop {
90            tokio::select! {
91                biased;
92                Some(internal_msg) = self.internal_rx.recv() => {
93                    match internal_msg {
94                        SupervisedTaskMessage::Shutdown => {
95                            #[cfg(feature = "with_tracing")]
96                            info!("Supervisor received shutdown signal");
97                            return Ok(());
98                        }
99                        _ => self.handle_internal_message(internal_msg).await,
100                    }
101                },
102                Some(user_msg) = self.external_rx.recv() => {
103                    self.handle_user_message(user_msg).await;
104                },
105                _ = health_check_ticker.tick() => {
106                    #[cfg(feature = "with_tracing")]
107                    debug!("Supervisor checking health of all tasks");
108                    self.check_all_health().await;
109                    self.check_dead_tasks_threshold().await?;
110                }
111            }
112        }
113    }
114
115    async fn handle_internal_message(&mut self, msg: SupervisedTaskMessage) {
116        match msg {
117            SupervisedTaskMessage::Restart(task_name) => {
118                #[cfg(feature = "with_tracing")]
119                info!("Processing restart request for task: {task_name}");
120                self.restart_task(task_name).await;
121            }
122            SupervisedTaskMessage::Completed(task_name, outcome) => {
123                #[cfg(feature = "with_tracing")]
124                match &outcome {
125                    Ok(()) => info!("Task '{task_name}' completed successfully"),
126                    Err(e) => warn!("Task '{task_name}' completed with error: {e}"),
127                }
128                self.handle_task_completion(task_name, outcome).await;
129            }
130            SupervisedTaskMessage::Shutdown => {
131                unreachable!("Shutdown is handled by the main select loop.");
132            }
133        }
134    }
135
136    /// Processes user commands received via the `SupervisorHandle`.
137    async fn handle_user_message(&mut self, msg: SupervisorMessage) {
138        match msg {
139            SupervisorMessage::AddTask(task_name, task_dyn) => {
140                // TODO: This branch should return an error
141                if self.tasks.contains_key(&task_name) {
142                    #[cfg(feature = "with_tracing")]
143                    warn!("Attempted to add task '{task_name}' but it already exists");
144                    return;
145                }
146
147                let mut task_handle =
148                    TaskHandle::new(task_dyn, self.max_restart_attempts, self.base_restart_delay);
149                Self::start_task(
150                    task_name.clone(),
151                    &mut task_handle,
152                    self.internal_tx.clone(),
153                )
154                .await;
155                self.tasks.insert(task_name, task_handle);
156            }
157            SupervisorMessage::RestartTask(task_name) => {
158                #[cfg(feature = "with_tracing")]
159                info!("User requested restart for task: {task_name}");
160                self.restart_task(task_name).await;
161            }
162            SupervisorMessage::KillTask(task_name) => {
163                if let Some(task_handle) = self.tasks.get_mut(&task_name) {
164                    if task_handle.status != TaskStatus::Dead {
165                        task_handle.mark(TaskStatus::Dead);
166                        task_handle.clean().await;
167                    }
168                } else {
169                    #[cfg(feature = "with_tracing")]
170                    warn!("Attempted to kill non-existent task: {task_name}");
171                }
172            }
173            SupervisorMessage::GetTaskStatus(task_name, sender) => {
174                let status = self.tasks.get(&task_name).map(|handle| handle.status);
175
176                #[cfg(feature = "with_tracing")]
177                debug!("Status query for task '{task_name}': {status:?}");
178
179                let _ = sender.send(status);
180            }
181            SupervisorMessage::GetAllTaskStatuses(sender) => {
182                let statuses = self
183                    .tasks
184                    .iter()
185                    .map(|(name, handle)| (name.clone(), handle.status))
186                    .collect();
187                let _ = sender.send(statuses);
188            }
189            SupervisorMessage::Shutdown => {
190                #[cfg(feature = "with_tracing")]
191                info!("User requested supervisor shutdown");
192
193                for (_, task_handle) in self.tasks.iter_mut() {
194                    if task_handle.status != TaskStatus::Dead
195                        && task_handle.status != TaskStatus::Completed
196                    {
197                        task_handle.clean().await;
198                        task_handle.mark(TaskStatus::Dead);
199                    }
200                }
201                let _ = self.internal_tx.send(SupervisedTaskMessage::Shutdown);
202            }
203        }
204    }
205
206    /// Starts a task, setting up its execution, heartbeat, and completion handling.
207    async fn start_task(
208        task_name: String,
209        task_handle: &mut TaskHandle,
210        internal_tx: mpsc::UnboundedSender<SupervisedTaskMessage>,
211    ) {
212        task_handle.started_at = Some(Instant::now());
213        task_handle.mark(TaskStatus::Healthy);
214
215        let token = CancellationToken::new();
216        task_handle.cancellation_token = Some(token.clone());
217
218        let (completion_tx, mut completion_rx) = mpsc::channel::<TaskResult>(1);
219
220        // Completion Listener Task
221        let task_name_completion = task_name.clone();
222        let token_completion = token.clone();
223        let internal_tx_completion = internal_tx.clone();
224        let completion_listener_handle = tokio::spawn(async move {
225            tokio::select! {
226                _ = token_completion.cancelled() => { }
227                Some(outcome) = completion_rx.recv() => {
228                    let completion_msg = SupervisedTaskMessage::Completed(task_name_completion.clone(), outcome);
229                    let _ = internal_tx_completion.send(completion_msg);
230                    token_completion.cancel();
231                }
232            }
233        });
234
235        // Main Task Execution
236        let mut task_instance = task_handle.task.clone_box();
237        let token_main = token.clone();
238        let main_task_execution_handle = tokio::spawn(async move {
239            tokio::select! {
240                _ = token_main.cancelled() => { }
241                run_result = task_instance.run() => {
242                    let _ = completion_tx.send(run_result).await;
243                }
244            }
245        });
246
247        task_handle.main_task_handle = Some(main_task_execution_handle);
248        task_handle.completion_task_handle = Some(completion_listener_handle);
249    }
250
251    /// Restarts a task after cleaning up its previous execution.
252    async fn restart_task(&mut self, task_name: String) {
253        if let Some(task_handle) = self.tasks.get_mut(&task_name) {
254            task_handle.clean().await;
255            Self::start_task(task_name, task_handle, self.internal_tx.clone()).await;
256        }
257    }
258
259    async fn check_all_health(&mut self) {
260        let mut tasks_needing_restart: Vec<String> = Vec::new();
261        let now = Instant::now();
262
263        for (task_name, task_handle) in self.tasks.iter_mut() {
264            if task_handle.status == TaskStatus::Healthy {
265                if let Some(main_handle) = &task_handle.main_task_handle {
266                    if main_handle.is_finished() {
267                        #[cfg(feature = "with_tracing")]
268                        warn!("Task '{task_name}' unexpectedly finished, marking as failed");
269
270                        task_handle.mark(TaskStatus::Failed);
271                        tasks_needing_restart.push(task_name.clone());
272                    } else {
273                        // Task is Healthy and running. Check for stability.
274                        if let Some(healthy_since) = task_handle.healthy_since {
275                            if (now.duration_since(healthy_since) > self.task_is_stable_after)
276                                && task_handle.restart_attempts > 0
277                            {
278                                #[cfg(feature = "with_tracing")]
279                                info!(
280                                    "Task '{task_name}' is now stable, resetting restart attempts",
281                                );
282
283                                task_handle.restart_attempts = 0;
284                            }
285                        } else {
286                            task_handle.healthy_since = Some(now);
287                        }
288                    }
289                } else {
290                    #[cfg(feature = "with_tracing")]
291                    error!("Task '{task_name}' has no main handle, marking as failed");
292
293                    task_handle.mark(TaskStatus::Failed);
294                    tasks_needing_restart.push(task_name.clone());
295                }
296            }
297        }
298
299        for task_name in tasks_needing_restart {
300            let Some(task_handle) = self.tasks.get_mut(&task_name) else {
301                continue;
302            };
303
304            // Ensure it's still failed
305            if task_handle.has_exceeded_max_retries() {
306                #[cfg(feature = "with_tracing")]
307                error!(
308                    "Task '{task_name}' exceeded max restart attempts ({}), marking as dead",
309                    self.max_restart_attempts
310                );
311
312                task_handle.mark(TaskStatus::Dead);
313                task_handle.clean().await;
314                continue;
315            }
316
317            // Increment before calculating delay for the current attempt
318            task_handle.restart_attempts = task_handle.restart_attempts.saturating_add(1);
319            let restart_delay = task_handle.restart_delay();
320
321            #[cfg(feature = "with_tracing")]
322            info!(
323                "Scheduling restart for task '{task_name}' in {restart_delay:?} (attempt {}/{})",
324                task_handle.restart_attempts, self.max_restart_attempts
325            );
326
327            let internal_tx_clone = self.internal_tx.clone();
328            tokio::spawn(async move {
329                tokio::time::sleep(restart_delay).await;
330                let _ = internal_tx_clone.send(SupervisedTaskMessage::Restart(task_name.clone()));
331            });
332        }
333    }
334
335    async fn handle_task_completion(&mut self, task_name: String, outcome: TaskResult) {
336        let Some(task_handle) = self.tasks.get_mut(&task_name) else {
337            #[cfg(feature = "with_tracing")]
338            warn!("Received completion for non-existent task: {}", task_name);
339            return;
340        };
341
342        task_handle.clean().await;
343
344        match outcome {
345            Ok(()) => {
346                #[cfg(feature = "with_tracing")]
347                info!("Task '{task_name}' completed successfully");
348
349                task_handle.mark(TaskStatus::Completed);
350            }
351            #[allow(unused_variables)]
352            Err(ref e) => {
353                #[cfg(feature = "with_tracing")]
354                error!("Task '{task_name}' failed with error: {e:?}");
355
356                task_handle.mark(TaskStatus::Failed);
357                if task_handle.has_exceeded_max_retries() {
358                    #[cfg(feature = "with_tracing")]
359                    error!(
360                        "Task '{task_name}' exceeded max restart attempts after failure, marking as dead",
361                    );
362
363                    task_handle.mark(TaskStatus::Dead);
364                    return;
365                }
366
367                task_handle.restart_attempts = task_handle.restart_attempts.saturating_add(1);
368                let restart_delay = task_handle.restart_delay();
369
370                #[cfg(feature = "with_tracing")]
371                info!(
372                    "Scheduling restart for failed task '{task_name}' in {restart_delay:?} (attempt {}/{})",
373                    task_handle.restart_attempts,
374                    self.max_restart_attempts
375                );
376
377                let internal_tx_clone = self.internal_tx.clone();
378                tokio::spawn(async move {
379                    tokio::time::sleep(restart_delay).await;
380                    let _ =
381                        internal_tx_clone.send(SupervisedTaskMessage::Restart(task_name.clone()));
382                });
383            }
384        }
385    }
386
387    async fn check_dead_tasks_threshold(&mut self) -> Result<(), SupervisorError> {
388        if let Some(threshold) = self.max_dead_tasks_percentage_threshold {
389            if !self.tasks.is_empty() {
390                let dead_task_count = self
391                    .tasks
392                    .values()
393                    .filter(|handle| handle.status == TaskStatus::Dead)
394                    .count();
395
396                let total_task_count = self.tasks.len();
397                let current_dead_percentage = dead_task_count as f64 / total_task_count as f64;
398
399                if current_dead_percentage > threshold {
400                    #[cfg(feature = "with_tracing")]
401                    error!(
402                        "Dead tasks threshold exceeded: {:.2}% > {:.2}% ({}/{} tasks dead)",
403                        current_dead_percentage * 100.0,
404                        threshold * 100.0,
405                        dead_task_count,
406                        total_task_count
407                    );
408
409                    // Kill all remaining non-dead/non-completed tasks
410                    #[allow(unused_variables)]
411                    for (task_name, task_handle) in self.tasks.iter_mut() {
412                        if task_handle.status != TaskStatus::Dead
413                            && task_handle.status != TaskStatus::Completed
414                        {
415                            #[cfg(feature = "with_tracing")]
416                            debug!("Killing task '{task_name}' due to threshold breach");
417
418                            task_handle.clean().await;
419                            task_handle.mark(TaskStatus::Dead);
420                        }
421                    }
422
423                    return Err(SupervisorError::TooManyDeadTasks {
424                        current_percentage: current_dead_percentage,
425                        threshold,
426                    });
427                }
428            }
429        };
430
431        Ok(())
432    }
433}