Skip to main content

task_supervisor/task/
mod.rs

1use std::{
2    future::Future,
3    pin::Pin,
4    time::{Duration, Instant},
5};
6
7use tokio::task::JoinHandle;
8use tokio_util::sync::CancellationToken;
9
10pub type TaskError = anyhow::Error;
11pub type TaskResult = Result<(), TaskError>;
12
13/// The trait users implement for tasks managed by the supervisor.
14///
15/// # Example
16/// ```rust
17/// use task_supervisor::{SupervisedTask, TaskResult};
18///
19/// #[derive(Clone)]
20/// struct MyTask;
21///
22/// impl SupervisedTask for MyTask {
23///     async fn run(&mut self) -> TaskResult {
24///         // do work...
25///         Ok(())
26///     }
27/// }
28/// ```
29pub trait SupervisedTask: Send + 'static {
30    /// Runs the task until completion or failure.
31    fn run(&mut self) -> impl Future<Output = TaskResult> + Send;
32}
33
34// ---- Internal dyn-compatible wrapper ----
35
36/// Object-safe version of `SupervisedTask` used internally for dynamic dispatch.
37/// Users never see or implement this trait directly.
38pub(crate) trait DynSupervisedTask: Send + 'static {
39    fn run_boxed(&mut self) -> Pin<Box<dyn Future<Output = TaskResult> + Send + '_>>;
40    fn clone_box(&self) -> Box<dyn DynSupervisedTask>;
41}
42
43impl<T> DynSupervisedTask for T
44where
45    T: SupervisedTask + Clone + Send + 'static,
46{
47    fn run_boxed(&mut self) -> Pin<Box<dyn Future<Output = TaskResult> + Send + '_>> {
48        Box::pin(self.run())
49    }
50
51    fn clone_box(&self) -> Box<dyn DynSupervisedTask> {
52        Box::new(self.clone())
53    }
54}
55
56pub(crate) type DynTask = Box<dyn DynSupervisedTask>;
57
/// Represents the current lifecycle state of a supervised task.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TaskStatus {
    /// Task has been created but not yet started.
    Created,
    /// Task is running and healthy.
    Healthy,
    /// Task has failed and is pending restart.
    Failed,
    /// Task has completed successfully.
    Completed,
    /// Task has failed too many times and is terminated.
    Dead,
}

impl TaskStatus {
    /// Returns `true` if the task has failed and is awaiting a restart
    /// (the status is [`TaskStatus::Failed`]).
    #[must_use]
    pub const fn is_restarting(&self) -> bool {
        matches!(self, TaskStatus::Failed)
    }

    /// Returns `true` if the status is [`TaskStatus::Healthy`].
    #[must_use]
    pub const fn is_healthy(&self) -> bool {
        matches!(self, TaskStatus::Healthy)
    }

    /// Returns `true` if the task has been terminated ([`TaskStatus::Dead`]).
    #[must_use]
    pub const fn is_dead(&self) -> bool {
        matches!(self, TaskStatus::Dead)
    }

    /// Returns `true` if the task finished successfully ([`TaskStatus::Completed`]).
    #[must_use]
    pub const fn has_completed(&self) -> bool {
        matches!(self, TaskStatus::Completed)
    }
}

impl std::fmt::Display for TaskStatus {
    /// Writes the lowercase status name.
    ///
    /// Uses `Formatter::pad` rather than `write!`, so width, alignment and
    /// precision flags (e.g. `{:>10}`) are honoured like they are for `str`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Self::Created => "created",
            Self::Healthy => "healthy",
            Self::Failed => "failed",
            Self::Completed => "completed",
            Self::Dead => "dead",
        };
        f.pad(name)
    }
}
102
103pub(crate) struct TaskHandle {
104    pub(crate) status: TaskStatus,
105    pub(crate) task: DynTask,
106    pub(crate) join_handle: Option<JoinHandle<()>>,
107    pub(crate) restart_attempts: u32,
108    pub(crate) healthy_since: Option<Instant>,
109    pub(crate) cancellation_token: Option<CancellationToken>,
110    pub(crate) max_restart_attempts: Option<u32>,
111    pub(crate) base_restart_delay: Duration,
112    pub(crate) max_backoff_exponent: u32,
113}
114
115impl TaskHandle {
116    /// Creates a `TaskHandle` from a boxed task.
117    /// Configuration is set later by the builder at `build()` time.
118    pub(crate) fn new(task: DynTask) -> Self {
119        Self {
120            status: TaskStatus::Created,
121            task,
122            join_handle: None,
123            restart_attempts: 0,
124            healthy_since: None,
125            cancellation_token: None,
126            max_restart_attempts: None,
127            base_restart_delay: Duration::from_secs(1),
128            max_backoff_exponent: 5,
129        }
130    }
131
132    /// Creates a new `TaskHandle` from a concrete task type.
133    pub(crate) fn from_task<T: SupervisedTask + Clone>(task: T) -> Self {
134        Self::new(Box::new(task))
135    }
136
137    /// Calculates the restart delay using exponential backoff.
138    pub(crate) fn restart_delay(&self) -> Duration {
139        let factor = 2u32.saturating_pow(self.restart_attempts.min(self.max_backoff_exponent));
140        self.base_restart_delay.saturating_mul(factor)
141    }
142
143    /// Checks if the task has exceeded its maximum restart attempts.
144    pub(crate) const fn has_exceeded_max_retries(&self) -> bool {
145        if let Some(max_restart_attempts) = self.max_restart_attempts {
146            self.restart_attempts >= max_restart_attempts
147        } else {
148            false
149        }
150    }
151
152    /// Updates the task's status.
153    pub(crate) fn mark(&mut self, status: TaskStatus) {
154        self.status = status;
155    }
156
157    /// Cleans up the task by cancelling its token and aborting the join handle.
158    pub(crate) fn clean(&mut self) {
159        if let Some(token) = self.cancellation_token.take() {
160            token.cancel();
161        }
162        if let Some(handle) = self.join_handle.take() {
163            handle.abort();
164        }
165        self.healthy_since = None;
166    }
167}