1#![doc = include_str!("../README.md")]
2#![forbid(unsafe_code)]
3#![warn(clippy::dbg_macro, clippy::use_debug, clippy::todo)]
4#![warn(missing_docs, missing_debug_implementations)]
5
6use std::future::Future;
7
8use futures::{future::Map, FutureExt};
9use tokio::task::{JoinError, JoinHandle};
10
/// Extension trait for shielding a future from cancellation.
///
/// The blanket implementation in this crate covers every `Send + 'static`
/// future whose output is `Send`. It hands the future to [`tokio::spawn`], so
/// the work keeps running even if the future returned by
/// [`shield`](Shield::shield) or [`try_shield`](Shield::try_shield) is dropped
/// before it completes.
pub trait Shield: Future + Send + 'static
where
    Self::Output: Send,
{
    /// Future returned by [`shield`](Shield::shield); resolves directly to the
    /// shielded future's output.
    type ShieldFuture: Future<Output = Self::Output>;

    /// Future returned by [`try_shield`](Shield::try_shield); resolves to the
    /// shielded future's output or to a [`TryShieldError`](Shield::TryShieldError).
    type TryShieldFuture: Future<Output = Result<Self::Output, Self::TryShieldError>>;

    /// Error produced when the shielded future could not be driven to
    /// completion.
    type TryShieldError;

    /// Shields this future and resolves to its output.
    ///
    /// # Panics
    ///
    /// The returned future panics when polled if the shielded future fails to
    /// complete. Use [`try_shield`](Shield::try_shield) to observe that
    /// failure as an error instead.
    fn shield(self) -> Self::ShieldFuture;

    /// Shields this future, reporting a failure to complete as an error
    /// instead of panicking.
    ///
    /// # Errors
    ///
    /// The returned future resolves to [`TryShieldError`](Shield::TryShieldError)
    /// if the shielded future did not run to completion, for example because
    /// it panicked.
    fn try_shield(self) -> Self::TryShieldFuture;
}
39
40impl<T> Shield for T
41where
42 T: Future + Send + 'static,
43 T::Output: Send,
44{
45 type ShieldFuture = Map<JoinHandle<T::Output>, fn(Result<T::Output, JoinError>) -> T::Output>;
46 type TryShieldFuture = JoinHandle<T::Output>;
47 type TryShieldError = JoinError;
48
49 #[inline]
50 fn shield(self) -> Self::ShieldFuture {
51 self.try_shield().map(Result::unwrap)
52 }
53
54 #[inline]
55 fn try_shield(self) -> Self::TryShieldFuture {
56 tokio::spawn(self)
57 }
58}
59
#[cfg(test)]
mod tests {
    use std::{sync::Arc, time::Duration};

    use tokio::{sync::Mutex, time::sleep};

    use super::*;

    /// `shield` resolves to the wrapped future's output.
    #[tokio::test]
    async fn returns_value() {
        let shielded = async { 42 }.shield();
        assert_eq!(shielded.await, 42);
    }

    /// Aborting the task that awaits a shielded future must not stop the
    /// shielded work itself.
    #[tokio::test]
    async fn survives_cancel() {
        let finished = Arc::new(Mutex::new(false));
        let finished_in_task = Arc::clone(&finished);

        // The shielded future sets the flag after 100 ms.
        let shielded = async move {
            sleep(Duration::from_millis(100)).await;
            *finished_in_task.lock().await = true;
        }
        .shield();
        let outer = tokio::spawn(shielded);

        // Cancel the outer task halfway through, then wait long enough for
        // the shielded work to finish on its own.
        sleep(Duration::from_millis(50)).await;
        outer.abort();
        sleep(Duration::from_millis(100)).await;

        let completed = *finished.lock().await;
        assert!(completed);
    }

    /// A panic inside the shielded future surfaces as an error from
    /// `try_shield` rather than unwinding into the caller.
    #[tokio::test]
    async fn inner_panic() {
        let outcome = async {
            panic!();
        }
        .try_shield()
        .await;
        outcome.unwrap_err();
    }
}