task_supervisor/task/
mod.rs1use std::{
2 future::Future,
3 pin::Pin,
4 time::{Duration, Instant},
5};
6
7use tokio::task::JoinHandle;
8use tokio_util::sync::CancellationToken;
9
10pub type TaskError = anyhow::Error;
11pub type TaskResult = Result<(), TaskError>;
12
13pub trait SupervisedTask: Send + 'static {
30 fn run(&mut self) -> impl Future<Output = TaskResult> + Send;
32}
33
34pub(crate) trait DynSupervisedTask: Send + 'static {
39 fn run_boxed(&mut self) -> Pin<Box<dyn Future<Output = TaskResult> + Send + '_>>;
40 fn clone_box(&self) -> Box<dyn DynSupervisedTask>;
41}
42
43impl<T> DynSupervisedTask for T
44where
45 T: SupervisedTask + Clone + Send + 'static,
46{
47 fn run_boxed(&mut self) -> Pin<Box<dyn Future<Output = TaskResult> + Send + '_>> {
48 Box::pin(self.run())
49 }
50
51 fn clone_box(&self) -> Box<dyn DynSupervisedTask> {
52 Box::new(self.clone())
53 }
54}
55
56pub(crate) type DynTask = Box<dyn DynSupervisedTask>;
57
/// Status of a supervised task.
///
/// The `Display` form of each variant is its lowercase name (`"created"`,
/// `"healthy"`, `"failed"`, `"completed"`, `"dead"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TaskStatus {
    /// Status a freshly constructed task handle starts in.
    Created,
    /// Reported by [`TaskStatus::is_healthy`].
    Healthy,
    /// Reported by [`TaskStatus::is_restarting`].
    // NOTE(review): presumably a failed task is pending a restart; confirm
    // against the supervisor loop.
    Failed,
    /// Reported by [`TaskStatus::has_completed`].
    Completed,
    /// Reported by [`TaskStatus::is_dead`].
    // NOTE(review): presumably terminal (no further restarts); confirm.
    Dead,
}

impl TaskStatus {
    /// Returns `true` when the status is [`TaskStatus::Failed`].
    #[must_use]
    pub const fn is_restarting(&self) -> bool {
        matches!(self, TaskStatus::Failed)
    }

    /// Returns `true` when the status is [`TaskStatus::Healthy`].
    #[must_use]
    pub const fn is_healthy(&self) -> bool {
        matches!(self, TaskStatus::Healthy)
    }

    /// Returns `true` when the status is [`TaskStatus::Dead`].
    #[must_use]
    pub const fn is_dead(&self) -> bool {
        matches!(self, TaskStatus::Dead)
    }

    /// Returns `true` when the status is [`TaskStatus::Completed`].
    #[must_use]
    pub const fn has_completed(&self) -> bool {
        matches!(self, TaskStatus::Completed)
    }
}

impl std::fmt::Display for TaskStatus {
    /// Writes the lowercase variant name.
    ///
    /// Width, fill, alignment and precision from the format spec are
    /// honoured, so `format!("{:<10}", status)` produces aligned columns.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Created => "created",
            Self::Healthy => "healthy",
            Self::Failed => "failed",
            Self::Completed => "completed",
            Self::Dead => "dead",
        };
        // `pad` applies the formatter's padding options; writing the literal
        // directly would silently ignore them.
        f.pad(label)
    }
}
102
103pub(crate) struct TaskHandle {
104 pub(crate) status: TaskStatus,
105 pub(crate) task: DynTask,
106 pub(crate) join_handle: Option<JoinHandle<()>>,
107 pub(crate) restart_attempts: u32,
108 pub(crate) healthy_since: Option<Instant>,
109 pub(crate) cancellation_token: Option<CancellationToken>,
110 pub(crate) max_restart_attempts: Option<u32>,
111 pub(crate) base_restart_delay: Duration,
112 pub(crate) max_backoff_exponent: u32,
113}
114
115impl TaskHandle {
116 pub(crate) fn new(task: DynTask) -> Self {
119 Self {
120 status: TaskStatus::Created,
121 task,
122 join_handle: None,
123 restart_attempts: 0,
124 healthy_since: None,
125 cancellation_token: None,
126 max_restart_attempts: None,
127 base_restart_delay: Duration::from_secs(1),
128 max_backoff_exponent: 5,
129 }
130 }
131
132 pub(crate) fn from_task<T: SupervisedTask + Clone>(task: T) -> Self {
134 Self::new(Box::new(task))
135 }
136
137 pub(crate) fn restart_delay(&self) -> Duration {
139 let factor = 2u32.saturating_pow(self.restart_attempts.min(self.max_backoff_exponent));
140 self.base_restart_delay.saturating_mul(factor)
141 }
142
143 pub(crate) const fn has_exceeded_max_retries(&self) -> bool {
145 if let Some(max_restart_attempts) = self.max_restart_attempts {
146 self.restart_attempts >= max_restart_attempts
147 } else {
148 false
149 }
150 }
151
152 pub(crate) fn mark(&mut self, status: TaskStatus) {
154 self.status = status;
155 }
156
157 pub(crate) fn clean(&mut self) {
159 if let Some(token) = self.cancellation_token.take() {
160 token.cancel();
161 }
162 if let Some(handle) = self.join_handle.take() {
163 handle.abort();
164 }
165 self.healthy_since = None;
166 }
167}