//! Supervisor core (`supervisor/mod.rs`): owns the set of supervised task
//! handles, runs the supervision loop, and restarts failed tasks with
//! backoff until they complete, are killed, or exhaust their retries.
1pub(crate) mod builder;
2pub(crate) mod handle;
3
4use std::{
5    collections::{BinaryHeap, HashMap},
6    sync::Arc,
7    time::{Duration, Instant},
8};
9
10use tokio::{sync::mpsc, time::interval};
11use tokio_util::sync::CancellationToken;
12
13#[cfg(feature = "tracing")]
14use tracing::{debug, error, info, warn};
15
16use crate::{
17    supervisor::handle::{SupervisorHandle, SupervisorMessage},
18    task::{TaskHandle, TaskResult, TaskStatus},
19};
20
21#[derive(Clone, Debug, thiserror::Error)]
22pub enum SupervisorError {
23    #[error("Too many tasks are dead (threshold exceeded: {current_percentage:.2}% > {threshold:.2}%), supervisor shutting down.")]
24    TooManyDeadTasks {
25        current_percentage: f64,
26        threshold: f64,
27    },
28}
29
/// Messages delivered to the supervision loop over the internal channel.
#[derive(Debug)]
pub(crate) enum SupervisedTaskMessage {
    /// A task's future finished (successfully or with an error); carries the
    /// task name and its `TaskResult`.
    Completed(Arc<str>, TaskResult),
    /// Asks the supervision loop to exit (sent by the user `Shutdown` command).
    Shutdown,
}
35
/// Pending restart ordered by deadline (earliest first).
struct PendingRestart {
    // Instant at which the restart should fire; awaited via
    // `tokio::time::sleep_until` in the supervision loop.
    deadline: tokio::time::Instant,
    // Key into `Supervisor::tasks` identifying the task to restart.
    task_name: Arc<str>,
}
41
42impl PartialEq for PendingRestart {
43    fn eq(&self, other: &Self) -> bool {
44        self.deadline == other.deadline
45    }
46}
47
48impl Eq for PendingRestart {}
49
impl PartialOrd for PendingRestart {
    // Canonical delegation to the total order below (the reversed,
    // earliest-deadline-first order), as clippy recommends.
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}
55
56impl Ord for PendingRestart {
57    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
58        // Reversed: BinaryHeap is a max-heap, we want earliest deadline first.
59        other.deadline.cmp(&self.deadline)
60    }
61}
62
/// Owns every supervised task plus the channels that drive the supervision
/// loop. Constructed via the `builder` module; consumed by [`Supervisor::run`].
pub struct Supervisor {
    // All known tasks, keyed by name.
    pub(crate) tasks: HashMap<Arc<str>, TaskHandle>,
    // Period of the health-check tick in the supervision loop.
    pub(crate) health_check_interval: Duration,
    // Delay before the first restart attempt (copied into new `TaskHandle`s).
    pub(crate) base_restart_delay: Duration,
    // A task healthy for longer than this gets its restart counter reset.
    pub(crate) task_is_stable_after: Duration,
    // `None` means unlimited restart attempts.
    pub(crate) max_restart_attempts: Option<u32>,
    // Cap on the backoff exponent — presumably bounds exponential restart
    // delay growth; see `TaskHandle::restart_delay` (TODO confirm).
    pub(crate) max_backoff_exponent: u32,
    // Dead-task ratio in 0.0..=1.0 above which the supervisor aborts;
    // `None` disables the check.
    pub(crate) max_dead_tasks_percentage_threshold: Option<f64>,
    // User-facing command channel; the sender is cloned into the handle.
    pub(crate) external_tx: mpsc::UnboundedSender<SupervisorMessage>,
    pub(crate) external_rx: mpsc::UnboundedReceiver<SupervisorMessage>,
    // Internal channel used by spawned task wrappers to report completion.
    pub(crate) internal_tx: mpsc::UnboundedSender<SupervisedTaskMessage>,
    pub(crate) internal_rx: mpsc::UnboundedReceiver<SupervisedTaskMessage>,
}
76
77impl Supervisor {
78    pub fn run(self) -> SupervisorHandle {
79        let user_tx = self.external_tx.clone();
80        let handle = tokio::spawn(async move { self.run_and_supervise().await });
81        SupervisorHandle::new(handle, user_tx)
82    }
83
    /// Body of the spawned supervisor task: launches every registered task,
    /// then supervises them until shutdown or a fatal error.
    async fn run_and_supervise(mut self) -> Result<(), SupervisorError> {
        self.start_all_tasks();
        self.supervise_all_tasks().await
    }
88
89    fn start_all_tasks(&mut self) {
90        let task_names: Vec<Arc<str>> = self.tasks.keys().cloned().collect();
91        for task_name in task_names {
92            self.start_task(&task_name);
93        }
94    }
95
96    async fn supervise_all_tasks(&mut self) -> Result<(), SupervisorError> {
97        let mut health_check_ticker = interval(self.health_check_interval);
98        let mut pending_restarts: BinaryHeap<PendingRestart> = BinaryHeap::new();
99
100        loop {
101            let next_restart = async {
102                match pending_restarts.peek() {
103                    Some(pr) => tokio::time::sleep_until(pr.deadline).await,
104                    None => std::future::pending().await,
105                }
106            };
107
108            tokio::select! {
109                biased;
110                Some(internal_msg) = self.internal_rx.recv() => {
111                    match internal_msg {
112                        SupervisedTaskMessage::Shutdown => {
113                            #[cfg(feature = "tracing")]
114                            info!("Supervisor received shutdown signal");
115                            return Ok(());
116                        }
117                        SupervisedTaskMessage::Completed(task_name, outcome) => {
118                            #[cfg(feature = "tracing")]
119                            match &outcome {
120                                Ok(()) => info!("Task '{}' completed successfully", task_name),
121                                Err(e) => warn!("Task '{}' completed with error: {e}", task_name),
122                            }
123                            self.handle_task_completion(&task_name, outcome, &mut pending_restarts);
124                        }
125                    }
126                },
127                Some(user_msg) = self.external_rx.recv() => {
128                    self.handle_user_message(user_msg, &mut pending_restarts);
129                },
130                _ = next_restart => {
131                    if let Some(pr) = pending_restarts.pop() {
132                        self.restart_task(&pr.task_name);
133                    }
134                },
135                _ = health_check_ticker.tick() => {
136                    #[cfg(feature = "tracing")]
137                    debug!("Health check tick");
138                    self.check_all_health(&mut pending_restarts);
139                    self.check_dead_tasks_threshold()?;
140                }
141            }
142        }
143    }
144
    /// Applies a single user command received over the external channel.
    fn handle_user_message(
        &mut self,
        msg: SupervisorMessage,
        pending_restarts: &mut BinaryHeap<PendingRestart>,
    ) {
        match msg {
            SupervisorMessage::AddTask(task_name, task_dyn) => {
                let key: Arc<str> = Arc::from(task_name);

                // First writer wins: an existing task is never replaced.
                if self.tasks.contains_key(&key) {
                    #[cfg(feature = "tracing")]
                    warn!("Task '{}' already exists, ignoring add", key);
                    return;
                }

                // New tasks inherit the supervisor-wide restart policy.
                let mut task_handle = TaskHandle::new(task_dyn);
                task_handle.max_restart_attempts = self.max_restart_attempts;
                task_handle.base_restart_delay = self.base_restart_delay;
                task_handle.max_backoff_exponent = self.max_backoff_exponent;

                self.tasks.insert(Arc::clone(&key), task_handle);
                self.start_task(&key);
            }
            SupervisorMessage::RestartTask(task_name) => {
                // User-requested restart works on any status, including Dead.
                let key: Arc<str> = Arc::from(task_name);
                #[cfg(feature = "tracing")]
                info!("User requested restart for task: {}", key);
                self.restart_task(&key);
            }
            SupervisorMessage::KillTask(task_name) => {
                let key: Arc<str> = Arc::from(task_name);
                if let Some(task_handle) = self.tasks.get_mut(&key) {
                    if task_handle.status != TaskStatus::Dead {
                        // NOTE(review): a restart already queued in
                        // `pending_restarts` for this task is not removed
                        // here — verify the restart path tolerates tasks
                        // killed while a restart was pending.
                        task_handle.mark(TaskStatus::Dead);
                        task_handle.clean();
                    }
                } else {
                    #[cfg(feature = "tracing")]
                    warn!("Attempted to kill non-existent task: {}", key);
                }
            }
            SupervisorMessage::GetTaskStatus(task_name, sender) => {
                // `None` in the reply means "no such task".
                let key: Arc<str> = Arc::from(task_name);
                let status = self.tasks.get(&key).map(|handle| handle.status);
                let _ = sender.send(status);
            }
            SupervisorMessage::GetAllTaskStatuses(sender) => {
                let statuses = self
                    .tasks
                    .iter()
                    .map(|(name, handle)| (String::from(name.as_ref()), handle.status))
                    .collect();
                // Send failure just means the requester went away; ignore it.
                let _ = sender.send(statuses);
            }
            SupervisorMessage::Shutdown => {
                #[cfg(feature = "tracing")]
                info!("User requested supervisor shutdown");

                // Tear down everything still running, drop queued restarts,
                // then tell the supervision loop to exit via the internal
                // channel (handled by the `biased` select arm).
                for (_, task_handle) in self.tasks.iter_mut() {
                    if task_handle.status != TaskStatus::Dead
                        && task_handle.status != TaskStatus::Completed
                    {
                        task_handle.clean();
                        task_handle.mark(TaskStatus::Dead);
                    }
                }
                pending_restarts.clear();
                let _ = self.internal_tx.send(SupervisedTaskMessage::Shutdown);
            }
        }
    }
216
    /// Spawns (or re-spawns) the named task and marks it `Healthy`.
    /// No-op if the name is unknown.
    fn start_task(&mut self, task_name: &Arc<str>) {
        let Some(task_handle) = self.tasks.get_mut(task_name) else {
            return;
        };

        task_handle.mark(TaskStatus::Healthy);

        // Fresh token per run so a previous run's cancellation cannot leak
        // into the new one.
        let token = CancellationToken::new();
        task_handle.cancellation_token = Some(token.clone());

        // Cloned from the stored original — owned fields reset, Arc fields shared.
        let mut task_instance = task_handle.task.clone_box();
        let internal_tx = self.internal_tx.clone();
        let name = Arc::clone(task_name);

        let join_handle = tokio::spawn(async move {
            tokio::select! {
                // Cancellation exits silently: no `Completed` message is
                // sent, so the supervisor never schedules a restart for a
                // deliberately cancelled run.
                _ = token.cancelled() => { }
                result = task_instance.run_boxed() => {
                    // Send failure means the supervisor is gone; ignore it.
                    let _ = internal_tx.send(SupervisedTaskMessage::Completed(name, result));
                }
            }
        });

        task_handle.join_handle = Some(join_handle);
    }
243
    /// Tears down the task's previous run (if any) and starts it again.
    fn restart_task(&mut self, task_name: &Arc<str>) {
        if let Some(task_handle) = self.tasks.get_mut(task_name) {
            // Presumably cancels the token / drops the old join handle —
            // see `TaskHandle::clean` (defined elsewhere).
            task_handle.clean();
        }
        self.start_task(task_name);
    }
250
    /// Periodic health pass over all `Healthy` tasks.
    ///
    /// A healthy task whose join handle has finished (or is missing) is
    /// marked `Failed` and queued for restart. A task that has stayed
    /// healthy longer than `task_is_stable_after` gets its restart counter
    /// reset so later failures start backoff from scratch.
    fn check_all_health(&mut self, pending_restarts: &mut BinaryHeap<PendingRestart>) {
        let now = Instant::now();
        // Failures are collected first: `schedule_restart_or_kill` needs
        // `&mut self`, which we cannot take while iterating `self.tasks`.
        let mut failed_names: Vec<Arc<str>> = Vec::new();

        for (task_name, task_handle) in self.tasks.iter_mut() {
            if task_handle.status != TaskStatus::Healthy {
                continue;
            }

            if let Some(handle) = &task_handle.join_handle {
                if handle.is_finished() {
                    // The future ended without reporting completion over the
                    // internal channel — treat as a failure.
                    #[cfg(feature = "tracing")]
                    warn!(
                        "Task '{}' unexpectedly finished, marking as failed",
                        task_name
                    );

                    task_handle.mark(TaskStatus::Failed);
                    failed_names.push(Arc::clone(task_name));
                } else {
                    // Stability check: reset restart counter after sustained health.
                    if let Some(healthy_since) = task_handle.healthy_since {
                        if now.duration_since(healthy_since) > self.task_is_stable_after
                            && task_handle.restart_attempts > 0
                        {
                            #[cfg(feature = "tracing")]
                            info!(
                                "Task '{}' is now stable, resetting restart attempts",
                                task_name
                            );
                            task_handle.restart_attempts = 0;
                        }
                    } else {
                        // NOTE(review): stability is measured from the first
                        // health tick that saw the task, not from spawn —
                        // confirm `mark(Healthy)`/`clean` intentionally leave
                        // `healthy_since` unset.
                        task_handle.healthy_since = Some(now);
                    }
                }
            } else {
                // Healthy with no join handle is an inconsistent state;
                // fail it so it gets restarted.
                #[cfg(feature = "tracing")]
                error!("Task '{}' has no join handle, marking as failed", task_name);

                task_handle.mark(TaskStatus::Failed);
                failed_names.push(Arc::clone(task_name));
            }
        }

        for task_name in failed_names {
            self.schedule_restart_or_kill(&task_name, pending_restarts);
        }
    }
300
    /// Handles a `Completed` message from a task wrapper: marks the task
    /// `Completed` on success, or `Failed` and queues a restart on error.
    fn handle_task_completion(
        &mut self,
        task_name: &Arc<str>,
        outcome: TaskResult,
        pending_restarts: &mut BinaryHeap<PendingRestart>,
    ) {
        let Some(task_handle) = self.tasks.get_mut(task_name) else {
            // Task was removed (or never registered) — nothing to update.
            #[cfg(feature = "tracing")]
            warn!("Completion for unknown task: {}", task_name);
            return;
        };

        // The run is over either way; release its token/handle first.
        task_handle.clean();

        match outcome {
            Ok(()) => {
                #[cfg(feature = "tracing")]
                info!("Task '{}' completed successfully", task_name);
                task_handle.mark(TaskStatus::Completed);
            }
            // `e` is read only when the `tracing` feature is enabled.
            #[allow(unused_variables)]
            Err(ref e) => {
                #[cfg(feature = "tracing")]
                error!("Task '{}' failed: {:?}", task_name, e);

                task_handle.mark(TaskStatus::Failed);
                self.schedule_restart_or_kill(task_name, pending_restarts);
            }
        }
    }
331
    /// Schedules a restart with backoff, or marks the task dead if retries exhausted.
    fn schedule_restart_or_kill(
        &mut self,
        task_name: &Arc<str>,
        pending_restarts: &mut BinaryHeap<PendingRestart>,
    ) {
        let Some(task_handle) = self.tasks.get_mut(task_name) else {
            return;
        };

        // Retries exhausted: the task becomes terminally `Dead` and counts
        // toward the dead-task threshold check.
        if task_handle.has_exceeded_max_retries() {
            #[cfg(feature = "tracing")]
            error!(
                "Task '{}' exceeded max restart attempts ({:?}), marking as dead",
                task_name,
                task_handle
                    .max_restart_attempts
                    .expect("is provided if has exceeded")
            );

            task_handle.mark(TaskStatus::Dead);
            task_handle.clean();
            return;
        }

        // Bump the attempt count before computing the delay so
        // `restart_delay()` (presumably backoff based on the attempt count —
        // TODO confirm in `TaskHandle`) sees the new value.
        task_handle.restart_attempts = task_handle.restart_attempts.saturating_add(1);
        let restart_delay = task_handle.restart_delay();

        #[cfg(feature = "tracing")]
        info!(
            "Scheduling restart for task '{}' in {:?} (attempt {}/{})",
            task_name,
            restart_delay,
            task_handle.restart_attempts,
            task_handle
                .max_restart_attempts
                .map(|t| t.to_string())
                .unwrap_or_else(|| "\u{221e}".to_string())
        );

        // The supervision loop pops this entry when its deadline elapses.
        pending_restarts.push(PendingRestart {
            deadline: tokio::time::Instant::now() + restart_delay,
            task_name: Arc::clone(task_name),
        });
    }
377
378    fn check_dead_tasks_threshold(&mut self) -> Result<(), SupervisorError> {
379        let Some(threshold) = self.max_dead_tasks_percentage_threshold else {
380            return Ok(());
381        };
382
383        let total_task_count = self.tasks.len();
384        if total_task_count == 0 {
385            return Ok(());
386        }
387
388        let dead_task_count = self
389            .tasks
390            .values()
391            .filter(|handle| handle.status == TaskStatus::Dead)
392            .count();
393
394        let current_dead_percentage = dead_task_count as f64 / total_task_count as f64;
395
396        if current_dead_percentage <= threshold {
397            return Ok(());
398        }
399
400        #[cfg(feature = "tracing")]
401        error!(
402            "Dead tasks threshold exceeded: {:.2}% > {:.2}% ({}/{} tasks dead)",
403            current_dead_percentage * 100.0,
404            threshold * 100.0,
405            dead_task_count,
406            total_task_count
407        );
408
409        #[allow(unused_variables)]
410        for (task_name, task_handle) in self.tasks.iter_mut() {
411            if task_handle.status != TaskStatus::Dead && task_handle.status != TaskStatus::Completed
412            {
413                #[cfg(feature = "tracing")]
414                debug!("Killing task '{}' due to threshold breach", task_name);
415
416                task_handle.clean();
417                task_handle.mark(TaskStatus::Dead);
418            }
419        }
420
421        Err(SupervisorError::TooManyDeadTasks {
422            current_percentage: current_dead_percentage,
423            threshold,
424        })
425    }
426}