task_supervisor/supervisor/
mod.rs1pub(crate) mod builder;
2pub(crate) mod handle;
3
4use std::{
5 collections::HashMap,
6 time::{Duration, Instant},
7};
8
9use tokio::{sync::mpsc, time::interval};
10use tokio_util::sync::CancellationToken;
11
12#[cfg(feature = "with_tracing")]
13use tracing::{debug, error, info, warn};
14
15use crate::{
16 supervisor::handle::{SupervisorHandle, SupervisorMessage},
17 task::{TaskHandle, TaskResult, TaskStatus},
18};
19
20#[derive(Clone, Debug, thiserror::Error)]
21pub enum SupervisorError {
22 #[error("Too many tasks are dead (threshold exceeded: {current_percentage:.2}% > {threshold:.2}%), supervisor shutting down.")]
23 TooManyDeadTasks {
24 current_percentage: f64,
25 threshold: f64,
26 },
27}
28
29#[derive(Debug)]
31pub(crate) enum SupervisedTaskMessage {
32 Restart(String),
34 Completed(String, TaskResult),
36 Shutdown,
38}
39
40pub struct Supervisor {
46 tasks: HashMap<String, TaskHandle>,
48 health_check_interval: Duration,
50 base_restart_delay: Duration,
51 task_is_stable_after: Duration,
52 max_restart_attempts: u32,
53 max_backoff_exponent: u32,
54 max_dead_tasks_percentage_threshold: Option<f64>,
55 external_tx: mpsc::UnboundedSender<SupervisorMessage>,
57 external_rx: mpsc::UnboundedReceiver<SupervisorMessage>,
58 internal_tx: mpsc::UnboundedSender<SupervisedTaskMessage>,
60 internal_rx: mpsc::UnboundedReceiver<SupervisedTaskMessage>,
61}
62
63impl Supervisor {
64 pub fn run(self) -> SupervisorHandle {
68 let user_tx = self.external_tx.clone();
69 let handle = tokio::spawn(async move { self.run_and_supervise().await });
70 SupervisorHandle::new(handle, user_tx)
71 }
72
73 async fn run_and_supervise(mut self) -> Result<(), SupervisorError> {
75 self.start_all_tasks().await;
76 self.supervise_all_tasks().await
77 }
78
79 async fn start_all_tasks(&mut self) {
81 for (task_name, task_handle) in self.tasks.iter_mut() {
82 Self::start_task(task_name.to_string(), task_handle, self.internal_tx.clone()).await;
83 }
84 }
85
86 async fn supervise_all_tasks(&mut self) -> Result<(), SupervisorError> {
88 let mut health_check_ticker = interval(self.health_check_interval);
89
90 loop {
91 tokio::select! {
92 biased;
93 Some(internal_msg) = self.internal_rx.recv() => {
94 match internal_msg {
95 SupervisedTaskMessage::Shutdown => {
96 #[cfg(feature = "with_tracing")]
97 info!("Supervisor received shutdown signal");
98 return Ok(());
99 }
100 _ => self.handle_internal_message(internal_msg).await,
101 }
102 },
103 Some(user_msg) = self.external_rx.recv() => {
104 self.handle_user_message(user_msg).await;
105 },
106 _ = health_check_ticker.tick() => {
107 #[cfg(feature = "with_tracing")]
108 debug!("Supervisor checking health of all tasks");
109 self.check_all_health().await;
110 self.check_dead_tasks_threshold().await?;
111 }
112 }
113 }
114 }
115
116 async fn handle_internal_message(&mut self, msg: SupervisedTaskMessage) {
117 match msg {
118 SupervisedTaskMessage::Restart(task_name) => {
119 #[cfg(feature = "with_tracing")]
120 info!("Processing restart request for task: {task_name}");
121 self.restart_task(task_name).await;
122 }
123 SupervisedTaskMessage::Completed(task_name, outcome) => {
124 #[cfg(feature = "with_tracing")]
125 match &outcome {
126 Ok(()) => info!("Task '{task_name}' completed successfully"),
127 Err(e) => warn!("Task '{task_name}' completed with error: {e}"),
128 }
129 self.handle_task_completion(task_name, outcome).await;
130 }
131 SupervisedTaskMessage::Shutdown => {
132 unreachable!("Shutdown is handled by the main select loop.");
133 }
134 }
135 }
136
137 async fn handle_user_message(&mut self, msg: SupervisorMessage) {
139 match msg {
140 SupervisorMessage::AddTask(task_name, task_dyn) => {
141 if self.tasks.contains_key(&task_name) {
143 #[cfg(feature = "with_tracing")]
144 warn!("Attempted to add task '{task_name}' but it already exists");
145 return;
146 }
147
148 let mut task_handle = TaskHandle::new(
149 task_dyn,
150 self.max_restart_attempts,
151 self.base_restart_delay,
152 self.max_backoff_exponent,
153 );
154
155 Self::start_task(
156 task_name.clone(),
157 &mut task_handle,
158 self.internal_tx.clone(),
159 )
160 .await;
161 self.tasks.insert(task_name, task_handle);
162 }
163 SupervisorMessage::RestartTask(task_name) => {
164 #[cfg(feature = "with_tracing")]
165 info!("User requested restart for task: {task_name}");
166 self.restart_task(task_name).await;
167 }
168 SupervisorMessage::KillTask(task_name) => {
169 if let Some(task_handle) = self.tasks.get_mut(&task_name) {
170 if task_handle.status != TaskStatus::Dead {
171 task_handle.mark(TaskStatus::Dead);
172 task_handle.clean().await;
173 }
174 } else {
175 #[cfg(feature = "with_tracing")]
176 warn!("Attempted to kill non-existent task: {task_name}");
177 }
178 }
179 SupervisorMessage::GetTaskStatus(task_name, sender) => {
180 let status = self.tasks.get(&task_name).map(|handle| handle.status);
181
182 #[cfg(feature = "with_tracing")]
183 debug!("Status query for task '{task_name}': {status:?}");
184
185 let _ = sender.send(status);
186 }
187 SupervisorMessage::GetAllTaskStatuses(sender) => {
188 let statuses = self
189 .tasks
190 .iter()
191 .map(|(name, handle)| (name.clone(), handle.status))
192 .collect();
193 let _ = sender.send(statuses);
194 }
195 SupervisorMessage::Shutdown => {
196 #[cfg(feature = "with_tracing")]
197 info!("User requested supervisor shutdown");
198
199 for (_, task_handle) in self.tasks.iter_mut() {
200 if task_handle.status != TaskStatus::Dead
201 && task_handle.status != TaskStatus::Completed
202 {
203 task_handle.clean().await;
204 task_handle.mark(TaskStatus::Dead);
205 }
206 }
207 let _ = self.internal_tx.send(SupervisedTaskMessage::Shutdown);
208 }
209 }
210 }
211
212 async fn start_task(
214 task_name: String,
215 task_handle: &mut TaskHandle,
216 internal_tx: mpsc::UnboundedSender<SupervisedTaskMessage>,
217 ) {
218 task_handle.started_at = Some(Instant::now());
219 task_handle.mark(TaskStatus::Healthy);
220
221 let token = CancellationToken::new();
222 task_handle.cancellation_token = Some(token.clone());
223
224 let (completion_tx, mut completion_rx) = mpsc::channel::<TaskResult>(1);
225
226 let task_name_completion = task_name.clone();
228 let token_completion = token.clone();
229 let internal_tx_completion = internal_tx.clone();
230 let completion_listener_handle = tokio::spawn(async move {
231 tokio::select! {
232 _ = token_completion.cancelled() => { }
233 Some(outcome) = completion_rx.recv() => {
234 let completion_msg = SupervisedTaskMessage::Completed(task_name_completion.clone(), outcome);
235 let _ = internal_tx_completion.send(completion_msg);
236 token_completion.cancel();
237 }
238 }
239 });
240
241 let mut task_instance = task_handle.task.clone_box();
243 let token_main = token.clone();
244 let main_task_execution_handle = tokio::spawn(async move {
245 tokio::select! {
246 _ = token_main.cancelled() => { }
247 run_result = task_instance.run() => {
248 let _ = completion_tx.send(run_result).await;
249 }
250 }
251 });
252
253 task_handle.main_task_handle = Some(main_task_execution_handle);
254 task_handle.completion_task_handle = Some(completion_listener_handle);
255 }
256
257 async fn restart_task(&mut self, task_name: String) {
259 if let Some(task_handle) = self.tasks.get_mut(&task_name) {
260 task_handle.clean().await;
261 Self::start_task(task_name, task_handle, self.internal_tx.clone()).await;
262 }
263 }
264
265 async fn check_all_health(&mut self) {
266 let mut tasks_needing_restart: Vec<String> = Vec::new();
267 let now = Instant::now();
268
269 for (task_name, task_handle) in self.tasks.iter_mut() {
270 if task_handle.status == TaskStatus::Healthy {
271 if let Some(main_handle) = &task_handle.main_task_handle {
272 if main_handle.is_finished() {
273 #[cfg(feature = "with_tracing")]
274 warn!("Task '{task_name}' unexpectedly finished, marking as failed");
275
276 task_handle.mark(TaskStatus::Failed);
277 tasks_needing_restart.push(task_name.clone());
278 } else {
279 if let Some(healthy_since) = task_handle.healthy_since {
281 if (now.duration_since(healthy_since) > self.task_is_stable_after)
282 && task_handle.restart_attempts > 0
283 {
284 #[cfg(feature = "with_tracing")]
285 info!(
286 "Task '{task_name}' is now stable, resetting restart attempts",
287 );
288
289 task_handle.restart_attempts = 0;
290 }
291 } else {
292 task_handle.healthy_since = Some(now);
293 }
294 }
295 } else {
296 #[cfg(feature = "with_tracing")]
297 error!("Task '{task_name}' has no main handle, marking as failed");
298
299 task_handle.mark(TaskStatus::Failed);
300 tasks_needing_restart.push(task_name.clone());
301 }
302 }
303 }
304
305 for task_name in tasks_needing_restart {
306 let Some(task_handle) = self.tasks.get_mut(&task_name) else {
307 continue;
308 };
309
310 if task_handle.has_exceeded_max_retries() {
312 #[cfg(feature = "with_tracing")]
313 error!(
314 "Task '{task_name}' exceeded max restart attempts ({}), marking as dead",
315 self.max_restart_attempts
316 );
317
318 task_handle.mark(TaskStatus::Dead);
319 task_handle.clean().await;
320 continue;
321 }
322
323 task_handle.restart_attempts = task_handle.restart_attempts.saturating_add(1);
325 let restart_delay = task_handle.restart_delay();
326
327 #[cfg(feature = "with_tracing")]
328 info!(
329 "Scheduling restart for task '{task_name}' in {restart_delay:?} (attempt {}/{})",
330 task_handle.restart_attempts, self.max_restart_attempts
331 );
332
333 let internal_tx_clone = self.internal_tx.clone();
334 tokio::spawn(async move {
335 tokio::time::sleep(restart_delay).await;
336 let _ = internal_tx_clone.send(SupervisedTaskMessage::Restart(task_name.clone()));
337 });
338 }
339 }
340
341 async fn handle_task_completion(&mut self, task_name: String, outcome: TaskResult) {
342 let Some(task_handle) = self.tasks.get_mut(&task_name) else {
343 #[cfg(feature = "with_tracing")]
344 warn!("Received completion for non-existent task: {}", task_name);
345 return;
346 };
347
348 task_handle.clean().await;
349
350 match outcome {
351 Ok(()) => {
352 #[cfg(feature = "with_tracing")]
353 info!("Task '{task_name}' completed successfully");
354
355 task_handle.mark(TaskStatus::Completed);
356 }
357 #[allow(unused_variables)]
358 Err(ref e) => {
359 #[cfg(feature = "with_tracing")]
360 error!("Task '{task_name}' failed with error: {e:?}");
361
362 task_handle.mark(TaskStatus::Failed);
363 if task_handle.has_exceeded_max_retries() {
364 #[cfg(feature = "with_tracing")]
365 error!(
366 "Task '{task_name}' exceeded max restart attempts after failure, marking as dead",
367 );
368
369 task_handle.mark(TaskStatus::Dead);
370 return;
371 }
372
373 task_handle.restart_attempts = task_handle.restart_attempts.saturating_add(1);
374 let restart_delay = task_handle.restart_delay();
375
376 #[cfg(feature = "with_tracing")]
377 info!(
378 "Scheduling restart for failed task '{task_name}' in {restart_delay:?} (attempt {}/{})",
379 task_handle.restart_attempts,
380 self.max_restart_attempts
381 );
382
383 let internal_tx_clone = self.internal_tx.clone();
384 tokio::spawn(async move {
385 tokio::time::sleep(restart_delay).await;
386 let _ =
387 internal_tx_clone.send(SupervisedTaskMessage::Restart(task_name.clone()));
388 });
389 }
390 }
391 }
392
393 async fn check_dead_tasks_threshold(&mut self) -> Result<(), SupervisorError> {
394 if let Some(threshold) = self.max_dead_tasks_percentage_threshold {
395 if !self.tasks.is_empty() {
396 let dead_task_count = self
397 .tasks
398 .values()
399 .filter(|handle| handle.status == TaskStatus::Dead)
400 .count();
401
402 let total_task_count = self.tasks.len();
403 let current_dead_percentage = dead_task_count as f64 / total_task_count as f64;
404
405 if current_dead_percentage > threshold {
406 #[cfg(feature = "with_tracing")]
407 error!(
408 "Dead tasks threshold exceeded: {:.2}% > {:.2}% ({}/{} tasks dead)",
409 current_dead_percentage * 100.0,
410 threshold * 100.0,
411 dead_task_count,
412 total_task_count
413 );
414
415 #[allow(unused_variables)]
417 for (task_name, task_handle) in self.tasks.iter_mut() {
418 if task_handle.status != TaskStatus::Dead
419 && task_handle.status != TaskStatus::Completed
420 {
421 #[cfg(feature = "with_tracing")]
422 debug!("Killing task '{task_name}' due to threshold breach");
423
424 task_handle.clean().await;
425 task_handle.mark(TaskStatus::Dead);
426 }
427 }
428
429 return Err(SupervisorError::TooManyDeadTasks {
430 current_percentage: current_dead_percentage,
431 threshold,
432 });
433 }
434 }
435 };
436
437 Ok(())
438 }
439}