task_supervisor/supervisor/handle.rs
1use std::{collections::HashMap, sync::Arc};
2
3use thiserror::Error;
4use tokio::{
5 sync::{mpsc, oneshot, Mutex},
6 task::{JoinError, JoinHandle, JoinSet},
7};
8
9use crate::{
10 task::{CloneableSupervisedTask, DynTask},
11 TaskName, TaskStatus,
12};
13
14use super::SupervisorError;
15
16#[derive(Debug, Error)]
17pub enum SupervisorHandleError {
18 #[error("Failed to send message to supervisor: {0}")]
19 SendError(#[from] tokio::sync::mpsc::error::SendError<SupervisorMessage>),
20 #[error("Failed to receive response from supervisor: {0}")]
21 RecvError(#[from] tokio::sync::oneshot::error::RecvError),
22}
23
24pub enum SupervisorMessage {
25 /// Request to add a new running Task
26 AddTask(TaskName, DynTask),
27 /// Request to restart a Task
28 RestartTask(TaskName),
29 /// Request to kill a Task
30 KillTask(TaskName),
31 /// Request the status of a Task
32 GetTaskStatus(TaskName, oneshot::Sender<Option<TaskStatus>>),
33 /// Request the status of all Task
34 GetAllTaskStatuses(oneshot::Sender<HashMap<TaskName, TaskStatus>>),
35 /// Request sent to shutdown all the running Tasks
36 Shutdown,
37}
38
39type SupervisorHandleResult = Result<Result<(), SupervisorError>, JoinError>;
40
41/// Handle used to interact with the `Supervisor`.
42#[derive(Debug, Clone)]
43pub struct SupervisorHandle {
44 pub(crate) tx: mpsc::UnboundedSender<SupervisorMessage>,
45 join_set: Arc<Mutex<JoinSet<SupervisorHandleResult>>>,
46}
47
48impl Drop for SupervisorHandle {
49 /// Automatically shuts down the supervisor when the handle is dropped.
50 fn drop(&mut self) {
51 if self.is_channel_open() {
52 let _ = self.shutdown();
53 }
54 }
55}
56
57impl SupervisorHandle {
58 /// Creates a new `SupervisorHandle`.
59 ///
60 /// This constructor is for internal use within the crate and initializes the handle
61 /// with the provided join handle and message sender.
62 ///
63 /// # Arguments
64 /// - `join_handle`: The `JoinHandle` representing the supervisor's task.
65 /// - `tx`: The unbounded sender for sending messages to the supervisor.
66 ///
67 /// # Returns
68 /// A new instance of `SupervisorHandle`.
69 pub(crate) fn new(
70 join_handle: JoinHandle<Result<(), SupervisorError>>,
71 tx: mpsc::UnboundedSender<SupervisorMessage>,
72 ) -> Self {
73 let mut join_set = JoinSet::new();
74 join_set.spawn(join_handle);
75
76 Self {
77 tx,
78 join_set: Arc::new(Mutex::new(join_set)),
79 }
80 }
81
82 /// Waits for the supervisor to complete its execution.
83 ///
84 /// # Returns
85 /// - `Ok(())` if the supervisor completed successfully.
86 /// - `Err(SupervisorError)` if the supervisor returned an error.
87 pub async fn wait(&self) -> Result<(), SupervisorError> {
88 let mut join_set = self.join_set.lock().await;
89
90 if join_set.is_empty() {
91 return Ok(());
92 }
93
94 if let Some(result) = join_set.join_next().await {
95 match result {
96 Ok(Ok(supervisor_result)) => supervisor_result,
97 // TODO: Handle the join_error & propagate it?
98 Err(_join_error) => Ok(()),
99 _ => Ok(()),
100 }
101 } else {
102 // JoinSet is empty, task already completed
103 Ok(())
104 }
105 }
106
107 /// Adds a new task to the supervisor.
108 ///
109 /// This method sends a message to the supervisor to add a new task with the specified name.
110 ///
111 /// # Arguments
112 /// - `task_name`: The unique name of the task.
113 /// - `task`: The task to be added, which must implement `SupervisedTask`.
114 ///
115 /// # Returns
116 /// - `Ok(())` if the message was sent successfully.
117 /// - `Err(SendError)` if the supervisor is no longer running.
118 pub fn add_task<T: CloneableSupervisedTask + 'static>(
119 &self,
120 task_name: &str,
121 task: T,
122 ) -> Result<(), SupervisorHandleError> {
123 self.tx
124 .send(SupervisorMessage::AddTask(task_name.into(), Box::new(task)))
125 .map_err(SupervisorHandleError::SendError)
126 }
127
128 /// Requests the supervisor to restart a specific task.
129 ///
130 /// This method sends a message to the supervisor to restart the task with the given name.
131 ///
132 /// # Arguments
133 /// - `task_name`: The name of the task to restart.
134 ///
135 /// # Returns
136 /// - `Ok(())` if the message was sent successfully.
137 /// - `Err(SendError)` if the supervisor is no longer running.
138 pub fn restart(&self, task_name: &str) -> Result<(), SupervisorHandleError> {
139 self.tx
140 .send(SupervisorMessage::RestartTask(task_name.into()))
141 .map_err(SupervisorHandleError::SendError)
142 }
143
144 /// Requests the supervisor to kill a specific task.
145 ///
146 /// This method sends a message to the supervisor to terminate the task with the given name.
147 ///
148 /// # Arguments
149 /// - `task_name`: The name of the task to kill.
150 ///
151 /// # Returns
152 /// - `Ok(())` if the message was sent successfully.
153 /// - `Err(SendError)` if the supervisor is no longer running.
154 pub fn kill_task(&self, task_name: &str) -> Result<(), SupervisorHandleError> {
155 self.tx
156 .send(SupervisorMessage::KillTask(task_name.into()))
157 .map_err(SupervisorHandleError::SendError)
158 }
159
160 /// Requests the supervisor to shut down all tasks and stop supervision.
161 ///
162 /// This method sends a message to the supervisor to terminate all tasks and cease operation.
163 ///
164 /// # Returns
165 /// - `Ok(())` if the message was sent successfully.
166 /// - `Err(SendError)` if the supervisor is no longer running.
167 pub fn shutdown(&self) -> Result<(), SupervisorHandleError> {
168 self.tx
169 .send(SupervisorMessage::Shutdown)
170 .map_err(SupervisorHandleError::SendError)
171 }
172
173 /// Queries the status of a specific task asynchronously.
174 ///
175 /// This method sends a request to the supervisor to retrieve the status of the specified task
176 /// and awaits the response.
177 ///
178 /// # Arguments
179 /// - `task_name`: The name of the task to query.
180 ///
181 /// # Returns
182 /// - `Ok(Some(TaskStatus))` if the task exists and its status is returned.
183 /// - `Ok(None)` if the task does not exist.
184 /// - `Err(RecvError)` if communication with the supervisor fails (e.g., it has shut down).
185 pub async fn get_task_status(
186 &self,
187 task_name: &str,
188 ) -> Result<Option<TaskStatus>, SupervisorHandleError> {
189 let (sender, receiver) = oneshot::channel();
190 self.tx
191 .send(SupervisorMessage::GetTaskStatus(task_name.into(), sender))
192 .map_err(SupervisorHandleError::SendError)?;
193 receiver.await.map_err(SupervisorHandleError::RecvError)
194 }
195
196 /// Queries the statuses of all tasks asynchronously.
197 ///
198 /// This method sends a request to the supervisor to retrieve the statuses of all tasks
199 /// and awaits the response.
200 ///
201 /// # Returns
202 /// - `Ok(HashMap<TaskName, TaskStatus>)` containing the statuses of all tasks.
203 /// - `Err(RecvError)` if communication with the supervisor fails (e.g., it has shut down).
204 pub async fn get_all_task_statuses(
205 &self,
206 ) -> Result<HashMap<String, TaskStatus>, SupervisorHandleError> {
207 let (sender, receiver) = oneshot::channel();
208 self.tx
209 .send(SupervisorMessage::GetAllTaskStatuses(sender))
210 .map_err(SupervisorHandleError::SendError)?;
211 receiver.await.map_err(SupervisorHandleError::RecvError)
212 }
213
214 /// Checks if the supervisor channel is still open.
215 fn is_channel_open(&self) -> bool {
216 !self.tx.is_closed()
217 }
218}