tokio_graceful_shutdown/subsystem/subsystem_handle.rs
1use std::{
2 mem::ManuallyDrop,
3 sync::{Arc, Mutex, atomic::Ordering},
4};
5
6use atomic::Atomic;
7use tokio::sync::{mpsc, oneshot};
8use tokio_util::sync::CancellationToken;
9
10use crate::{
11 AsyncSubsysFn, BoxedError, ErrTypeTraits, ErrorAction, NestedSubsystem, SubsystemBuilder,
12 errors::{SubsystemError, handle_dropped_error},
13 runner::{AliveGuard, SubsystemRunner},
14 utils::{JoinerToken, remote_drop_collection::RemotelyDroppableItems},
15};
16
17use super::{ErrorActions, error_collector::ErrorCollector};
18
19struct Inner<ErrType: ErrTypeTraits> {
20 name: Arc<str>,
21 cancellation_token: CancellationToken,
22 toplevel_cancellation_token: CancellationToken,
23 joiner_token: JoinerToken<ErrType>,
24 children: RemotelyDroppableItems<SubsystemRunner>,
25}
26
27/// The handle given to each subsystem through which the subsystem can interact with this crate.
28pub struct SubsystemHandle<ErrType: ErrTypeTraits = BoxedError> {
29 inner: ManuallyDrop<Inner<ErrType>>,
30 // When dropped, redirect Self into this channel.
31 // Required to pass the handle back out of `tokio::spawn`.
32 drop_redirect: Option<oneshot::Sender<WeakSubsystemHandle<ErrType>>>,
33}
34
35pub(crate) struct WeakSubsystemHandle<ErrType: ErrTypeTraits> {
36 pub(crate) joiner_token: JoinerToken<ErrType>,
37 // Children are stored here to keep them alive
38 _children: RemotelyDroppableItems<SubsystemRunner>,
39}
40
41impl<ErrType: ErrTypeTraits> SubsystemHandle<ErrType> {
42 /// Start a nested subsystem.
43 ///
44 /// Once called, the subsystem will be started immediately, similar to [`tokio::spawn`].
45 ///
46 /// # Arguments
47 ///
48 /// * `builder` - The [`SubsystemBuilder`] that contains all the information
49 /// about the subsystem that should be spawned.
50 ///
51 /// # Returns
52 ///
53 /// A [`NestedSubsystem`] that can be used to control or join the subsystem.
54 ///
55 /// # Examples
56 ///
57 /// ```
58 /// use miette::Result;
59 /// use tokio_graceful_shutdown::{SubsystemBuilder, SubsystemHandle};
60 ///
61 /// async fn nested_subsystem(subsys: &mut SubsystemHandle) -> Result<()> {
62 /// subsys.on_shutdown_requested().await;
63 /// Ok(())
64 /// }
65 ///
66 /// async fn my_subsystem(subsys: &mut SubsystemHandle) -> Result<()> {
67 /// // start a nested subsystem
68 /// subsys.start(SubsystemBuilder::new("Nested", nested_subsystem));
69 ///
70 /// subsys.on_shutdown_requested().await;
71 /// Ok(())
72 /// }
73 /// ```
74 #[track_caller]
75 pub fn start<Err, Subsys>(&self, builder: SubsystemBuilder<Subsys>) -> NestedSubsystem<ErrType>
76 where
77 Subsys: 'static + for<'a> AsyncSubsysFn<&'a mut SubsystemHandle<ErrType>, Result<(), Err>>,
78 Err: Into<ErrType>,
79 {
80 self.start_with_abs_name(
81 if self.inner.name.as_ref() == "/" {
82 Arc::from(format!("/{}", builder.name))
83 } else {
84 Arc::from(format!("{}/{}", self.inner.name, builder.name))
85 },
86 builder.subsystem,
87 ErrorActions {
88 on_failure: Atomic::new(builder.failure_action),
89 on_panic: Atomic::new(builder.panic_action),
90 },
91 builder.detached,
92 )
93 }
94
95 #[track_caller]
96 pub(crate) fn start_with_abs_name<Err, Subsys>(
97 &self,
98 name: Arc<str>,
99 subsystem: Subsys,
100 error_actions: ErrorActions,
101 detached: bool,
102 ) -> NestedSubsystem<ErrType>
103 where
104 Subsys: 'static + for<'a> AsyncSubsysFn<&'a mut SubsystemHandle<ErrType>, Result<(), Err>>,
105 Err: Into<ErrType>,
106 {
107 let alive_guard = AliveGuard::new();
108
109 let (error_sender, errors) = mpsc::unbounded_channel();
110
111 let cancellation_token = if detached {
112 CancellationToken::new()
113 } else {
114 self.inner.cancellation_token.child_token()
115 };
116
117 let error_actions = Arc::new(error_actions);
118
119 let (joiner_token, joiner_token_ref) = self.inner.joiner_token.child_token({
120 let cancellation_token = cancellation_token.clone();
121 let error_actions = Arc::clone(&error_actions);
122 move |e| {
123 let error_action = match &e {
124 SubsystemError::Failed(_, _) => {
125 error_actions.on_failure.load(Ordering::Relaxed)
126 }
127 SubsystemError::Panicked(_) => error_actions.on_panic.load(Ordering::Relaxed),
128 };
129
130 match error_action {
131 ErrorAction::Forward => Some(e),
132 ErrorAction::CatchAndLocalShutdown => {
133 handle_dropped_error(error_sender.send(e));
134 cancellation_token.cancel();
135 None
136 }
137 }
138 }
139 });
140
141 let child_handle = SubsystemHandle {
142 inner: ManuallyDrop::new(Inner {
143 name: Arc::clone(&name),
144 cancellation_token: cancellation_token.clone(),
145 toplevel_cancellation_token: self.inner.toplevel_cancellation_token.clone(),
146 joiner_token,
147 children: RemotelyDroppableItems::new(),
148 }),
149 drop_redirect: None,
150 };
151
152 let runner = SubsystemRunner::new(name, subsystem, child_handle, alive_guard.clone());
153 let abort_handle = runner.abort_handle();
154
155 // Shenanigans to juggle child ownership
156 //
157 // RACE CONDITION SAFETY:
158 // If the subsystem ends before `on_finished` was able to be called, nothing bad happens.
159 // alive_guard will keep the guard alive and the callback will only be called inside of
160 // the guard's drop() implementation.
161 let child_dropper = self.inner.children.insert(runner);
162 alive_guard.on_finished(|| {
163 drop(child_dropper);
164 });
165
166 NestedSubsystem {
167 joiner: joiner_token_ref,
168 cancellation_token,
169 errors: Mutex::new(ErrorCollector::new(errors)),
170 error_actions,
171 abort_handle,
172 }
173 }
174
175 /// Waits until all the children of this subsystem are finished.
176 pub async fn wait_for_children(&self) {
177 self.inner.joiner_token.join_children().await
178 }
179
180 // For internal use only - should never be used by users.
181 // Required as a short-lived second reference inside of `runner`.
182 pub(crate) fn delayed_clone(&mut self) -> oneshot::Receiver<WeakSubsystemHandle<ErrType>> {
183 let (sender, receiver) = oneshot::channel();
184
185 let previous = self.drop_redirect.replace(sender);
186 assert!(previous.is_none());
187
188 receiver
189 }
190
191 /// Wait for the shutdown mode to be triggered.
192 ///
193 /// Once the shutdown mode is entered, all existing calls to this
194 /// method will be released and future calls to this method will
195 /// return immediately.
196 ///
197 /// This is the primary method of subsystems to react to
198 /// the shutdown requests. Most often, it will be used in [`tokio::select`]
199 /// statements to cancel other code as soon as the shutdown is requested.
200 ///
201 /// # Examples
202 ///
203 /// ```
204 /// use miette::Result;
205 /// use tokio::time::{sleep, Duration};
206 /// use tokio_graceful_shutdown::SubsystemHandle;
207 ///
208 /// async fn countdown() {
209 /// for i in (1..10).rev() {
210 /// tracing::info!("Countdown: {}", i);
211 /// sleep(Duration::from_millis(1000)).await;
212 /// }
213 /// }
214 ///
215 /// async fn countdown_subsystem(subsys: SubsystemHandle) -> Result<()> {
216 /// tracing::info!("Starting countdown ...");
217 ///
218 /// // This cancels the countdown as soon as shutdown
219 /// // mode was entered
220 /// tokio::select! {
221 /// _ = subsys.on_shutdown_requested() => {
222 /// tracing::info!("Countdown cancelled.");
223 /// },
224 /// _ = countdown() => {
225 /// tracing::info!("Countdown finished.");
226 /// }
227 /// };
228 ///
229 /// Ok(())
230 /// }
231 /// ```
232 pub async fn on_shutdown_requested(&self) {
233 self.inner.cancellation_token.cancelled().await
234 }
235
236 /// Returns whether a shutdown should be performed now.
237 ///
238 /// This method is provided for subsystems that need to query the shutdown
239 /// request state repeatedly.
240 ///
241 /// This can be useful in scenarios where a subsystem depends on the graceful
242 /// shutdown of its nested coroutines before it can run final cleanup steps itself.
243 ///
244 /// # Examples
245 ///
246 /// ```
247 /// use miette::Result;
248 /// use tokio::time::{sleep, Duration};
249 /// use tokio_graceful_shutdown::SubsystemHandle;
250 ///
251 /// async fn uncancellable_action(subsys: &SubsystemHandle) {
252 /// tokio::select! {
253 /// // Execute an action. A dummy `sleep` in this case.
254 /// _ = sleep(Duration::from_millis(1000)) => {
255 /// tracing::info!("Action finished.");
256 /// }
257 /// // Perform a shutdown if requested
258 /// _ = subsys.on_shutdown_requested() => {
259 /// tracing::info!("Action aborted.");
260 /// },
261 /// }
262 /// }
263 ///
264 /// async fn my_subsystem(subsys: SubsystemHandle) -> Result<()> {
265 /// tracing::info!("Starting subsystem ...");
266 ///
267 /// // We cannot do a `tokio::select` with `on_shutdown_requested`
268 /// // here, because a shutdown would cancel the action without giving
269 /// // it the chance to react first.
270 /// while !subsys.is_shutdown_requested() {
271 /// uncancellable_action(&subsys).await;
272 /// }
273 ///
274 /// tracing::info!("Subsystem stopped.");
275 ///
276 /// Ok(())
277 /// }
278 /// ```
279 pub fn is_shutdown_requested(&self) -> bool {
280 self.inner.cancellation_token.is_cancelled()
281 }
282
283 /// Triggers a shutdown of the entire subsystem tree.
284 ///
285 /// # Examples
286 ///
287 /// ```
288 /// use miette::Result;
289 /// use tokio::time::{sleep, Duration};
290 /// use tokio_graceful_shutdown::SubsystemHandle;
291 ///
292 /// async fn stop_subsystem(subsys: SubsystemHandle) -> Result<()> {
293 /// // This subsystem wait for one second and then stops the program.
294 /// sleep(Duration::from_millis(1000)).await;
295 ///
296 /// // Shut down the entire subsystem tree
297 /// subsys.request_shutdown();
298 ///
299 /// Ok(())
300 /// }
301 /// ```
302 pub fn request_shutdown(&self) {
303 self.inner.toplevel_cancellation_token.cancel();
304 }
305
306 /// Triggers a shutdown of the current subsystem and all
307 /// of its children.
308 pub fn request_local_shutdown(&self) {
309 self.inner.cancellation_token.cancel();
310 }
311
312 pub(crate) fn get_cancellation_token(&self) -> &CancellationToken {
313 &self.inner.cancellation_token
314 }
315
316 /// Creates a cancellation token that will get triggered once the
317 /// subsystem shuts down.
318 ///
319 /// This is intended for more lightweight situations where
320 /// creating full-blown subsystems would be too much overhead,
321 /// like spawning connection handlers of a webserver.
322 ///
323 /// For more information, see the [hyper example](https://github.com/Finomnis/tokio-graceful-shutdown/blob/main/examples/hyper.rs).
324 pub fn create_cancellation_token(&self) -> CancellationToken {
325 self.inner.cancellation_token.child_token()
326 }
327
328 /// Get the name associated with this subsystem.
329 ///
330 /// Note that the names of nested subsystems are built unix-path alike,
331 /// starting and delimited by slashes (e.g. `/a/b/c`).
332 ///
333 /// See [`SubsystemBuilder::new()`] how to set this name.
334 pub fn name(&self) -> &str {
335 &self.inner.name
336 }
337}
338
339impl<ErrType: ErrTypeTraits> Drop for SubsystemHandle<ErrType> {
340 fn drop(&mut self) {
341 // SAFETY: This is how ManuallyDrop is meant to be used.
342 // `self.inner` won't ever be used again because `self` will be gone after this
343 // function is finished.
344 // This takes the `self.inner` object and makes it droppable again.
345 //
346 // This workaround is required to take ownership for the `self.drop_redirect` channel.
347 let inner = unsafe { ManuallyDrop::take(&mut self.inner) };
348
349 if let Some(redirect) = self.drop_redirect.take() {
350 let redirected_self = WeakSubsystemHandle {
351 joiner_token: inner.joiner_token,
352 _children: inner.children,
353 };
354
355 // ignore error; an error would indicate that there is no receiver.
356 // in that case, do nothing.
357 let _ = redirect.send(redirected_self);
358 }
359 }
360}
361
362pub(crate) fn root_handle<ErrType: ErrTypeTraits>(
363 cancellation_token: CancellationToken,
364 on_error: impl Fn(SubsystemError<ErrType>) + Sync + Send + 'static,
365) -> SubsystemHandle<ErrType> {
366 SubsystemHandle {
367 inner: ManuallyDrop::new(Inner {
368 name: Arc::from(""),
369 cancellation_token: cancellation_token.clone(),
370 toplevel_cancellation_token: cancellation_token.clone(),
371 joiner_token: JoinerToken::new(move |e| {
372 on_error(e);
373 cancellation_token.cancel();
374 None
375 })
376 .0,
377 children: RemotelyDroppableItems::new(),
378 }),
379 drop_redirect: None,
380 }
381}
382
// Unit tests for this module are kept in a separate `tests` module file.
#[cfg(test)]
mod tests;