task_supervisor/supervisor/
mod.rs

1pub(crate) mod builder;
2pub(crate) mod handle;
3
4use std::{
5    collections::HashMap,
6    time::{Duration, Instant},
7};
8
9use handle::SupervisorMessage;
10use tokio::{sync::mpsc, time::interval_at};
11use tokio_util::sync::CancellationToken;
12
13use crate::{
14    supervisor::handle::SupervisorHandle,
15    task::{TaskHandle, TaskOutcome, TaskStatus},
16};
17
/// A liveness signal emitted on behalf of a supervised task.
#[derive(Debug, Clone)]
pub(crate) struct Heartbeat {
    /// Name of the task this heartbeat belongs to.
    pub(crate) task_name: String,
    /// Moment at which the heartbeat was created.
    pub(crate) timestamp: Instant,
}

impl Heartbeat {
    /// Creates a new heartbeat for the given task name, stamped with the current time.
    ///
    /// The name is taken as `&str` so callers may pass a string slice or a borrowed
    /// `String` (which coerces automatically); it is copied into the heartbeat.
    pub fn new(task_name: &str) -> Self {
        Self {
            task_name: task_name.to_owned(),
            timestamp: Instant::now(),
        }
    }
}
34
35/// Internal messages sent from tasks and by the `Supervisor` to manage task lifecycle.
36#[derive(Debug)]
37pub(crate) enum SupervisedTaskMessage {
38    /// Sent by tasks to indicate they are alive.
39    Heartbeat(Heartbeat),
40    /// Sent by the supervisor to trigger a task restart.
41    Restart(String),
42    /// Sent when a task completes, either successfully or with a failure.
43    Completed(String, TaskOutcome),
44    /// Sent when a shutdown signal is asked by the User. Close every tasks & the supervisor.
45    Shutdown,
46}
47
48/// Manages a set of tasks, ensuring they remain operational through heartbeats and restarts.
49///
50/// The `Supervisor` spawns each task with a heartbeat mechanism to monitor liveness.
51/// If a task stops sending heartbeats or fails, it is restarted with an exponential backoff.
52/// User commands such as adding, restarting, or killing tasks are supported via the `SupervisorHandle`.
53pub struct Supervisor {
54    // List of current tasks
55    tasks: HashMap<String, TaskHandle>,
56    // Durations for tasks lifecycle
57    timeout_threshold: Duration,
58    heartbeat_interval: Duration,
59    health_check_initial_delay: Duration,
60    health_check_interval: Duration,
61    base_restart_delay: Duration,
62    task_is_stable_after: Duration,
63    max_restart_attempts: u32,
64    // Channels between the User & the Supervisor
65    external_tx: mpsc::UnboundedSender<SupervisorMessage>,
66    external_rx: mpsc::UnboundedReceiver<SupervisorMessage>,
67    // Channels used in the Supervisor internally (between threads)
68    internal_tx: mpsc::UnboundedSender<SupervisedTaskMessage>,
69    internal_rx: mpsc::UnboundedReceiver<SupervisedTaskMessage>,
70}
71
72impl Supervisor {
73    /// Runs the supervisor, consuming it and returning a handle for external control.
74    ///
75    /// This method initiates all tasks and starts the supervision loop.
76    pub fn run(self) -> SupervisorHandle {
77        let user_tx = self.external_tx.clone();
78        let handle = tokio::spawn(async move {
79            self.run_and_supervise().await;
80        });
81        SupervisorHandle::new(handle, user_tx)
82    }
83
84    /// Starts and supervises all tasks, running the main supervision loop.
85    async fn run_and_supervise(mut self) {
86        self.start_all_tasks().await;
87        self.supervise_all_tasks().await;
88    }
89
90    /// Initiates all tasks managed by the supervisor.
91    async fn start_all_tasks(&mut self) {
92        for (task_name, task_handle) in self.tasks.iter_mut() {
93            Self::start_task(
94                task_name.to_string(),
95                task_handle,
96                self.internal_tx.clone(),
97                self.heartbeat_interval,
98            )
99            .await;
100        }
101    }
102
103    /// Supervises tasks by processing messages and performing periodic health checks.
104    async fn supervise_all_tasks(&mut self) {
105        let mut health_check_interval = interval_at(
106            tokio::time::Instant::now() + self.health_check_initial_delay,
107            self.health_check_interval,
108        );
109
110        loop {
111            tokio::select! {
112                Some(internal_msg) = self.internal_rx.recv() => {
113                    // Exit the supervising loop
114                    if matches!(internal_msg, SupervisedTaskMessage::Shutdown) {
115                        return;
116                    }
117                    self.handle_internal_message(internal_msg).await;
118                },
119                Some(user_msg) = self.external_rx.recv() => {
120                    self.handle_user_message(user_msg).await;
121                },
122                _ = health_check_interval.tick() => {
123                    self.check_all_health();
124                }
125            }
126        }
127    }
128
129    async fn handle_internal_message(&mut self, msg: SupervisedTaskMessage) {
130        match msg {
131            SupervisedTaskMessage::Heartbeat(heartbeat) => {
132                self.register_heartbeat(heartbeat);
133            }
134            SupervisedTaskMessage::Restart(task_name) => {
135                self.restart_task(task_name).await;
136            }
137            SupervisedTaskMessage::Completed(task_name, outcome) => {
138                self.handle_task_completion(task_name, outcome).await;
139            }
140            SupervisedTaskMessage::Shutdown => unreachable!(),
141        }
142    }
143
144    /// Processes user commands received via the `SupervisorHandle`.
145    async fn handle_user_message(&mut self, msg: SupervisorMessage) {
146        match msg {
147            SupervisorMessage::AddTask(task_name, task) => {
148                if self.tasks.contains_key(&task_name) {
149                    return;
150                }
151                let mut task_handle =
152                    TaskHandle::new(task, self.max_restart_attempts, self.base_restart_delay);
153                Self::start_task(
154                    task_name.clone(),
155                    &mut task_handle,
156                    self.internal_tx.clone(),
157                    self.heartbeat_interval,
158                )
159                .await;
160                self.tasks.insert(task_name, task_handle);
161            }
162            SupervisorMessage::RestartTask(task_name) => {
163                self.restart_task(task_name).await;
164            }
165            SupervisorMessage::KillTask(task_name) => {
166                let Some(task_handle) = self.tasks.get_mut(&task_name) else {
167                    return;
168                };
169                if task_handle.status == TaskStatus::Dead {
170                    return;
171                }
172                task_handle.mark(TaskStatus::Dead);
173                task_handle.clean().await;
174            }
175            SupervisorMessage::GetTaskStatus(task_name, sender) => {
176                let status = self.tasks.get(&task_name).map(|handle| handle.status);
177                let _ = sender.send(status);
178            }
179            SupervisorMessage::GetAllTaskStatuses(sender) => {
180                let statuses = self
181                    .tasks
182                    .iter()
183                    .map(|(name, handle)| (name.clone(), handle.status))
184                    .collect();
185                let _ = sender.send(statuses);
186            }
187            SupervisorMessage::Shutdown => {
188                for (_, task_handle) in self.tasks.iter_mut() {
189                    if task_handle.status != TaskStatus::Dead {
190                        task_handle.clean().await;
191                        task_handle.mark(TaskStatus::Dead);
192                    }
193                }
194                let _ = self.internal_tx.send(SupervisedTaskMessage::Shutdown);
195            }
196        }
197    }
198
199    /// Starts a task, setting up its execution, heartbeat, and completion handling.
200    async fn start_task(
201        task_name: String,
202        task_handle: &mut TaskHandle,
203        internal_tx: mpsc::UnboundedSender<SupervisedTaskMessage>,
204        heartbeat_interval: Duration,
205    ) {
206        let token = CancellationToken::new();
207
208        // Completion task
209        let (completion_tx, mut completion_rx) = mpsc::channel::<TaskOutcome>(1);
210        let tx_heartbeat = internal_tx.clone();
211        let task_name_clone = task_name.clone();
212        let token_completion = token.clone();
213        let completion_task = tokio::spawn(async move {
214            tokio::select! {
215                _ = token_completion.cancelled() => {
216                    // Task cancelled
217                }
218                Some(outcome) = completion_rx.recv() => {
219                    let completion_msg = SupervisedTaskMessage::Completed(task_name_clone, outcome);
220                    let _ = tx_heartbeat.send(completion_msg);
221                    token_completion.cancel(); // Cancel other tasks upon completion
222                }
223            }
224        });
225
226        // Heartbeat task
227        let task_name_heartbeat = task_name.clone();
228        let token_heartbeat = token.clone();
229        let heartbeat_task = tokio::spawn(async move {
230            let mut beat_interval = tokio::time::interval(heartbeat_interval);
231            loop {
232                tokio::select! {
233                    _ = beat_interval.tick() => {
234                        let beat = SupervisedTaskMessage::Heartbeat(Heartbeat::new(&task_name_heartbeat));
235                        if internal_tx.send(beat).is_err() {
236                            break;
237                        }
238                    }
239                    _ = token_heartbeat.cancelled() => {
240                        break; // Stop heartbeat on cancellation
241                    }
242                }
243            }
244        });
245
246        // Main task
247        let mut task = task_handle.task.clone_box();
248        let token_main = token.clone();
249        let ran_task = tokio::spawn(async move {
250            tokio::select! {
251                _ = token_main.cancelled() => {
252                    // Task cancelled
253                }
254                _ = async {
255                    match task.run().await {
256                        Ok(outcome) => {
257                            let _ = completion_tx.send(outcome).await;
258                        }
259                        Err(e) => {
260                            let _ = completion_tx.send(TaskOutcome::Failed(e.to_string())).await;
261                        }
262                    }
263                } => {}
264            }
265        });
266
267        // Mark the task as `Starting`
268        task_handle.mark(TaskStatus::Starting);
269        // Store the token and handles in TaskHandle
270        task_handle.cancellation_token = Some(token);
271        task_handle.handles = Some(vec![ran_task, heartbeat_task, completion_task]);
272    }
273
274    /// Updates a task's status based on received heartbeats.
275    fn register_heartbeat(&mut self, heartbeat: Heartbeat) {
276        let Some(task_handle) = self.tasks.get_mut(&heartbeat.task_name) else {
277            return;
278        };
279
280        if task_handle.status == TaskStatus::Dead {
281            return;
282        }
283
284        task_handle.ticked_at(heartbeat.timestamp);
285
286        match task_handle.status {
287            TaskStatus::Starting => {
288                task_handle.mark(TaskStatus::Healthy);
289                task_handle.healthy_since = Some(heartbeat.timestamp);
290            }
291            TaskStatus::Healthy => {
292                // Reset the `restart_attempts` if the task has been healthy & stable for some time
293                if let Some(healthy_since) = task_handle.healthy_since {
294                    if heartbeat.timestamp.duration_since(healthy_since) > self.task_is_stable_after
295                    {
296                        task_handle.restart_attempts = 0;
297                    }
298                } else {
299                    task_handle.healthy_since = Some(heartbeat.timestamp);
300                }
301            }
302            _ => {}
303        }
304    }
305
306    /// Restarts a task after cleaning up its previous execution.
307    async fn restart_task(&mut self, task_name: String) {
308        let Some(task_handle) = self.tasks.get_mut(&task_name) else {
309            return;
310        };
311        task_handle.clean().await;
312        task_handle.mark(TaskStatus::Created);
313        Self::start_task(
314            task_name,
315            task_handle,
316            self.internal_tx.clone(),
317            self.heartbeat_interval,
318        )
319        .await;
320    }
321
322    /// Checks task health and schedules restarts for crashed tasks.
323    fn check_all_health(&mut self) {
324        let crashed_tasks = self
325            .tasks
326            .iter()
327            .filter(|(_, handle)| handle.has_crashed(self.timeout_threshold))
328            .map(|(name, _)| name.clone())
329            .collect::<Vec<_>>();
330
331        for crashed_task in crashed_tasks {
332            let Some(task_handle) = self.tasks.get_mut(&crashed_task) else {
333                continue;
334            };
335            if task_handle.has_exceeded_max_retries() && task_handle.status != TaskStatus::Dead {
336                task_handle.mark(TaskStatus::Dead);
337                continue;
338            }
339            let restart_delay = task_handle.restart_delay();
340            task_handle.mark(TaskStatus::Failed);
341            task_handle.restart_attempts = task_handle.restart_attempts.saturating_add(1);
342            let internal_tx = self.internal_tx.clone();
343            tokio::spawn(async move {
344                tokio::time::sleep(restart_delay).await;
345                let _ = internal_tx.send(SupervisedTaskMessage::Restart(crashed_task));
346            });
347        }
348    }
349
350    /// Handles task completion outcomes, deciding whether to mark as completed or restart.
351    async fn handle_task_completion(&mut self, task_name: String, outcome: TaskOutcome) {
352        let Some(task_handle) = self.tasks.get_mut(&task_name) else {
353            return;
354        };
355
356        match outcome {
357            TaskOutcome::Completed => {
358                task_handle.mark(TaskStatus::Completed);
359            }
360            TaskOutcome::Failed(_) => {
361                task_handle.mark(TaskStatus::Failed);
362                if task_handle.has_exceeded_max_retries() {
363                    task_handle.mark(TaskStatus::Dead);
364                    return;
365                }
366                let restart_delay = task_handle.restart_delay();
367                task_handle.restart_attempts = task_handle.restart_attempts.saturating_add(1);
368                let internal_tx = self.internal_tx.clone();
369                tokio::spawn(async move {
370                    tokio::time::sleep(restart_delay).await;
371                    let _ = internal_tx.send(SupervisedTaskMessage::Restart(task_name));
372                });
373            }
374        }
375    }
376}