// task_supervisor/task.rs

1use std::time::{Duration, Instant};
2
3#[async_trait::async_trait]
4pub trait SupervisedTask {
5    type Error: Send;
6
7    fn name(&self) -> Option<&str> {
8        None
9    }
10
11    async fn run_forever(&mut self) -> Result<(), Self::Error>;
12}
13
/// Lifecycle state of a supervised task, as tracked by its handle.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TaskStatus {
    /// Freshly created; will be started soon.
    Created,
    /// In the process of starting up.
    Starting,
    /// Up and running normally.
    Healthy,
    /// Failed; a restart is pending.
    Failed,
    /// Given up on after exceeding the maximum number of restart attempts.
    Dead,
}
28
29#[derive(Debug)]
30pub(crate) struct TaskHandle<T: SupervisedTask> {
31    pub(crate) status: TaskStatus,
32    pub(crate) task: T,
33    pub(crate) handle: Option<tokio::task::JoinHandle<()>>,
34    pub(crate) last_heartbeat: Option<Instant>,
35    pub(crate) restart_attempts: u32,
36    max_restart_attempts: u32,
37    base_restart_delay: Duration,
38}
39
40impl<T: SupervisedTask> TaskHandle<T> {
41    pub(crate) fn new(task: T) -> Self {
42        Self {
43            status: TaskStatus::Created,
44            task,
45            handle: None,
46            last_heartbeat: None,
47            restart_attempts: 0,
48            max_restart_attempts: 5,
49            base_restart_delay: Duration::from_secs(1),
50        }
51    }
52
53    pub(crate) fn ticked_at(&mut self, at: Instant) {
54        self.last_heartbeat = Some(at);
55    }
56
57    pub(crate) fn time_since_last_heartbeat(&self) -> Option<Duration> {
58        self.last_heartbeat
59            .map(|last_heartbeat| Instant::now().duration_since(last_heartbeat))
60    }
61
62    pub(crate) fn has_crashed(&self, timeout_threshold: Duration) -> bool {
63        let Some(time_since_last_heartbeat) = self.time_since_last_heartbeat() else {
64            return self.status != TaskStatus::Dead;
65        };
66        (self.status != TaskStatus::Dead) && (time_since_last_heartbeat > timeout_threshold)
67    }
68
69    pub(crate) fn restart_delay(&self) -> Duration {
70        let factor = 2u32.saturating_pow(self.restart_attempts.min(5));
71        self.base_restart_delay.saturating_mul(factor)
72    }
73
74    pub(crate) const fn has_exceeded_max_retries(&self) -> bool {
75        self.restart_attempts >= self.max_restart_attempts
76    }
77
78    pub(crate) fn mark(&mut self, status: TaskStatus) {
79        self.status = status;
80    }
81
82    pub(crate) fn clean_before_restart(&mut self) {
83        self.last_heartbeat = None;
84        if let Some(still_running_task) = self.handle.take() {
85            still_running_task.abort();
86        }
87    }
88}