Skip to main content

task_supervisor/task/
mod.rs

1use std::{
2    future::Future,
3    pin::Pin,
4    time::{Duration, Instant},
5};
6
7use tokio::task::JoinHandle;
8use tokio_util::sync::CancellationToken;
9
/// Boxed, thread-safe error produced by a failing task run.
pub type TaskError = Box<dyn std::error::Error + Send + Sync>;
/// Outcome of one task run: `Ok(())` on success, a [`TaskError`] on failure.
pub type TaskResult = std::result::Result<(), TaskError>;
12
13/// The trait users implement for tasks managed by the supervisor.
14///
15/// # Clone and restart semantics
16///
17/// The supervisor stores the **original** instance and clones it for each
18/// run. Mutations via `&mut self` only live in the clone and are lost on
19/// restart. Shared state (`Arc<...>`) survives because `Clone` just bumps
20/// the refcount.
21///
22/// Use owned fields for per-run state, `Arc` for cross-restart state.
23///
24/// # Example
25///
26/// ```rust
27/// use task_supervisor::{SupervisedTask, TaskResult};
28/// use std::sync::Arc;
29/// use std::sync::atomic::{AtomicUsize, Ordering};
30///
31/// #[derive(Clone)]
32/// struct MyTask {
33///     /// Reset to 0 on every restart (owned, cloned from original).
34///     local_counter: u64,
35///     /// Shared across restarts (Arc, cloned by reference).
36///     total_runs: Arc<AtomicUsize>,
37/// }
38///
39/// impl SupervisedTask for MyTask {
40///     async fn run(&mut self) -> TaskResult {
41///         self.total_runs.fetch_add(1, Ordering::Relaxed);
42///         self.local_counter += 1;
43///         // local_counter is always 1 here — fresh clone each restart.
44///         Ok(())
45///     }
46/// }
47/// ```
48pub trait SupervisedTask: Send + 'static {
49    /// Runs the task until completion or failure.
50    ///
51    /// Mutations to `&mut self` are **not** preserved across restarts.
52    /// See the [trait-level docs](SupervisedTask) for details.
53    fn run(&mut self) -> impl Future<Output = TaskResult> + Send;
54}
55
56/// Dyn-compatible wrapper for `SupervisedTask`. Not user-facing.
57pub(crate) trait DynSupervisedTask: Send + 'static {
58    fn run_boxed(&mut self) -> Pin<Box<dyn Future<Output = TaskResult> + Send + '_>>;
59    fn clone_box(&self) -> Box<dyn DynSupervisedTask>;
60}
61
62impl<T> DynSupervisedTask for T
63where
64    T: SupervisedTask + Clone + Send + 'static,
65{
66    fn run_boxed(&mut self) -> Pin<Box<dyn Future<Output = TaskResult> + Send + '_>> {
67        Box::pin(self.run())
68    }
69
70    fn clone_box(&self) -> Box<dyn DynSupervisedTask> {
71        Box::new(self.clone())
72    }
73}
74
75pub(crate) type DynTask = Box<dyn DynSupervisedTask>;
76
/// Lifecycle state of a supervised task.
///
/// The [`Display`](std::fmt::Display) form is the lowercase variant name
/// (`"created"`, `"healthy"`, `"failed"`, `"completed"`, `"dead"`, also
/// available without allocating via [`TaskStatus::as_str`]) and honours
/// width, fill and alignment flags such as `{:>10}`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TaskStatus {
    /// Initial status of every task, before it has been run.
    Created,
    /// Reported healthy (see [`TaskStatus::is_healthy`]).
    Healthy,
    /// The last run failed; the one status reported by
    /// [`TaskStatus::is_restarting`].
    Failed,
    /// Marked as completed (see [`TaskStatus::has_completed`]).
    Completed,
    /// Marked as dead (see [`TaskStatus::is_dead`]).
    // NOTE(review): presumably terminal once restart attempts are exhausted —
    // confirm against the supervisor loop.
    Dead,
}

impl TaskStatus {
    /// Returns the lowercase name used by the `Display` impl.
    pub const fn as_str(&self) -> &'static str {
        match self {
            Self::Created => "created",
            Self::Healthy => "healthy",
            Self::Failed => "failed",
            Self::Completed => "completed",
            Self::Dead => "dead",
        }
    }

    /// Returns `true` for [`TaskStatus::Failed`].
    pub const fn is_restarting(&self) -> bool {
        matches!(self, TaskStatus::Failed)
    }

    /// Returns `true` for [`TaskStatus::Healthy`].
    pub const fn is_healthy(&self) -> bool {
        matches!(self, TaskStatus::Healthy)
    }

    /// Returns `true` for [`TaskStatus::Dead`].
    pub const fn is_dead(&self) -> bool {
        matches!(self, TaskStatus::Dead)
    }

    /// Returns `true` for [`TaskStatus::Completed`].
    pub const fn has_completed(&self) -> bool {
        matches!(self, TaskStatus::Completed)
    }
}

impl std::fmt::Display for TaskStatus {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `pad` applies the caller's width/fill/alignment/precision; writing a
        // literal with `write!` silently ignores them (e.g. `{:<9}` in tables).
        f.pad(self.as_str())
    }
}
115
116pub(crate) struct TaskHandle {
117    pub(crate) status: TaskStatus,
118    pub(crate) task: DynTask,
119    pub(crate) join_handle: Option<JoinHandle<()>>,
120    pub(crate) restart_attempts: u32,
121    pub(crate) healthy_since: Option<Instant>,
122    pub(crate) cancellation_token: Option<CancellationToken>,
123    pub(crate) max_restart_attempts: Option<u32>,
124    pub(crate) base_restart_delay: Duration,
125    pub(crate) max_backoff_exponent: u32,
126}
127
128impl TaskHandle {
129    pub(crate) fn new(task: DynTask) -> Self {
130        Self {
131            status: TaskStatus::Created,
132            task,
133            join_handle: None,
134            restart_attempts: 0,
135            healthy_since: None,
136            cancellation_token: None,
137            // Defaults — overwritten by the builder at build() time.
138            max_restart_attempts: None,
139            base_restart_delay: Duration::from_secs(1),
140            max_backoff_exponent: 5,
141        }
142    }
143
144    pub(crate) fn from_task<T: SupervisedTask + Clone>(task: T) -> Self {
145        Self::new(Box::new(task))
146    }
147
148    /// Delay = base_restart_delay * 2^min(attempts, max_backoff_exponent).
149    pub(crate) fn restart_delay(&self) -> Duration {
150        let factor = 2u32.saturating_pow(self.restart_attempts.min(self.max_backoff_exponent));
151        self.base_restart_delay.saturating_mul(factor)
152    }
153
154    pub(crate) const fn has_exceeded_max_retries(&self) -> bool {
155        if let Some(max_restart_attempts) = self.max_restart_attempts {
156            self.restart_attempts >= max_restart_attempts
157        } else {
158            false
159        }
160    }
161
162    pub(crate) fn mark(&mut self, status: TaskStatus) {
163        self.status = status;
164    }
165
166    pub(crate) fn clean(&mut self) {
167        if let Some(token) = self.cancellation_token.take() {
168            token.cancel();
169        }
170        if let Some(handle) = self.join_handle.take() {
171            handle.abort();
172        }
173        self.healthy_since = None;
174    }
175}