Skip to main content

task_supervisor/supervisor/
handle.rs

1use std::collections::HashMap;
2
3use thiserror::Error;
4use tokio::sync::{mpsc, oneshot, watch};
5
6use crate::{task::DynTask, SupervisedTask, TaskName, TaskStatus};
7
8use super::SupervisorError;
9
10#[derive(Debug, Error)]
11pub enum SupervisorHandleError {
12    #[error("Failed to send message to supervisor: channel closed")]
13    SendError,
14    #[error("Failed to receive response from supervisor: {0}")]
15    RecvError(#[from] tokio::sync::oneshot::error::RecvError),
16}
17
18pub(crate) enum SupervisorMessage {
19    AddTask(TaskName, DynTask),
20    RestartTask(TaskName),
21    KillTask(TaskName),
22    GetTaskStatus(TaskName, oneshot::Sender<Option<TaskStatus>>),
23    GetAllTaskStatuses(oneshot::Sender<HashMap<TaskName, TaskStatus>>),
24    Shutdown,
25}
26
27/// Handle to interact with a running `Supervisor`. Cheap to clone.
28#[derive(Clone)]
29pub struct SupervisorHandle {
30    pub(crate) tx: mpsc::UnboundedSender<SupervisorMessage>,
31    result_rx: watch::Receiver<Option<Result<(), SupervisorError>>>,
32}
33
34impl Drop for SupervisorHandle {
35    fn drop(&mut self) {
36        if self.is_channel_open() {
37            let _ = self.shutdown();
38        }
39    }
40}
41
42impl SupervisorHandle {
43    pub(crate) fn new(
44        join_handle: tokio::task::JoinHandle<Result<(), SupervisorError>>,
45        tx: mpsc::UnboundedSender<SupervisorMessage>,
46    ) -> Self {
47        let (result_tx, result_rx) = watch::channel(None);
48
49        tokio::spawn(async move {
50            let result = match join_handle.await {
51                Ok(supervisor_result) => supervisor_result,
52                Err(_join_error) => Ok(()),
53            };
54            let _ = result_tx.send(Some(result));
55        });
56
57        Self { tx, result_rx }
58    }
59
60    /// Waits for the supervisor to finish. Safe to call from multiple clones concurrently.
61    pub async fn wait(&self) -> Result<(), SupervisorError> {
62        let mut rx = self.result_rx.clone();
63        loop {
64            if let Some(result) = rx.borrow_and_update().clone() {
65                return result;
66            }
67            if rx.changed().await.is_err() {
68                return Ok(());
69            }
70        }
71    }
72
73    pub fn add_task<T: SupervisedTask + Clone>(
74        &self,
75        task_name: &str,
76        task: T,
77    ) -> Result<(), SupervisorHandleError> {
78        self.tx
79            .send(SupervisorMessage::AddTask(task_name.into(), Box::new(task)))
80            .map_err(|_| SupervisorHandleError::SendError)
81    }
82
83    pub fn restart(&self, task_name: &str) -> Result<(), SupervisorHandleError> {
84        self.tx
85            .send(SupervisorMessage::RestartTask(task_name.into()))
86            .map_err(|_| SupervisorHandleError::SendError)
87    }
88
89    pub fn kill_task(&self, task_name: &str) -> Result<(), SupervisorHandleError> {
90        self.tx
91            .send(SupervisorMessage::KillTask(task_name.into()))
92            .map_err(|_| SupervisorHandleError::SendError)
93    }
94
95    pub fn shutdown(&self) -> Result<(), SupervisorHandleError> {
96        self.tx
97            .send(SupervisorMessage::Shutdown)
98            .map_err(|_| SupervisorHandleError::SendError)
99    }
100
101    pub async fn get_task_status(
102        &self,
103        task_name: &str,
104    ) -> Result<Option<TaskStatus>, SupervisorHandleError> {
105        let (sender, receiver) = oneshot::channel();
106        self.tx
107            .send(SupervisorMessage::GetTaskStatus(task_name.into(), sender))
108            .map_err(|_| SupervisorHandleError::SendError)?;
109        receiver.await.map_err(SupervisorHandleError::RecvError)
110    }
111
112    pub async fn get_all_task_statuses(
113        &self,
114    ) -> Result<HashMap<String, TaskStatus>, SupervisorHandleError> {
115        let (sender, receiver) = oneshot::channel();
116        self.tx
117            .send(SupervisorMessage::GetAllTaskStatuses(sender))
118            .map_err(|_| SupervisorHandleError::SendError)?;
119        receiver.await.map_err(SupervisorHandleError::RecvError)
120    }
121
122    fn is_channel_open(&self) -> bool {
123        !self.tx.is_closed()
124    }
125}
126
127impl std::fmt::Debug for SupervisorHandle {
128    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
129        f.debug_struct("SupervisorHandle")
130            .field("channel_open", &self.is_channel_open())
131            .finish()
132    }
133}