tokio_interruptible_future/
lib.rs

//! Easily interrupt async code at given checkpoints. Useful for interrupting threads/fibers.
// TODO: Documentation comments.
3
4use std::{fmt, future::Future};
5use async_channel::Receiver;
6
/// Error signalling that an interruptible future was stopped before it completed.
///
/// `interruptible_straight` produces this error (converted through `From`) when the
/// interrupt channel fires before the wrapped future finishes.
///
/// The type carries no data, so it is cheap to copy and all values compare equal.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct InterruptError {}

impl InterruptError {
    /// Creates a new `InterruptError`. Equivalent to `InterruptError::default()`.
    pub fn new() -> Self {
        Self {}
    }
}
15
16impl std::error::Error for InterruptError { }
17
18impl fmt::Display for InterruptError {
19    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
20        write!(f, "Async fiber interrupted.")
21    }
22}
23
24/// You usually use `interruptible` or `interruptible_sendable` instead.
25pub async fn interruptible_straight<T, E: From<InterruptError>>(
26    rx: Receiver<()>,
27    f: impl Future<Output=Result<T, E>>
28) -> Result<T, E>
29{
30    tokio::select!{
31        r = f => r,
32        _ = async { // shorten lock lifetime
33            let _ = rx.recv().await;
34        } => Err(InterruptError::new().into()),
35    }
36}
37
38pub async fn interruptible<T, E: From<InterruptError>>(
39    rx: Receiver<()>,
40    f: impl Future<Output=Result<T, E>> + Unpin
41) -> Result<T, E>
42{
43    interruptible_straight(rx, f).await
44}
45
46pub async fn interruptible_sendable<T, E: From<InterruptError>>(
47    rx: Receiver<()>,
48    f: impl Future<Output=Result<T, E>> + Send + Unpin
49) -> Result<T, E>
50{
51    interruptible_straight(rx, f).await
52}
53
/// Unit tests for the interruptible wrappers.
/// TODO: More tests.
#[cfg(test)]
mod tests {
    use std::future::{pending, Future};

    use async_channel::bounded;
    use futures::executor::block_on;

    use crate::{interruptible, interruptible_sendable, InterruptError};

    /// A second error kind, used to check that the wrapped future's own errors pass through.
    #[derive(Debug, PartialEq, Eq)]
    struct AnotherError {}

    impl AnotherError {
        pub fn new() -> Self {
            Self {}
        }
    }

    /// Caller-side error type combining interruption with the wrapped future's errors.
    #[derive(Debug, PartialEq, Eq)]
    enum MyError {
        Interrupted(InterruptError),
        Another(AnotherError),
    }

    impl From<InterruptError> for MyError {
        fn from(value: InterruptError) -> Self {
            Self::Interrupted(value)
        }
    }

    impl From<AnotherError> for MyError {
        fn from(value: AnotherError) -> Self {
            Self::Another(value)
        }
    }

    struct Test {}

    impl Test {
        pub fn new() -> Self {
            Self {}
        }

        /// The wrapped future succeeds; the sender is kept alive but never signals.
        pub async fn g(self) -> Result<u8, MyError> {
            let (_tx, rx) = bounded(1);

            interruptible(rx, Box::pin(async move { Ok(123) })).await
        }

        /// The wrapped future fails with its own error; the sender never signals.
        pub async fn h(self) -> Result<u8, MyError> {
            let (_tx, rx) = bounded(1);

            interruptible(rx, Box::pin(async move { Err(AnotherError::new().into()) })).await
        }
    }

    /// Without an interrupt, the wrapped future's result is returned unchanged.
    #[test]
    fn not_interrupted() {
        let test = Test::new();
        block_on(async {
            assert_eq!(test.g().await, Ok(123));
        });
        let test = Test::new();
        block_on(async {
            assert_eq!(test.h().await, Err(AnotherError::new().into()));
        });
    }

    /// A queued interrupt stops a future that would otherwise never complete.
    #[test]
    fn interrupted() {
        let (tx, rx) = bounded(1);
        block_on(async {
            // Queue the signal up front so the interrupt branch is ready immediately,
            // while the wrapped future stays pending forever.
            tx.send(()).await.expect("receiver is still alive");
            let result: Result<u8, MyError> =
                interruptible(rx, Box::pin(pending::<Result<u8, MyError>>())).await;
            assert_eq!(result, Err(MyError::Interrupted(InterruptError::new())));
        });
    }

    #[test]
    fn check_interruptible_sendable() {
        let (_tx, rx) = bounded(1);

        // Check that `interruptible_sendable(...)` is a `Send` future.
        let _: &(dyn Future<Output = Result<i32, InterruptError>> + Send) =
            &interruptible_sendable(rx, Box::pin(async move { Ok(123) }));
    }
}