task_supervisor/supervisor/handle.rs

1use std::{collections::HashMap, sync::Arc};
2
3use thiserror::Error;
4use tokio::{
5    sync::{mpsc, oneshot, Mutex},
6    task::{JoinError, JoinHandle, JoinSet},
7};
8
9use crate::{
10    task::{CloneableSupervisedTask, DynTask},
11    TaskName, TaskStatus,
12};
13
14use super::SupervisorError;
15
16#[derive(Debug, Error)]
17pub enum SupervisorHandleError {
18    #[error("Failed to send message to supervisor: {0}")]
19    SendError(#[from] tokio::sync::mpsc::error::SendError<SupervisorMessage>),
20    #[error("Failed to receive response from supervisor: {0}")]
21    RecvError(#[from] tokio::sync::oneshot::error::RecvError),
22}
23
24pub enum SupervisorMessage {
25    /// Request to add a new running Task
26    AddTask(TaskName, DynTask),
27    /// Request to restart a Task
28    RestartTask(TaskName),
29    /// Request to kill a Task
30    KillTask(TaskName),
31    /// Request the status of a  Task
32    GetTaskStatus(TaskName, oneshot::Sender<Option<TaskStatus>>),
33    /// Request the status of all Task
34    GetAllTaskStatuses(oneshot::Sender<HashMap<TaskName, TaskStatus>>),
35    /// Request sent to shutdown all the running Tasks
36    Shutdown,
37}
38
39type SupervisorHandleResult = Result<Result<(), SupervisorError>, JoinError>;
40
41/// Handle used to interact with the `Supervisor`.
42#[derive(Debug, Clone)]
43pub struct SupervisorHandle {
44    pub(crate) tx: mpsc::UnboundedSender<SupervisorMessage>,
45    join_set: Arc<Mutex<JoinSet<SupervisorHandleResult>>>,
46}
47
48impl Drop for SupervisorHandle {
49    /// Automatically shuts down the supervisor when the handle is dropped.
50    fn drop(&mut self) {
51        if self.is_channel_open() {
52            let _ = self.shutdown();
53        }
54    }
55}
56
57impl SupervisorHandle {
58    /// Creates a new `SupervisorHandle`.
59    ///
60    /// This constructor is for internal use within the crate and initializes the handle
61    /// with the provided join handle and message sender.
62    ///
63    /// # Arguments
64    /// - `join_handle`: The `JoinHandle` representing the supervisor's task.
65    /// - `tx`: The unbounded sender for sending messages to the supervisor.
66    ///
67    /// # Returns
68    /// A new instance of `SupervisorHandle`.
69    pub(crate) fn new(
70        join_handle: JoinHandle<Result<(), SupervisorError>>,
71        tx: mpsc::UnboundedSender<SupervisorMessage>,
72    ) -> Self {
73        let mut join_set = JoinSet::new();
74        join_set.spawn(join_handle);
75
76        Self {
77            tx,
78            join_set: Arc::new(Mutex::new(join_set)),
79        }
80    }
81
82    /// Waits for the supervisor to complete its execution.
83    ///
84    /// # Returns
85    /// - `Ok(())` if the supervisor completed successfully.
86    /// - `Err(SupervisorError)` if the supervisor returned an error.
87    pub async fn wait(&self) -> Result<(), SupervisorError> {
88        let mut join_set = self.join_set.lock().await;
89
90        if join_set.is_empty() {
91            return Ok(());
92        }
93
94        if let Some(result) = join_set.join_next().await {
95            match result {
96                Ok(Ok(supervisor_result)) => supervisor_result,
97                // TODO: Handle the join_error & propagate it?
98                Err(_join_error) => Ok(()),
99                _ => Ok(()),
100            }
101        } else {
102            // JoinSet is empty, task already completed
103            Ok(())
104        }
105    }
106
107    /// Adds a new task to the supervisor.
108    ///
109    /// This method sends a message to the supervisor to add a new task with the specified name.
110    ///
111    /// # Arguments
112    /// - `task_name`: The unique name of the task.
113    /// - `task`: The task to be added, which must implement `SupervisedTask`.
114    ///
115    /// # Returns
116    /// - `Ok(())` if the message was sent successfully.
117    /// - `Err(SendError)` if the supervisor is no longer running.
118    pub fn add_task<T: CloneableSupervisedTask + 'static>(
119        &self,
120        task_name: &str,
121        task: T,
122    ) -> Result<(), SupervisorHandleError> {
123        self.tx
124            .send(SupervisorMessage::AddTask(task_name.into(), Box::new(task)))
125            .map_err(SupervisorHandleError::SendError)
126    }
127
128    /// Requests the supervisor to restart a specific task.
129    ///
130    /// This method sends a message to the supervisor to restart the task with the given name.
131    ///
132    /// # Arguments
133    /// - `task_name`: The name of the task to restart.
134    ///
135    /// # Returns
136    /// - `Ok(())` if the message was sent successfully.
137    /// - `Err(SendError)` if the supervisor is no longer running.
138    pub fn restart(&self, task_name: &str) -> Result<(), SupervisorHandleError> {
139        self.tx
140            .send(SupervisorMessage::RestartTask(task_name.into()))
141            .map_err(SupervisorHandleError::SendError)
142    }
143
144    /// Requests the supervisor to kill a specific task.
145    ///
146    /// This method sends a message to the supervisor to terminate the task with the given name.
147    ///
148    /// # Arguments
149    /// - `task_name`: The name of the task to kill.
150    ///
151    /// # Returns
152    /// - `Ok(())` if the message was sent successfully.
153    /// - `Err(SendError)` if the supervisor is no longer running.
154    pub fn kill_task(&self, task_name: &str) -> Result<(), SupervisorHandleError> {
155        self.tx
156            .send(SupervisorMessage::KillTask(task_name.into()))
157            .map_err(SupervisorHandleError::SendError)
158    }
159
160    /// Requests the supervisor to shut down all tasks and stop supervision.
161    ///
162    /// This method sends a message to the supervisor to terminate all tasks and cease operation.
163    ///
164    /// # Returns
165    /// - `Ok(())` if the message was sent successfully.
166    /// - `Err(SendError)` if the supervisor is no longer running.
167    pub fn shutdown(&self) -> Result<(), SupervisorHandleError> {
168        self.tx
169            .send(SupervisorMessage::Shutdown)
170            .map_err(SupervisorHandleError::SendError)
171    }
172
173    /// Queries the status of a specific task asynchronously.
174    ///
175    /// This method sends a request to the supervisor to retrieve the status of the specified task
176    /// and awaits the response.
177    ///
178    /// # Arguments
179    /// - `task_name`: The name of the task to query.
180    ///
181    /// # Returns
182    /// - `Ok(Some(TaskStatus))` if the task exists and its status is returned.
183    /// - `Ok(None)` if the task does not exist.
184    /// - `Err(RecvError)` if communication with the supervisor fails (e.g., it has shut down).
185    pub async fn get_task_status(
186        &self,
187        task_name: &str,
188    ) -> Result<Option<TaskStatus>, SupervisorHandleError> {
189        let (sender, receiver) = oneshot::channel();
190        self.tx
191            .send(SupervisorMessage::GetTaskStatus(task_name.into(), sender))
192            .map_err(SupervisorHandleError::SendError)?;
193        receiver.await.map_err(SupervisorHandleError::RecvError)
194    }
195
196    /// Queries the statuses of all tasks asynchronously.
197    ///
198    /// This method sends a request to the supervisor to retrieve the statuses of all tasks
199    /// and awaits the response.
200    ///
201    /// # Returns
202    /// - `Ok(HashMap<TaskName, TaskStatus>)` containing the statuses of all tasks.
203    /// - `Err(RecvError)` if communication with the supervisor fails (e.g., it has shut down).
204    pub async fn get_all_task_statuses(
205        &self,
206    ) -> Result<HashMap<String, TaskStatus>, SupervisorHandleError> {
207        let (sender, receiver) = oneshot::channel();
208        self.tx
209            .send(SupervisorMessage::GetAllTaskStatuses(sender))
210            .map_err(SupervisorHandleError::SendError)?;
211        receiver.await.map_err(SupervisorHandleError::RecvError)
212    }
213
214    /// Checks if the supervisor channel is still open.
215    fn is_channel_open(&self) -> bool {
216        !self.tx.is_closed()
217    }
218}