task_supervisor/task/
mod.rs1use std::{
2 error::Error,
3 time::{Duration, Instant},
4};
5
6use tokio::task::JoinHandle;
7use tokio_util::sync::CancellationToken;
8
9pub type DynTask = Box<dyn CloneableSupervisedTask>;
10pub type TaskError = Box<dyn Error + Send + Sync>;
11
12#[async_trait::async_trait]
13pub trait SupervisedTask: Send + 'static {
14 async fn run(&mut self) -> Result<TaskOutcome, TaskError>;
16}
17
18pub trait CloneableSupervisedTask: SupervisedTask {
19 fn clone_box(&self) -> Box<dyn CloneableSupervisedTask>;
20}
21
22impl<T> CloneableSupervisedTask for T
23where
24 T: SupervisedTask + Clone + Send + 'static,
25{
26 fn clone_box(&self) -> Box<dyn CloneableSupervisedTask> {
27 Box::new(self.clone())
28 }
29}
30
/// Lifecycle state of a supervised task.
///
/// A freshly built `TaskHandle` starts in [`TaskStatus::Created`]; later states can be
/// recorded with `TaskHandle::mark`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TaskStatus {
    /// Initial state assigned when the handle is constructed.
    Created,
    /// Reported as healthy by [`TaskStatus::is_healthy`].
    Healthy,
    /// The task failed; [`TaskStatus::is_restarting`] treats this state as pending a restart.
    Failed,
    /// The task finished; see [`TaskStatus::has_completed`].
    Completed,
    /// Reported by [`TaskStatus::is_dead`].
    /// NOTE(review): presumably set once the restart budget is exhausted — confirm against
    /// the supervisor.
    Dead,
}

impl TaskStatus {
    /// Returns `true` for [`TaskStatus::Failed`], the state in which a restart is expected.
    #[must_use]
    pub const fn is_restarting(&self) -> bool {
        matches!(self, Self::Failed)
    }

    /// Returns `true` only for [`TaskStatus::Healthy`].
    #[must_use]
    pub const fn is_healthy(&self) -> bool {
        matches!(self, Self::Healthy)
    }

    /// Returns `true` only for [`TaskStatus::Dead`].
    #[must_use]
    pub const fn is_dead(&self) -> bool {
        matches!(self, Self::Dead)
    }

    /// Returns `true` only for [`TaskStatus::Completed`].
    #[must_use]
    pub const fn has_completed(&self) -> bool {
        matches!(self, Self::Completed)
    }

    /// Lowercase name of the status, as rendered by the `Display` impl.
    #[must_use]
    pub const fn as_str(&self) -> &'static str {
        match self {
            Self::Created => "created",
            Self::Healthy => "healthy",
            Self::Failed => "failed",
            Self::Completed => "completed",
            Self::Dead => "dead",
        }
    }
}

impl std::fmt::Display for TaskStatus {
    /// Writes the lowercase status name.
    ///
    /// `Formatter::pad` is used (rather than `write!` with a literal) so that width,
    /// alignment and precision flags such as `{:>10}` are honoured, e.g. in log tables.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.pad(self.as_str())
    }
}
75
/// How a single run of a supervised task ended.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TaskOutcome {
    /// The run finished successfully.
    Completed,
    /// The run failed; the payload is the message displayed after `failed: `.
    Failed(String),
}

impl std::fmt::Display for TaskOutcome {
    /// Renders `completed`, or `failed: <message>` for a failure.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            TaskOutcome::Completed => f.write_str("completed"),
            TaskOutcome::Failed(reason) => {
                f.write_str("failed: ")?;
                f.write_str(reason)
            }
        }
    }
}
93
94pub(crate) struct TaskHandle {
95 pub(crate) status: TaskStatus,
96 pub(crate) task: DynTask,
97 pub(crate) main_task_handle: Option<JoinHandle<()>>,
98 pub(crate) completion_task_handle: Option<JoinHandle<()>>,
99 pub(crate) restart_attempts: u32,
100 pub(crate) started_at: Option<Instant>,
101 pub(crate) healthy_since: Option<Instant>,
102 pub(crate) cancellation_token: Option<CancellationToken>,
103 max_restart_attempts: u32,
104 base_restart_delay: Duration,
105}
106
107impl TaskHandle {
108 pub(crate) fn new(
110 task: Box<dyn CloneableSupervisedTask>,
111 max_restart_attempts: u32,
112 base_restart_delay: Duration,
113 ) -> Self {
114 Self {
115 status: TaskStatus::Created,
116 task,
117 main_task_handle: None,
118 completion_task_handle: None,
119 restart_attempts: 0,
120 started_at: None,
121 healthy_since: None,
122 cancellation_token: None,
123 max_restart_attempts,
124 base_restart_delay,
125 }
126 }
127
128 pub(crate) fn from_task<T: CloneableSupervisedTask + 'static>(
130 task: T,
131 max_restart_attempts: u32,
132 base_restart_delay: Duration,
133 ) -> Self {
134 let task = Box::new(task);
135 Self::new(task, max_restart_attempts, base_restart_delay)
136 }
137
138 pub(crate) fn restart_delay(&self) -> Duration {
140 let factor = 2u32.saturating_pow(self.restart_attempts.min(5));
141 self.base_restart_delay.saturating_mul(factor)
142 }
143
144 pub(crate) const fn has_exceeded_max_retries(&self) -> bool {
146 self.restart_attempts >= self.max_restart_attempts
147 }
148
149 pub(crate) fn mark(&mut self, status: TaskStatus) {
151 self.status = status;
152 }
153
154 pub(crate) async fn clean(&mut self) {
156 if let Some(token) = self.cancellation_token.take() {
157 token.cancel();
158 }
159 if let Some(handle) = self.main_task_handle.take() {
160 handle.abort();
161 }
162 if let Some(handle) = self.completion_task_handle.take() {
163 handle.abort();
164 }
165 self.healthy_since = None;
166 self.started_at = None;
167 }
168}