Skip to main content

// task_supervisor/supervisor/handle.rs

1use std::collections::HashMap;
2
3use thiserror::Error;
4use tokio::sync::{mpsc, oneshot, watch};
5
6use crate::{task::DynTask, SupervisedTask, TaskName, TaskStatus};
7
8use super::SupervisorError;
9
10#[derive(Debug, Error)]
11pub enum SupervisorHandleError {
12    #[error("Failed to send message to supervisor: channel closed")]
13    SendError,
14    #[error("Failed to receive response from supervisor: {0}")]
15    RecvError(#[from] tokio::sync::oneshot::error::RecvError),
16}
17
18pub(crate) enum SupervisorMessage {
19    /// Request to add a new running Task
20    AddTask(TaskName, DynTask),
21    /// Request to restart a Task
22    RestartTask(TaskName),
23    /// Request to kill a Task
24    KillTask(TaskName),
25    /// Request the status of a Task
26    GetTaskStatus(TaskName, oneshot::Sender<Option<TaskStatus>>),
27    /// Request the status of all Tasks
28    GetAllTaskStatuses(oneshot::Sender<HashMap<TaskName, TaskStatus>>),
29    /// Request sent to shutdown all the running Tasks
30    Shutdown,
31}
32
33/// Handle used to interact with the `Supervisor`.
34///
35/// Cloning this handle is cheap. All clones share the same connection to
36/// the supervisor.
37#[derive(Clone)]
38pub struct SupervisorHandle {
39    pub(crate) tx: mpsc::UnboundedSender<SupervisorMessage>,
40    /// Receives the supervisor's final result.
41    result_rx: watch::Receiver<Option<Result<(), SupervisorError>>>,
42}
43
44impl Drop for SupervisorHandle {
45    /// Automatically shuts down the supervisor when the handle is dropped.
46    fn drop(&mut self) {
47        if self.is_channel_open() {
48            let _ = self.shutdown();
49        }
50    }
51}
52
53impl SupervisorHandle {
54    /// Creates a new `SupervisorHandle`.
55    ///
56    /// Spawns a background task that awaits the supervisor's `JoinHandle` and
57    /// broadcasts the result via a `watch` channel.
58    pub(crate) fn new(
59        join_handle: tokio::task::JoinHandle<Result<(), SupervisorError>>,
60        tx: mpsc::UnboundedSender<SupervisorMessage>,
61    ) -> Self {
62        let (result_tx, result_rx) = watch::channel(None);
63
64        tokio::spawn(async move {
65            let result = match join_handle.await {
66                Ok(supervisor_result) => supervisor_result,
67                // Supervisor task was cancelled/panicked — treat as Ok
68                Err(_join_error) => Ok(()),
69            };
70            let _ = result_tx.send(Some(result));
71        });
72
73        Self { tx, result_rx }
74    }
75
76    /// Waits for the supervisor to complete its execution.
77    ///
78    /// Multiple callers can `wait()` concurrently; all will receive the same result.
79    ///
80    /// # Returns
81    /// - `Ok(())` if the supervisor completed successfully.
82    /// - `Err(SupervisorError)` if the supervisor returned an error.
83    pub async fn wait(&self) -> Result<(), SupervisorError> {
84        let mut rx = self.result_rx.clone();
85        // Wait until the value is Some (i.e., the supervisor has finished).
86        loop {
87            if let Some(result) = rx.borrow_and_update().clone() {
88                return result;
89            }
90            // Value is still None — wait for a change.
91            if rx.changed().await.is_err() {
92                // Sender dropped without sending — supervisor is gone.
93                return Ok(());
94            }
95        }
96    }
97
98    /// Adds a new task to the supervisor.
99    ///
100    /// # Arguments
101    /// - `task_name`: The unique name of the task.
102    /// - `task`: The task to be added, which must implement `SupervisedTask`.
103    ///
104    /// # Returns
105    /// - `Ok(())` if the message was sent successfully.
106    /// - `Err(SendError)` if the supervisor is no longer running.
107    pub fn add_task<T: SupervisedTask + Clone>(
108        &self,
109        task_name: impl Into<String>,
110        task: T,
111    ) -> Result<(), SupervisorHandleError> {
112        self.tx
113            .send(SupervisorMessage::AddTask(task_name.into(), Box::new(task)))
114            .map_err(|_| SupervisorHandleError::SendError)
115    }
116
117    /// Requests the supervisor to restart a specific task.
118    ///
119    /// # Arguments
120    /// - `task_name`: The name of the task to restart.
121    ///
122    /// # Returns
123    /// - `Ok(())` if the message was sent successfully.
124    /// - `Err(SendError)` if the supervisor is no longer running.
125    pub fn restart(&self, task_name: impl Into<String>) -> Result<(), SupervisorHandleError> {
126        self.tx
127            .send(SupervisorMessage::RestartTask(task_name.into()))
128            .map_err(|_| SupervisorHandleError::SendError)
129    }
130
131    /// Requests the supervisor to kill a specific task.
132    ///
133    /// # Arguments
134    /// - `task_name`: The name of the task to kill.
135    ///
136    /// # Returns
137    /// - `Ok(())` if the message was sent successfully.
138    /// - `Err(SendError)` if the supervisor is no longer running.
139    pub fn kill_task(&self, task_name: impl Into<String>) -> Result<(), SupervisorHandleError> {
140        self.tx
141            .send(SupervisorMessage::KillTask(task_name.into()))
142            .map_err(|_| SupervisorHandleError::SendError)
143    }
144
145    /// Requests the supervisor to shut down all tasks and stop supervision.
146    ///
147    /// # Returns
148    /// - `Ok(())` if the message was sent successfully.
149    /// - `Err(SendError)` if the supervisor is no longer running.
150    pub fn shutdown(&self) -> Result<(), SupervisorHandleError> {
151        self.tx
152            .send(SupervisorMessage::Shutdown)
153            .map_err(|_| SupervisorHandleError::SendError)
154    }
155
156    /// Queries the status of a specific task asynchronously.
157    ///
158    /// # Arguments
159    /// - `task_name`: The name of the task to query.
160    ///
161    /// # Returns
162    /// - `Ok(Some(TaskStatus))` if the task exists and its status is returned.
163    /// - `Ok(None)` if the task does not exist.
164    /// - `Err(RecvError)` if communication with the supervisor fails (e.g., it has shut down).
165    pub async fn get_task_status(
166        &self,
167        task_name: impl Into<String>,
168    ) -> Result<Option<TaskStatus>, SupervisorHandleError> {
169        let (sender, receiver) = oneshot::channel();
170        self.tx
171            .send(SupervisorMessage::GetTaskStatus(task_name.into(), sender))
172            .map_err(|_| SupervisorHandleError::SendError)?;
173        receiver.await.map_err(SupervisorHandleError::RecvError)
174    }
175
176    /// Queries the statuses of all tasks asynchronously.
177    ///
178    /// # Returns
179    /// - `Ok(HashMap<TaskName, TaskStatus>)` containing the statuses of all tasks.
180    /// - `Err(RecvError)` if communication with the supervisor fails (e.g., it has shut down).
181    pub async fn get_all_task_statuses(
182        &self,
183    ) -> Result<HashMap<String, TaskStatus>, SupervisorHandleError> {
184        let (sender, receiver) = oneshot::channel();
185        self.tx
186            .send(SupervisorMessage::GetAllTaskStatuses(sender))
187            .map_err(|_| SupervisorHandleError::SendError)?;
188        receiver.await.map_err(SupervisorHandleError::RecvError)
189    }
190
191    /// Checks if the supervisor channel is still open.
192    fn is_channel_open(&self) -> bool {
193        !self.tx.is_closed()
194    }
195}
196
197impl std::fmt::Debug for SupervisorHandle {
198    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
199        f.debug_struct("SupervisorHandle")
200            .field("channel_open", &self.is_channel_open())
201            .finish()
202    }
203}