veilid_tools/
must_join_handle.rs

1use super::*;
2
3use core::task::{Context, Poll};
4
5#[derive(Debug)]
6pub struct MustJoinHandle<T> {
7    join_handle: Option<LowLevelJoinHandle<T>>,
8    completed: bool,
9}
10
11impl<T> MustJoinHandle<T> {
12    pub fn new(join_handle: LowLevelJoinHandle<T>) -> Self {
13        Self {
14            join_handle: Some(join_handle),
15            completed: false,
16        }
17    }
18
19    pub fn detach(mut self) {
20        cfg_if! {
21            if #[cfg(feature="rt-async-std")] {
22                self.join_handle = None;
23            } else if #[cfg(feature="rt-tokio")] {
24                self.join_handle = None;
25            } else if #[cfg(all(target_arch = "wasm32", target_os = "unknown"))] {
26                if let Some(jh) = self.join_handle.take() {
27                    jh.detach();
28                }
29            } else {
30                compile_error!("needs executor implementation");
31            }
32        }
33        self.completed = true;
34    }
35
36    #[allow(unused_mut)]
37    pub async fn abort(mut self) {
38        if !self.completed {
39            cfg_if! {
40                if #[cfg(feature="rt-async-std")] {
41                    if let Some(jh) = self.join_handle.take() {
42                        jh.cancel().await;
43                        self.completed = true;
44                    }
45                } else if #[cfg(feature="rt-tokio")] {
46                    if let Some(jh) = self.join_handle.take() {
47                        jh.abort();
48                        let _ = jh.await;
49                        self.completed = true;
50                    }
51                } else if #[cfg(all(target_arch = "wasm32", target_os = "unknown"))] {
52                    drop(self.join_handle.take());
53                    self.completed = true;
54                } else {
55                    compile_error!("needs executor implementation");
56                }
57
58            }
59        }
60    }
61}
62
63impl<T> Drop for MustJoinHandle<T> {
64    fn drop(&mut self) {
65        // panic if we haven't completed
66        if !self.completed {
67            panic!("MustJoinHandle was not completed upon drop. Add cooperative cancellation where appropriate to ensure this is completed before drop.")
68        }
69    }
70}
71
72impl<T: 'static> Future for MustJoinHandle<T> {
73    type Output = T;
74
75    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
76        match Pin::new(self.join_handle.as_mut().unwrap()).poll(cx) {
77            Poll::Ready(t) => {
78                if self.completed {
79                    panic!("should not poll completed join handle");
80                }
81                self.completed = true;
82                cfg_if! {
83                    if #[cfg(feature="rt-async-std")] {
84                        Poll::Ready(t)
85                    } else if #[cfg(feature="rt-tokio")] {
86                        match t {
87                            Ok(t) => Poll::Ready(t),
88                            Err(e) => {
89                                if e.is_panic() {
90                                    // Resume the panic on the main task
91                                    std::panic::resume_unwind(e.into_panic());
92                                } else {
93                                    panic!("join error was not a panic, should not poll after abort");
94                                }
95                            }
96                        }
97                    } else if #[cfg(all(target_arch = "wasm32", target_os = "unknown"))] {
98                        Poll::Ready(t)
99                    } else {
100                        compile_error!("needs executor implementation");
101                    }
102                }
103            }
104            Poll::Pending => Poll::Pending,
105        }
106    }
107}