task_supervisor/task/
mod.rs1use std::{
2 future::Future,
3 pin::Pin,
4 time::{Duration, Instant},
5};
6
7use tokio::task::JoinHandle;
8use tokio_util::sync::CancellationToken;
9
/// Error type produced by a failing [`SupervisedTask`] run.
///
/// Without the `anyhow` feature this is a boxed, `Send + Sync` standard-library
/// error trait object.
#[cfg(not(feature = "anyhow"))]
pub type TaskError = Box<dyn std::error::Error + Send + Sync>;
/// Error type produced by a failing [`SupervisedTask`] run.
///
/// With the `anyhow` feature enabled this is [`anyhow::Error`].
#[cfg(feature = "anyhow")]
pub type TaskError = anyhow::Error;

/// Outcome of a single task run: `Ok(())` on success, otherwise a [`TaskError`].
pub type TaskResult = std::result::Result<(), TaskError>;
16
17pub trait SupervisedTask: Send + 'static {
53 fn run(&mut self) -> impl Future<Output = TaskResult> + Send;
58}
59
60pub(crate) trait DynSupervisedTask: Send + 'static {
62 fn run_boxed(&mut self) -> Pin<Box<dyn Future<Output = TaskResult> + Send + '_>>;
63 fn clone_box(&self) -> Box<dyn DynSupervisedTask>;
64}
65
66impl<T> DynSupervisedTask for T
67where
68 T: SupervisedTask + Clone + Send + 'static,
69{
70 fn run_boxed(&mut self) -> Pin<Box<dyn Future<Output = TaskResult> + Send + '_>> {
71 Box::pin(self.run())
72 }
73
74 fn clone_box(&self) -> Box<dyn DynSupervisedTask> {
75 Box::new(self.clone())
76 }
77}
78
79pub(crate) type DynTask = Box<dyn DynSupervisedTask>;
80
/// Lifecycle state of a supervised task.
///
/// Formats via [`Display`](std::fmt::Display) as a lowercase word (see
/// [`TaskStatus::as_str`]), honouring width/fill/alignment/precision flags.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TaskStatus {
    /// Initial state assigned to a newly constructed handle.
    Created,
    /// The task is reported healthy (see [`TaskStatus::is_healthy`]).
    Healthy,
    /// The task failed; [`TaskStatus::is_restarting`] treats this state as
    /// restarting (presumably awaiting a backoff-delayed restart).
    Failed,
    /// The task has completed (see [`TaskStatus::has_completed`]).
    Completed,
    /// Terminal state reported by [`TaskStatus::is_dead`]. NOTE(review):
    /// presumably entered once restarts are exhausted — confirm in the supervisor.
    Dead,
}

impl TaskStatus {
    /// Returns the lowercase name of the status, as used by its `Display` impl.
    #[must_use]
    pub const fn as_str(&self) -> &'static str {
        match self {
            Self::Created => "created",
            Self::Healthy => "healthy",
            Self::Failed => "failed",
            Self::Completed => "completed",
            Self::Dead => "dead",
        }
    }

    /// Returns `true` only for [`TaskStatus::Failed`], the state in which the
    /// task is treated as restarting.
    #[must_use]
    pub const fn is_restarting(&self) -> bool {
        matches!(self, Self::Failed)
    }

    /// Returns `true` only for [`TaskStatus::Healthy`].
    #[must_use]
    pub const fn is_healthy(&self) -> bool {
        matches!(self, Self::Healthy)
    }

    /// Returns `true` only for [`TaskStatus::Dead`].
    #[must_use]
    pub const fn is_dead(&self) -> bool {
        matches!(self, Self::Dead)
    }

    /// Returns `true` only for [`TaskStatus::Completed`].
    #[must_use]
    pub const fn has_completed(&self) -> bool {
        matches!(self, Self::Completed)
    }
}

impl std::fmt::Display for TaskStatus {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `pad` (unlike `write!` with a literal) respects `{:>9}`, `{:<8}`,
        // `{:^9}`, `{:.3}` etc. supplied by the caller.
        f.pad(self.as_str())
    }
}
119
120pub(crate) struct TaskHandle {
121 pub(crate) status: TaskStatus,
122 pub(crate) task: DynTask,
123 pub(crate) join_handle: Option<JoinHandle<()>>,
124 pub(crate) restart_attempts: u32,
125 pub(crate) healthy_since: Option<Instant>,
126 pub(crate) cancellation_token: Option<CancellationToken>,
127 pub(crate) max_restart_attempts: Option<u32>,
128 pub(crate) base_restart_delay: Duration,
129 pub(crate) max_backoff_exponent: u32,
130}
131
132impl TaskHandle {
133 pub(crate) fn new(task: DynTask) -> Self {
134 Self {
135 status: TaskStatus::Created,
136 task,
137 join_handle: None,
138 restart_attempts: 0,
139 healthy_since: None,
140 cancellation_token: None,
141 max_restart_attempts: None,
143 base_restart_delay: Duration::from_secs(1),
144 max_backoff_exponent: 5,
145 }
146 }
147
148 pub(crate) fn from_task<T: SupervisedTask + Clone>(task: T) -> Self {
149 Self::new(Box::new(task))
150 }
151
152 pub(crate) fn restart_delay(&self) -> Duration {
154 let factor = 2u32.saturating_pow(self.restart_attempts.min(self.max_backoff_exponent));
155 self.base_restart_delay.saturating_mul(factor)
156 }
157
158 pub(crate) const fn has_exceeded_max_retries(&self) -> bool {
159 if let Some(max_restart_attempts) = self.max_restart_attempts {
160 self.restart_attempts >= max_restart_attempts
161 } else {
162 false
163 }
164 }
165
166 pub(crate) fn mark(&mut self, status: TaskStatus) {
167 self.status = status;
168 }
169
170 pub(crate) fn clean(&mut self) {
171 if let Some(token) = self.cancellation_token.take() {
172 token.cancel();
173 }
174 if let Some(handle) = self.join_handle.take() {
175 handle.abort();
176 }
177 self.healthy_since = None;
178 }
179}