task_supervisor/supervisor/handle.rs
1use std::{collections::HashMap, sync::Arc};
2
3use thiserror::Error;
4use tokio::{
5 sync::{mpsc, oneshot, Mutex},
6 task::{JoinError, JoinHandle},
7};
8
9use crate::{
10 task::{CloneableSupervisedTask, DynTask},
11 TaskName, TaskStatus,
12};
13
14#[derive(Debug, Error)]
15pub enum SupervisorHandleError {
16 #[error("Failed to send message to supervisor: {0}")]
17 SendError(#[from] tokio::sync::mpsc::error::SendError<SupervisorMessage>),
18 #[error("Failed to receive response from supervisor: {0}")]
19 RecvError(#[from] tokio::sync::oneshot::error::RecvError),
20}
21
22pub enum SupervisorMessage {
23 /// Request to add a new running Task
24 AddTask(TaskName, DynTask),
25 /// Request to restart a Task
26 RestartTask(TaskName),
27 /// Request to kill a Task
28 KillTask(TaskName),
29 /// Request the status of a Task
30 GetTaskStatus(TaskName, oneshot::Sender<Option<TaskStatus>>),
31 /// Request the status of all Task
32 GetAllTaskStatuses(oneshot::Sender<HashMap<TaskName, TaskStatus>>),
33 /// Request sent to shutdown all the running Tasks
34 Shutdown,
35}
36
37/// Handle used to interact with the `Supervisor`.
38#[derive(Debug, Clone)]
39pub struct SupervisorHandle {
40 pub(crate) join_handle: Arc<Mutex<Option<JoinHandle<()>>>>,
41 pub(crate) tx: mpsc::UnboundedSender<SupervisorMessage>,
42}
43
44impl Drop for SupervisorHandle {
45 /// Automatically shuts down the supervisor when the handle is dropped.
46 fn drop(&mut self) {
47 if self.is_channel_open() {
48 let _ = self.shutdown();
49 }
50 }
51}
52
53impl SupervisorHandle {
54 /// Creates a new `SupervisorHandle`.
55 ///
56 /// This constructor is for internal use within the crate and initializes the handle
57 /// with the provided join handle and message sender.
58 ///
59 /// # Arguments
60 /// - `join_handle`: The `JoinHandle` representing the supervisor's task.
61 /// - `tx`: The unbounded sender for sending messages to the supervisor.
62 ///
63 /// # Returns
64 /// A new instance of `SupervisorHandle`.
65 pub(crate) fn new(
66 join_handle: JoinHandle<()>,
67 tx: mpsc::UnboundedSender<SupervisorMessage>,
68 ) -> Self {
69 Self {
70 join_handle: Arc::new(Mutex::new(Some(join_handle))),
71 tx,
72 }
73 }
74
75 /// Waits for the supervisor to complete its execution.
76 ///
77 /// This method consumes the handle and waits for the supervisor task to finish.
78 /// It should be called only once per handle, as it takes ownership of the join handle.
79 ///
80 /// # Returns
81 /// - `Ok(())` if the supervisor completed successfully.
82 /// - `Err(JoinError)` if the supervisor task panicked.
83 ///
84 /// # Panics
85 /// Panics if `wait()` has already been called on any clone of this handle.
86 pub async fn wait(self) -> Result<(), JoinError> {
87 let handle_opt = {
88 let mut guard = self.join_handle.lock().await;
89 guard.take()
90 };
91
92 match handle_opt {
93 Some(handle) => handle.await,
94 None => panic!("SupervisorHandle::wait() was already called on a clone of this handle"),
95 }
96 }
97
98 /// Adds a new task to the supervisor.
99 ///
100 /// This method sends a message to the supervisor to add a new task with the specified name.
101 ///
102 /// # Arguments
103 /// - `task_name`: The unique name of the task.
104 /// - `task`: The task to be added, which must implement `SupervisedTask`.
105 ///
106 /// # Returns
107 /// - `Ok(())` if the message was sent successfully.
108 /// - `Err(SendError)` if the supervisor is no longer running.
109 pub fn add_task<T: CloneableSupervisedTask + 'static>(
110 &self,
111 task_name: &str,
112 task: T,
113 ) -> Result<(), SupervisorHandleError> {
114 self.tx
115 .send(SupervisorMessage::AddTask(task_name.into(), Box::new(task)))
116 .map_err(SupervisorHandleError::SendError)
117 }
118
119 /// Requests the supervisor to restart a specific task.
120 ///
121 /// This method sends a message to the supervisor to restart the task with the given name.
122 ///
123 /// # Arguments
124 /// - `task_name`: The name of the task to restart.
125 ///
126 /// # Returns
127 /// - `Ok(())` if the message was sent successfully.
128 /// - `Err(SendError)` if the supervisor is no longer running.
129 pub fn restart(&self, task_name: &str) -> Result<(), SupervisorHandleError> {
130 self.tx
131 .send(SupervisorMessage::RestartTask(task_name.into()))
132 .map_err(SupervisorHandleError::SendError)
133 }
134
135 /// Requests the supervisor to kill a specific task.
136 ///
137 /// This method sends a message to the supervisor to terminate the task with the given name.
138 ///
139 /// # Arguments
140 /// - `task_name`: The name of the task to kill.
141 ///
142 /// # Returns
143 /// - `Ok(())` if the message was sent successfully.
144 /// - `Err(SendError)` if the supervisor is no longer running.
145 pub fn kill_task(&self, task_name: &str) -> Result<(), SupervisorHandleError> {
146 self.tx
147 .send(SupervisorMessage::KillTask(task_name.into()))
148 .map_err(SupervisorHandleError::SendError)
149 }
150
151 /// Requests the supervisor to shut down all tasks and stop supervision.
152 ///
153 /// This method sends a message to the supervisor to terminate all tasks and cease operation.
154 ///
155 /// # Returns
156 /// - `Ok(())` if the message was sent successfully.
157 /// - `Err(SendError)` if the supervisor is no longer running.
158 pub fn shutdown(&self) -> Result<(), SupervisorHandleError> {
159 self.tx
160 .send(SupervisorMessage::Shutdown)
161 .map_err(SupervisorHandleError::SendError)
162 }
163
164 /// Queries the status of a specific task asynchronously.
165 ///
166 /// This method sends a request to the supervisor to retrieve the status of the specified task
167 /// and awaits the response.
168 ///
169 /// # Arguments
170 /// - `task_name`: The name of the task to query.
171 ///
172 /// # Returns
173 /// - `Ok(Some(TaskStatus))` if the task exists and its status is returned.
174 /// - `Ok(None)` if the task does not exist.
175 /// - `Err(RecvError)` if communication with the supervisor fails (e.g., it has shut down).
176 pub async fn get_task_status(
177 &self,
178 task_name: &str,
179 ) -> Result<Option<TaskStatus>, SupervisorHandleError> {
180 let (sender, receiver) = oneshot::channel();
181 self.tx
182 .send(SupervisorMessage::GetTaskStatus(task_name.into(), sender))
183 .map_err(SupervisorHandleError::SendError)?;
184 receiver.await.map_err(SupervisorHandleError::RecvError)
185 }
186
187 /// Queries the statuses of all tasks asynchronously.
188 ///
189 /// This method sends a request to the supervisor to retrieve the statuses of all tasks
190 /// and awaits the response.
191 ///
192 /// # Returns
193 /// - `Ok(HashMap<TaskName, TaskStatus>)` containing the statuses of all tasks.
194 /// - `Err(RecvError)` if communication with the supervisor fails (e.g., it has shut down).
195 pub async fn get_all_task_statuses(
196 &self,
197 ) -> Result<HashMap<String, TaskStatus>, SupervisorHandleError> {
198 let (sender, receiver) = oneshot::channel();
199 self.tx
200 .send(SupervisorMessage::GetAllTaskStatuses(sender))
201 .map_err(SupervisorHandleError::SendError)?;
202 receiver.await.map_err(SupervisorHandleError::RecvError)
203 }
204
205 /// Checks if the supervisor channel is still open.
206 fn is_channel_open(&self) -> bool {
207 !self.tx.is_closed()
208 }
209}