Skip to main content

task_supervisor/supervisor/
handle.rs

1use std::collections::HashMap;
2
3use thiserror::Error;
4use tokio::sync::{mpsc, oneshot, watch};
5
6use crate::{task::DynTask, SupervisedTask, TaskName, TaskStatus};
7
8use super::SupervisorError;
9
10#[derive(Debug, Error)]
11pub enum SupervisorHandleError {
12    #[error("Failed to send message to supervisor: channel closed")]
13    SendError,
14    #[error("Failed to receive response from supervisor: {0}")]
15    RecvError(#[from] tokio::sync::oneshot::error::RecvError),
16}
17
18pub(crate) enum SupervisorMessage {
19    /// Request to add a new running Task
20    AddTask(TaskName, DynTask),
21    /// Request to restart a Task
22    RestartTask(TaskName),
23    /// Request to kill a Task
24    KillTask(TaskName),
25    /// Request the status of a Task
26    GetTaskStatus(TaskName, oneshot::Sender<Option<TaskStatus>>),
27    /// Request the status of all Tasks
28    GetAllTaskStatuses(oneshot::Sender<HashMap<TaskName, TaskStatus>>),
29    /// Request sent to shutdown all the running Tasks
30    Shutdown,
31}
32
33/// Handle used to interact with the `Supervisor`.
34///
35/// Cloning this handle is cheap. All clones share the same connection to
36/// the supervisor. Multiple clones can call `wait()` concurrently and all
37/// will receive the result.
38#[derive(Clone)]
39pub struct SupervisorHandle {
40    pub(crate) tx: mpsc::UnboundedSender<SupervisorMessage>,
41    /// Receives the supervisor's final result. Using a `watch` channel
42    /// allows multiple concurrent `wait()` callers to all see the outcome,
43    /// unlike `JoinSet` which would consume the result on the first `join_next()`.
44    result_rx: watch::Receiver<Option<Result<(), SupervisorError>>>,
45}
46
47impl Drop for SupervisorHandle {
48    /// Automatically shuts down the supervisor when the handle is dropped.
49    fn drop(&mut self) {
50        if self.is_channel_open() {
51            let _ = self.shutdown();
52        }
53    }
54}
55
56impl SupervisorHandle {
57    /// Creates a new `SupervisorHandle`.
58    ///
59    /// Spawns a background task that awaits the supervisor's `JoinHandle` and
60    /// broadcasts the result via a `watch` channel.
61    pub(crate) fn new(
62        join_handle: tokio::task::JoinHandle<Result<(), SupervisorError>>,
63        tx: mpsc::UnboundedSender<SupervisorMessage>,
64    ) -> Self {
65        let (result_tx, result_rx) = watch::channel(None);
66
67        tokio::spawn(async move {
68            let result = match join_handle.await {
69                Ok(supervisor_result) => supervisor_result,
70                // Supervisor task was cancelled/panicked — treat as Ok
71                Err(_join_error) => Ok(()),
72            };
73            let _ = result_tx.send(Some(result));
74        });
75
76        Self { tx, result_rx }
77    }
78
79    /// Waits for the supervisor to complete its execution.
80    ///
81    /// Multiple callers can `wait()` concurrently; all will receive the same result.
82    ///
83    /// # Returns
84    /// - `Ok(())` if the supervisor completed successfully.
85    /// - `Err(SupervisorError)` if the supervisor returned an error.
86    pub async fn wait(&self) -> Result<(), SupervisorError> {
87        let mut rx = self.result_rx.clone();
88        // Wait until the value is Some (i.e., the supervisor has finished).
89        loop {
90            if let Some(result) = rx.borrow_and_update().clone() {
91                return result;
92            }
93            // Value is still None — wait for a change.
94            if rx.changed().await.is_err() {
95                // Sender dropped without sending — supervisor is gone.
96                return Ok(());
97            }
98        }
99    }
100
101    /// Adds a new task to the supervisor.
102    ///
103    /// # Arguments
104    /// - `task_name`: The unique name of the task.
105    /// - `task`: The task to be added, which must implement `SupervisedTask`.
106    ///
107    /// # Returns
108    /// - `Ok(())` if the message was sent successfully.
109    /// - `Err(SendError)` if the supervisor is no longer running.
110    pub fn add_task<T: SupervisedTask + Clone>(
111        &self,
112        task_name: &str,
113        task: T,
114    ) -> Result<(), SupervisorHandleError> {
115        self.tx
116            .send(SupervisorMessage::AddTask(task_name.into(), Box::new(task)))
117            .map_err(|_| SupervisorHandleError::SendError)
118    }
119
120    /// Requests the supervisor to restart a specific task.
121    ///
122    /// # Arguments
123    /// - `task_name`: The name of the task to restart.
124    ///
125    /// # Returns
126    /// - `Ok(())` if the message was sent successfully.
127    /// - `Err(SendError)` if the supervisor is no longer running.
128    pub fn restart(&self, task_name: &str) -> Result<(), SupervisorHandleError> {
129        self.tx
130            .send(SupervisorMessage::RestartTask(task_name.into()))
131            .map_err(|_| SupervisorHandleError::SendError)
132    }
133
134    /// Requests the supervisor to kill a specific task.
135    ///
136    /// # Arguments
137    /// - `task_name`: The name of the task to kill.
138    ///
139    /// # Returns
140    /// - `Ok(())` if the message was sent successfully.
141    /// - `Err(SendError)` if the supervisor is no longer running.
142    pub fn kill_task(&self, task_name: &str) -> Result<(), SupervisorHandleError> {
143        self.tx
144            .send(SupervisorMessage::KillTask(task_name.into()))
145            .map_err(|_| SupervisorHandleError::SendError)
146    }
147
148    /// Requests the supervisor to shut down all tasks and stop supervision.
149    ///
150    /// # Returns
151    /// - `Ok(())` if the message was sent successfully.
152    /// - `Err(SendError)` if the supervisor is no longer running.
153    pub fn shutdown(&self) -> Result<(), SupervisorHandleError> {
154        self.tx
155            .send(SupervisorMessage::Shutdown)
156            .map_err(|_| SupervisorHandleError::SendError)
157    }
158
159    /// Queries the status of a specific task asynchronously.
160    ///
161    /// # Arguments
162    /// - `task_name`: The name of the task to query.
163    ///
164    /// # Returns
165    /// - `Ok(Some(TaskStatus))` if the task exists and its status is returned.
166    /// - `Ok(None)` if the task does not exist.
167    /// - `Err(RecvError)` if communication with the supervisor fails (e.g., it has shut down).
168    pub async fn get_task_status(
169        &self,
170        task_name: &str,
171    ) -> Result<Option<TaskStatus>, SupervisorHandleError> {
172        let (sender, receiver) = oneshot::channel();
173        self.tx
174            .send(SupervisorMessage::GetTaskStatus(task_name.into(), sender))
175            .map_err(|_| SupervisorHandleError::SendError)?;
176        receiver.await.map_err(SupervisorHandleError::RecvError)
177    }
178
179    /// Queries the statuses of all tasks asynchronously.
180    ///
181    /// # Returns
182    /// - `Ok(HashMap<TaskName, TaskStatus>)` containing the statuses of all tasks.
183    /// - `Err(RecvError)` if communication with the supervisor fails (e.g., it has shut down).
184    pub async fn get_all_task_statuses(
185        &self,
186    ) -> Result<HashMap<String, TaskStatus>, SupervisorHandleError> {
187        let (sender, receiver) = oneshot::channel();
188        self.tx
189            .send(SupervisorMessage::GetAllTaskStatuses(sender))
190            .map_err(|_| SupervisorHandleError::SendError)?;
191        receiver.await.map_err(SupervisorHandleError::RecvError)
192    }
193
194    /// Checks if the supervisor channel is still open.
195    fn is_channel_open(&self) -> bool {
196        !self.tx.is_closed()
197    }
198}
199
200impl std::fmt::Debug for SupervisorHandle {
201    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
202        f.debug_struct("SupervisorHandle")
203            .field("channel_open", &self.is_channel_open())
204            .finish()
205    }
206}