// teloxide_ng/utils/shutdown_token.rs

1use std::{
2    fmt,
3    future::Future,
4    sync::{
5        Arc,
6        atomic::{AtomicU8, Ordering},
7    },
8};
9
10use tokio::sync::Notify;
11
12/// A token which used to shutdown [`Dispatcher`].
13///
14/// [`Dispatcher`]: crate::dispatching::Dispatcher
15#[derive(Clone)]
16pub struct ShutdownToken {
17    // FIXME: use a single arc
18    dispatcher_state: Arc<DispatcherState>,
19    shutdown_notify_back: Arc<Notify>,
20}
21
/// This error is returned from [`ShutdownToken::shutdown`] when trying to
/// shut down an idle [`Dispatcher`].
///
/// It is a field-less marker, so it is trivially copyable and comparable.
///
/// [`Dispatcher`]: crate::dispatching::Dispatcher
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct IdleShutdownError;
28
29impl ShutdownToken {
30    /// Tries to shutdown dispatching.
31    ///
32    /// Returns an error if the dispatcher is idle at the moment.
33    ///
34    /// If you don't need to wait for shutdown, the returned future can be
35    /// ignored.
36    pub fn shutdown(&self) -> Result<impl Future<Output = ()> + '_, IdleShutdownError> {
37        match shutdown_inner(&self.dispatcher_state) {
38            Ok(()) | Err(Ok(AlreadyShuttingDown)) => Ok(async move {
39                log::info!("Trying to shutdown the dispatcher...");
40                self.shutdown_notify_back.notified().await
41            }),
42            Err(Err(err)) => Err(err),
43        }
44    }
45
46    pub(crate) fn new() -> Self {
47        Self {
48            dispatcher_state: Arc::new(DispatcherState {
49                inner: AtomicU8::new(ShutdownState::Idle as _),
50                notify: <_>::default(),
51            }),
52            shutdown_notify_back: <_>::default(),
53        }
54    }
55
56    pub(crate) async fn wait_for_changes(&self) {
57        self.dispatcher_state.notify.notified().await;
58    }
59
60    pub(crate) fn start_dispatching(&self) {
61        if let Err(actual) =
62            self.dispatcher_state.compare_exchange(ShutdownState::Idle, ShutdownState::Running)
63        {
64            panic!(
65                "Dispatching is already running: expected `{:?}` state, found `{:?}`",
66                ShutdownState::Idle,
67                actual
68            );
69        }
70    }
71
72    pub(crate) fn is_shutting_down(&self) -> bool {
73        matches!(self.dispatcher_state.load(), ShutdownState::ShuttingDown)
74    }
75
76    pub(crate) fn done(&self) {
77        if self.is_shutting_down() {
78            // Stopped because of a `shutdown` call.
79
80            // Notify `shutdown`s that we finished
81            self.shutdown_notify_back.notify_waiters();
82            log::info!("Dispatching has been shut down.");
83        } else {
84            log::info!("Dispatching has been stopped (listener returned `None`).");
85        }
86
87        self.dispatcher_state.store(ShutdownState::Idle);
88    }
89}
90
91impl fmt::Display for IdleShutdownError {
92    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
93        write!(f, "Dispatcher was idle and as such couldn't be shut down")
94    }
95}
96
97impl std::error::Error for IdleShutdownError {}
98
99struct DispatcherState {
100    inner: AtomicU8,
101    notify: Notify,
102}
103
104impl DispatcherState {
105    // Ordering::Relaxed: only one atomic variable, nothing to synchronize.
106    fn load(&self) -> ShutdownState {
107        ShutdownState::from_u8(self.inner.load(Ordering::Relaxed))
108    }
109
110    fn store(&self, new: ShutdownState) {
111        self.inner.store(new as _, Ordering::Relaxed);
112        self.notify.notify_waiters();
113    }
114
115    fn compare_exchange(
116        &self,
117        current: ShutdownState,
118        new: ShutdownState,
119    ) -> Result<ShutdownState, ShutdownState> {
120        self.inner
121            .compare_exchange(current as _, new as _, Ordering::Relaxed, Ordering::Relaxed)
122            .map(ShutdownState::from_u8)
123            .map_err(ShutdownState::from_u8)
124            // FIXME: `Result::inspect` when :(
125            .inspect(|_| self.notify.notify_waiters())
126    }
127}
128
/// Lifecycle of a dispatcher, stored in an `AtomicU8` via its discriminant.
#[repr(u8)]
#[derive(Debug)]
enum ShutdownState {
    /// Dispatching is in progress.
    Running,
    /// A shutdown was requested and is in progress.
    ShuttingDown,
    /// Dispatching is not running.
    Idle,
}

impl ShutdownState {
    /// Converts a discriminant produced by `state as u8` back into a
    /// [`ShutdownState`].
    ///
    /// # Panics
    ///
    /// Panics if `n` is not the discriminant of any variant.
    fn from_u8(n: u8) -> Self {
        if n == ShutdownState::Running as u8 {
            ShutdownState::Running
        } else if n == ShutdownState::ShuttingDown as u8 {
            ShutdownState::ShuttingDown
        } else if n == ShutdownState::Idle as u8 {
            ShutdownState::Idle
        } else {
            unreachable!()
        }
    }
}
151
152struct AlreadyShuttingDown;
153
154fn shutdown_inner(
155    state: &DispatcherState,
156) -> Result<(), Result<AlreadyShuttingDown, IdleShutdownError>> {
157    use ShutdownState::*;
158
159    let res = state.compare_exchange(Running, ShuttingDown);
160
161    match res {
162        Ok(_) => Ok(()),
163        Err(ShuttingDown) => Err(Ok(AlreadyShuttingDown)),
164        Err(Idle) => Err(Err(IdleShutdownError)),
165        Err(Running) => unreachable!(),
166    }
167}