Skip to main content

scope_spawn/
scope.rs

//!
//! Implementation of [`Scope`], a handle for spawning Tokio tasks that are
//! cancelled when the scope is dropped.
4
5use tokio::task::JoinHandle;
6use tokio_util::future::FutureExt;
7use tokio_util::sync::CancellationToken;
8use tokio_util::task::TaskTracker;
9
10/// A Scope for spawning Tokio Tasks.
11///
12/// When the Scope is dropped, any tasks which are still executing are cancelled.
13/// If you want to cancel all tasks sooner, you can use [Scope::cancel()].
14#[derive(Clone, Debug, Default)]
15pub struct Scope {
16    token: CancellationToken,
17    tracker: TaskTracker,
18}
19
20impl Scope {
21    /// Create a new scoped task spawner
22    ///
23    /// Tasks may be cancelled manually, using `cancel()`, or automatically
24    /// when the Scope is dropped.
25    pub fn new() -> Self {
26        Self {
27            token: CancellationToken::new(),
28            tracker: TaskTracker::new(),
29        }
30    }
31
32    /// Cancel all the spawned, and still executing, tasks.
33    pub fn cancel(&self) {
34        self.token.cancel();
35    }
36
37    /// Spawn a task within the scope.
38    ///
39    /// This function is the primary way to introduce concurrency within a `Scope`.
40    /// The spawned task will be cancelled automatically when the `Scope` is dropped
41    /// or when `Scope::cancel()` is called.
42    ///
43    /// # When to use `spawn`
44    ///
45    /// Use `spawn` when you need to react to the outcome of the spawned task.
46    /// It returns a `JoinHandle<Option<Output>>`, allowing you to `.await` the
47    /// result. The `Option` will be:
48    ///
49    /// - `Some(value)` if the future completes successfully.
50    /// - `None` if the future is cancelled before completion.
51    ///
52    /// If you only need to run a side-effect on completion or cancellation (like
53    /// decrementing a counter), consider using [Scope::spawn_with_hooks] for a more
54    /// direct API.
55    ///
56    /// ```
57    /// use scope_spawn::scope::Scope;
58    /// use std::time::Duration;
59    /// use tokio::time::sleep;
60    ///
61    /// #[tokio::main]
62    /// async fn main() {
63    ///     let scope = Scope::new();
64    ///
65    ///     let handle = scope.spawn(async {
66    ///         // Simulate some work
67    ///         sleep(Duration::from_millis(10)).await;
68    ///         "Hello from a spawned task!"
69    ///     });
70    ///
71    ///     let result = handle.await.unwrap();
72    ///     assert!(result.is_some());
73    ///     assert_eq!(result.unwrap(), "Hello from a spawned task!");
74    ///
75    ///     // scope is dropped here, and any remaining tasks are cancelled.
76    /// }
77    /// ```
78    pub fn spawn<F, R>(&self, future: F) -> JoinHandle<Option<F::Output>>
79    where
80        F: Future<Output = R> + Send + 'static,
81        R: Send + 'static,
82    {
83        let token = self.token.clone();
84        self.tracker
85            .spawn(async move { future.with_cancellation_token_owned(token).await })
86    }
87
88    /// Spawn a "fire-and-forget" task with completion and cancellation hooks.
89    ///
90    /// This function is useful when you need to execute a side-effect based on the
91    /// task's outcome, but do not need to handle its return value directly.
92    ///
93    /// - The `on_completion` closure runs if the task finishes successfully.
94    /// - The `on_cancellation` closure runs if the task is cancelled.
95    ///
96    /// # When to use `spawn_with_hooks`
97    ///
98    /// This method is ideal for managing resources, such as semaphores or counters,
99    /// that are tied to the lifecycle of a task. For example, you might decrement
100    /// an "in-flight requests" counter in both hooks to ensure it's always accurate,
101    /// regardless of how the task terminates.
102    ///
103    /// If you need to `await` the task's result, use [Scope::spawn] instead.
104    ///
105    /// ```
106    /// use scope_spawn::scope::Scope;
107    /// use std::sync::Arc;
108    /// use std::sync::atomic::AtomicUsize;
109    /// use std::sync::atomic::Ordering;
110    ///
111    /// #[tokio::main]
112    /// async fn main() {
113    ///     let scope = Scope::new();
114    ///     let completed_count = Arc::new(AtomicUsize::new(0));
115    ///
116    ///     let count_clone = completed_count.clone();
117    ///     scope.spawn_with_hooks(
118    ///         async { /* ... */ },
119    ///         move || { count_clone.fetch_add(1, Ordering::SeqCst); },
120    ///         || { /* handle cancellation */ }
121    ///     );
122    ///
123    ///     // Give the task time to complete.
124    ///     tokio::time::sleep(std::time::Duration::from_millis(10)).await;
125    ///
126    ///     assert_eq!(completed_count.load(Ordering::SeqCst), 1);
127    ///     // scope is dropped here, and spawned tasks are cancelled.
128    /// }
129    /// ```
130    pub fn spawn_with_hooks<F, C, D, R>(&self, future: F, on_completion: C, on_cancellation: D)
131    where
132        F: Future<Output = R> + Send + 'static,
133        C: FnOnce() + Send + 'static,
134        D: FnOnce() + Send + 'static,
135        R: Send + 'static,
136    {
137        let token = self.token.clone();
138        self.tracker.spawn(async move {
139            match future.with_cancellation_token_owned(token).await {
140                Some(r) => {
141                    on_completion();
142                    Some(r)
143                }
144                None => {
145                    on_cancellation();
146                    None
147                }
148            }
149        });
150    }
151}
152
153impl Drop for Scope {
154    fn drop(&mut self) {
155        self.cancel();
156    }
157}
158
#[cfg(test)]
mod tests {
    use std::future::pending;
    use std::sync::Arc;
    use std::sync::atomic::AtomicBool;
    use std::sync::atomic::AtomicUsize;
    use std::sync::atomic::Ordering;
    use std::time::Duration;

    use axum::Router;
    use axum::body::Body;
    use axum::http::Request;
    use axum::http::StatusCode;
    use axum::routing::get;
    use tokio::sync::oneshot;
    use tokio::time;
    use tokio_util::sync::CancellationToken;
    use tower::Service;

    use super::*;

    /// Upper bound on how long a test waits for an asynchronous effect (a hook
    /// running, a task being dropped) before failing.
    const PROPAGATION_TIMEOUT: Duration = Duration::from_secs(1);

    /// Wait until `condition` holds, re-checking every millisecond.
    ///
    /// Replaces fixed "give it a moment" sleeps: the test proceeds as soon as the
    /// effect is visible, and fails only if it does not appear within
    /// `PROPAGATION_TIMEOUT`.
    async fn eventually(condition: impl Fn() -> bool) {
        time::timeout(PROPAGATION_TIMEOUT, async {
            while !condition() {
                time::sleep(Duration::from_millis(1)).await;
            }
        })
        .await
        .expect("condition was not met before the timeout");
    }

    #[tokio::test]
    async fn it_correctly_processes_spawned_panic() {
        let scope = Scope::new();
        let (tx, rx) = oneshot::channel::<()>();

        let jh = scope.spawn(async move {
            // Only this task holds the sender, so `rx` resolves (with an error)
            // once the task has run and its panic has dropped `_tx`.
            let _tx = tx;
            panic!("to panic is human");
        });

        // Deterministically wait for the panic before dropping the scope.
        // Otherwise the cancellation might be observed first and the future
        // dropped without ever running.
        assert!(rx.await.is_err());

        drop(scope);

        // The panic happened before the drop, so it is still reported through
        // the JoinHandle as a panic (not as a cancellation).
        let result = jh.await;
        assert!(result.unwrap_err().is_panic());
    }

    #[tokio::test]
    async fn it_works() {
        let scope = Scope::new();
        let (tx, rx) = oneshot::channel::<()>();

        scope.spawn(async move {
            // This task never finishes on its own. When the scope is dropped it
            // is cancelled, and dropping its future drops the oneshot sender.
            let _tx = tx;
            pending::<()>().await;
        });

        drop(scope);

        // The receiver will get an error when the sender is dropped.
        assert!(rx.await.is_err());
    }

    #[tokio::test]
    async fn test_scope_cancellation_with_tokio() {
        let scope = Scope::new();
        let started = Arc::new(AtomicBool::new(false));
        let cancelled = Arc::new(AtomicBool::new(false));

        let s_clone = started.clone();
        let c_clone = cancelled.clone();

        scope.spawn_with_hooks(
            async move {
                s_clone.store(true, Ordering::SeqCst);
                // Never completes; only cancellation ends this task.
                pending::<()>().await;
            },
            || (),
            move || {
                c_clone.store(true, Ordering::SeqCst);
            },
        );

        // Wait for the spawned task to start.
        eventually(|| started.load(Ordering::SeqCst)).await;

        drop(scope);

        // Wait for the cancellation hook to run.
        eventually(|| cancelled.load(Ordering::SeqCst)).await;
    }

    #[tokio::test]
    async fn test_scope_with_tower() {
        let token = CancellationToken::new();
        // Cancelled (via a drop guard) when the task spawned inside the handler's
        // scope is dropped, which lets the test observe the scope's effect.
        let task_dropped = CancellationToken::new();

        let handler = {
            let token = token.clone();
            let task_dropped = task_dropped.clone();
            move || async move {
                let scope = Scope::new();
                let guard = task_dropped.drop_guard();
                scope.spawn(async move {
                    let _guard = guard;
                    pending::<()>().await;
                });
                token.cancelled().await;
            }
        };

        let mut service = Router::new().route("/", get(handler)).into_service();

        let request = Request::new(Body::empty());

        let response_future = service.call(request);

        let response_handle = tokio::spawn(response_future);

        // Give the handler time to start and park on `token.cancelled()`.
        time::sleep(Duration::from_millis(50)).await;

        token.cancel();

        // The handler should complete, dropping the scope, cancelling the task.
        let response = response_handle.await.unwrap().unwrap();
        assert_eq!(response.status(), StatusCode::OK);

        // Verify the inner task really was cancelled: its future, and with it
        // the drop guard, must be dropped once the handler's scope goes away.
        time::timeout(PROPAGATION_TIMEOUT, task_dropped.cancelled())
            .await
            .expect("the task spawned inside the handler was not cancelled");
    }

    #[tokio::test]
    async fn test_spawn_with_hooks_completion() {
        let scope = Scope::new();
        let completed = Arc::new(AtomicBool::new(false));
        let cancelled = Arc::new(AtomicBool::new(false));

        let comp_clone = completed.clone();
        let canc_clone = cancelled.clone();

        scope.spawn_with_hooks(
            async move {
                // Task completes normally
            },
            move || {
                comp_clone.store(true, Ordering::SeqCst);
            },
            move || {
                canc_clone.store(true, Ordering::SeqCst);
            },
        );

        // Wait for the completion hook, then check the other hook did not run.
        eventually(|| completed.load(Ordering::SeqCst)).await;
        assert!(!cancelled.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn test_multiple_tasks() {
        let scope = Scope::new();
        let count = Arc::new(AtomicUsize::new(0));

        for _ in 0..10 {
            let c = count.clone();
            scope.spawn_with_hooks(
                async move {
                    pending::<()>().await;
                },
                || (),
                move || {
                    c.fetch_add(1, Ordering::SeqCst);
                },
            );
        }

        // Let the tasks run (they just pend). This checks that nothing happens,
        // so a fixed wait is the only option here.
        time::sleep(Duration::from_millis(50)).await;
        assert_eq!(count.load(Ordering::SeqCst), 0);

        drop(scope);

        // Wait for every cancellation hook to run.
        eventually(|| count.load(Ordering::SeqCst) == 10).await;
        assert_eq!(count.load(Ordering::SeqCst), 10);
    }

    #[tokio::test]
    async fn test_scope_manual_cancel() {
        let scope = Scope::new();
        let cancelled = Arc::new(AtomicBool::new(false));
        let c_clone = cancelled.clone();

        scope.spawn_with_hooks(
            async move {
                pending::<()>().await;
            },
            || (),
            move || {
                c_clone.store(true, Ordering::SeqCst);
            },
        );

        scope.cancel();

        // Wait for the cancellation hook to run.
        eventually(|| cancelled.load(Ordering::SeqCst)).await;
    }

    #[tokio::test]
    async fn test_scope_clone_shares_lifecycle() {
        let scope = Scope::new();
        let scope_clone = scope.clone();
        let cancelled = Arc::new(AtomicBool::new(false));
        let c_clone = cancelled.clone();

        scope_clone.spawn_with_hooks(
            async move {
                pending::<()>().await;
            },
            || (),
            move || {
                c_clone.store(true, Ordering::SeqCst);
            },
        );

        // Dropping the original scope should cancel tasks in the clone
        drop(scope);

        // Wait for the cancellation hook to run.
        eventually(|| cancelled.load(Ordering::SeqCst)).await;
    }

    #[tokio::test]
    async fn test_spawn_after_cancel() {
        let scope = Scope::new();
        scope.cancel();

        let cancelled = Arc::new(AtomicBool::new(false));
        let c_clone = cancelled.clone();

        scope.spawn_with_hooks(
            async move {
                pending::<()>().await;
            },
            || (),
            move || {
                c_clone.store(true, Ordering::SeqCst);
            },
        );

        // The scope is already cancelled, so the new task must be cancelled too.
        eventually(|| cancelled.load(Ordering::SeqCst)).await;
    }
}