tokio_graceful_shutdown/subsystem/subsystem_handle.rs

1use std::{
2    mem::ManuallyDrop,
3    sync::{Arc, Mutex, atomic::Ordering},
4};
5
6use atomic::Atomic;
7use tokio::sync::{mpsc, oneshot};
8use tokio_util::sync::CancellationToken;
9
10use crate::{
11    AsyncSubsysFn, BoxedError, ErrTypeTraits, ErrorAction, NestedSubsystem, SubsystemBuilder,
12    errors::{SubsystemError, handle_dropped_error},
13    runner::{AliveGuard, SubsystemRunner},
14    utils::{JoinerToken, remote_drop_collection::RemotelyDroppableItems},
15};
16
17use super::{ErrorActions, error_collector::ErrorCollector};
18
19struct Inner<ErrType: ErrTypeTraits> {
20    name: Arc<str>,
21    cancellation_token: CancellationToken,
22    toplevel_cancellation_token: CancellationToken,
23    joiner_token: JoinerToken<ErrType>,
24    children: RemotelyDroppableItems<SubsystemRunner>,
25}
26
27/// The handle given to each subsystem through which the subsystem can interact with this crate.
28pub struct SubsystemHandle<ErrType: ErrTypeTraits = BoxedError> {
29    inner: ManuallyDrop<Inner<ErrType>>,
30    // When dropped, redirect Self into this channel.
31    // Required to pass the handle back out of `tokio::spawn`.
32    drop_redirect: Option<oneshot::Sender<WeakSubsystemHandle<ErrType>>>,
33}
34
35pub(crate) struct WeakSubsystemHandle<ErrType: ErrTypeTraits> {
36    pub(crate) joiner_token: JoinerToken<ErrType>,
37    // Children are stored here to keep them alive
38    _children: RemotelyDroppableItems<SubsystemRunner>,
39}
40
41impl<ErrType: ErrTypeTraits> SubsystemHandle<ErrType> {
42    /// Start a nested subsystem.
43    ///
44    /// Once called, the subsystem will be started immediately, similar to [`tokio::spawn`].
45    ///
46    /// # Arguments
47    ///
48    /// * `builder` - The [`SubsystemBuilder`] that contains all the information
49    ///   about the subsystem that should be spawned.
50    ///
51    /// # Returns
52    ///
53    /// A [`NestedSubsystem`] that can be used to control or join the subsystem.
54    ///
55    /// # Examples
56    ///
57    /// ```
58    /// use miette::Result;
59    /// use tokio_graceful_shutdown::{SubsystemBuilder, SubsystemHandle};
60    ///
61    /// async fn nested_subsystem(subsys: &mut SubsystemHandle) -> Result<()> {
62    ///     subsys.on_shutdown_requested().await;
63    ///     Ok(())
64    /// }
65    ///
66    /// async fn my_subsystem(subsys: &mut SubsystemHandle) -> Result<()> {
67    ///     // start a nested subsystem
68    ///     subsys.start(SubsystemBuilder::new("Nested", nested_subsystem));
69    ///
70    ///     subsys.on_shutdown_requested().await;
71    ///     Ok(())
72    /// }
73    /// ```
74    #[track_caller]
75    pub fn start<Err, Subsys>(&self, builder: SubsystemBuilder<Subsys>) -> NestedSubsystem<ErrType>
76    where
77        Subsys: 'static + for<'a> AsyncSubsysFn<&'a mut SubsystemHandle<ErrType>, Result<(), Err>>,
78        Err: Into<ErrType>,
79    {
80        self.start_with_abs_name(
81            if self.inner.name.as_ref() == "/" {
82                Arc::from(format!("/{}", builder.name))
83            } else {
84                Arc::from(format!("{}/{}", self.inner.name, builder.name))
85            },
86            builder.subsystem,
87            ErrorActions {
88                on_failure: Atomic::new(builder.failure_action),
89                on_panic: Atomic::new(builder.panic_action),
90            },
91            builder.detached,
92        )
93    }
94
95    #[track_caller]
96    pub(crate) fn start_with_abs_name<Err, Subsys>(
97        &self,
98        name: Arc<str>,
99        subsystem: Subsys,
100        error_actions: ErrorActions,
101        detached: bool,
102    ) -> NestedSubsystem<ErrType>
103    where
104        Subsys: 'static + for<'a> AsyncSubsysFn<&'a mut SubsystemHandle<ErrType>, Result<(), Err>>,
105        Err: Into<ErrType>,
106    {
107        let alive_guard = AliveGuard::new();
108
109        let (error_sender, errors) = mpsc::unbounded_channel();
110
111        let cancellation_token = if detached {
112            CancellationToken::new()
113        } else {
114            self.inner.cancellation_token.child_token()
115        };
116
117        let error_actions = Arc::new(error_actions);
118
119        let (joiner_token, joiner_token_ref) = self.inner.joiner_token.child_token({
120            let cancellation_token = cancellation_token.clone();
121            let error_actions = Arc::clone(&error_actions);
122            move |e| {
123                let error_action = match &e {
124                    SubsystemError::Failed(_, _) => {
125                        error_actions.on_failure.load(Ordering::Relaxed)
126                    }
127                    SubsystemError::Panicked(_) => error_actions.on_panic.load(Ordering::Relaxed),
128                };
129
130                match error_action {
131                    ErrorAction::Forward => Some(e),
132                    ErrorAction::CatchAndLocalShutdown => {
133                        handle_dropped_error(error_sender.send(e));
134                        cancellation_token.cancel();
135                        None
136                    }
137                }
138            }
139        });
140
141        let child_handle = SubsystemHandle {
142            inner: ManuallyDrop::new(Inner {
143                name: Arc::clone(&name),
144                cancellation_token: cancellation_token.clone(),
145                toplevel_cancellation_token: self.inner.toplevel_cancellation_token.clone(),
146                joiner_token,
147                children: RemotelyDroppableItems::new(),
148            }),
149            drop_redirect: None,
150        };
151
152        let runner = SubsystemRunner::new(name, subsystem, child_handle, alive_guard.clone());
153        let abort_handle = runner.abort_handle();
154
155        // Shenanigans to juggle child ownership
156        //
157        // RACE CONDITION SAFETY:
158        // If the subsystem ends before `on_finished` was able to be called, nothing bad happens.
159        // alive_guard will keep the guard alive and the callback will only be called inside of
160        // the guard's drop() implementation.
161        let child_dropper = self.inner.children.insert(runner);
162        alive_guard.on_finished(|| {
163            drop(child_dropper);
164        });
165
166        NestedSubsystem {
167            joiner: joiner_token_ref,
168            cancellation_token,
169            errors: Mutex::new(ErrorCollector::new(errors)),
170            error_actions,
171            abort_handle,
172        }
173    }
174
175    /// Waits until all the children of this subsystem are finished.
176    pub async fn wait_for_children(&self) {
177        self.inner.joiner_token.join_children().await
178    }
179
180    // For internal use only - should never be used by users.
181    // Required as a short-lived second reference inside of `runner`.
182    pub(crate) fn delayed_clone(&mut self) -> oneshot::Receiver<WeakSubsystemHandle<ErrType>> {
183        let (sender, receiver) = oneshot::channel();
184
185        let previous = self.drop_redirect.replace(sender);
186        assert!(previous.is_none());
187
188        receiver
189    }
190
191    /// Wait for the shutdown mode to be triggered.
192    ///
193    /// Once the shutdown mode is entered, all existing calls to this
194    /// method will be released and future calls to this method will
195    /// return immediately.
196    ///
197    /// This is the primary method of subsystems to react to
198    /// the shutdown requests. Most often, it will be used in [`tokio::select`]
199    /// statements to cancel other code as soon as the shutdown is requested.
200    ///
201    /// # Examples
202    ///
203    /// ```
204    /// use miette::Result;
205    /// use tokio::time::{sleep, Duration};
206    /// use tokio_graceful_shutdown::SubsystemHandle;
207    ///
208    /// async fn countdown() {
209    ///     for i in (1..10).rev() {
210    ///         tracing::info!("Countdown: {}", i);
211    ///         sleep(Duration::from_millis(1000)).await;
212    ///     }
213    /// }
214    ///
215    /// async fn countdown_subsystem(subsys: SubsystemHandle) -> Result<()> {
216    ///     tracing::info!("Starting countdown ...");
217    ///
218    ///     // This cancels the countdown as soon as shutdown
219    ///     // mode was entered
220    ///     tokio::select! {
221    ///         _ = subsys.on_shutdown_requested() => {
222    ///             tracing::info!("Countdown cancelled.");
223    ///         },
224    ///         _ = countdown() => {
225    ///             tracing::info!("Countdown finished.");
226    ///         }
227    ///     };
228    ///
229    ///     Ok(())
230    /// }
231    /// ```
232    pub async fn on_shutdown_requested(&self) {
233        self.inner.cancellation_token.cancelled().await
234    }
235
236    /// Returns whether a shutdown should be performed now.
237    ///
238    /// This method is provided for subsystems that need to query the shutdown
239    /// request state repeatedly.
240    ///
241    /// This can be useful in scenarios where a subsystem depends on the graceful
242    /// shutdown of its nested coroutines before it can run final cleanup steps itself.
243    ///
244    /// # Examples
245    ///
246    /// ```
247    /// use miette::Result;
248    /// use tokio::time::{sleep, Duration};
249    /// use tokio_graceful_shutdown::SubsystemHandle;
250    ///
251    /// async fn uncancellable_action(subsys: &SubsystemHandle) {
252    ///     tokio::select! {
253    ///         // Execute an action. A dummy `sleep` in this case.
254    ///         _ = sleep(Duration::from_millis(1000)) => {
255    ///             tracing::info!("Action finished.");
256    ///         }
257    ///         // Perform a shutdown if requested
258    ///         _ = subsys.on_shutdown_requested() => {
259    ///             tracing::info!("Action aborted.");
260    ///         },
261    ///     }
262    /// }
263    ///
264    /// async fn my_subsystem(subsys: SubsystemHandle) -> Result<()> {
265    ///     tracing::info!("Starting subsystem ...");
266    ///
267    ///     // We cannot do a `tokio::select` with `on_shutdown_requested`
268    ///     // here, because a shutdown would cancel the action without giving
269    ///     // it the chance to react first.
270    ///     while !subsys.is_shutdown_requested() {
271    ///         uncancellable_action(&subsys).await;
272    ///     }
273    ///
274    ///     tracing::info!("Subsystem stopped.");
275    ///
276    ///     Ok(())
277    /// }
278    /// ```
279    pub fn is_shutdown_requested(&self) -> bool {
280        self.inner.cancellation_token.is_cancelled()
281    }
282
283    /// Triggers a shutdown of the entire subsystem tree.
284    ///
285    /// # Examples
286    ///
287    /// ```
288    /// use miette::Result;
289    /// use tokio::time::{sleep, Duration};
290    /// use tokio_graceful_shutdown::SubsystemHandle;
291    ///
292    /// async fn stop_subsystem(subsys: SubsystemHandle) -> Result<()> {
293    ///     // This subsystem wait for one second and then stops the program.
294    ///     sleep(Duration::from_millis(1000)).await;
295    ///
296    ///     // Shut down the entire subsystem tree
297    ///     subsys.request_shutdown();
298    ///
299    ///     Ok(())
300    /// }
301    /// ```
302    pub fn request_shutdown(&self) {
303        self.inner.toplevel_cancellation_token.cancel();
304    }
305
306    /// Triggers a shutdown of the current subsystem and all
307    /// of its children.
308    pub fn request_local_shutdown(&self) {
309        self.inner.cancellation_token.cancel();
310    }
311
312    pub(crate) fn get_cancellation_token(&self) -> &CancellationToken {
313        &self.inner.cancellation_token
314    }
315
316    /// Creates a cancellation token that will get triggered once the
317    /// subsystem shuts down.
318    ///
319    /// This is intended for more lightweight situations where
320    /// creating full-blown subsystems would be too much overhead,
321    /// like spawning connection handlers of a webserver.
322    ///
323    /// For more information, see the [hyper example](https://github.com/Finomnis/tokio-graceful-shutdown/blob/main/examples/hyper.rs).
324    pub fn create_cancellation_token(&self) -> CancellationToken {
325        self.inner.cancellation_token.child_token()
326    }
327
328    /// Get the name associated with this subsystem.
329    ///
330    /// Note that the names of nested subsystems are built unix-path alike,
331    /// starting and delimited by slashes (e.g. `/a/b/c`).
332    ///
333    /// See [`SubsystemBuilder::new()`] how to set this name.
334    pub fn name(&self) -> &str {
335        &self.inner.name
336    }
337}
338
339impl<ErrType: ErrTypeTraits> Drop for SubsystemHandle<ErrType> {
340    fn drop(&mut self) {
341        // SAFETY: This is how ManuallyDrop is meant to be used.
342        // `self.inner` won't ever be used again because `self` will be gone after this
343        // function is finished.
344        // This takes the `self.inner` object and makes it droppable again.
345        //
346        // This workaround is required to take ownership for the `self.drop_redirect` channel.
347        let inner = unsafe { ManuallyDrop::take(&mut self.inner) };
348
349        if let Some(redirect) = self.drop_redirect.take() {
350            let redirected_self = WeakSubsystemHandle {
351                joiner_token: inner.joiner_token,
352                _children: inner.children,
353            };
354
355            // ignore error; an error would indicate that there is no receiver.
356            // in that case, do nothing.
357            let _ = redirect.send(redirected_self);
358        }
359    }
360}
361
362pub(crate) fn root_handle<ErrType: ErrTypeTraits>(
363    cancellation_token: CancellationToken,
364    on_error: impl Fn(SubsystemError<ErrType>) + Sync + Send + 'static,
365) -> SubsystemHandle<ErrType> {
366    SubsystemHandle {
367        inner: ManuallyDrop::new(Inner {
368            name: Arc::from(""),
369            cancellation_token: cancellation_token.clone(),
370            toplevel_cancellation_token: cancellation_token.clone(),
371            joiner_token: JoinerToken::new(move |e| {
372                on_error(e);
373                cancellation_token.cancel();
374                None
375            })
376            .0,
377            children: RemotelyDroppableItems::new(),
378        }),
379        drop_redirect: None,
380    }
381}
382
// Tests for this module are declared here and live in a separate file.
#[cfg(test)]
mod tests;