task_supervisor/supervisor/handle.rs
1use std::collections::HashMap;
2
3use thiserror::Error;
4use tokio::sync::{mpsc, oneshot, watch};
5
6use crate::{task::DynTask, SupervisedTask, TaskName, TaskStatus};
7
8use super::SupervisorError;
9
/// Errors that can occur when talking to the supervisor through a
/// `SupervisorHandle`.
#[derive(Debug, Error)]
pub enum SupervisorHandleError {
    /// The message could not be delivered because the channel to the
    /// supervisor is closed.
    #[error("Failed to send message to supervisor: channel closed")]
    SendError,
    /// The reply channel was closed before a response arrived; wraps the
    /// underlying `oneshot` receive error.
    #[error("Failed to receive response from supervisor: {0}")]
    RecvError(#[from] tokio::sync::oneshot::error::RecvError),
}
17
/// Messages sent from a `SupervisorHandle` to the supervisor over the
/// unbounded `mpsc` channel.
///
/// Query variants carry a `oneshot::Sender` on which the answer is expected.
pub(crate) enum SupervisorMessage {
    /// Request to add a new running Task under the given name.
    AddTask(TaskName, DynTask),
    /// Request to restart the Task with the given name.
    RestartTask(TaskName),
    /// Request to kill the Task with the given name.
    KillTask(TaskName),
    /// Request the status of a single Task; the reply is sent on the
    /// enclosed channel (`None` is used when there is no status to report).
    GetTaskStatus(TaskName, oneshot::Sender<Option<TaskStatus>>),
    /// Request the status of all Tasks; the reply, keyed by task name, is
    /// sent on the enclosed channel.
    GetAllTaskStatuses(oneshot::Sender<HashMap<TaskName, TaskStatus>>),
    /// Request sent to shutdown all the running Tasks.
    Shutdown,
}
32
/// Handle used to interact with the `Supervisor`.
///
/// Cloning this handle is cheap. All clones share the same connection to
/// the supervisor.
///
/// Note that the `Drop` implementation sends a shutdown request whenever a
/// handle is dropped while the channel is still open, and that applies to
/// each clone individually.
#[derive(Clone)]
pub struct SupervisorHandle {
    /// Sending half of the unbounded channel that carries
    /// `SupervisorMessage`s to the supervisor.
    pub(crate) tx: mpsc::UnboundedSender<SupervisorMessage>,
    /// Receives the supervisor's final result; holds `None` until the
    /// supervisor task has finished.
    result_rx: watch::Receiver<Option<Result<(), SupervisorError>>>,
}
43
44impl Drop for SupervisorHandle {
45 /// Automatically shuts down the supervisor when the handle is dropped.
46 fn drop(&mut self) {
47 if self.is_channel_open() {
48 let _ = self.shutdown();
49 }
50 }
51}
52
53impl SupervisorHandle {
54 /// Creates a new `SupervisorHandle`.
55 ///
56 /// Spawns a background task that awaits the supervisor's `JoinHandle` and
57 /// broadcasts the result via a `watch` channel.
58 pub(crate) fn new(
59 join_handle: tokio::task::JoinHandle<Result<(), SupervisorError>>,
60 tx: mpsc::UnboundedSender<SupervisorMessage>,
61 ) -> Self {
62 let (result_tx, result_rx) = watch::channel(None);
63
64 tokio::spawn(async move {
65 let result = match join_handle.await {
66 Ok(supervisor_result) => supervisor_result,
67 // Supervisor task was cancelled/panicked — treat as Ok
68 Err(_join_error) => Ok(()),
69 };
70 let _ = result_tx.send(Some(result));
71 });
72
73 Self { tx, result_rx }
74 }
75
76 /// Waits for the supervisor to complete its execution.
77 ///
78 /// Multiple callers can `wait()` concurrently; all will receive the same result.
79 ///
80 /// # Returns
81 /// - `Ok(())` if the supervisor completed successfully.
82 /// - `Err(SupervisorError)` if the supervisor returned an error.
83 pub async fn wait(&self) -> Result<(), SupervisorError> {
84 let mut rx = self.result_rx.clone();
85 // Wait until the value is Some (i.e., the supervisor has finished).
86 loop {
87 if let Some(result) = rx.borrow_and_update().clone() {
88 return result;
89 }
90 // Value is still None — wait for a change.
91 if rx.changed().await.is_err() {
92 // Sender dropped without sending — supervisor is gone.
93 return Ok(());
94 }
95 }
96 }
97
98 /// Adds a new task to the supervisor.
99 ///
100 /// # Arguments
101 /// - `task_name`: The unique name of the task.
102 /// - `task`: The task to be added, which must implement `SupervisedTask`.
103 ///
104 /// # Returns
105 /// - `Ok(())` if the message was sent successfully.
106 /// - `Err(SendError)` if the supervisor is no longer running.
107 pub fn add_task<T: SupervisedTask + Clone>(
108 &self,
109 task_name: impl Into<String>,
110 task: T,
111 ) -> Result<(), SupervisorHandleError> {
112 self.tx
113 .send(SupervisorMessage::AddTask(task_name.into(), Box::new(task)))
114 .map_err(|_| SupervisorHandleError::SendError)
115 }
116
117 /// Requests the supervisor to restart a specific task.
118 ///
119 /// # Arguments
120 /// - `task_name`: The name of the task to restart.
121 ///
122 /// # Returns
123 /// - `Ok(())` if the message was sent successfully.
124 /// - `Err(SendError)` if the supervisor is no longer running.
125 pub fn restart(&self, task_name: impl Into<String>) -> Result<(), SupervisorHandleError> {
126 self.tx
127 .send(SupervisorMessage::RestartTask(task_name.into()))
128 .map_err(|_| SupervisorHandleError::SendError)
129 }
130
131 /// Requests the supervisor to kill a specific task.
132 ///
133 /// # Arguments
134 /// - `task_name`: The name of the task to kill.
135 ///
136 /// # Returns
137 /// - `Ok(())` if the message was sent successfully.
138 /// - `Err(SendError)` if the supervisor is no longer running.
139 pub fn kill_task(&self, task_name: impl Into<String>) -> Result<(), SupervisorHandleError> {
140 self.tx
141 .send(SupervisorMessage::KillTask(task_name.into()))
142 .map_err(|_| SupervisorHandleError::SendError)
143 }
144
145 /// Requests the supervisor to shut down all tasks and stop supervision.
146 ///
147 /// # Returns
148 /// - `Ok(())` if the message was sent successfully.
149 /// - `Err(SendError)` if the supervisor is no longer running.
150 pub fn shutdown(&self) -> Result<(), SupervisorHandleError> {
151 self.tx
152 .send(SupervisorMessage::Shutdown)
153 .map_err(|_| SupervisorHandleError::SendError)
154 }
155
156 /// Queries the status of a specific task asynchronously.
157 ///
158 /// # Arguments
159 /// - `task_name`: The name of the task to query.
160 ///
161 /// # Returns
162 /// - `Ok(Some(TaskStatus))` if the task exists and its status is returned.
163 /// - `Ok(None)` if the task does not exist.
164 /// - `Err(RecvError)` if communication with the supervisor fails (e.g., it has shut down).
165 pub async fn get_task_status(
166 &self,
167 task_name: impl Into<String>,
168 ) -> Result<Option<TaskStatus>, SupervisorHandleError> {
169 let (sender, receiver) = oneshot::channel();
170 self.tx
171 .send(SupervisorMessage::GetTaskStatus(task_name.into(), sender))
172 .map_err(|_| SupervisorHandleError::SendError)?;
173 receiver.await.map_err(SupervisorHandleError::RecvError)
174 }
175
176 /// Queries the statuses of all tasks asynchronously.
177 ///
178 /// # Returns
179 /// - `Ok(HashMap<TaskName, TaskStatus>)` containing the statuses of all tasks.
180 /// - `Err(RecvError)` if communication with the supervisor fails (e.g., it has shut down).
181 pub async fn get_all_task_statuses(
182 &self,
183 ) -> Result<HashMap<String, TaskStatus>, SupervisorHandleError> {
184 let (sender, receiver) = oneshot::channel();
185 self.tx
186 .send(SupervisorMessage::GetAllTaskStatuses(sender))
187 .map_err(|_| SupervisorHandleError::SendError)?;
188 receiver.await.map_err(SupervisorHandleError::RecvError)
189 }
190
191 /// Checks if the supervisor channel is still open.
192 fn is_channel_open(&self) -> bool {
193 !self.tx.is_closed()
194 }
195}
196
197impl std::fmt::Debug for SupervisorHandle {
198 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
199 f.debug_struct("SupervisorHandle")
200 .field("channel_open", &self.is_channel_open())
201 .finish()
202 }
203}