//! task_supervisor/supervisor/mod.rs — core supervisor module for the task_supervisor crate.

1pub(crate) mod builder;
2pub(crate) mod handle;
3
4use std::{
5    collections::{BinaryHeap, HashMap},
6    sync::Arc,
7    time::{Duration, Instant},
8};
9
10use tokio::{sync::mpsc, time::interval};
11use tokio_util::sync::CancellationToken;
12
13#[cfg(feature = "with_tracing")]
14use tracing::{debug, error, info, warn};
15
16use crate::{
17    supervisor::handle::{SupervisorHandle, SupervisorMessage},
18    task::{TaskHandle, TaskResult, TaskStatus},
19};
20
21#[derive(Clone, Debug, thiserror::Error)]
22pub enum SupervisorError {
23    #[error("Too many tasks are dead (threshold exceeded: {current_percentage:.2}% > {threshold:.2}%), supervisor shutting down.")]
24    TooManyDeadTasks {
25        current_percentage: f64,
26        threshold: f64,
27    },
28}
29
/// Internal messages sent from tasks to the supervisor.
#[derive(Debug)]
pub(crate) enum SupervisedTaskMessage {
    /// Sent when a task completes, either successfully or with a failure.
    /// Carries the task's name and its `TaskResult`.
    Completed(Arc<str>, TaskResult),
    /// Sent when a shutdown signal is requested by the user; the
    /// supervision loop returns `Ok(())` upon receiving it.
    Shutdown,
}
38
/// A pending restart, ordered by deadline (earliest first).
///
/// Stored in a `BinaryHeap`; the `Ord` impl below reverses the natural
/// comparison so the max-heap pops the earliest deadline first.
struct PendingRestart {
    // Absolute point in time at which the task should be restarted.
    deadline: tokio::time::Instant,
    // Name of the task to restart (shared key into `Supervisor::tasks`).
    task_name: Arc<str>,
}
44
45impl PartialEq for PendingRestart {
46    fn eq(&self, other: &Self) -> bool {
47        self.deadline == other.deadline
48    }
49}
50
51impl Eq for PendingRestart {}
52
53impl PartialOrd for PendingRestart {
54    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
55        Some(self.cmp(other))
56    }
57}
58
59impl Ord for PendingRestart {
60    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
61        // Reverse ordering so BinaryHeap (max-heap) pops the earliest deadline first.
62        other.deadline.cmp(&self.deadline)
63    }
64}
65
/// Manages a set of tasks, ensuring they remain operational through restarts.
///
/// The `Supervisor` spawns each task with monitoring to detect failures.
/// If a task fails, it is restarted with an exponential backoff.
/// User commands such as adding, restarting, or killing tasks are supported via the `SupervisorHandle`.
pub struct Supervisor {
    // All supervised tasks, keyed by name. `Arc<str>` keys make clones cheap.
    pub(crate) tasks: HashMap<Arc<str>, TaskHandle>,
    // Durations for tasks lifecycle
    // How often the supervision loop polls task liveness.
    pub(crate) health_check_interval: Duration,
    // Initial delay before a failed task is restarted (grows with backoff).
    pub(crate) base_restart_delay: Duration,
    // A task healthy at least this long has its restart counter reset.
    pub(crate) task_is_stable_after: Duration,
    // Maximum restarts before a task is declared dead; `None` = unlimited.
    pub(crate) max_restart_attempts: Option<u32>,
    // Upper bound on the exponent used for exponential backoff.
    pub(crate) max_backoff_exponent: u32,
    // Fraction in [0, 1] of dead tasks that aborts the supervisor;
    // `None` disables the check (see `check_dead_tasks_threshold`).
    pub(crate) max_dead_tasks_percentage_threshold: Option<f64>,
    // Channels between the User & the Supervisor
    pub(crate) external_tx: mpsc::UnboundedSender<SupervisorMessage>,
    pub(crate) external_rx: mpsc::UnboundedReceiver<SupervisorMessage>,
    // Internal channel: tasks -> supervisor
    pub(crate) internal_tx: mpsc::UnboundedSender<SupervisedTaskMessage>,
    pub(crate) internal_rx: mpsc::UnboundedReceiver<SupervisedTaskMessage>,
}
87
88impl Supervisor {
89    /// Runs the supervisor, consuming it and returning a handle for external control.
90    pub fn run(self) -> SupervisorHandle {
91        let user_tx = self.external_tx.clone();
92        let handle = tokio::spawn(async move { self.run_and_supervise().await });
93        SupervisorHandle::new(handle, user_tx)
94    }
95
    /// Starts every registered task, then runs the supervision loop until
    /// shutdown or a fatal `SupervisorError`.
    async fn run_and_supervise(mut self) -> Result<(), SupervisorError> {
        self.start_all_tasks();
        self.supervise_all_tasks().await
    }
100
101    fn start_all_tasks(&mut self) {
102        let task_names: Vec<Arc<str>> = self.tasks.keys().cloned().collect();
103        for task_name in task_names {
104            self.start_task(&task_name);
105        }
106    }
107
    /// Main supervision loop.
    ///
    /// Multiplexes four event sources with `tokio::select!` in `biased`
    /// (top-to-bottom) priority order: internal task messages, user
    /// commands, due restarts, and the periodic health check. Returns
    /// `Ok(())` after a shutdown request, or `Err` when the dead-task
    /// threshold check fails.
    async fn supervise_all_tasks(&mut self) -> Result<(), SupervisorError> {
        let mut health_check_ticker = interval(self.health_check_interval);
        // Scheduled restarts; `PendingRestart`'s reversed `Ord` makes this
        // max-heap pop the earliest deadline first.
        let mut pending_restarts: BinaryHeap<PendingRestart> = BinaryHeap::new();

        loop {
            // Compute the sleep for the next pending restart (if any).
            // With an empty heap this future never completes, which
            // effectively disables the restart arm of the select below.
            let next_restart = async {
                match pending_restarts.peek() {
                    Some(pr) => tokio::time::sleep_until(pr.deadline).await,
                    None => std::future::pending().await,
                }
            };

            tokio::select! {
                // `biased` drains task/user messages before timer-driven work.
                biased;
                Some(internal_msg) = self.internal_rx.recv() => {
                    match internal_msg {
                        SupervisedTaskMessage::Shutdown => {
                            #[cfg(feature = "with_tracing")]
                            info!("Supervisor received shutdown signal");
                            return Ok(());
                        }
                        SupervisedTaskMessage::Completed(task_name, outcome) => {
                            #[cfg(feature = "with_tracing")]
                            match &outcome {
                                Ok(()) => info!("Task '{}' completed successfully", task_name),
                                Err(e) => warn!("Task '{}' completed with error: {e}", task_name),
                            }
                            self.handle_task_completion(&task_name, outcome, &mut pending_restarts);
                        }
                    }
                },
                Some(user_msg) = self.external_rx.recv() => {
                    self.handle_user_message(user_msg, &mut pending_restarts);
                },
                _ = next_restart => {
                    // Pop and execute the restart whose deadline has arrived.
                    if let Some(pr) = pending_restarts.pop() {
                        self.restart_task(&pr.task_name);
                    }
                },
                _ = health_check_ticker.tick() => {
                    #[cfg(feature = "with_tracing")]
                    debug!("Supervisor checking health of all tasks");
                    self.check_all_health(&mut pending_restarts);
                    // Propagates `TooManyDeadTasks`, ending the loop.
                    self.check_dead_tasks_threshold()?;
                }
            }
        }
    }
159
    /// Processes user commands received via the `SupervisorHandle`.
    ///
    /// Mutates the task table; for `Shutdown` it additionally clears all
    /// pending restarts and sends an internal shutdown message that makes
    /// the supervision loop return.
    fn handle_user_message(
        &mut self,
        msg: SupervisorMessage,
        pending_restarts: &mut BinaryHeap<PendingRestart>,
    ) {
        match msg {
            // Register a brand-new task and start it immediately.
            SupervisorMessage::AddTask(task_name, task_dyn) => {
                let key: Arc<str> = Arc::from(task_name);

                // TODO: This branch should return an error
                if self.tasks.contains_key(&key) {
                    #[cfg(feature = "with_tracing")]
                    warn!("Attempted to add task '{}' but it already exists", key);
                    return;
                }

                // New handles inherit the supervisor-wide restart policy.
                let mut task_handle = TaskHandle::new(task_dyn);
                task_handle.max_restart_attempts = self.max_restart_attempts;
                task_handle.base_restart_delay = self.base_restart_delay;
                task_handle.max_backoff_exponent = self.max_backoff_exponent;

                self.tasks.insert(Arc::clone(&key), task_handle);
                self.start_task(&key);
            }
            // Force an immediate restart, regardless of current status.
            SupervisorMessage::RestartTask(task_name) => {
                let key: Arc<str> = Arc::from(task_name);
                #[cfg(feature = "with_tracing")]
                info!("User requested restart for task: {}", key);
                self.restart_task(&key);
            }
            // Permanently stop a single task; no restart is scheduled.
            SupervisorMessage::KillTask(task_name) => {
                let key: Arc<str> = Arc::from(task_name);
                if let Some(task_handle) = self.tasks.get_mut(&key) {
                    if task_handle.status != TaskStatus::Dead {
                        // NOTE(review): this arm marks Dead before `clean()`,
                        // while the Shutdown arm below cleans first — assumes
                        // `clean()` never touches `status`; confirm in
                        // `TaskHandle`.
                        task_handle.mark(TaskStatus::Dead);
                        task_handle.clean();
                    }
                } else {
                    #[cfg(feature = "with_tracing")]
                    warn!("Attempted to kill non-existent task: {}", key);
                }
            }
            // Report a single task's status (`None` if the name is unknown).
            SupervisorMessage::GetTaskStatus(task_name, sender) => {
                let key: Arc<str> = Arc::from(task_name);
                let status = self.tasks.get(&key).map(|handle| handle.status);

                #[cfg(feature = "with_tracing")]
                debug!("Status query for task '{}': {:?}", key, status);

                // The querying side may have been dropped; ignore send errors.
                let _ = sender.send(status);
            }
            // Snapshot every task's status as (name, status) pairs.
            SupervisorMessage::GetAllTaskStatuses(sender) => {
                let statuses = self
                    .tasks
                    .iter()
                    .map(|(name, handle)| (String::from(name.as_ref()), handle.status))
                    .collect();
                let _ = sender.send(statuses);
            }
            // Stop everything and ask the supervision loop to return.
            SupervisorMessage::Shutdown => {
                #[cfg(feature = "with_tracing")]
                info!("User requested supervisor shutdown");

                for (_, task_handle) in self.tasks.iter_mut() {
                    if task_handle.status != TaskStatus::Dead
                        && task_handle.status != TaskStatus::Completed
                    {
                        task_handle.clean();
                        task_handle.mark(TaskStatus::Dead);
                    }
                }
                pending_restarts.clear();
                // The supervision loop exits when it receives this message.
                let _ = self.internal_tx.send(SupervisedTaskMessage::Shutdown);
            }
        }
    }
237
238    /// Starts a task, spawning a single tokio task that runs it and reports completion.
239    fn start_task(&mut self, task_name: &Arc<str>) {
240        let Some(task_handle) = self.tasks.get_mut(task_name) else {
241            return;
242        };
243
244        task_handle.mark(TaskStatus::Healthy);
245
246        let token = CancellationToken::new();
247        task_handle.cancellation_token = Some(token.clone());
248
249        let mut task_instance = task_handle.task.clone_box();
250        let internal_tx = self.internal_tx.clone();
251        let name = Arc::clone(task_name);
252
253        let join_handle = tokio::spawn(async move {
254            tokio::select! {
255                _ = token.cancelled() => { }
256                result = task_instance.run_boxed() => {
257                    let _ = internal_tx.send(SupervisedTaskMessage::Completed(name, result));
258                }
259            }
260        });
261
262        task_handle.join_handle = Some(join_handle);
263    }
264
265    /// Restarts a task after cleaning up its previous execution.
266    fn restart_task(&mut self, task_name: &Arc<str>) {
267        if let Some(task_handle) = self.tasks.get_mut(task_name) {
268            task_handle.clean();
269        }
270        self.start_task(task_name);
271    }
272
    /// Periodic health check over all tasks currently marked `Healthy`.
    ///
    /// A healthy task whose spawned future has finished (or that has no
    /// join handle at all) is marked `Failed` and queued for
    /// restart-or-kill. Tasks that stay healthy longer than
    /// `task_is_stable_after` get their restart counter reset.
    fn check_all_health(&mut self, pending_restarts: &mut BinaryHeap<PendingRestart>) {
        let now = Instant::now();

        // First pass: mark failed tasks and collect their names. Restart
        // scheduling needs `&mut self`, which we cannot take while iterating
        // `self.tasks`. `Vec::new` allocates nothing until first push, so
        // the common zero-failure tick stays allocation-free.
        let mut failed_names: Vec<Arc<str>> = Vec::new();

        for (task_name, task_handle) in self.tasks.iter_mut() {
            if task_handle.status != TaskStatus::Healthy {
                continue;
            }

            if let Some(handle) = &task_handle.join_handle {
                if handle.is_finished() {
                    // NOTE(review): a finished task may also have a Completed
                    // message still in flight; the biased select drains
                    // messages before ticking, but confirm both paths cannot
                    // each schedule a restart for the same failure.
                    #[cfg(feature = "with_tracing")]
                    warn!(
                        "Task '{}' unexpectedly finished, marking as failed",
                        task_name
                    );

                    task_handle.mark(TaskStatus::Failed);
                    failed_names.push(Arc::clone(task_name));
                } else {
                    // Task is running. Check for stability — reset restart counter
                    // once a task has been healthy long enough.
                    if let Some(healthy_since) = task_handle.healthy_since {
                        if now.duration_since(healthy_since) > self.task_is_stable_after
                            && task_handle.restart_attempts > 0
                        {
                            #[cfg(feature = "with_tracing")]
                            info!(
                                "Task '{}' is now stable, resetting restart attempts",
                                task_name
                            );
                            task_handle.restart_attempts = 0;
                        }
                    } else {
                        // First tick after becoming healthy: start the clock.
                        task_handle.healthy_since = Some(now);
                    }
                }
            } else {
                // Healthy but never spawned — unexpected; treat as a failure.
                #[cfg(feature = "with_tracing")]
                error!("Task '{}' has no join handle, marking as failed", task_name);

                task_handle.mark(TaskStatus::Failed);
                failed_names.push(Arc::clone(task_name));
            }
        }

        // Second pass: schedule each failed task for restart, or kill it.
        for task_name in failed_names {
            self.schedule_restart_or_kill(&task_name, pending_restarts);
        }
    }
327
328    fn handle_task_completion(
329        &mut self,
330        task_name: &Arc<str>,
331        outcome: TaskResult,
332        pending_restarts: &mut BinaryHeap<PendingRestart>,
333    ) {
334        let Some(task_handle) = self.tasks.get_mut(task_name) else {
335            #[cfg(feature = "with_tracing")]
336            warn!("Received completion for non-existent task: {}", task_name);
337            return;
338        };
339
340        task_handle.clean();
341
342        match outcome {
343            Ok(()) => {
344                #[cfg(feature = "with_tracing")]
345                info!("Task '{}' completed successfully", task_name);
346
347                task_handle.mark(TaskStatus::Completed);
348            }
349            #[allow(unused_variables)]
350            Err(ref e) => {
351                #[cfg(feature = "with_tracing")]
352                error!("Task '{}' failed with error: {:?}", task_name, e);
353
354                task_handle.mark(TaskStatus::Failed);
355                self.schedule_restart_or_kill(task_name, pending_restarts);
356            }
357        }
358    }
359
360    /// Shared logic for scheduling a restart with backoff, or marking the task dead
361    /// if max retries have been exceeded.
362    fn schedule_restart_or_kill(
363        &mut self,
364        task_name: &Arc<str>,
365        pending_restarts: &mut BinaryHeap<PendingRestart>,
366    ) {
367        let Some(task_handle) = self.tasks.get_mut(task_name) else {
368            return;
369        };
370
371        if task_handle.has_exceeded_max_retries() {
372            #[cfg(feature = "with_tracing")]
373            error!(
374                "Task '{}' exceeded max restart attempts ({:?}), marking as dead",
375                task_name,
376                task_handle
377                    .max_restart_attempts
378                    .expect("is provided if has exceeded")
379            );
380
381            task_handle.mark(TaskStatus::Dead);
382            task_handle.clean();
383            return;
384        }
385
386        task_handle.restart_attempts = task_handle.restart_attempts.saturating_add(1);
387        let restart_delay = task_handle.restart_delay();
388
389        #[cfg(feature = "with_tracing")]
390        info!(
391            "Scheduling restart for task '{}' in {:?} (attempt {}/{})",
392            task_name,
393            restart_delay,
394            task_handle.restart_attempts,
395            task_handle
396                .max_restart_attempts
397                .map(|t| t.to_string())
398                .unwrap_or_else(|| "\u{221e}".to_string())
399        );
400
401        pending_restarts.push(PendingRestart {
402            deadline: tokio::time::Instant::now() + restart_delay,
403            task_name: Arc::clone(task_name),
404        });
405    }
406
407    fn check_dead_tasks_threshold(&mut self) -> Result<(), SupervisorError> {
408        let Some(threshold) = self.max_dead_tasks_percentage_threshold else {
409            return Ok(());
410        };
411
412        let total_task_count = self.tasks.len();
413        if total_task_count == 0 {
414            return Ok(());
415        }
416
417        // Single-pass: count dead tasks.
418        let dead_task_count = self
419            .tasks
420            .values()
421            .filter(|handle| handle.status == TaskStatus::Dead)
422            .count();
423
424        let current_dead_percentage = dead_task_count as f64 / total_task_count as f64;
425
426        if current_dead_percentage <= threshold {
427            return Ok(());
428        }
429
430        #[cfg(feature = "with_tracing")]
431        error!(
432            "Dead tasks threshold exceeded: {:.2}% > {:.2}% ({}/{} tasks dead)",
433            current_dead_percentage * 100.0,
434            threshold * 100.0,
435            dead_task_count,
436            total_task_count
437        );
438
439        // Kill all remaining non-dead/non-completed tasks
440        #[allow(unused_variables)]
441        for (task_name, task_handle) in self.tasks.iter_mut() {
442            if task_handle.status != TaskStatus::Dead && task_handle.status != TaskStatus::Completed
443            {
444                #[cfg(feature = "with_tracing")]
445                debug!("Killing task '{}' due to threshold breach", task_name);
446
447                task_handle.clean();
448                task_handle.mark(TaskStatus::Dead);
449            }
450        }
451
452        Err(SupervisorError::TooManyDeadTasks {
453            current_percentage: current_dead_percentage,
454            threshold,
455        })
456    }
457}