task_supervisor/supervisor/
mod.rs1pub(crate) mod builder;
2pub(crate) mod handle;
3
4use std::{
5 collections::HashMap,
6 time::{Duration, Instant},
7};
8
9use tokio::{sync::mpsc, time::interval};
10use tokio_util::sync::CancellationToken;
11
12#[cfg(feature = "with_tracing")]
13use tracing::{debug, error, info, warn};
14
15use crate::{
16 supervisor::handle::{SupervisorHandle, SupervisorMessage},
17 task::{TaskHandle, TaskResult, TaskStatus},
18};
19
20#[derive(Clone, Debug, thiserror::Error)]
21pub enum SupervisorError {
22 #[error("Too many tasks are dead (threshold exceeded: {current_percentage:.2}% > {threshold:.2}%), supervisor shutting down.")]
23 TooManyDeadTasks {
24 current_percentage: f64,
25 threshold: f64,
26 },
27}
28
29#[derive(Debug)]
31pub(crate) enum SupervisedTaskMessage {
32 Restart(String),
34 Completed(String, TaskResult),
36 Shutdown,
38}
39
40pub struct Supervisor {
46 tasks: HashMap<String, TaskHandle>,
48 health_check_interval: Duration,
50 base_restart_delay: Duration,
51 task_is_stable_after: Duration,
52 max_restart_attempts: u32,
53 max_dead_tasks_percentage_threshold: Option<f64>,
54 external_tx: mpsc::UnboundedSender<SupervisorMessage>,
56 external_rx: mpsc::UnboundedReceiver<SupervisorMessage>,
57 internal_tx: mpsc::UnboundedSender<SupervisedTaskMessage>,
59 internal_rx: mpsc::UnboundedReceiver<SupervisedTaskMessage>,
60}
61
62impl Supervisor {
63 pub fn run(self) -> SupervisorHandle {
67 let user_tx = self.external_tx.clone();
68 let handle = tokio::spawn(async move { self.run_and_supervise().await });
69 SupervisorHandle::new(handle, user_tx)
70 }
71
72 async fn run_and_supervise(mut self) -> Result<(), SupervisorError> {
74 self.start_all_tasks().await;
75 self.supervise_all_tasks().await
76 }
77
78 async fn start_all_tasks(&mut self) {
80 for (task_name, task_handle) in self.tasks.iter_mut() {
81 Self::start_task(task_name.to_string(), task_handle, self.internal_tx.clone()).await;
82 }
83 }
84
85 async fn supervise_all_tasks(&mut self) -> Result<(), SupervisorError> {
87 let mut health_check_ticker = interval(self.health_check_interval);
88
89 loop {
90 tokio::select! {
91 biased;
92 Some(internal_msg) = self.internal_rx.recv() => {
93 match internal_msg {
94 SupervisedTaskMessage::Shutdown => {
95 #[cfg(feature = "with_tracing")]
96 info!("Supervisor received shutdown signal");
97 return Ok(());
98 }
99 _ => self.handle_internal_message(internal_msg).await,
100 }
101 },
102 Some(user_msg) = self.external_rx.recv() => {
103 self.handle_user_message(user_msg).await;
104 },
105 _ = health_check_ticker.tick() => {
106 #[cfg(feature = "with_tracing")]
107 debug!("Supervisor checking health of all tasks");
108 self.check_all_health().await;
109 self.check_dead_tasks_threshold().await?;
110 }
111 }
112 }
113 }
114
115 async fn handle_internal_message(&mut self, msg: SupervisedTaskMessage) {
116 match msg {
117 SupervisedTaskMessage::Restart(task_name) => {
118 #[cfg(feature = "with_tracing")]
119 info!("Processing restart request for task: {task_name}");
120 self.restart_task(task_name).await;
121 }
122 SupervisedTaskMessage::Completed(task_name, outcome) => {
123 #[cfg(feature = "with_tracing")]
124 match &outcome {
125 Ok(()) => info!("Task '{task_name}' completed successfully"),
126 Err(e) => warn!("Task '{task_name}' completed with error: {e}"),
127 }
128 self.handle_task_completion(task_name, outcome).await;
129 }
130 SupervisedTaskMessage::Shutdown => {
131 unreachable!("Shutdown is handled by the main select loop.");
132 }
133 }
134 }
135
136 async fn handle_user_message(&mut self, msg: SupervisorMessage) {
138 match msg {
139 SupervisorMessage::AddTask(task_name, task_dyn) => {
140 if self.tasks.contains_key(&task_name) {
142 #[cfg(feature = "with_tracing")]
143 warn!("Attempted to add task '{task_name}' but it already exists");
144 return;
145 }
146
147 let mut task_handle =
148 TaskHandle::new(task_dyn, self.max_restart_attempts, self.base_restart_delay);
149 Self::start_task(
150 task_name.clone(),
151 &mut task_handle,
152 self.internal_tx.clone(),
153 )
154 .await;
155 self.tasks.insert(task_name, task_handle);
156 }
157 SupervisorMessage::RestartTask(task_name) => {
158 #[cfg(feature = "with_tracing")]
159 info!("User requested restart for task: {task_name}");
160 self.restart_task(task_name).await;
161 }
162 SupervisorMessage::KillTask(task_name) => {
163 if let Some(task_handle) = self.tasks.get_mut(&task_name) {
164 if task_handle.status != TaskStatus::Dead {
165 task_handle.mark(TaskStatus::Dead);
166 task_handle.clean().await;
167 }
168 } else {
169 #[cfg(feature = "with_tracing")]
170 warn!("Attempted to kill non-existent task: {task_name}");
171 }
172 }
173 SupervisorMessage::GetTaskStatus(task_name, sender) => {
174 let status = self.tasks.get(&task_name).map(|handle| handle.status);
175
176 #[cfg(feature = "with_tracing")]
177 debug!("Status query for task '{task_name}': {status:?}");
178
179 let _ = sender.send(status);
180 }
181 SupervisorMessage::GetAllTaskStatuses(sender) => {
182 let statuses = self
183 .tasks
184 .iter()
185 .map(|(name, handle)| (name.clone(), handle.status))
186 .collect();
187 let _ = sender.send(statuses);
188 }
189 SupervisorMessage::Shutdown => {
190 #[cfg(feature = "with_tracing")]
191 info!("User requested supervisor shutdown");
192
193 for (_, task_handle) in self.tasks.iter_mut() {
194 if task_handle.status != TaskStatus::Dead
195 && task_handle.status != TaskStatus::Completed
196 {
197 task_handle.clean().await;
198 task_handle.mark(TaskStatus::Dead);
199 }
200 }
201 let _ = self.internal_tx.send(SupervisedTaskMessage::Shutdown);
202 }
203 }
204 }
205
206 async fn start_task(
208 task_name: String,
209 task_handle: &mut TaskHandle,
210 internal_tx: mpsc::UnboundedSender<SupervisedTaskMessage>,
211 ) {
212 task_handle.started_at = Some(Instant::now());
213 task_handle.mark(TaskStatus::Healthy);
214
215 let token = CancellationToken::new();
216 task_handle.cancellation_token = Some(token.clone());
217
218 let (completion_tx, mut completion_rx) = mpsc::channel::<TaskResult>(1);
219
220 let task_name_completion = task_name.clone();
222 let token_completion = token.clone();
223 let internal_tx_completion = internal_tx.clone();
224 let completion_listener_handle = tokio::spawn(async move {
225 tokio::select! {
226 _ = token_completion.cancelled() => { }
227 Some(outcome) = completion_rx.recv() => {
228 let completion_msg = SupervisedTaskMessage::Completed(task_name_completion.clone(), outcome);
229 let _ = internal_tx_completion.send(completion_msg);
230 token_completion.cancel();
231 }
232 }
233 });
234
235 let mut task_instance = task_handle.task.clone_box();
237 let token_main = token.clone();
238 let main_task_execution_handle = tokio::spawn(async move {
239 tokio::select! {
240 _ = token_main.cancelled() => { }
241 run_result = task_instance.run() => {
242 let _ = completion_tx.send(run_result).await;
243 }
244 }
245 });
246
247 task_handle.main_task_handle = Some(main_task_execution_handle);
248 task_handle.completion_task_handle = Some(completion_listener_handle);
249 }
250
251 async fn restart_task(&mut self, task_name: String) {
253 if let Some(task_handle) = self.tasks.get_mut(&task_name) {
254 task_handle.clean().await;
255 Self::start_task(task_name, task_handle, self.internal_tx.clone()).await;
256 }
257 }
258
259 async fn check_all_health(&mut self) {
260 let mut tasks_needing_restart: Vec<String> = Vec::new();
261 let now = Instant::now();
262
263 for (task_name, task_handle) in self.tasks.iter_mut() {
264 if task_handle.status == TaskStatus::Healthy {
265 if let Some(main_handle) = &task_handle.main_task_handle {
266 if main_handle.is_finished() {
267 #[cfg(feature = "with_tracing")]
268 warn!("Task '{task_name}' unexpectedly finished, marking as failed");
269
270 task_handle.mark(TaskStatus::Failed);
271 tasks_needing_restart.push(task_name.clone());
272 } else {
273 if let Some(healthy_since) = task_handle.healthy_since {
275 if (now.duration_since(healthy_since) > self.task_is_stable_after)
276 && task_handle.restart_attempts > 0
277 {
278 #[cfg(feature = "with_tracing")]
279 info!(
280 "Task '{task_name}' is now stable, resetting restart attempts",
281 );
282
283 task_handle.restart_attempts = 0;
284 }
285 } else {
286 task_handle.healthy_since = Some(now);
287 }
288 }
289 } else {
290 #[cfg(feature = "with_tracing")]
291 error!("Task '{task_name}' has no main handle, marking as failed");
292
293 task_handle.mark(TaskStatus::Failed);
294 tasks_needing_restart.push(task_name.clone());
295 }
296 }
297 }
298
299 for task_name in tasks_needing_restart {
300 let Some(task_handle) = self.tasks.get_mut(&task_name) else {
301 continue;
302 };
303
304 if task_handle.has_exceeded_max_retries() {
306 #[cfg(feature = "with_tracing")]
307 error!(
308 "Task '{task_name}' exceeded max restart attempts ({}), marking as dead",
309 self.max_restart_attempts
310 );
311
312 task_handle.mark(TaskStatus::Dead);
313 task_handle.clean().await;
314 continue;
315 }
316
317 task_handle.restart_attempts = task_handle.restart_attempts.saturating_add(1);
319 let restart_delay = task_handle.restart_delay();
320
321 #[cfg(feature = "with_tracing")]
322 info!(
323 "Scheduling restart for task '{task_name}' in {restart_delay:?} (attempt {}/{})",
324 task_handle.restart_attempts, self.max_restart_attempts
325 );
326
327 let internal_tx_clone = self.internal_tx.clone();
328 tokio::spawn(async move {
329 tokio::time::sleep(restart_delay).await;
330 let _ = internal_tx_clone.send(SupervisedTaskMessage::Restart(task_name.clone()));
331 });
332 }
333 }
334
335 async fn handle_task_completion(&mut self, task_name: String, outcome: TaskResult) {
336 let Some(task_handle) = self.tasks.get_mut(&task_name) else {
337 #[cfg(feature = "with_tracing")]
338 warn!("Received completion for non-existent task: {}", task_name);
339 return;
340 };
341
342 task_handle.clean().await;
343
344 match outcome {
345 Ok(()) => {
346 #[cfg(feature = "with_tracing")]
347 info!("Task '{task_name}' completed successfully");
348
349 task_handle.mark(TaskStatus::Completed);
350 }
351 #[allow(unused_variables)]
352 Err(ref e) => {
353 #[cfg(feature = "with_tracing")]
354 error!("Task '{task_name}' failed with error: {e:?}");
355
356 task_handle.mark(TaskStatus::Failed);
357 if task_handle.has_exceeded_max_retries() {
358 #[cfg(feature = "with_tracing")]
359 error!(
360 "Task '{task_name}' exceeded max restart attempts after failure, marking as dead",
361 );
362
363 task_handle.mark(TaskStatus::Dead);
364 return;
365 }
366
367 task_handle.restart_attempts = task_handle.restart_attempts.saturating_add(1);
368 let restart_delay = task_handle.restart_delay();
369
370 #[cfg(feature = "with_tracing")]
371 info!(
372 "Scheduling restart for failed task '{task_name}' in {restart_delay:?} (attempt {}/{})",
373 task_handle.restart_attempts,
374 self.max_restart_attempts
375 );
376
377 let internal_tx_clone = self.internal_tx.clone();
378 tokio::spawn(async move {
379 tokio::time::sleep(restart_delay).await;
380 let _ =
381 internal_tx_clone.send(SupervisedTaskMessage::Restart(task_name.clone()));
382 });
383 }
384 }
385 }
386
387 async fn check_dead_tasks_threshold(&mut self) -> Result<(), SupervisorError> {
388 if let Some(threshold) = self.max_dead_tasks_percentage_threshold {
389 if !self.tasks.is_empty() {
390 let dead_task_count = self
391 .tasks
392 .values()
393 .filter(|handle| handle.status == TaskStatus::Dead)
394 .count();
395
396 let total_task_count = self.tasks.len();
397 let current_dead_percentage = dead_task_count as f64 / total_task_count as f64;
398
399 if current_dead_percentage > threshold {
400 #[cfg(feature = "with_tracing")]
401 error!(
402 "Dead tasks threshold exceeded: {:.2}% > {:.2}% ({}/{} tasks dead)",
403 current_dead_percentage * 100.0,
404 threshold * 100.0,
405 dead_task_count,
406 total_task_count
407 );
408
409 #[allow(unused_variables)]
411 for (task_name, task_handle) in self.tasks.iter_mut() {
412 if task_handle.status != TaskStatus::Dead
413 && task_handle.status != TaskStatus::Completed
414 {
415 #[cfg(feature = "with_tracing")]
416 debug!("Killing task '{task_name}' due to threshold breach");
417
418 task_handle.clean().await;
419 task_handle.mark(TaskStatus::Dead);
420 }
421 }
422
423 return Err(SupervisorError::TooManyDeadTasks {
424 current_percentage: current_dead_percentage,
425 threshold,
426 });
427 }
428 }
429 };
430
431 Ok(())
432 }
433}