task_supervisor/task/
mod.rs1use std::time::{Duration, Instant};
2
3use tokio::task::JoinHandle;
4use tokio_util::sync::CancellationToken;
5
6pub type DynTask = Box<dyn CloneableSupervisedTask>;
7
8pub type TaskError = anyhow::Error;
9pub type TaskResult = Result<(), TaskError>;
10
11#[async_trait::async_trait]
12pub trait SupervisedTask: Send + 'static {
13 async fn run(&mut self) -> TaskResult;
15}
16
17pub trait CloneableSupervisedTask: SupervisedTask {
18 fn clone_box(&self) -> Box<dyn CloneableSupervisedTask>;
19}
20
21impl<T> CloneableSupervisedTask for T
22where
23 T: SupervisedTask + Clone + Send + 'static,
24{
25 fn clone_box(&self) -> Box<dyn CloneableSupervisedTask> {
26 Box::new(self.clone())
27 }
28}
29
/// Lifecycle state of a supervised task.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TaskStatus {
    /// Initial state, assigned when a task handle is constructed.
    Created,
    /// The task is considered healthy (see [`TaskStatus::is_healthy`]).
    Healthy,
    /// The task failed; [`TaskStatus::is_restarting`] reports this state as restarting.
    Failed,
    /// The task finished (see [`TaskStatus::has_completed`]).
    Completed,
    /// Terminal state. NOTE(review): presumably entered once the restart budget is
    /// exhausted — confirm against the supervisor.
    Dead,
}

impl TaskStatus {
    /// Returns the lowercase name of this status, as written by the `Display` impl.
    #[must_use]
    pub const fn as_str(&self) -> &'static str {
        match self {
            Self::Created => "created",
            Self::Healthy => "healthy",
            Self::Failed => "failed",
            Self::Completed => "completed",
            Self::Dead => "dead",
        }
    }

    /// Returns `true` only for [`TaskStatus::Failed`], the state treated as restarting.
    #[must_use]
    pub const fn is_restarting(&self) -> bool {
        matches!(self, Self::Failed)
    }

    /// Returns `true` only for [`TaskStatus::Healthy`].
    #[must_use]
    pub const fn is_healthy(&self) -> bool {
        matches!(self, Self::Healthy)
    }

    /// Returns `true` only for [`TaskStatus::Dead`].
    #[must_use]
    pub const fn is_dead(&self) -> bool {
        matches!(self, Self::Dead)
    }

    /// Returns `true` only for [`TaskStatus::Completed`].
    #[must_use]
    pub const fn has_completed(&self) -> bool {
        matches!(self, Self::Completed)
    }
}

impl std::fmt::Display for TaskStatus {
    /// Writes the lowercase status name.
    ///
    /// Uses `Formatter::pad`, so width, fill and alignment flags such as
    /// `{:>10}` are honoured (`write!` would silently ignore them).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.pad(self.as_str())
    }
}

75pub(crate) struct TaskHandle {
76 pub(crate) status: TaskStatus,
77 pub(crate) task: DynTask,
78 pub(crate) main_task_handle: Option<JoinHandle<()>>,
79 pub(crate) completion_task_handle: Option<JoinHandle<()>>,
80 pub(crate) restart_attempts: u32,
81 pub(crate) started_at: Option<Instant>,
82 pub(crate) healthy_since: Option<Instant>,
83 pub(crate) cancellation_token: Option<CancellationToken>,
84 max_restart_attempts: u32,
85 base_restart_delay: Duration,
86 max_backoff_exponent: u32,
87}
88
89impl TaskHandle {
90 pub(crate) fn new(
92 task: Box<dyn CloneableSupervisedTask>,
93 max_restart_attempts: u32,
94 base_restart_delay: Duration,
95 max_backoff_exponent: u32,
96 ) -> Self {
97 Self {
98 status: TaskStatus::Created,
99 task,
100 main_task_handle: None,
101 completion_task_handle: None,
102 restart_attempts: 0,
103 started_at: None,
104 healthy_since: None,
105 cancellation_token: None,
106 max_restart_attempts,
107 base_restart_delay,
108 max_backoff_exponent,
109 }
110 }
111
112 pub(crate) fn from_task<T: CloneableSupervisedTask + 'static>(
114 task: T,
115 max_restart_attempts: u32,
116 base_restart_delay: Duration,
117 max_backoff_exponent: u32,
118 ) -> Self {
119 let task = Box::new(task);
120 Self::new(
121 task,
122 max_restart_attempts,
123 base_restart_delay,
124 max_backoff_exponent,
125 )
126 }
127
128 pub(crate) fn restart_delay(&self) -> Duration {
130 let factor = 2u32.saturating_pow(self.restart_attempts.min(self.max_backoff_exponent));
131 self.base_restart_delay.saturating_mul(factor)
132 }
133
134 pub(crate) const fn has_exceeded_max_retries(&self) -> bool {
136 self.restart_attempts >= self.max_restart_attempts
137 }
138
139 pub(crate) fn mark(&mut self, status: TaskStatus) {
141 self.status = status;
142 }
143
144 pub(crate) async fn clean(&mut self) {
146 if let Some(token) = self.cancellation_token.take() {
147 token.cancel();
148 }
149 if let Some(handle) = self.main_task_handle.take() {
150 handle.abort();
151 }
152 if let Some(handle) = self.completion_task_handle.take() {
153 handle.abort();
154 }
155 self.healthy_since = None;
156 self.started_at = None;
157 }
158}