// tokio_async_utils/task_handle.rs

1use std::ops::Deref;
2
3use tokio::task::JoinHandle;
4
5/// TaskHandle is simple wrapper around Tokio task::JoinHandle that aborts tasks on Handle drop.
6/// The easiest way to obtain TaskHandle is importing TaskExt trait and running `to_task_handle()` on Tokio task::JoinHandle.
7///
8/// # Example
9/// ```rs
10/// let (tx, mut rx) = mpsc::channel(20);
11/// let handle =
12///     tokio::spawn(async move { while let Some(_) = rx.recv().await {} }).to_task_handle();
13///
14/// let r = tx.send(true).await;
15/// // drop handle so the inner task is aborted
16/// drop(handle);
17///
18/// // sadly seems like we need to wait so Tokio runtime has time to actually drop all variables
19/// sleep(Duration::from_millis(1)).await;
20/// let r = tx.send(false).await;
21/// assert!(r.is_err(), "'rx' along with task inside 'handle' should be dropped at this point so tx.send fails");
22/// ```
23#[derive(Debug)]
24pub struct TaskHandle<T>(pub JoinHandle<T>);
25
26impl<T> Deref for TaskHandle<T> {
27    type Target = JoinHandle<T>;
28
29    fn deref(&self) -> &Self::Target {
30        &self.0
31    }
32}
33
34pub trait TaskExt<T> {
35    fn to_task_handle(self) -> TaskHandle<T>;
36}
37
38impl<T> TaskExt<T> for JoinHandle<T> {
39    fn to_task_handle(self) -> TaskHandle<T> {
40        TaskHandle(self)
41    }
42}
43
44impl<T> Drop for TaskHandle<T> {
45    fn drop(&mut self) {
46        self.0.abort()
47    }
48}
49
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use tokio::{sync::mpsc, time::timeout};

    use super::*;

    /// Dropping a `TaskHandle` must abort the wrapped task, which in turn drops everything the
    /// task's future owned (here: the channel receiver), so later sends fail.
    #[tokio::test]
    async fn is_dropped_correctly() {
        let (tx, mut rx) = mpsc::channel(20);
        let handle =
            tokio::spawn(async move { while rx.recv().await.is_some() {} }).to_task_handle();

        let r = tx.send(true).await;
        assert!(r.is_ok());
        drop(handle);

        // Abort is asynchronous: the runtime drops the task's future (and `rx` with it) some
        // time after `drop`. Wait for exactly that event instead of sleeping for a guessed
        // duration; the timeout turns a regression (task never aborted) into a failure, not a hang.
        timeout(Duration::from_secs(1), tx.closed())
            .await
            .expect("receiver was not dropped after the TaskHandle was dropped");

        let r = tx.send(false).await;
        assert!(r.is_err(), "expected error, but got ok");
    }
}