Skip to main content

task_supervisor/task/
mod.rs

1use std::{
2    future::Future,
3    pin::Pin,
4    time::{Duration, Instant},
5};
6
7use tokio::task::JoinHandle;
8use tokio_util::sync::CancellationToken;
9
/// Error produced by a task run that fails.
///
/// With the `anyhow` feature enabled this is `anyhow::Error`.
#[cfg(feature = "anyhow")]
pub type TaskError = anyhow::Error;
/// Error produced by a task run that fails.
///
/// Without the `anyhow` feature this is a boxed standard error trait object
/// that can be sent and shared across threads.
#[cfg(not(feature = "anyhow"))]
pub type TaskError = Box<dyn std::error::Error + Sync + Send>;

/// Outcome of one task run: `Ok(())` on success, `Err(TaskError)` on failure.
pub type TaskResult = std::result::Result<(), TaskError>;
16
17/// The trait users implement for tasks managed by the supervisor.
18///
19/// # Clone and restart semantics
20///
21/// The supervisor stores the **original** instance and clones it for each
22/// run. Mutations via `&mut self` only live in the clone and are lost on
23/// restart. Shared state (`Arc<...>`) survives because `Clone` just bumps
24/// the refcount.
25///
26/// Use owned fields for per-run state, `Arc` for cross-restart state.
27///
28/// # Example
29///
30/// ```rust
31/// use task_supervisor::{SupervisedTask, TaskResult};
32/// use std::sync::Arc;
33/// use std::sync::atomic::{AtomicUsize, Ordering};
34///
35/// #[derive(Clone)]
36/// struct MyTask {
37///     /// Reset to 0 on every restart (owned, cloned from original).
38///     local_counter: u64,
39///     /// Shared across restarts (Arc, cloned by reference).
40///     total_runs: Arc<AtomicUsize>,
41/// }
42///
43/// impl SupervisedTask for MyTask {
44///     async fn run(&mut self) -> TaskResult {
45///         self.total_runs.fetch_add(1, Ordering::Relaxed);
46///         self.local_counter += 1;
47///         // local_counter is always 1 here — fresh clone each restart.
48///         Ok(())
49///     }
50/// }
51/// ```
52pub trait SupervisedTask: Send + 'static {
53    /// Runs the task until completion or failure.
54    ///
55    /// Mutations to `&mut self` are **not** preserved across restarts.
56    /// See the [trait-level docs](SupervisedTask) for details.
57    fn run(&mut self) -> impl Future<Output = TaskResult> + Send;
58}
59
60/// Dyn-compatible wrapper for `SupervisedTask`. Not user-facing.
61pub(crate) trait DynSupervisedTask: Send + 'static {
62    fn run_boxed(&mut self) -> Pin<Box<dyn Future<Output = TaskResult> + Send + '_>>;
63    fn clone_box(&self) -> Box<dyn DynSupervisedTask>;
64}
65
66impl<T> DynSupervisedTask for T
67where
68    T: SupervisedTask + Clone + Send + 'static,
69{
70    fn run_boxed(&mut self) -> Pin<Box<dyn Future<Output = TaskResult> + Send + '_>> {
71        Box::pin(self.run())
72    }
73
74    fn clone_box(&self) -> Box<dyn DynSupervisedTask> {
75        Box::new(self.clone())
76    }
77}
78
79pub(crate) type DynTask = Box<dyn DynSupervisedTask>;
80
/// Lifecycle state of a supervised task.
///
/// The `Display` output is the lowercase variant name (`"created"`,
/// `"healthy"`, `"failed"`, `"completed"`, `"dead"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TaskStatus {
    /// Initial state assigned when a task handle is first constructed.
    Created,
    /// Reported as healthy by [`TaskStatus::is_healthy`].
    Healthy,
    /// The task failed; [`TaskStatus::is_restarting`] treats this state as restarting.
    Failed,
    /// Reported by [`TaskStatus::has_completed`].
    Completed,
    /// Reported by [`TaskStatus::is_dead`].
    Dead,
}

impl TaskStatus {
    /// Returns the lowercase name of this status (the same text `Display` writes).
    #[must_use]
    pub const fn as_str(&self) -> &'static str {
        match self {
            Self::Created => "created",
            Self::Healthy => "healthy",
            Self::Failed => "failed",
            Self::Completed => "completed",
            Self::Dead => "dead",
        }
    }

    /// `true` for [`TaskStatus::Failed`], which is the state a task is in while awaiting restart.
    #[must_use]
    pub const fn is_restarting(&self) -> bool {
        matches!(self, TaskStatus::Failed)
    }

    /// `true` only for [`TaskStatus::Healthy`].
    #[must_use]
    pub const fn is_healthy(&self) -> bool {
        matches!(self, TaskStatus::Healthy)
    }

    /// `true` only for [`TaskStatus::Dead`].
    #[must_use]
    pub const fn is_dead(&self) -> bool {
        matches!(self, TaskStatus::Dead)
    }

    /// `true` only for [`TaskStatus::Completed`].
    #[must_use]
    pub const fn has_completed(&self) -> bool {
        matches!(self, TaskStatus::Completed)
    }
}

impl std::fmt::Display for TaskStatus {
    /// Writes the lowercase status name.
    ///
    /// Goes through `Formatter::pad` so width, fill, alignment and precision
    /// flags are honored (e.g. `{:>10}` right-aligns the name in a column);
    /// plain `{}` output is unchanged.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.pad(self.as_str())
    }
}
119
120pub(crate) struct TaskHandle {
121    pub(crate) status: TaskStatus,
122    pub(crate) task: DynTask,
123    pub(crate) join_handle: Option<JoinHandle<()>>,
124    pub(crate) restart_attempts: u32,
125    pub(crate) healthy_since: Option<Instant>,
126    pub(crate) cancellation_token: Option<CancellationToken>,
127    pub(crate) max_restart_attempts: Option<u32>,
128    pub(crate) base_restart_delay: Duration,
129    pub(crate) max_backoff_exponent: u32,
130}
131
132impl TaskHandle {
133    pub(crate) fn new(task: DynTask) -> Self {
134        Self {
135            status: TaskStatus::Created,
136            task,
137            join_handle: None,
138            restart_attempts: 0,
139            healthy_since: None,
140            cancellation_token: None,
141            // Defaults — overwritten by the builder at build() time.
142            max_restart_attempts: None,
143            base_restart_delay: Duration::from_secs(1),
144            max_backoff_exponent: 5,
145        }
146    }
147
148    pub(crate) fn from_task<T: SupervisedTask + Clone>(task: T) -> Self {
149        Self::new(Box::new(task))
150    }
151
152    /// Delay = base_restart_delay * 2^min(attempts, max_backoff_exponent).
153    pub(crate) fn restart_delay(&self) -> Duration {
154        let factor = 2u32.saturating_pow(self.restart_attempts.min(self.max_backoff_exponent));
155        self.base_restart_delay.saturating_mul(factor)
156    }
157
158    pub(crate) const fn has_exceeded_max_retries(&self) -> bool {
159        if let Some(max_restart_attempts) = self.max_restart_attempts {
160            self.restart_attempts >= max_restart_attempts
161        } else {
162            false
163        }
164    }
165
166    pub(crate) fn mark(&mut self, status: TaskStatus) {
167        self.status = status;
168    }
169
170    pub(crate) fn clean(&mut self) {
171        if let Some(token) = self.cancellation_token.take() {
172            token.cancel();
173        }
174        if let Some(handle) = self.join_handle.take() {
175            handle.abort();
176        }
177        self.healthy_since = None;
178    }
179}