// task_supervisor/task/mod.rs

1use std::{
2    error::Error,
3    time::{Duration, Instant},
4};
5
6use tokio::task::JoinHandle;
7use tokio_util::sync::CancellationToken;
8
9pub type DynTask = Box<dyn CloneableSupervisedTask>;
10pub type TaskError = Box<dyn Error + Send + Sync>;
11
12#[async_trait::async_trait]
13pub trait SupervisedTask: Send + 'static {
14    /// Runs the task until completion or failure.
15    async fn run(&mut self) -> Result<TaskOutcome, TaskError>;
16}
17
18pub trait CloneableSupervisedTask: SupervisedTask {
19    fn clone_box(&self) -> Box<dyn CloneableSupervisedTask>;
20}
21
22impl<T> CloneableSupervisedTask for T
23where
24    T: SupervisedTask + Clone + Send + 'static,
25{
26    fn clone_box(&self) -> Box<dyn CloneableSupervisedTask> {
27        Box::new(self.clone())
28    }
29}
30
31/// Represents the current state of a supervised task.
32#[derive(Debug, Clone, Copy, PartialEq, Eq)]
33pub enum TaskStatus {
34    /// Task has been created but not yet started.
35    Created,
36    /// Task is running and healthy.
37    Healthy,
38    /// Task has failed and is pending restart.
39    Failed,
40    /// Task has completed successfully.
41    Completed,
42    /// Task has failed too many times and is terminated.
43    Dead,
44}
45
46impl TaskStatus {
47    pub fn is_restarting(&self) -> bool {
48        matches!(self, TaskStatus::Failed)
49    }
50
51    pub fn is_healthy(&self) -> bool {
52        matches!(self, TaskStatus::Healthy)
53    }
54
55    pub fn is_dead(&self) -> bool {
56        matches!(self, TaskStatus::Dead)
57    }
58
59    pub fn has_completed(&self) -> bool {
60        matches!(self, TaskStatus::Completed)
61    }
62}
63
64impl std::fmt::Display for TaskStatus {
65    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
66        match self {
67            Self::Created => write!(f, "created"),
68            Self::Healthy => write!(f, "healthy"),
69            Self::Failed => write!(f, "failed"),
70            Self::Completed => write!(f, "completed"),
71            Self::Dead => write!(f, "dead"),
72        }
73    }
74}
75
76/// Outcome of a task's execution.
77#[derive(Debug, Clone, PartialEq, Eq)]
78pub enum TaskOutcome {
79    /// Task completed successfully and should not be restarted.
80    Completed,
81    /// Task failed and may be restarted, with an optional reason.
82    Failed(String),
83}
84
85impl std::fmt::Display for TaskOutcome {
86    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
87        match self {
88            Self::Completed => write!(f, "completed"),
89            Self::Failed(e) => write!(f, "failed: {e}"),
90        }
91    }
92}
93
94pub(crate) struct TaskHandle {
95    pub(crate) status: TaskStatus,
96    pub(crate) task: DynTask,
97    pub(crate) main_task_handle: Option<JoinHandle<()>>,
98    pub(crate) completion_task_handle: Option<JoinHandle<()>>,
99    pub(crate) restart_attempts: u32,
100    pub(crate) started_at: Option<Instant>,
101    pub(crate) healthy_since: Option<Instant>,
102    pub(crate) cancellation_token: Option<CancellationToken>,
103    max_restart_attempts: u32,
104    base_restart_delay: Duration,
105}
106
107impl TaskHandle {
108    /// Creates a `TaskHandle` from a boxed task with default configuration.
109    pub(crate) fn new(
110        task: Box<dyn CloneableSupervisedTask>,
111        max_restart_attempts: u32,
112        base_restart_delay: Duration,
113    ) -> Self {
114        Self {
115            status: TaskStatus::Created,
116            task,
117            main_task_handle: None,
118            completion_task_handle: None,
119            restart_attempts: 0,
120            started_at: None,
121            healthy_since: None,
122            cancellation_token: None,
123            max_restart_attempts,
124            base_restart_delay,
125        }
126    }
127
128    /// Creates a new `TaskHandle` with custom restart configuration.
129    pub(crate) fn from_task<T: CloneableSupervisedTask + 'static>(
130        task: T,
131        max_restart_attempts: u32,
132        base_restart_delay: Duration,
133    ) -> Self {
134        let task = Box::new(task);
135        Self::new(task, max_restart_attempts, base_restart_delay)
136    }
137
138    /// Calculates the restart delay using exponential backoff.
139    pub(crate) fn restart_delay(&self) -> Duration {
140        let factor = 2u32.saturating_pow(self.restart_attempts.min(5));
141        self.base_restart_delay.saturating_mul(factor)
142    }
143
144    /// Checks if the task has exceeded its maximum restart attempts.
145    pub(crate) const fn has_exceeded_max_retries(&self) -> bool {
146        self.restart_attempts >= self.max_restart_attempts
147    }
148
149    /// Updates the task's status.
150    pub(crate) fn mark(&mut self, status: TaskStatus) {
151        self.status = status;
152    }
153
154    /// Cleans up the task by aborting its handle and resetting state.
155    pub(crate) async fn clean(&mut self) {
156        if let Some(token) = self.cancellation_token.take() {
157            token.cancel();
158        }
159        if let Some(handle) = self.main_task_handle.take() {
160            handle.abort();
161        }
162        if let Some(handle) = self.completion_task_handle.take() {
163            handle.abort();
164        }
165        self.healthy_since = None;
166        self.started_at = None;
167    }
168}