task_supervisor/supervisor/handle.rs
1use std::collections::HashMap;
2
3use thiserror::Error;
4use tokio::sync::{mpsc, oneshot, watch};
5
6use crate::{task::DynTask, SupervisedTask, TaskName, TaskStatus};
7
8use super::SupervisorError;
9
10#[derive(Debug, Error)]
11pub enum SupervisorHandleError {
12 #[error("Failed to send message to supervisor: channel closed")]
13 SendError,
14 #[error("Failed to receive response from supervisor: {0}")]
15 RecvError(#[from] tokio::sync::oneshot::error::RecvError),
16}
17
18pub(crate) enum SupervisorMessage {
19 /// Request to add a new running Task
20 AddTask(TaskName, DynTask),
21 /// Request to restart a Task
22 RestartTask(TaskName),
23 /// Request to kill a Task
24 KillTask(TaskName),
25 /// Request the status of a Task
26 GetTaskStatus(TaskName, oneshot::Sender<Option<TaskStatus>>),
27 /// Request the status of all Tasks
28 GetAllTaskStatuses(oneshot::Sender<HashMap<TaskName, TaskStatus>>),
29 /// Request sent to shutdown all the running Tasks
30 Shutdown,
31}
32
33/// Handle used to interact with the `Supervisor`.
34///
35/// Cloning this handle is cheap. All clones share the same connection to
36/// the supervisor. Multiple clones can call `wait()` concurrently and all
37/// will receive the result.
38#[derive(Clone)]
39pub struct SupervisorHandle {
40 pub(crate) tx: mpsc::UnboundedSender<SupervisorMessage>,
41 /// Receives the supervisor's final result. Using a `watch` channel
42 /// allows multiple concurrent `wait()` callers to all see the outcome,
43 /// unlike `JoinSet` which would consume the result on the first `join_next()`.
44 result_rx: watch::Receiver<Option<Result<(), SupervisorError>>>,
45}
46
47impl Drop for SupervisorHandle {
48 /// Automatically shuts down the supervisor when the handle is dropped.
49 fn drop(&mut self) {
50 if self.is_channel_open() {
51 let _ = self.shutdown();
52 }
53 }
54}
55
56impl SupervisorHandle {
57 /// Creates a new `SupervisorHandle`.
58 ///
59 /// Spawns a background task that awaits the supervisor's `JoinHandle` and
60 /// broadcasts the result via a `watch` channel.
61 pub(crate) fn new(
62 join_handle: tokio::task::JoinHandle<Result<(), SupervisorError>>,
63 tx: mpsc::UnboundedSender<SupervisorMessage>,
64 ) -> Self {
65 let (result_tx, result_rx) = watch::channel(None);
66
67 tokio::spawn(async move {
68 let result = match join_handle.await {
69 Ok(supervisor_result) => supervisor_result,
70 // Supervisor task was cancelled/panicked — treat as Ok
71 Err(_join_error) => Ok(()),
72 };
73 let _ = result_tx.send(Some(result));
74 });
75
76 Self { tx, result_rx }
77 }
78
79 /// Waits for the supervisor to complete its execution.
80 ///
81 /// Multiple callers can `wait()` concurrently; all will receive the same result.
82 ///
83 /// # Returns
84 /// - `Ok(())` if the supervisor completed successfully.
85 /// - `Err(SupervisorError)` if the supervisor returned an error.
86 pub async fn wait(&self) -> Result<(), SupervisorError> {
87 let mut rx = self.result_rx.clone();
88 // Wait until the value is Some (i.e., the supervisor has finished).
89 loop {
90 if let Some(result) = rx.borrow_and_update().clone() {
91 return result;
92 }
93 // Value is still None — wait for a change.
94 if rx.changed().await.is_err() {
95 // Sender dropped without sending — supervisor is gone.
96 return Ok(());
97 }
98 }
99 }
100
101 /// Adds a new task to the supervisor.
102 ///
103 /// # Arguments
104 /// - `task_name`: The unique name of the task.
105 /// - `task`: The task to be added, which must implement `SupervisedTask`.
106 ///
107 /// # Returns
108 /// - `Ok(())` if the message was sent successfully.
109 /// - `Err(SendError)` if the supervisor is no longer running.
110 pub fn add_task<T: SupervisedTask + Clone>(
111 &self,
112 task_name: &str,
113 task: T,
114 ) -> Result<(), SupervisorHandleError> {
115 self.tx
116 .send(SupervisorMessage::AddTask(task_name.into(), Box::new(task)))
117 .map_err(|_| SupervisorHandleError::SendError)
118 }
119
120 /// Requests the supervisor to restart a specific task.
121 ///
122 /// # Arguments
123 /// - `task_name`: The name of the task to restart.
124 ///
125 /// # Returns
126 /// - `Ok(())` if the message was sent successfully.
127 /// - `Err(SendError)` if the supervisor is no longer running.
128 pub fn restart(&self, task_name: &str) -> Result<(), SupervisorHandleError> {
129 self.tx
130 .send(SupervisorMessage::RestartTask(task_name.into()))
131 .map_err(|_| SupervisorHandleError::SendError)
132 }
133
134 /// Requests the supervisor to kill a specific task.
135 ///
136 /// # Arguments
137 /// - `task_name`: The name of the task to kill.
138 ///
139 /// # Returns
140 /// - `Ok(())` if the message was sent successfully.
141 /// - `Err(SendError)` if the supervisor is no longer running.
142 pub fn kill_task(&self, task_name: &str) -> Result<(), SupervisorHandleError> {
143 self.tx
144 .send(SupervisorMessage::KillTask(task_name.into()))
145 .map_err(|_| SupervisorHandleError::SendError)
146 }
147
148 /// Requests the supervisor to shut down all tasks and stop supervision.
149 ///
150 /// # Returns
151 /// - `Ok(())` if the message was sent successfully.
152 /// - `Err(SendError)` if the supervisor is no longer running.
153 pub fn shutdown(&self) -> Result<(), SupervisorHandleError> {
154 self.tx
155 .send(SupervisorMessage::Shutdown)
156 .map_err(|_| SupervisorHandleError::SendError)
157 }
158
159 /// Queries the status of a specific task asynchronously.
160 ///
161 /// # Arguments
162 /// - `task_name`: The name of the task to query.
163 ///
164 /// # Returns
165 /// - `Ok(Some(TaskStatus))` if the task exists and its status is returned.
166 /// - `Ok(None)` if the task does not exist.
167 /// - `Err(RecvError)` if communication with the supervisor fails (e.g., it has shut down).
168 pub async fn get_task_status(
169 &self,
170 task_name: &str,
171 ) -> Result<Option<TaskStatus>, SupervisorHandleError> {
172 let (sender, receiver) = oneshot::channel();
173 self.tx
174 .send(SupervisorMessage::GetTaskStatus(task_name.into(), sender))
175 .map_err(|_| SupervisorHandleError::SendError)?;
176 receiver.await.map_err(SupervisorHandleError::RecvError)
177 }
178
179 /// Queries the statuses of all tasks asynchronously.
180 ///
181 /// # Returns
182 /// - `Ok(HashMap<TaskName, TaskStatus>)` containing the statuses of all tasks.
183 /// - `Err(RecvError)` if communication with the supervisor fails (e.g., it has shut down).
184 pub async fn get_all_task_statuses(
185 &self,
186 ) -> Result<HashMap<String, TaskStatus>, SupervisorHandleError> {
187 let (sender, receiver) = oneshot::channel();
188 self.tx
189 .send(SupervisorMessage::GetAllTaskStatuses(sender))
190 .map_err(|_| SupervisorHandleError::SendError)?;
191 receiver.await.map_err(SupervisorHandleError::RecvError)
192 }
193
194 /// Checks if the supervisor channel is still open.
195 fn is_channel_open(&self) -> bool {
196 !self.tx.is_closed()
197 }
198}
199
200impl std::fmt::Debug for SupervisorHandle {
201 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
202 f.debug_struct("SupervisorHandle")
203 .field("channel_open", &self.is_channel_open())
204 .finish()
205 }
206}