task_supervisor/supervisor/
mod.rs1pub(crate) mod builder;
2pub(crate) mod handle;
3
4use std::{
5 collections::HashMap,
6 time::{Duration, Instant},
7};
8
9use tokio::{sync::mpsc, time::interval};
10use tokio_util::sync::CancellationToken;
11
12use crate::{
13 supervisor::handle::{SupervisorHandle, SupervisorMessage},
14 task::{TaskHandle, TaskOutcome, TaskStatus},
15};
16
17#[derive(Debug, thiserror::Error)]
18pub enum SupervisorError {
19 #[error("Too many tasks are dead (threshold exceeded: {current_percentage:.2}% > {threshold:.2}%), supervisor shutting down.")]
20 TooManyDeadTasks {
21 current_percentage: f64,
22 threshold: f64,
23 },
24}
25
26#[derive(Debug)]
28pub(crate) enum SupervisedTaskMessage {
29 Restart(String),
31 Completed(String, TaskOutcome),
33 Shutdown,
35}
36
37pub struct Supervisor {
43 tasks: HashMap<String, TaskHandle>,
45 health_check_interval: Duration,
47 base_restart_delay: Duration,
48 task_is_stable_after: Duration,
49 max_restart_attempts: u32,
50 max_dead_tasks_percentage_threshold: Option<f64>,
51 external_tx: mpsc::UnboundedSender<SupervisorMessage>,
53 external_rx: mpsc::UnboundedReceiver<SupervisorMessage>,
54 internal_tx: mpsc::UnboundedSender<SupervisedTaskMessage>,
56 internal_rx: mpsc::UnboundedReceiver<SupervisedTaskMessage>,
57}
58
59impl Supervisor {
60 pub fn run(self) -> SupervisorHandle {
64 let user_tx = self.external_tx.clone();
65 let handle = tokio::spawn(async move { self.run_and_supervise().await });
66 SupervisorHandle::new(handle, user_tx)
67 }
68
69 async fn run_and_supervise(mut self) -> Result<(), SupervisorError> {
71 self.start_all_tasks().await;
72 self.supervise_all_tasks().await
73 }
74
75 async fn start_all_tasks(&mut self) {
77 for (task_name, task_handle) in self.tasks.iter_mut() {
78 Self::start_task(task_name.to_string(), task_handle, self.internal_tx.clone()).await;
79 }
80 }
81
82 async fn supervise_all_tasks(&mut self) -> Result<(), SupervisorError> {
84 let mut health_check_ticker = interval(self.health_check_interval);
85
86 loop {
87 tokio::select! {
88 biased;
89 Some(internal_msg) = self.internal_rx.recv() => {
90 match internal_msg {
91 SupervisedTaskMessage::Shutdown => {
92 return Ok(());
93 }
94 _ => self.handle_internal_message(internal_msg).await,
95 }
96 },
97 Some(user_msg) = self.external_rx.recv() => {
98 self.handle_user_message(user_msg).await;
99 },
100 _ = health_check_ticker.tick() => {
101 self.check_all_health().await;
102 }
103 }
104
105 self.check_dead_tasks_threshold().await?;
106 }
107 }
108
109 async fn handle_internal_message(&mut self, msg: SupervisedTaskMessage) {
110 match msg {
111 SupervisedTaskMessage::Restart(task_name) => {
112 self.restart_task(task_name).await;
113 }
114 SupervisedTaskMessage::Completed(task_name, outcome) => {
115 self.handle_task_completion(task_name, outcome).await;
116 }
117 SupervisedTaskMessage::Shutdown => {
118 unreachable!("Shutdown should be handled by the main select loop to break.");
119 }
120 }
121 }
122
123 async fn handle_user_message(&mut self, msg: SupervisorMessage) {
125 match msg {
126 SupervisorMessage::AddTask(task_name, task_dyn) => {
127 if self.tasks.contains_key(&task_name) {
129 return;
130 }
131 let mut task_handle =
132 TaskHandle::new(task_dyn, self.max_restart_attempts, self.base_restart_delay);
133 Self::start_task(
134 task_name.clone(),
135 &mut task_handle,
136 self.internal_tx.clone(),
137 )
138 .await;
139 self.tasks.insert(task_name, task_handle);
140 }
141 SupervisorMessage::RestartTask(task_name) => {
142 self.restart_task(task_name).await;
143 }
144 SupervisorMessage::KillTask(task_name) => {
145 if let Some(task_handle) = self.tasks.get_mut(&task_name) {
146 if task_handle.status != TaskStatus::Dead {
147 task_handle.mark(TaskStatus::Dead);
148 task_handle.clean().await;
149 }
150 }
151 }
152 SupervisorMessage::GetTaskStatus(task_name, sender) => {
153 let status = self.tasks.get(&task_name).map(|handle| handle.status);
154 let _ = sender.send(status);
155 }
156 SupervisorMessage::GetAllTaskStatuses(sender) => {
157 let statuses = self
158 .tasks
159 .iter()
160 .map(|(name, handle)| (name.clone(), handle.status))
161 .collect();
162 let _ = sender.send(statuses);
163 }
164 SupervisorMessage::Shutdown => {
165 for (_, task_handle) in self.tasks.iter_mut() {
166 if task_handle.status != TaskStatus::Dead
167 && task_handle.status != TaskStatus::Completed
168 {
169 task_handle.clean().await;
170 task_handle.mark(TaskStatus::Dead);
171 }
172 }
173 let _ = self.internal_tx.send(SupervisedTaskMessage::Shutdown);
174 }
175 }
176 }
177
178 async fn start_task(
180 task_name: String,
181 task_handle: &mut TaskHandle,
182 internal_tx: mpsc::UnboundedSender<SupervisedTaskMessage>,
183 ) {
184 task_handle.started_at = Some(Instant::now());
185 task_handle.mark(TaskStatus::Healthy);
186
187 let token = CancellationToken::new();
188 task_handle.cancellation_token = Some(token.clone());
189
190 let (completion_tx, mut completion_rx) = mpsc::channel::<TaskOutcome>(1);
191
192 let task_name_completion = task_name.clone();
194 let token_completion = token.clone();
195 let internal_tx_completion = internal_tx.clone();
196 let completion_listener_handle = tokio::spawn(async move {
197 tokio::select! {
198 _ = token_completion.cancelled() => { }
199 Some(outcome) = completion_rx.recv() => {
200 let completion_msg = SupervisedTaskMessage::Completed(task_name_completion.clone(), outcome);
201 let _ = internal_tx_completion.send(completion_msg);
202 token_completion.cancel();
203 }
204 }
205 });
206
207 let mut task_instance = task_handle.task.clone_box();
209 let token_main = token.clone();
210 let main_task_execution_handle = tokio::spawn(async move {
211 tokio::select! {
212 _ = token_main.cancelled() => { }
213 run_result = task_instance.run() => {
214 match run_result {
215 Ok(outcome) => {
216 let _ = completion_tx.send(outcome).await;
217 }
218 Err(e) => {
219 let _ = completion_tx.send(TaskOutcome::Failed(e.to_string())).await;
220 }
221 }
222 }
223 }
224 });
225
226 task_handle.main_task_handle = Some(main_task_execution_handle);
227 task_handle.completion_task_handle = Some(completion_listener_handle);
228 }
229
230 async fn restart_task(&mut self, task_name: String) {
232 if let Some(task_handle) = self.tasks.get_mut(&task_name) {
233 task_handle.clean().await;
234 Self::start_task(task_name, task_handle, self.internal_tx.clone()).await;
235 }
236 }
237
238 async fn check_all_health(&mut self) {
239 let mut tasks_needing_restart: Vec<String> = Vec::new();
240 let now = Instant::now();
241
242 for (task_name, task_handle) in self.tasks.iter_mut() {
243 if task_handle.status == TaskStatus::Healthy {
244 if let Some(main_handle) = &task_handle.main_task_handle {
245 if main_handle.is_finished() {
246 task_handle.mark(TaskStatus::Failed);
247 tasks_needing_restart.push(task_name.clone());
248 } else {
249 if let Some(healthy_since) = task_handle.healthy_since {
251 if (now.duration_since(healthy_since) > self.task_is_stable_after)
252 && task_handle.restart_attempts > 0
253 {
254 task_handle.restart_attempts = 0;
255 }
256 } else {
257 task_handle.healthy_since = Some(now);
258 }
259 }
260 } else {
261 task_handle.mark(TaskStatus::Failed);
262 tasks_needing_restart.push(task_name.clone());
263 }
264 }
265 }
266
267 for task_name in tasks_needing_restart {
268 let Some(task_handle) = self.tasks.get_mut(&task_name) else {
269 continue;
270 };
271
272 if task_handle.has_exceeded_max_retries() {
274 task_handle.mark(TaskStatus::Dead);
275 task_handle.clean().await;
276 continue;
277 }
278
279 task_handle.restart_attempts = task_handle.restart_attempts.saturating_add(1);
281 let restart_delay = task_handle.restart_delay();
282
283 let internal_tx_clone = self.internal_tx.clone();
284 tokio::spawn(async move {
285 tokio::time::sleep(restart_delay).await;
286 let _ = internal_tx_clone.send(SupervisedTaskMessage::Restart(task_name.clone()));
287 });
288 }
289 }
290
291 async fn handle_task_completion(&mut self, task_name: String, outcome: TaskOutcome) {
292 let Some(task_handle) = self.tasks.get_mut(&task_name) else {
293 return;
294 };
295
296 task_handle.clean().await;
297
298 match outcome {
299 TaskOutcome::Completed => {
300 task_handle.mark(TaskStatus::Completed);
301 }
302 TaskOutcome::Failed(_) => {
303 task_handle.mark(TaskStatus::Failed);
304
305 if task_handle.has_exceeded_max_retries() {
306 task_handle.mark(TaskStatus::Dead);
307 return;
308 }
309
310 task_handle.restart_attempts = task_handle.restart_attempts.saturating_add(1);
311 let restart_delay = task_handle.restart_delay();
312
313 let internal_tx_clone = self.internal_tx.clone();
314 tokio::spawn(async move {
315 tokio::time::sleep(restart_delay).await;
316 let _ =
317 internal_tx_clone.send(SupervisedTaskMessage::Restart(task_name.clone()));
318 });
319 }
320 }
321 }
322
323 async fn check_dead_tasks_threshold(&mut self) -> Result<(), SupervisorError> {
324 if let Some(threshold) = self.max_dead_tasks_percentage_threshold {
325 if !self.tasks.is_empty() {
326 let dead_task_count = self
327 .tasks
328 .values()
329 .filter(|handle| handle.status == TaskStatus::Dead)
330 .count();
331
332 let total_task_count = self.tasks.len();
333 let current_dead_percentage = dead_task_count as f64 / total_task_count as f64;
334
335 if current_dead_percentage > threshold {
336 for (_, task_handle) in self.tasks.iter_mut() {
338 if task_handle.status != TaskStatus::Dead
339 && task_handle.status != TaskStatus::Completed
340 {
341 task_handle.clean().await;
342 task_handle.mark(TaskStatus::Dead);
343 }
344 }
345
346 return Err(SupervisorError::TooManyDeadTasks {
347 current_percentage: current_dead_percentage,
348 threshold,
349 });
350 }
351 }
352 };
353
354 Ok(())
355 }
356}