1use std::time::{Duration, Instant};
2
3#[async_trait::async_trait]
4pub trait SupervisedTask {
5 type Error: Send;
6
7 fn name(&self) -> Option<&str> {
8 None
9 }
10
11 async fn run_forever(&mut self) -> Result<(), Self::Error>;
12}
13
/// Lifecycle state of a supervised task.
///
/// Freshly constructed handles start in [`TaskStatus::Created`]; tasks in
/// [`TaskStatus::Dead`] are excluded from heartbeat-based crash detection.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TaskStatus {
    /// Initial state, before the task has been spawned.
    Created,
    /// The task is being launched.
    /// NOTE(review): the exact transition into this state is driven by the
    /// supervisor — confirm.
    Starting,
    /// Presumably the task is running and reporting heartbeats — confirm.
    Healthy,
    /// Presumably the task stopped with an error or missed its heartbeat
    /// deadline — confirm against the supervisor.
    Failed,
    /// Presumably terminal (e.g. retries exhausted); never reported as crashed
    /// by heartbeat checks.
    Dead,
}
28
29#[derive(Debug)]
30pub(crate) struct TaskHandle<T: SupervisedTask> {
31 pub(crate) status: TaskStatus,
32 pub(crate) task: T,
33 pub(crate) handle: Option<tokio::task::JoinHandle<()>>,
34 pub(crate) last_heartbeat: Option<Instant>,
35 pub(crate) restart_attempts: u32,
36 max_restart_attempts: u32,
37 base_restart_delay: Duration,
38}
39
40impl<T: SupervisedTask> TaskHandle<T> {
41 pub(crate) fn new(task: T) -> Self {
42 Self {
43 status: TaskStatus::Created,
44 task,
45 handle: None,
46 last_heartbeat: None,
47 restart_attempts: 0,
48 max_restart_attempts: 5,
49 base_restart_delay: Duration::from_secs(1),
50 }
51 }
52
53 pub(crate) fn ticked_at(&mut self, at: Instant) {
54 self.last_heartbeat = Some(at);
55 }
56
57 pub(crate) fn time_since_last_heartbeat(&self) -> Option<Duration> {
58 self.last_heartbeat
59 .map(|last_heartbeat| Instant::now().duration_since(last_heartbeat))
60 }
61
62 pub(crate) fn has_crashed(&self, timeout_threshold: Duration) -> bool {
63 let Some(time_since_last_heartbeat) = self.time_since_last_heartbeat() else {
64 return self.status != TaskStatus::Dead;
65 };
66 (self.status != TaskStatus::Dead) && (time_since_last_heartbeat > timeout_threshold)
67 }
68
69 pub(crate) fn restart_delay(&self) -> Duration {
70 let factor = 2u32.saturating_pow(self.restart_attempts.min(5));
71 self.base_restart_delay.saturating_mul(factor)
72 }
73
74 pub(crate) const fn has_exceeded_max_retries(&self) -> bool {
75 self.restart_attempts >= self.max_restart_attempts
76 }
77
78 pub(crate) fn mark(&mut self, status: TaskStatus) {
79 self.status = status;
80 }
81
82 pub(crate) fn clean_before_restart(&mut self) {
83 self.last_heartbeat = None;
84 if let Some(still_running_task) = self.handle.take() {
85 still_running_task.abort();
86 }
87 }
88}