task_supervisor/supervisor/
mod.rs1pub(crate) mod builder;
2pub(crate) mod handle;
3
4use std::{
5 collections::{BinaryHeap, HashMap},
6 sync::Arc,
7 time::{Duration, Instant},
8};
9
10use tokio::{sync::mpsc, time::interval};
11use tokio_util::sync::CancellationToken;
12
13#[cfg(feature = "tracing")]
14use tracing::{debug, error, info, warn};
15
16use crate::{
17 supervisor::handle::{SupervisorHandle, SupervisorMessage},
18 task::{TaskHandle, TaskResult, TaskStatus},
19};
20
21#[derive(Clone, Debug, thiserror::Error)]
22pub enum SupervisorError {
23 #[error("Too many tasks are dead (threshold exceeded: {current_percentage:.2}% > {threshold:.2}%), supervisor shutting down.")]
24 TooManyDeadTasks {
25 current_percentage: f64,
26 threshold: f64,
27 },
28}
29
/// Events flowing over the supervisor's *internal* channel
/// (spawned tasks and the supervisor itself are the senders).
#[derive(Debug)]
pub(crate) enum SupervisedTaskMessage {
    /// A task's future ran to completion; carries the task name and its
    /// `Ok`/`Err` outcome.
    Completed(Arc<str>, TaskResult),
    /// Instructs the supervision loop to exit cleanly.
    Shutdown,
}
35
/// A scheduled restart of a failed task, stored in a `BinaryHeap` and
/// ordered by its deadline (see the `Ord` impl below).
struct PendingRestart {
    /// Instant at which the restart becomes due.
    deadline: tokio::time::Instant,
    /// Name of the task to restart when the deadline elapses.
    task_name: Arc<str>,
}
41
42impl PartialEq for PendingRestart {
43 fn eq(&self, other: &Self) -> bool {
44 self.deadline == other.deadline
45 }
46}
47
48impl Eq for PendingRestart {}
49
50impl PartialOrd for PendingRestart {
51 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
52 Some(self.cmp(other))
53 }
54}
55
56impl Ord for PendingRestart {
57 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
58 other.deadline.cmp(&self.deadline)
60 }
61}
62
/// Supervises a set of named tasks: starts them, watches their health,
/// schedules restarts for failures, and stops everything on shutdown or
/// when too many tasks have died.
pub struct Supervisor {
    /// All known tasks, keyed by name.
    pub(crate) tasks: HashMap<Arc<str>, TaskHandle>,
    /// Period of the health-check tick in `supervise_all_tasks`.
    pub(crate) health_check_interval: Duration,
    /// Base delay copied into each new `TaskHandle`; the actual restart
    /// delay comes from `TaskHandle::restart_delay` (defined elsewhere).
    pub(crate) base_restart_delay: Duration,
    /// How long a task must stay `Healthy` before its restart-attempt
    /// counter is reset to zero.
    pub(crate) task_is_stable_after: Duration,
    /// Per-task restart budget; `None` means unlimited attempts.
    pub(crate) max_restart_attempts: Option<u32>,
    /// Copied into each new `TaskHandle`; presumably caps the backoff
    /// exponent used by `restart_delay` — TODO confirm against `TaskHandle`.
    pub(crate) max_backoff_exponent: u32,
    /// Dead-task fraction (0.0..=1.0) above which the supervisor aborts;
    /// `None` disables the check.
    pub(crate) max_dead_tasks_percentage_threshold: Option<f64>,
    /// Sender half handed to `SupervisorHandle` for user commands.
    pub(crate) external_tx: mpsc::UnboundedSender<SupervisorMessage>,
    /// Receiver for user commands issued through the handle.
    pub(crate) external_rx: mpsc::UnboundedReceiver<SupervisorMessage>,
    /// Sender cloned into every spawned task to report completion, and used
    /// internally to route shutdown through the loop.
    pub(crate) internal_tx: mpsc::UnboundedSender<SupervisedTaskMessage>,
    /// Receiver for internal task events.
    pub(crate) internal_rx: mpsc::UnboundedReceiver<SupervisedTaskMessage>,
}
76
impl Supervisor {
    /// Consumes the supervisor, spawns its supervision loop on the Tokio
    /// runtime, and returns a [`SupervisorHandle`] the caller can use to
    /// send commands (add/restart/kill tasks, query statuses, shut down).
    pub fn run(self) -> SupervisorHandle {
        let user_tx = self.external_tx.clone();
        let handle = tokio::spawn(async move { self.run_and_supervise().await });
        SupervisorHandle::new(handle, user_tx)
    }

    /// Starts every registered task, then runs the supervision loop until a
    /// shutdown request or a dead-task threshold breach ends it.
    async fn run_and_supervise(mut self) -> Result<(), SupervisorError> {
        self.start_all_tasks();
        self.supervise_all_tasks().await
    }

    /// Spawns every task currently registered in `self.tasks`.
    fn start_all_tasks(&mut self) {
        // Snapshot the names first: `start_task` takes `&mut self`, so we
        // cannot call it while an iterator borrows `self.tasks`.
        let task_names: Vec<Arc<str>> = self.tasks.keys().cloned().collect();
        for task_name in task_names {
            self.start_task(&task_name);
        }
    }

    /// Core event loop. Multiplexes four event sources:
    /// 1. internal messages from spawned tasks (completion / shutdown),
    /// 2. user commands arriving through the handle,
    /// 3. the earliest pending scheduled restart,
    /// 4. the periodic health-check tick.
    ///
    /// Returns `Ok(())` on a requested shutdown, or
    /// `SupervisorError::TooManyDeadTasks` if the threshold check fails.
    async fn supervise_all_tasks(&mut self) -> Result<(), SupervisorError> {
        let mut health_check_ticker = interval(self.health_check_interval);
        // `PendingRestart`'s reversed `Ord` turns this max-heap into a
        // min-heap by deadline: `peek`/`pop` yield the earliest restart.
        let mut pending_restarts: BinaryHeap<PendingRestart> = BinaryHeap::new();

        loop {
            // Resolves when the earliest scheduled restart is due; pends
            // forever when the heap is empty. Rebuilt each iteration so it
            // always tracks the current heap top.
            let next_restart = async {
                match pending_restarts.peek() {
                    Some(pr) => tokio::time::sleep_until(pr.deadline).await,
                    None => std::future::pending().await,
                }
            };

            tokio::select! {
                // `biased` polls arms top-to-bottom: internal task events
                // take priority over user commands, restarts, and ticks.
                biased;
                Some(internal_msg) = self.internal_rx.recv() => {
                    match internal_msg {
                        SupervisedTaskMessage::Shutdown => {
                            #[cfg(feature = "tracing")]
                            info!("Supervisor received shutdown signal");
                            return Ok(());
                        }
                        SupervisedTaskMessage::Completed(task_name, outcome) => {
                            #[cfg(feature = "tracing")]
                            match &outcome {
                                Ok(()) => info!("Task '{}' completed successfully", task_name),
                                Err(e) => warn!("Task '{}' completed with error: {e}", task_name),
                            }
                            self.handle_task_completion(&task_name, outcome, &mut pending_restarts);
                        }
                    }
                },
                Some(user_msg) = self.external_rx.recv() => {
                    self.handle_user_message(user_msg, &mut pending_restarts);
                },
                _ = next_restart => {
                    // A restart deadline elapsed: pop it and restart the task.
                    if let Some(pr) = pending_restarts.pop() {
                        self.restart_task(&pr.task_name);
                    }
                },
                _ = health_check_ticker.tick() => {
                    #[cfg(feature = "tracing")]
                    debug!("Health check tick");
                    self.check_all_health(&mut pending_restarts);
                    // Propagates `TooManyDeadTasks`, ending the loop.
                    self.check_dead_tasks_threshold()?;
                }
            }
        }
    }

    /// Dispatches a single user command received through the handle.
    fn handle_user_message(
        &mut self,
        msg: SupervisorMessage,
        pending_restarts: &mut BinaryHeap<PendingRestart>,
    ) {
        match msg {
            SupervisorMessage::AddTask(task_name, task_dyn) => {
                let key: Arc<str> = Arc::from(task_name);

                // Adding a task under an existing name is a no-op.
                if self.tasks.contains_key(&key) {
                    #[cfg(feature = "tracing")]
                    warn!("Task '{}' already exists, ignoring add", key);
                    return;
                }

                // New tasks inherit the supervisor-wide restart policy.
                let mut task_handle = TaskHandle::new(task_dyn);
                task_handle.max_restart_attempts = self.max_restart_attempts;
                task_handle.base_restart_delay = self.base_restart_delay;
                task_handle.max_backoff_exponent = self.max_backoff_exponent;

                self.tasks.insert(Arc::clone(&key), task_handle);
                self.start_task(&key);
            }
            SupervisorMessage::RestartTask(task_name) => {
                let key: Arc<str> = Arc::from(task_name);
                #[cfg(feature = "tracing")]
                info!("User requested restart for task: {}", key);
                self.restart_task(&key);
            }
            SupervisorMessage::KillTask(task_name) => {
                let key: Arc<str> = Arc::from(task_name);
                if let Some(task_handle) = self.tasks.get_mut(&key) {
                    // Already-dead tasks are left untouched.
                    if task_handle.status != TaskStatus::Dead {
                        task_handle.mark(TaskStatus::Dead);
                        task_handle.clean();
                    }
                } else {
                    #[cfg(feature = "tracing")]
                    warn!("Attempted to kill non-existent task: {}", key);
                }
            }
            SupervisorMessage::GetTaskStatus(task_name, sender) => {
                let key: Arc<str> = Arc::from(task_name);
                let status = self.tasks.get(&key).map(|handle| handle.status);
                // The requester may have dropped the receiver; ignore errors.
                let _ = sender.send(status);
            }
            SupervisorMessage::GetAllTaskStatuses(sender) => {
                let statuses = self
                    .tasks
                    .iter()
                    .map(|(name, handle)| (String::from(name.as_ref()), handle.status))
                    .collect();
                let _ = sender.send(statuses);
            }
            SupervisorMessage::Shutdown => {
                #[cfg(feature = "tracing")]
                info!("User requested supervisor shutdown");

                // Stop every task that is still running before exiting.
                for (_, task_handle) in self.tasks.iter_mut() {
                    if task_handle.status != TaskStatus::Dead
                        && task_handle.status != TaskStatus::Completed
                    {
                        task_handle.clean();
                        task_handle.mark(TaskStatus::Dead);
                    }
                }
                pending_restarts.clear();
                // Route shutdown through the internal channel so the
                // supervision loop exits on its next (biased) poll.
                let _ = self.internal_tx.send(SupervisedTaskMessage::Shutdown);
            }
        }
    }

    /// Spawns (or re-spawns) the named task, wiring up a fresh cancellation
    /// token and a completion notification back to the supervision loop.
    /// Silently does nothing for unknown names.
    fn start_task(&mut self, task_name: &Arc<str>) {
        let Some(task_handle) = self.tasks.get_mut(task_name) else {
            return;
        };

        task_handle.mark(TaskStatus::Healthy);

        let token = CancellationToken::new();
        task_handle.cancellation_token = Some(token.clone());

        let mut task_instance = task_handle.task.clone_box();
        let internal_tx = self.internal_tx.clone();
        let name = Arc::clone(task_name);

        let join_handle = tokio::spawn(async move {
            tokio::select! {
                // Cancellation wins silently: no completion message is sent.
                _ = token.cancelled() => { }
                result = task_instance.run_boxed() => {
                    let _ = internal_tx.send(SupervisedTaskMessage::Completed(name, result));
                }
            }
        });

        task_handle.join_handle = Some(join_handle);
    }

    /// Cleans up any previous run of the task, then starts it again.
    fn restart_task(&mut self, task_name: &Arc<str>) {
        if let Some(task_handle) = self.tasks.get_mut(task_name) {
            task_handle.clean();
        }
        self.start_task(task_name);
    }

    /// Periodic health sweep over tasks currently marked `Healthy`:
    /// - a finished or missing join handle means the task ended without
    ///   reporting, so it is marked `Failed` and queued for restart/kill;
    /// - a task healthy for longer than `task_is_stable_after` gets its
    ///   restart-attempt counter reset.
    fn check_all_health(&mut self, pending_restarts: &mut BinaryHeap<PendingRestart>) {
        let now = Instant::now();
        // Deferred: `schedule_restart_or_kill` needs `&mut self`, which we
        // cannot take while iterating over `self.tasks`.
        let mut failed_names: Vec<Arc<str>> = Vec::new();

        for (task_name, task_handle) in self.tasks.iter_mut() {
            if task_handle.status != TaskStatus::Healthy {
                continue;
            }

            if let Some(handle) = &task_handle.join_handle {
                if handle.is_finished() {
                    // The task's future ended without sending a completion
                    // message (e.g. panicked) — treat as a failure.
                    #[cfg(feature = "tracing")]
                    warn!(
                        "Task '{}' unexpectedly finished, marking as failed",
                        task_name
                    );

                    task_handle.mark(TaskStatus::Failed);
                    failed_names.push(Arc::clone(task_name));
                } else {
                    if let Some(healthy_since) = task_handle.healthy_since {
                        // Stable long enough: forget past restart attempts so
                        // a future failure starts from the base delay again.
                        if now.duration_since(healthy_since) > self.task_is_stable_after
                            && task_handle.restart_attempts > 0
                        {
                            #[cfg(feature = "tracing")]
                            info!(
                                "Task '{}' is now stable, resetting restart attempts",
                                task_name
                            );
                            task_handle.restart_attempts = 0;
                        }
                    } else {
                        // First healthy observation: start the stability clock.
                        task_handle.healthy_since = Some(now);
                    }
                }
            } else {
                // Healthy status without a join handle is an inconsistency;
                // recover by treating the task as failed.
                #[cfg(feature = "tracing")]
                error!("Task '{}' has no join handle, marking as failed", task_name);

                task_handle.mark(TaskStatus::Failed);
                failed_names.push(Arc::clone(task_name));
            }
        }

        for task_name in failed_names {
            self.schedule_restart_or_kill(&task_name, pending_restarts);
        }
    }

    /// Handles a task's own completion report: marks it `Completed` on
    /// success, or `Failed` plus restart-or-kill on error.
    ///
    /// NOTE(review): when tracing is enabled the outcome is logged both here
    /// and by the caller in `supervise_all_tasks`, producing duplicate log
    /// lines per completion.
    fn handle_task_completion(
        &mut self,
        task_name: &Arc<str>,
        outcome: TaskResult,
        pending_restarts: &mut BinaryHeap<PendingRestart>,
    ) {
        let Some(task_handle) = self.tasks.get_mut(task_name) else {
            #[cfg(feature = "tracing")]
            warn!("Completion for unknown task: {}", task_name);
            return;
        };

        task_handle.clean();

        match outcome {
            Ok(()) => {
                #[cfg(feature = "tracing")]
                info!("Task '{}' completed successfully", task_name);
                task_handle.mark(TaskStatus::Completed);
            }
            #[allow(unused_variables)]
            Err(ref e) => {
                #[cfg(feature = "tracing")]
                error!("Task '{}' failed: {:?}", task_name, e);

                task_handle.mark(TaskStatus::Failed);
                self.schedule_restart_or_kill(task_name, pending_restarts);
            }
        }
    }

    /// Either queues a delayed restart for a failed task, or — if it has
    /// exhausted its restart budget — marks it `Dead` permanently.
    fn schedule_restart_or_kill(
        &mut self,
        task_name: &Arc<str>,
        pending_restarts: &mut BinaryHeap<PendingRestart>,
    ) {
        let Some(task_handle) = self.tasks.get_mut(task_name) else {
            return;
        };

        if task_handle.has_exceeded_max_retries() {
            #[cfg(feature = "tracing")]
            error!(
                "Task '{}' exceeded max restart attempts ({:?}), marking as dead",
                task_name,
                task_handle
                    .max_restart_attempts
                    .expect("is provided if has exceeded")
            );

            task_handle.mark(TaskStatus::Dead);
            task_handle.clean();
            return;
        }

        task_handle.restart_attempts = task_handle.restart_attempts.saturating_add(1);
        // Delay policy lives in `TaskHandle::restart_delay` (not visible
        // here); presumably backoff derived from `base_restart_delay` and
        // capped via `max_backoff_exponent` — confirm against TaskHandle.
        let restart_delay = task_handle.restart_delay();

        #[cfg(feature = "tracing")]
        info!(
            "Scheduling restart for task '{}' in {:?} (attempt {}/{})",
            task_name,
            restart_delay,
            task_handle.restart_attempts,
            task_handle
                .max_restart_attempts
                .map(|t| t.to_string())
                .unwrap_or_else(|| "\u{221e}".to_string())
        );

        pending_restarts.push(PendingRestart {
            deadline: tokio::time::Instant::now() + restart_delay,
            task_name: Arc::clone(task_name),
        });
    }

    /// Errors out when the fraction of `Dead` tasks exceeds the configured
    /// threshold (both fractions in `0.0..=1.0`). On breach, every task
    /// still running is stopped before `TooManyDeadTasks` is returned.
    fn check_dead_tasks_threshold(&mut self) -> Result<(), SupervisorError> {
        let Some(threshold) = self.max_dead_tasks_percentage_threshold else {
            return Ok(());
        };

        let total_task_count = self.tasks.len();
        // No tasks: nothing can be dead; avoids a divide-by-zero below.
        if total_task_count == 0 {
            return Ok(());
        }

        let dead_task_count = self
            .tasks
            .values()
            .filter(|handle| handle.status == TaskStatus::Dead)
            .count();

        let current_dead_percentage = dead_task_count as f64 / total_task_count as f64;

        if current_dead_percentage <= threshold {
            return Ok(());
        }

        #[cfg(feature = "tracing")]
        error!(
            "Dead tasks threshold exceeded: {:.2}% > {:.2}% ({}/{} tasks dead)",
            current_dead_percentage * 100.0,
            threshold * 100.0,
            dead_task_count,
            total_task_count
        );

        // Threshold breached: the supervisor is going down, so stop
        // everything that is still running.
        #[allow(unused_variables)]
        for (task_name, task_handle) in self.tasks.iter_mut() {
            if task_handle.status != TaskStatus::Dead && task_handle.status != TaskStatus::Completed
            {
                #[cfg(feature = "tracing")]
                debug!("Killing task '{}' due to threshold breach", task_name);

                task_handle.clean();
                task_handle.mark(TaskStatus::Dead);
            }
        }

        Err(SupervisorError::TooManyDeadTasks {
            current_percentage: current_dead_percentage,
            threshold,
        })
    }
}