tokio_shield/
lib.rs

1#![doc = include_str!("../README.md")]
2#![forbid(unsafe_code)]
3#![warn(clippy::dbg_macro, clippy::use_debug, clippy::todo)]
4#![warn(missing_docs, missing_debug_implementations)]
5
6use std::future::Future;
7
8use futures::{future::Map, FutureExt};
9use tokio::task::{JoinError, JoinHandle};
10
11/// Adds methods to futures to prevent them from being aborted.
/// Adds methods to futures to prevent them from being aborted.
///
/// Both methods hand the future to [`tokio::spawn()`], so it runs as an
/// independent task. Dropping or aborting the future returned from
/// [`shield()`](Self::shield) or [`try_shield()`](Self::try_shield) only
/// detaches that task; the original future keeps running to completion.
pub trait Shield
where
    Self: Future + Send + 'static,
    Self::Output: Send,
{
    /// The [`Future`] returned from [`shield()`](Self::shield).
    type ShieldFuture: Future<Output = Self::Output>;
    /// The [`Future`] returned from [`try_shield()`](Self::try_shield).
    type TryShieldFuture: Future<Output = Result<Self::Output, Self::TryShieldError>>;
    /// The error returned from [`try_shield()`](Self::try_shield).
    type TryShieldError;

    /// Prevent this future from being aborted by wrapping it in a task.
    ///
    /// `future.shield().await` is equivalent to
    /// `future.try_shield().await.unwrap()`.
    ///
    /// # Panics
    /// Panics if called outside the context of a Tokio runtime, because the
    /// task is created with [`tokio::spawn()`].
    ///
    /// The *returned future* panics when awaited if the spawned task fails,
    /// i.e. if this future panics or the task is cancelled (for example
    /// because the runtime is shutting down).
    fn shield(self) -> Self::ShieldFuture;

    /// Prevent this future from being aborted by wrapping it in a task.
    ///
    /// Since the task is created using [`tokio::spawn()`], this future is
    /// scheduled to run as soon as this method is called; it does not wait
    /// for the returned future to be polled.
    ///
    /// # Errors
    /// The returned future resolves to `Err` if the spawned task fails, i.e.
    /// if this future panics or the task is cancelled.
    ///
    /// # Panics
    /// Panics if called outside the context of a Tokio runtime.
    fn try_shield(self) -> Self::TryShieldFuture;
}
39
40impl<T> Shield for T
41where
42    T: Future + Send + 'static,
43    T::Output: Send,
44{
45    type ShieldFuture = Map<JoinHandle<T::Output>, fn(Result<T::Output, JoinError>) -> T::Output>;
46    type TryShieldFuture = JoinHandle<T::Output>;
47    type TryShieldError = JoinError;
48
49    #[inline]
50    fn shield(self) -> Self::ShieldFuture {
51        self.try_shield().map(Result::unwrap)
52    }
53
54    #[inline]
55    fn try_shield(self) -> Self::TryShieldFuture {
56        tokio::spawn(self)
57    }
58}
59
#[cfg(test)]
mod tests {
    use tokio::sync::oneshot;

    use super::*;

    #[tokio::test]
    async fn returns_value() {
        let result = async { 42 }.shield().await;
        assert_eq!(result, 42);
    }

    /// Aborting the task that awaits a shielded future must not cancel the
    /// shielded future itself. Channels replace sleeps so the outcome does
    /// not depend on scheduler timing.
    #[tokio::test]
    async fn survives_cancel() {
        let (release_tx, release_rx) = oneshot::channel::<()>();
        let (done_tx, done_rx) = oneshot::channel::<()>();
        let outer = tokio::spawn(
            async move {
                // Stay pending until the outer task has been aborted.
                release_rx.await.expect("release sender dropped");
                done_tx.send(()).expect("done receiver dropped");
            }
            .shield(),
        );

        // The outer task cannot finish on its own (the shielded future is
        // blocked on `release_rx`), so this abort must cancel it.
        outer.abort();
        let err = outer
            .await
            .expect_err("outer task should have been aborted");
        assert!(err.is_cancelled());

        // If the shielded future had been dropped, its receiver would be gone
        // and this send would fail.
        release_tx.send(()).expect("shielded future was dropped");
        done_rx
            .await
            .expect("shielded future was dropped before completing");
    }

    #[tokio::test]
    async fn inner_panic() {
        let err = async {
            panic!();
        }
        .try_shield()
        .await
        .unwrap_err();
        assert!(err.is_panic());
    }

    #[tokio::test]
    #[should_panic]
    async fn shield_panics_on_inner_panic() {
        async {
            panic!();
        }
        .shield()
        .await;
    }
}