// task_supervisor/supervisor/mod.rs
pub(crate) mod builder;
pub(crate) mod handle;
3
4use std::{
5 collections::{BinaryHeap, HashMap},
6 sync::Arc,
7 time::{Duration, Instant},
8};
9
10use tokio::{sync::mpsc, time::interval};
11use tokio_util::sync::CancellationToken;
12
13#[cfg(feature = "with_tracing")]
14use tracing::{debug, error, info, warn};
15
16use crate::{
17 supervisor::handle::{SupervisorHandle, SupervisorMessage},
18 task::{TaskHandle, TaskResult, TaskStatus},
19};
20
/// Errors that can abort the supervision loop.
#[derive(Clone, Debug, thiserror::Error)]
pub enum SupervisorError {
    /// Returned when the fraction of `Dead` tasks exceeds the configured
    /// threshold; the supervisor kills all remaining tasks and exits.
    // NOTE(review): this Display prints the raw fraction (0.0..=1.0) directly
    // followed by a '%' sign, whereas the tracing log in
    // `check_dead_tasks_threshold` multiplies by 100 first — confirm which
    // scale is intended for the message.
    #[error("Too many tasks are dead (threshold exceeded: {current_percentage:.2}% > {threshold:.2}%), supervisor shutting down.")]
    TooManyDeadTasks {
        current_percentage: f64,
        threshold: f64,
    },
}
29
/// Messages delivered on the supervisor's internal channel, sent either by
/// spawned task wrappers or by the supervisor itself.
#[derive(Debug)]
pub(crate) enum SupervisedTaskMessage {
    /// A task's future ran to completion (successfully or with an error).
    Completed(Arc<str>, TaskResult),
    /// Instructs the supervision loop to return and shut down.
    Shutdown,
}
38
/// A scheduled restart: `task_name` is restarted once `deadline` is reached.
/// Stored in a `BinaryHeap`; the `Ord` impl is reversed so the earliest
/// deadline sits at the top (min-heap behavior).
struct PendingRestart {
    deadline: tokio::time::Instant,
    task_name: Arc<str>,
}
44
45impl PartialEq for PendingRestart {
46 fn eq(&self, other: &Self) -> bool {
47 self.deadline == other.deadline
48 }
49}
50
51impl Eq for PendingRestart {}
52
53impl PartialOrd for PendingRestart {
54 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
55 Some(self.cmp(other))
56 }
57}
58
59impl Ord for PendingRestart {
60 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
61 other.deadline.cmp(&self.deadline)
63 }
64}
65
/// Supervises a set of named tasks: starts them, monitors their health on a
/// fixed interval, restarts failed tasks after a backoff delay, and shuts
/// everything down if too many tasks end up dead.
pub struct Supervisor {
    // All known tasks, keyed by name.
    pub(crate) tasks: HashMap<Arc<str>, TaskHandle>,
    // Period of the health-check tick in the supervision loop.
    pub(crate) health_check_interval: Duration,
    // Copied onto each new TaskHandle; presumably the base of the restart
    // backoff computed by `TaskHandle::restart_delay` — confirm in task.rs.
    pub(crate) base_restart_delay: Duration,
    // A restarted task that stays Healthy for at least this long has its
    // restart-attempt counter reset to 0 (see `check_all_health`).
    pub(crate) task_is_stable_after: Duration,
    // Max restarts per task before it is marked Dead; `None` = unlimited.
    pub(crate) max_restart_attempts: Option<u32>,
    // Copied onto each new TaskHandle; presumably caps the backoff exponent
    // used by `restart_delay` — confirm in task.rs.
    pub(crate) max_backoff_exponent: u32,
    // If `Some`, the supervisor errors out once the dead-task fraction
    // exceeds this value; `None` disables the check.
    pub(crate) max_dead_tasks_percentage_threshold: Option<f64>,
    // User-facing command channel; the sender is cloned into the
    // `SupervisorHandle` returned by `run`.
    pub(crate) external_tx: mpsc::UnboundedSender<SupervisorMessage>,
    pub(crate) external_rx: mpsc::UnboundedReceiver<SupervisorMessage>,
    // Internal channel: spawned task wrappers report completion here, and
    // the supervisor signals its own shutdown through it.
    pub(crate) internal_tx: mpsc::UnboundedSender<SupervisedTaskMessage>,
    pub(crate) internal_rx: mpsc::UnboundedReceiver<SupervisedTaskMessage>,
}
87
impl Supervisor {
    /// Spawns the supervision loop onto the tokio runtime and returns a
    /// handle through which the user can send `SupervisorMessage`s.
    pub fn run(self) -> SupervisorHandle {
        let user_tx = self.external_tx.clone();
        let handle = tokio::spawn(async move { self.run_and_supervise().await });
        SupervisorHandle::new(handle, user_tx)
    }

    /// Starts every registered task, then blocks on the supervision loop
    /// until shutdown or a fatal error.
    async fn run_and_supervise(mut self) -> Result<(), SupervisorError> {
        self.start_all_tasks();
        self.supervise_all_tasks().await
    }

    /// Starts each currently registered task once.
    fn start_all_tasks(&mut self) {
        // Names are collected first so `start_task` can take `&mut self`.
        let task_names: Vec<Arc<str>> = self.tasks.keys().cloned().collect();
        for task_name in task_names {
            self.start_task(&task_name);
        }
    }

    /// Core event loop: multiplexes internal task-completion messages, user
    /// commands, due restarts, and periodic health checks.
    ///
    /// Returns `Ok(())` on a clean shutdown, or
    /// `Err(SupervisorError::TooManyDeadTasks)` when the dead-task threshold
    /// is breached during a health check.
    async fn supervise_all_tasks(&mut self) -> Result<(), SupervisorError> {
        let mut health_check_ticker = interval(self.health_check_interval);
        // Min-heap of scheduled restarts (the reversed `Ord` on
        // `PendingRestart` puts the earliest deadline on top).
        let mut pending_restarts: BinaryHeap<PendingRestart> = BinaryHeap::new();

        loop {
            // Resolves when the earliest scheduled restart is due; never
            // resolves while the heap is empty.
            let next_restart = async {
                match pending_restarts.peek() {
                    Some(pr) => tokio::time::sleep_until(pr.deadline).await,
                    None => std::future::pending().await,
                }
            };

            tokio::select! {
                // `biased` polls branches in order, giving internal messages
                // (notably Shutdown) priority over everything else.
                biased;
                Some(internal_msg) = self.internal_rx.recv() => {
                    match internal_msg {
                        SupervisedTaskMessage::Shutdown => {
                            #[cfg(feature = "with_tracing")]
                            info!("Supervisor received shutdown signal");
                            return Ok(());
                        }
                        SupervisedTaskMessage::Completed(task_name, outcome) => {
                            // NOTE(review): completion is logged both here and
                            // again inside `handle_task_completion` — confirm
                            // the duplicate logging is intentional.
                            #[cfg(feature = "with_tracing")]
                            match &outcome {
                                Ok(()) => info!("Task '{}' completed successfully", task_name),
                                Err(e) => warn!("Task '{}' completed with error: {e}", task_name),
                            }
                            self.handle_task_completion(&task_name, outcome, &mut pending_restarts);
                        }
                    }
                },
                Some(user_msg) = self.external_rx.recv() => {
                    self.handle_user_message(user_msg, &mut pending_restarts);
                },
                _ = next_restart => {
                    // The peeked entry is due; pop it and restart the task.
                    if let Some(pr) = pending_restarts.pop() {
                        self.restart_task(&pr.task_name);
                    }
                },
                _ = health_check_ticker.tick() => {
                    #[cfg(feature = "with_tracing")]
                    debug!("Supervisor checking health of all tasks");
                    self.check_all_health(&mut pending_restarts);
                    self.check_dead_tasks_threshold()?;
                }
            }
        }
    }

    /// Dispatches a single user command received via the external channel.
    fn handle_user_message(
        &mut self,
        msg: SupervisorMessage,
        pending_restarts: &mut BinaryHeap<PendingRestart>,
    ) {
        match msg {
            SupervisorMessage::AddTask(task_name, task_dyn) => {
                let key: Arc<str> = Arc::from(task_name);

                // Duplicate names are rejected rather than replacing the
                // existing task.
                if self.tasks.contains_key(&key) {
                    #[cfg(feature = "with_tracing")]
                    warn!("Attempted to add task '{}' but it already exists", key);
                    return;
                }

                // Propagate supervisor-wide restart policy onto the handle.
                let mut task_handle = TaskHandle::new(task_dyn);
                task_handle.max_restart_attempts = self.max_restart_attempts;
                task_handle.base_restart_delay = self.base_restart_delay;
                task_handle.max_backoff_exponent = self.max_backoff_exponent;

                self.tasks.insert(Arc::clone(&key), task_handle);
                self.start_task(&key);
            }
            SupervisorMessage::RestartTask(task_name) => {
                let key: Arc<str> = Arc::from(task_name);
                #[cfg(feature = "with_tracing")]
                info!("User requested restart for task: {}", key);
                // NOTE(review): no status check here — a user restart will
                // also revive Dead/Completed tasks; confirm that is intended.
                self.restart_task(&key);
            }
            SupervisorMessage::KillTask(task_name) => {
                let key: Arc<str> = Arc::from(task_name);
                if let Some(task_handle) = self.tasks.get_mut(&key) {
                    if task_handle.status != TaskStatus::Dead {
                        // NOTE(review): marks Dead before clean() here, while
                        // the Shutdown arm below cleans before marking —
                        // confirm the ordering difference is intentional.
                        task_handle.mark(TaskStatus::Dead);
                        task_handle.clean();
                    }
                    // NOTE(review): any PendingRestart already queued for this
                    // task is NOT removed, so a due restart can resurrect a
                    // killed task — verify against `restart_task`.
                } else {
                    #[cfg(feature = "with_tracing")]
                    warn!("Attempted to kill non-existent task: {}", key);
                }
            }
            SupervisorMessage::GetTaskStatus(task_name, sender) => {
                let key: Arc<str> = Arc::from(task_name);
                // `None` means the task name is unknown.
                let status = self.tasks.get(&key).map(|handle| handle.status);

                #[cfg(feature = "with_tracing")]
                debug!("Status query for task '{}': {:?}", key, status);

                // Receiver may have been dropped; ignore the send error.
                let _ = sender.send(status);
            }
            SupervisorMessage::GetAllTaskStatuses(sender) => {
                let statuses = self
                    .tasks
                    .iter()
                    .map(|(name, handle)| (String::from(name.as_ref()), handle.status))
                    .collect();
                let _ = sender.send(statuses);
            }
            SupervisorMessage::Shutdown => {
                #[cfg(feature = "with_tracing")]
                info!("User requested supervisor shutdown");

                // Cancel and mark every task that is still live, then tell
                // the supervision loop to exit via the internal channel.
                for (_, task_handle) in self.tasks.iter_mut() {
                    if task_handle.status != TaskStatus::Dead
                        && task_handle.status != TaskStatus::Completed
                    {
                        task_handle.clean();
                        task_handle.mark(TaskStatus::Dead);
                    }
                }
                pending_restarts.clear();
                let _ = self.internal_tx.send(SupervisedTaskMessage::Shutdown);
            }
        }
    }

    /// Spawns a fresh instance of the named task, wiring up a cancellation
    /// token and the completion callback. No-op for unknown names.
    fn start_task(&mut self, task_name: &Arc<str>) {
        let Some(task_handle) = self.tasks.get_mut(task_name) else {
            return;
        };

        // Optimistically Healthy; health checks will downgrade it if needed.
        task_handle.mark(TaskStatus::Healthy);

        // Fresh token per run so `clean()` can cancel exactly this instance.
        let token = CancellationToken::new();
        task_handle.cancellation_token = Some(token.clone());

        let mut task_instance = task_handle.task.clone_box();
        let internal_tx = self.internal_tx.clone();
        let name = Arc::clone(task_name);

        let join_handle = tokio::spawn(async move {
            tokio::select! {
                // Cancellation: exit silently, no Completed message is sent.
                _ = token.cancelled() => { }
                result = task_instance.run_boxed() => {
                    // Supervisor may already be gone; ignore send failure.
                    let _ = internal_tx.send(SupervisedTaskMessage::Completed(name, result));
                }
            }
        });

        task_handle.join_handle = Some(join_handle);
    }

    /// Cancels the task's current instance (if any) and starts a new one.
    fn restart_task(&mut self, task_name: &Arc<str>) {
        if let Some(task_handle) = self.tasks.get_mut(task_name) {
            task_handle.clean();
        }
        self.start_task(task_name);
    }

    /// Health-check pass over all `Healthy` tasks:
    /// - a finished or missing join handle marks the task Failed and
    ///   schedules a restart (or kills it, see `schedule_restart_or_kill`);
    /// - a task that has stayed healthy longer than `task_is_stable_after`
    ///   gets its restart-attempt counter reset.
    fn check_all_health(&mut self, pending_restarts: &mut BinaryHeap<PendingRestart>) {
        let now = Instant::now();

        // Failures are collected first because `schedule_restart_or_kill`
        // needs `&mut self` again after the iteration.
        let mut failed_names: Vec<Arc<str>> = Vec::new();

        for (task_name, task_handle) in self.tasks.iter_mut() {
            if task_handle.status != TaskStatus::Healthy {
                continue;
            }

            if let Some(handle) = &task_handle.join_handle {
                if handle.is_finished() {
                    // The task's future ended without reporting completion —
                    // treat as a failure.
                    #[cfg(feature = "with_tracing")]
                    warn!(
                        "Task '{}' unexpectedly finished, marking as failed",
                        task_name
                    );

                    task_handle.mark(TaskStatus::Failed);
                    failed_names.push(Arc::clone(task_name));
                } else {
                    if let Some(healthy_since) = task_handle.healthy_since {
                        // Stable long enough after a restart: forget previous
                        // failures so backoff starts over next time.
                        if now.duration_since(healthy_since) > self.task_is_stable_after
                            && task_handle.restart_attempts > 0
                        {
                            #[cfg(feature = "with_tracing")]
                            info!(
                                "Task '{}' is now stable, resetting restart attempts",
                                task_name
                            );
                            task_handle.restart_attempts = 0;
                        }
                    } else {
                        // First healthy observation: start the stability clock.
                        task_handle.healthy_since = Some(now);
                    }
                }
            } else {
                // Healthy status but no join handle is an inconsistency.
                #[cfg(feature = "with_tracing")]
                error!("Task '{}' has no join handle, marking as failed", task_name);

                task_handle.mark(TaskStatus::Failed);
                failed_names.push(Arc::clone(task_name));
            }
        }

        for task_name in failed_names {
            self.schedule_restart_or_kill(&task_name, pending_restarts);
        }
    }

    /// Handles a `Completed` message from a task wrapper: marks the task
    /// Completed on success, or Failed (and schedules a restart) on error.
    fn handle_task_completion(
        &mut self,
        task_name: &Arc<str>,
        outcome: TaskResult,
        pending_restarts: &mut BinaryHeap<PendingRestart>,
    ) {
        let Some(task_handle) = self.tasks.get_mut(task_name) else {
            #[cfg(feature = "with_tracing")]
            warn!("Received completion for non-existent task: {}", task_name);
            return;
        };

        // Drop the finished instance's handle/token before re-marking.
        task_handle.clean();

        match outcome {
            Ok(()) => {
                #[cfg(feature = "with_tracing")]
                info!("Task '{}' completed successfully", task_name);

                task_handle.mark(TaskStatus::Completed);
            }
            #[allow(unused_variables)]
            Err(ref e) => {
                #[cfg(feature = "with_tracing")]
                error!("Task '{}' failed with error: {:?}", task_name, e);

                task_handle.mark(TaskStatus::Failed);
                self.schedule_restart_or_kill(task_name, pending_restarts);
            }
        }
    }

    /// For a failed task: either marks it Dead (retry budget exhausted) or
    /// increments the attempt counter and pushes a delayed restart onto the
    /// heap. No-op for unknown names.
    fn schedule_restart_or_kill(
        &mut self,
        task_name: &Arc<str>,
        pending_restarts: &mut BinaryHeap<PendingRestart>,
    ) {
        let Some(task_handle) = self.tasks.get_mut(task_name) else {
            return;
        };

        if task_handle.has_exceeded_max_retries() {
            #[cfg(feature = "with_tracing")]
            error!(
                "Task '{}' exceeded max restart attempts ({:?}), marking as dead",
                task_name,
                task_handle
                    .max_restart_attempts
                    .expect("is provided if has exceeded")
            );

            task_handle.mark(TaskStatus::Dead);
            task_handle.clean();
            return;
        }

        task_handle.restart_attempts = task_handle.restart_attempts.saturating_add(1);
        // Backoff delay computed by the handle from its attempt counter —
        // presumably exponential with `base_restart_delay` and capped by
        // `max_backoff_exponent`; confirm in task.rs.
        let restart_delay = task_handle.restart_delay();

        #[cfg(feature = "with_tracing")]
        info!(
            "Scheduling restart for task '{}' in {:?} (attempt {}/{})",
            task_name,
            restart_delay,
            task_handle.restart_attempts,
            task_handle
                .max_restart_attempts
                .map(|t| t.to_string())
                .unwrap_or_else(|| "\u{221e}".to_string())
        );

        pending_restarts.push(PendingRestart {
            deadline: tokio::time::Instant::now() + restart_delay,
            task_name: Arc::clone(task_name),
        });
    }

    /// Checks the configured dead-task threshold. When exceeded, kills all
    /// remaining live tasks and returns `TooManyDeadTasks`, which aborts the
    /// supervision loop. Disabled when the threshold is `None`.
    fn check_dead_tasks_threshold(&mut self) -> Result<(), SupervisorError> {
        let Some(threshold) = self.max_dead_tasks_percentage_threshold else {
            return Ok(());
        };

        let total_task_count = self.tasks.len();
        // Avoid a 0/0 division when no tasks are registered.
        if total_task_count == 0 {
            return Ok(());
        }

        let dead_task_count = self
            .tasks
            .values()
            .filter(|handle| handle.status == TaskStatus::Dead)
            .count();

        // Fraction in 0.0..=1.0; compared directly against `threshold`, so
        // the threshold is on the same (fractional) scale.
        let current_dead_percentage = dead_task_count as f64 / total_task_count as f64;

        if current_dead_percentage <= threshold {
            return Ok(());
        }

        #[cfg(feature = "with_tracing")]
        error!(
            "Dead tasks threshold exceeded: {:.2}% > {:.2}% ({}/{} tasks dead)",
            current_dead_percentage * 100.0,
            threshold * 100.0,
            dead_task_count,
            total_task_count
        );

        // Kill everything still live before reporting the fatal error.
        #[allow(unused_variables)]
        for (task_name, task_handle) in self.tasks.iter_mut() {
            if task_handle.status != TaskStatus::Dead && task_handle.status != TaskStatus::Completed
            {
                #[cfg(feature = "with_tracing")]
                debug!("Killing task '{}' due to threshold breach", task_name);

                task_handle.clean();
                task_handle.mark(TaskStatus::Dead);
            }
        }

        Err(SupervisorError::TooManyDeadTasks {
            current_percentage: current_dead_percentage,
            threshold,
        })
    }
}