task_supervisor/supervisor/handle.rs

1use std::{collections::HashMap, sync::Arc};
2
3use thiserror::Error;
4use tokio::{
5    sync::{mpsc, oneshot, Mutex},
6    task::{JoinError, JoinHandle},
7};
8
9use crate::{
10    task::{CloneableSupervisedTask, DynTask},
11    TaskName, TaskStatus,
12};
13
14#[derive(Debug, Error)]
15pub enum SupervisorHandleError {
16    #[error("Failed to send message to supervisor: {0}")]
17    SendError(#[from] tokio::sync::mpsc::error::SendError<SupervisorMessage>),
18    #[error("Failed to receive response from supervisor: {0}")]
19    RecvError(#[from] tokio::sync::oneshot::error::RecvError),
20}
21
22pub enum SupervisorMessage {
23    /// Request to add a new running Task
24    AddTask(TaskName, DynTask),
25    /// Request to restart a Task
26    RestartTask(TaskName),
27    /// Request to kill a Task
28    KillTask(TaskName),
29    /// Request the status of a  Task
30    GetTaskStatus(TaskName, oneshot::Sender<Option<TaskStatus>>),
31    /// Request the status of all Task
32    GetAllTaskStatuses(oneshot::Sender<HashMap<TaskName, TaskStatus>>),
33    /// Request sent to shutdown all the running Tasks
34    Shutdown,
35}
36
37/// Handle used to interact with the `Supervisor`.
38#[derive(Debug, Clone)]
39pub struct SupervisorHandle {
40    pub(crate) join_handle: Arc<Mutex<Option<JoinHandle<()>>>>,
41    pub(crate) tx: mpsc::UnboundedSender<SupervisorMessage>,
42}
43
44impl Drop for SupervisorHandle {
45    /// Automatically shuts down the supervisor when the handle is dropped.
46    fn drop(&mut self) {
47        if self.is_channel_open() {
48            let _ = self.shutdown();
49        }
50    }
51}
52
53impl SupervisorHandle {
54    /// Creates a new `SupervisorHandle`.
55    ///
56    /// This constructor is for internal use within the crate and initializes the handle
57    /// with the provided join handle and message sender.
58    ///
59    /// # Arguments
60    /// - `join_handle`: The `JoinHandle` representing the supervisor's task.
61    /// - `tx`: The unbounded sender for sending messages to the supervisor.
62    ///
63    /// # Returns
64    /// A new instance of `SupervisorHandle`.
65    pub(crate) fn new(
66        join_handle: JoinHandle<()>,
67        tx: mpsc::UnboundedSender<SupervisorMessage>,
68    ) -> Self {
69        Self {
70            join_handle: Arc::new(Mutex::new(Some(join_handle))),
71            tx,
72        }
73    }
74
75    /// Waits for the supervisor to complete its execution.
76    ///
77    /// This method consumes the handle and waits for the supervisor task to finish.
78    /// It should be called only once per handle, as it takes ownership of the join handle.
79    ///
80    /// # Returns
81    /// - `Ok(())` if the supervisor completed successfully.
82    /// - `Err(JoinError)` if the supervisor task panicked.
83    ///
84    /// # Panics
85    /// Panics if `wait()` has already been called on any clone of this handle.
86    pub async fn wait(self) -> Result<(), JoinError> {
87        let handle_opt = {
88            let mut guard = self.join_handle.lock().await;
89            guard.take()
90        };
91
92        match handle_opt {
93            Some(handle) => handle.await,
94            None => panic!("SupervisorHandle::wait() was already called on a clone of this handle"),
95        }
96    }
97
98    /// Adds a new task to the supervisor.
99    ///
100    /// This method sends a message to the supervisor to add a new task with the specified name.
101    ///
102    /// # Arguments
103    /// - `task_name`: The unique name of the task.
104    /// - `task`: The task to be added, which must implement `SupervisedTask`.
105    ///
106    /// # Returns
107    /// - `Ok(())` if the message was sent successfully.
108    /// - `Err(SendError)` if the supervisor is no longer running.
109    pub fn add_task<T: CloneableSupervisedTask + 'static>(
110        &self,
111        task_name: &str,
112        task: T,
113    ) -> Result<(), SupervisorHandleError> {
114        self.tx
115            .send(SupervisorMessage::AddTask(task_name.into(), Box::new(task)))
116            .map_err(SupervisorHandleError::SendError)
117    }
118
119    /// Requests the supervisor to restart a specific task.
120    ///
121    /// This method sends a message to the supervisor to restart the task with the given name.
122    ///
123    /// # Arguments
124    /// - `task_name`: The name of the task to restart.
125    ///
126    /// # Returns
127    /// - `Ok(())` if the message was sent successfully.
128    /// - `Err(SendError)` if the supervisor is no longer running.
129    pub fn restart(&self, task_name: &str) -> Result<(), SupervisorHandleError> {
130        self.tx
131            .send(SupervisorMessage::RestartTask(task_name.into()))
132            .map_err(SupervisorHandleError::SendError)
133    }
134
135    /// Requests the supervisor to kill a specific task.
136    ///
137    /// This method sends a message to the supervisor to terminate the task with the given name.
138    ///
139    /// # Arguments
140    /// - `task_name`: The name of the task to kill.
141    ///
142    /// # Returns
143    /// - `Ok(())` if the message was sent successfully.
144    /// - `Err(SendError)` if the supervisor is no longer running.
145    pub fn kill_task(&self, task_name: &str) -> Result<(), SupervisorHandleError> {
146        self.tx
147            .send(SupervisorMessage::KillTask(task_name.into()))
148            .map_err(SupervisorHandleError::SendError)
149    }
150
151    /// Requests the supervisor to shut down all tasks and stop supervision.
152    ///
153    /// This method sends a message to the supervisor to terminate all tasks and cease operation.
154    ///
155    /// # Returns
156    /// - `Ok(())` if the message was sent successfully.
157    /// - `Err(SendError)` if the supervisor is no longer running.
158    pub fn shutdown(&self) -> Result<(), SupervisorHandleError> {
159        self.tx
160            .send(SupervisorMessage::Shutdown)
161            .map_err(SupervisorHandleError::SendError)
162    }
163
164    /// Queries the status of a specific task asynchronously.
165    ///
166    /// This method sends a request to the supervisor to retrieve the status of the specified task
167    /// and awaits the response.
168    ///
169    /// # Arguments
170    /// - `task_name`: The name of the task to query.
171    ///
172    /// # Returns
173    /// - `Ok(Some(TaskStatus))` if the task exists and its status is returned.
174    /// - `Ok(None)` if the task does not exist.
175    /// - `Err(RecvError)` if communication with the supervisor fails (e.g., it has shut down).
176    pub async fn get_task_status(
177        &self,
178        task_name: &str,
179    ) -> Result<Option<TaskStatus>, SupervisorHandleError> {
180        let (sender, receiver) = oneshot::channel();
181        self.tx
182            .send(SupervisorMessage::GetTaskStatus(task_name.into(), sender))
183            .map_err(SupervisorHandleError::SendError)?;
184        receiver.await.map_err(SupervisorHandleError::RecvError)
185    }
186
187    /// Queries the statuses of all tasks asynchronously.
188    ///
189    /// This method sends a request to the supervisor to retrieve the statuses of all tasks
190    /// and awaits the response.
191    ///
192    /// # Returns
193    /// - `Ok(HashMap<TaskName, TaskStatus>)` containing the statuses of all tasks.
194    /// - `Err(RecvError)` if communication with the supervisor fails (e.g., it has shut down).
195    pub async fn get_all_task_statuses(
196        &self,
197    ) -> Result<HashMap<String, TaskStatus>, SupervisorHandleError> {
198        let (sender, receiver) = oneshot::channel();
199        self.tx
200            .send(SupervisorMessage::GetAllTaskStatuses(sender))
201            .map_err(SupervisorHandleError::SendError)?;
202        receiver.await.map_err(SupervisorHandleError::RecvError)
203    }
204
205    /// Checks if the supervisor channel is still open.
206    fn is_channel_open(&self) -> bool {
207        !self.tx.is_closed()
208    }
209}