scuffle_context/
lib.rs

//! A crate that provides the ability to cancel futures using a Go-like
//! context approach, allowing for graceful shutdowns and cancellations.
3//!
4//! ## Why do we need this?
5//!
//! It's often useful to wait for all futures to shut down, or to cancel them
//! when we no longer care about their results. This crate provides an
//! interface to cancel all futures associated with a context, or to wait for
//! them to finish before shutting down, allowing for graceful shutdowns and
//! cancellations.
10//!
11//! ## Usage
12//!
13//! Here is an example of how to use the `Context` to cancel a spawned task.
14//!
15//! ```rust
16//! # use scuffle_context::{Context, ContextFutExt};
17//! # tokio_test::block_on(async {
18//! let (ctx, handler) = Context::new();
19//!
20//! tokio::spawn(async {
21//!     // Do some work
22//!     tokio::time::sleep(std::time::Duration::from_secs(10)).await;
23//! }.with_context(ctx));
24//!
25//! // Will stop the spawned task and cancel all associated futures.
26//! handler.cancel();
27//! # });
28//! ```
29//!
30//! ## License
31//!
32//! This project is licensed under the [MIT](./LICENSE.MIT) or
33//! [Apache-2.0](./LICENSE.Apache-2.0) license. You can choose between one of
34//! them if you use this work.
35//!
36//! `SPDX-License-Identifier: MIT OR Apache-2.0`
37#![cfg_attr(all(coverage_nightly, test), feature(coverage_attribute))]
38
39use std::sync::Arc;
40use std::sync::atomic::{AtomicBool, AtomicUsize};
41
42use tokio_util::sync::CancellationToken;
43
44/// For extending types.
45mod ext;
46
47pub use ext::*;
48
49/// Create by calling [`ContextTrackerInner::child`].
50#[derive(Debug)]
51struct ContextTracker(Arc<ContextTrackerInner>);
52
53impl Drop for ContextTracker {
54    fn drop(&mut self) {
55        let prev_active_count = self.0.active_count.fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
56        // If this was the last active `ContextTracker` and the context has been
57        // stopped, then notify the waiters
58        if prev_active_count == 1 && self.0.stopped.load(std::sync::atomic::Ordering::Relaxed) {
59            self.0.notify.notify_waiters();
60        }
61    }
62}
63
64#[derive(Debug)]
65struct ContextTrackerInner {
66    stopped: AtomicBool,
67    /// This count keeps track of the number of `ContextTrackers` that exist for
68    /// this `ContextTrackerInner`.
69    active_count: AtomicUsize,
70    notify: tokio::sync::Notify,
71}
72
73impl ContextTrackerInner {
74    fn new() -> Arc<Self> {
75        Arc::new(Self {
76            stopped: AtomicBool::new(false),
77            active_count: AtomicUsize::new(0),
78            notify: tokio::sync::Notify::new(),
79        })
80    }
81
82    /// Create a new `ContextTracker` from an `Arc<ContextTrackerInner>`.
83    fn child(self: &Arc<Self>) -> ContextTracker {
84        self.active_count.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
85        ContextTracker(Arc::clone(self))
86    }
87
88    /// Mark this `ContextTrackerInner` as stopped.
89    fn stop(&self) {
90        self.stopped.store(true, std::sync::atomic::Ordering::Relaxed);
91    }
92
93    /// Wait for this `ContextTrackerInner` to be stopped and all associated
94    /// `ContextTracker`s to be dropped.
95    async fn wait(&self) {
96        let notify = self.notify.notified();
97
98        // If there are no active children, then the notify will never be called
99        if self.active_count.load(std::sync::atomic::Ordering::Relaxed) == 0 {
100            return;
101        }
102
103        notify.await;
104    }
105}
106
107/// A context for cancelling futures and waiting for shutdown.
108///
109/// A context can be created from a handler by calling [`Handler::context`] or
110/// from another context by calling [`Context::new_child`] so to have a
111/// hierarchy of contexts.
112///
113/// Contexts can then be attached to futures or streams in order to
114/// automatically cancel them when the context is done, when invoking
115/// [`Handler::cancel`].
116/// The [`Handler::shutdown`] method will block until all contexts have been
117/// dropped allowing for a graceful shutdown.
118#[derive(Debug)]
119pub struct Context {
120    token: CancellationToken,
121    tracker: ContextTracker,
122}
123
124impl Clone for Context {
125    fn clone(&self) -> Self {
126        Self {
127            token: self.token.clone(),
128            tracker: self.tracker.0.child(),
129        }
130    }
131}
132
133impl Context {
134    #[must_use]
135    /// Create a new context using the global handler.
136    /// Returns a child context and child handler of the global handler.
137    pub fn new() -> (Self, Handler) {
138        Handler::global().new_child()
139    }
140
141    #[must_use]
142    /// Create a new child context from this context.
143    /// Returns a new child context and child handler of this context.
144    ///
145    /// # Example
146    ///
147    /// ```rust
148    /// use scuffle_context::Context;
149    ///
150    /// let (parent, parent_handler) = Context::new();
151    /// let (child, child_handler) = parent.new_child();
152    /// ```
153    pub fn new_child(&self) -> (Self, Handler) {
154        let token = self.token.child_token();
155        let tracker = ContextTrackerInner::new();
156
157        (
158            Self {
159                tracker: tracker.child(),
160                token: token.clone(),
161            },
162            Handler {
163                token: Arc::new(TokenDropGuard(token)),
164                tracker,
165            },
166        )
167    }
168
169    #[must_use]
170    /// Returns the global context
171    pub fn global() -> Self {
172        Handler::global().context()
173    }
174
175    /// Wait for the context to be done (the handler to be shutdown).
176    pub async fn done(&self) {
177        self.token.cancelled().await;
178    }
179
180    /// The same as [`Context::done`] but takes ownership of the context.
181    pub async fn into_done(self) {
182        self.done().await;
183    }
184
185    /// Returns true if the context is done.
186    #[must_use]
187    pub fn is_done(&self) -> bool {
188        self.token.is_cancelled()
189    }
190}
191
192/// A wrapper type around [`CancellationToken`] that will cancel the token as
193/// soon as it is dropped.
194#[derive(Debug)]
195struct TokenDropGuard(CancellationToken);
196
197impl TokenDropGuard {
198    #[must_use]
199    fn child(&self) -> CancellationToken {
200        self.0.child_token()
201    }
202
203    fn cancel(&self) {
204        self.0.cancel();
205    }
206}
207
208impl Drop for TokenDropGuard {
209    fn drop(&mut self) {
210        self.cancel();
211    }
212}
213
214#[derive(Debug, Clone)]
215pub struct Handler {
216    token: Arc<TokenDropGuard>,
217    tracker: Arc<ContextTrackerInner>,
218}
219
220impl Default for Handler {
221    fn default() -> Self {
222        Self::new()
223    }
224}
225
226impl Handler {
227    #[must_use]
228    /// Create a new handler.
229    pub fn new() -> Handler {
230        let token = CancellationToken::new();
231        let tracker = ContextTrackerInner::new();
232
233        Handler {
234            token: Arc::new(TokenDropGuard(token)),
235            tracker,
236        }
237    }
238
239    #[must_use]
240    /// Returns the global handler.
241    pub fn global() -> &'static Self {
242        static GLOBAL: std::sync::OnceLock<Handler> = std::sync::OnceLock::new();
243
244        GLOBAL.get_or_init(Handler::new)
245    }
246
247    /// Shutdown the handler and wait for all contexts to be done.
248    pub async fn shutdown(&self) {
249        self.cancel();
250        self.done().await;
251    }
252
253    /// Waits for the handler to be done (waiting for all contexts to be done).
254    pub async fn done(&self) {
255        self.token.0.cancelled().await;
256        self.wait().await;
257    }
258
259    /// Waits for the handler to be done (waiting for all contexts to be done).
260    /// Returns once all contexts are done, even if the handler is not done and
261    /// contexts can be created after this call.
262    pub async fn wait(&self) {
263        self.tracker.wait().await;
264    }
265
266    #[must_use]
267    /// Create a new context from this handler.
268    pub fn context(&self) -> Context {
269        Context {
270            token: self.token.child(),
271            tracker: self.tracker.child(),
272        }
273    }
274
275    #[must_use]
276    /// Create a new child context from this handler
277    pub fn new_child(&self) -> (Context, Handler) {
278        self.context().new_child()
279    }
280
281    /// Cancel the handler.
282    pub fn cancel(&self) {
283        self.tracker.stop();
284        self.token.cancel();
285    }
286
287    /// Returns true if the handler is done.
288    pub fn is_done(&self) -> bool {
289        self.token.0.is_cancelled()
290    }
291}
292
#[cfg_attr(all(coverage_nightly, test), coverage(off))]
#[cfg(test)]
mod tests {
    use scuffle_future_ext::FutureExt;

    use crate::{Context, Handler};

    /// How long to wait before deciding that a shutdown/wait future is stuck.
    const TIMEOUT: std::time::Duration = std::time::Duration::from_millis(200);

    // Every test except `global_handler` builds its contexts from a private
    // root `Handler` rather than `Context::new()`. `Context::new()` derives from
    // the process-wide global handler, which `global_handler` cancels; the
    // default test harness runs tests in parallel threads of one process, so
    // sharing the global would make the "not done" assertions below flaky.
    // Each root is bound to a local so it stays alive for the whole test
    // (dropping the last handler clone cancels its token).

    #[tokio::test]
    async fn new() {
        let handler = Handler::default();
        assert!(!handler.is_done());

        let (ctx, child_handler) = handler.new_child();
        assert!(!child_handler.is_done());
        assert!(!ctx.is_done());
    }

    #[tokio::test]
    async fn cancel() {
        let root = Handler::new();
        let (ctx, handler) = root.new_child();
        let (child_ctx, child_handler) = ctx.new_child();
        let child_ctx2 = ctx.clone();

        assert!(!handler.is_done());
        assert!(!ctx.is_done());
        assert!(!child_handler.is_done());
        assert!(!child_ctx.is_done());
        assert!(!child_ctx2.is_done());

        handler.cancel();

        assert!(handler.is_done());
        assert!(ctx.is_done());
        assert!(child_handler.is_done());
        assert!(child_ctx.is_done());
        assert!(child_ctx2.is_done());
    }

    #[tokio::test]
    async fn cancel_child() {
        let root = Handler::new();
        let (ctx, handler) = root.new_child();
        let (child_ctx, child_handler) = ctx.new_child();

        assert!(!handler.is_done());
        assert!(!ctx.is_done());
        assert!(!child_handler.is_done());
        assert!(!child_ctx.is_done());

        child_handler.cancel();

        assert!(!handler.is_done());
        assert!(!ctx.is_done());
        assert!(child_handler.is_done());
        assert!(child_ctx.is_done());
    }

    #[tokio::test]
    async fn shutdown() {
        let root = Handler::new();
        let (ctx, handler) = root.new_child();

        assert!(!handler.is_done());
        assert!(!ctx.is_done());

        // This is expected to time out because `ctx` is still alive.
        assert!(handler.shutdown().with_timeout(TIMEOUT).await.is_err());
        assert!(handler.is_done());
        assert!(ctx.is_done());
        assert!(ctx.into_done().with_timeout(TIMEOUT).await.is_ok());

        assert!(handler.shutdown().with_timeout(TIMEOUT).await.is_ok());
        assert!(handler.wait().with_timeout(TIMEOUT).await.is_ok());
        assert!(handler.done().with_timeout(TIMEOUT).await.is_ok());
        assert!(handler.is_done());
    }

    #[tokio::test]
    async fn global_handler() {
        // The only test allowed to touch (and cancel) the global handler.
        let (ctx, handler) = Context::new();
        assert!(!handler.is_done());
        assert!(!ctx.is_done());

        let global = Handler::global();
        assert!(!global.is_done());

        global.cancel();

        assert!(global.is_done());
        assert!(Handler::global().is_done());
        assert!(Context::global().is_done());
        // Contexts derived from the global handler are cancelled with it.
        assert!(handler.is_done());
        assert!(ctx.is_done());

        let (child_ctx, child_handler) = Handler::global().new_child();
        assert!(child_handler.is_done());
        assert!(child_ctx.is_done());
    }
}