task_supervisor/supervisor/
handle.rs1use std::collections::HashMap;
2
3use thiserror::Error;
4use tokio::sync::{mpsc, oneshot, watch};
5
6use crate::{task::DynTask, SupervisedTask, TaskName, TaskStatus};
7
8use super::SupervisorError;
9
10#[derive(Debug, Error)]
11pub enum SupervisorHandleError {
12 #[error("Failed to send message to supervisor: channel closed")]
13 SendError,
14 #[error("Failed to receive response from supervisor: {0}")]
15 RecvError(#[from] tokio::sync::oneshot::error::RecvError),
16}
17
18pub(crate) enum SupervisorMessage {
19 AddTask(TaskName, DynTask),
20 RestartTask(TaskName),
21 KillTask(TaskName),
22 GetTaskStatus(TaskName, oneshot::Sender<Option<TaskStatus>>),
23 GetAllTaskStatuses(oneshot::Sender<HashMap<TaskName, TaskStatus>>),
24 Shutdown,
25}
26
27#[derive(Clone)]
29pub struct SupervisorHandle {
30 pub(crate) tx: mpsc::UnboundedSender<SupervisorMessage>,
31 result_rx: watch::Receiver<Option<Result<(), SupervisorError>>>,
32}
33
34impl Drop for SupervisorHandle {
35 fn drop(&mut self) {
36 if self.is_channel_open() {
37 let _ = self.shutdown();
38 }
39 }
40}
41
42impl SupervisorHandle {
43 pub(crate) fn new(
44 join_handle: tokio::task::JoinHandle<Result<(), SupervisorError>>,
45 tx: mpsc::UnboundedSender<SupervisorMessage>,
46 ) -> Self {
47 let (result_tx, result_rx) = watch::channel(None);
48
49 tokio::spawn(async move {
50 let result = match join_handle.await {
51 Ok(supervisor_result) => supervisor_result,
52 Err(_join_error) => Ok(()),
53 };
54 let _ = result_tx.send(Some(result));
55 });
56
57 Self { tx, result_rx }
58 }
59
60 pub async fn wait(&self) -> Result<(), SupervisorError> {
62 let mut rx = self.result_rx.clone();
63 loop {
64 if let Some(result) = rx.borrow_and_update().clone() {
65 return result;
66 }
67 if rx.changed().await.is_err() {
68 return Ok(());
69 }
70 }
71 }
72
73 pub fn add_task<T: SupervisedTask + Clone>(
74 &self,
75 task_name: &str,
76 task: T,
77 ) -> Result<(), SupervisorHandleError> {
78 self.tx
79 .send(SupervisorMessage::AddTask(task_name.into(), Box::new(task)))
80 .map_err(|_| SupervisorHandleError::SendError)
81 }
82
83 pub fn restart(&self, task_name: &str) -> Result<(), SupervisorHandleError> {
84 self.tx
85 .send(SupervisorMessage::RestartTask(task_name.into()))
86 .map_err(|_| SupervisorHandleError::SendError)
87 }
88
89 pub fn kill_task(&self, task_name: &str) -> Result<(), SupervisorHandleError> {
90 self.tx
91 .send(SupervisorMessage::KillTask(task_name.into()))
92 .map_err(|_| SupervisorHandleError::SendError)
93 }
94
95 pub fn shutdown(&self) -> Result<(), SupervisorHandleError> {
96 self.tx
97 .send(SupervisorMessage::Shutdown)
98 .map_err(|_| SupervisorHandleError::SendError)
99 }
100
101 pub async fn get_task_status(
102 &self,
103 task_name: &str,
104 ) -> Result<Option<TaskStatus>, SupervisorHandleError> {
105 let (sender, receiver) = oneshot::channel();
106 self.tx
107 .send(SupervisorMessage::GetTaskStatus(task_name.into(), sender))
108 .map_err(|_| SupervisorHandleError::SendError)?;
109 receiver.await.map_err(SupervisorHandleError::RecvError)
110 }
111
112 pub async fn get_all_task_statuses(
113 &self,
114 ) -> Result<HashMap<String, TaskStatus>, SupervisorHandleError> {
115 let (sender, receiver) = oneshot::channel();
116 self.tx
117 .send(SupervisorMessage::GetAllTaskStatuses(sender))
118 .map_err(|_| SupervisorHandleError::SendError)?;
119 receiver.await.map_err(SupervisorHandleError::RecvError)
120 }
121
122 fn is_channel_open(&self) -> bool {
123 !self.tx.is_closed()
124 }
125}
126
127impl std::fmt::Debug for SupervisorHandle {
128 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
129 f.debug_struct("SupervisorHandle")
130 .field("channel_open", &self.is_channel_open())
131 .finish()
132 }
133}