// task_supervisor/task/mod.rs

1use std::time::{Duration, Instant};
2
3use tokio::task::JoinHandle;
4use tokio_util::sync::CancellationToken;
5
6pub type DynTask = Box<dyn CloneableSupervisedTask>;
7
8pub type TaskError = anyhow::Error;
9pub type TaskResult = Result<(), TaskError>;
10
11#[async_trait::async_trait]
12pub trait SupervisedTask: Send + 'static {
13    /// Runs the task until completion or failure.
14    async fn run(&mut self) -> TaskResult;
15}
16
17pub trait CloneableSupervisedTask: SupervisedTask {
18    fn clone_box(&self) -> Box<dyn CloneableSupervisedTask>;
19}
20
21impl<T> CloneableSupervisedTask for T
22where
23    T: SupervisedTask + Clone + Send + 'static,
24{
25    fn clone_box(&self) -> Box<dyn CloneableSupervisedTask> {
26        Box::new(self.clone())
27    }
28}
29
/// Represents the current state of a supervised task.
///
/// This is a plain fieldless enum, so it is `Copy` and can be compared and hashed
/// cheaply (e.g. used as a `HashMap`/`HashSet` key when aggregating statuses).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TaskStatus {
    /// Task has been created but not yet started.
    Created,
    /// Task is running and healthy.
    Healthy,
    /// Task has failed and is pending restart.
    Failed,
    /// Task has completed successfully.
    Completed,
    /// Task has failed too many times and is terminated.
    Dead,
}
44
45impl TaskStatus {
46    pub fn is_restarting(&self) -> bool {
47        matches!(self, TaskStatus::Failed)
48    }
49
50    pub fn is_healthy(&self) -> bool {
51        matches!(self, TaskStatus::Healthy)
52    }
53
54    pub fn is_dead(&self) -> bool {
55        matches!(self, TaskStatus::Dead)
56    }
57
58    pub fn has_completed(&self) -> bool {
59        matches!(self, TaskStatus::Completed)
60    }
61}
62
63impl std::fmt::Display for TaskStatus {
64    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
65        match self {
66            Self::Created => write!(f, "created"),
67            Self::Healthy => write!(f, "healthy"),
68            Self::Failed => write!(f, "failed"),
69            Self::Completed => write!(f, "completed"),
70            Self::Dead => write!(f, "dead"),
71        }
72    }
73}
74
75pub(crate) struct TaskHandle {
76    pub(crate) status: TaskStatus,
77    pub(crate) task: DynTask,
78    pub(crate) main_task_handle: Option<JoinHandle<()>>,
79    pub(crate) completion_task_handle: Option<JoinHandle<()>>,
80    pub(crate) restart_attempts: u32,
81    pub(crate) started_at: Option<Instant>,
82    pub(crate) healthy_since: Option<Instant>,
83    pub(crate) cancellation_token: Option<CancellationToken>,
84    max_restart_attempts: u32,
85    base_restart_delay: Duration,
86}
87
88impl TaskHandle {
89    /// Creates a `TaskHandle` from a boxed task with default configuration.
90    pub(crate) fn new(
91        task: Box<dyn CloneableSupervisedTask>,
92        max_restart_attempts: u32,
93        base_restart_delay: Duration,
94    ) -> Self {
95        Self {
96            status: TaskStatus::Created,
97            task,
98            main_task_handle: None,
99            completion_task_handle: None,
100            restart_attempts: 0,
101            started_at: None,
102            healthy_since: None,
103            cancellation_token: None,
104            max_restart_attempts,
105            base_restart_delay,
106        }
107    }
108
109    /// Creates a new `TaskHandle` with custom restart configuration.
110    pub(crate) fn from_task<T: CloneableSupervisedTask + 'static>(
111        task: T,
112        max_restart_attempts: u32,
113        base_restart_delay: Duration,
114    ) -> Self {
115        let task = Box::new(task);
116        Self::new(task, max_restart_attempts, base_restart_delay)
117    }
118
119    /// Calculates the restart delay using exponential backoff.
120    pub(crate) fn restart_delay(&self) -> Duration {
121        let factor = 2u32.saturating_pow(self.restart_attempts.min(5));
122        self.base_restart_delay.saturating_mul(factor)
123    }
124
125    /// Checks if the task has exceeded its maximum restart attempts.
126    pub(crate) const fn has_exceeded_max_retries(&self) -> bool {
127        self.restart_attempts >= self.max_restart_attempts
128    }
129
130    /// Updates the task's status.
131    pub(crate) fn mark(&mut self, status: TaskStatus) {
132        self.status = status;
133    }
134
135    /// Cleans up the task by aborting its handle and resetting state.
136    pub(crate) async fn clean(&mut self) {
137        if let Some(token) = self.cancellation_token.take() {
138            token.cancel();
139        }
140        if let Some(handle) = self.main_task_handle.take() {
141            handle.abort();
142        }
143        if let Some(handle) = self.completion_task_handle.take() {
144            handle.abort();
145        }
146        self.healthy_since = None;
147        self.started_at = None;
148    }
149}