1use core::{
2 any::Any,
3 fmt,
4 future::{pending, Future},
5 mem::{self, ManuallyDrop},
6 panic::{AssertUnwindSafe, UnwindSafe},
7 pin::Pin,
8 ptr,
9 task::{Context, Poll, RawWaker, RawWakerVTable, Waker},
10};
11
/// Extension trait for racing one future against another.
pub trait Select: Sized {
    /// Combine `self` and `other` into a [`SelectFuture`] that resolves with the output of
    /// whichever future becomes ready first.
    ///
    /// `self` is always polled before `other`, so if both are ready on the same poll the
    /// result is [`SelectOutput::A`].
    fn select<Fut>(self, other: Fut) -> SelectFuture<Self, Fut>
    where
        Fut: Future;
}

impl<F> Select for F
where
    F: Future,
{
    #[inline]
    fn select<Fut>(self, other: Fut) -> SelectFuture<Self, Fut>
    where
        Fut: Future,
    {
        let fut1 = self;
        let fut2 = other;
        SelectFuture { fut1, fut2 }
    }
}

/// Future returned by [`Select::select`]; yields the output of the first inner future that
/// completes, tagged with which side produced it.
pub struct SelectFuture<Fut1, Fut2> {
    fut1: Fut1,
    fut2: Fut2,
}

impl<Fut1, Fut2> Future for SelectFuture<Fut1, Fut2>
where
    Fut1: Future,
    Fut2: Future,
{
    type Output = SelectOutput<Fut1::Output, Fut2::Output>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // SAFETY: structural pin projection. Neither field is ever moved out of the pinned
        // struct, so pinning `self` keeps both inner futures pinned in place.
        let this = unsafe { self.get_unchecked_mut() };

        // Biased: the first future gets the first chance to complete on every poll.
        let first = unsafe { Pin::new_unchecked(&mut this.fut1) };
        match first.poll(cx) {
            Poll::Ready(a) => Poll::Ready(SelectOutput::A(a)),
            Poll::Pending => {
                // SAFETY: same structural projection as above, on the other field.
                let second = unsafe { Pin::new_unchecked(&mut this.fut2) };
                second.poll(cx).map(SelectOutput::B)
            }
        }
    }
}

/// Result of a [`SelectFuture`]: `A` holds the first future's output, `B` the second's.
pub enum SelectOutput<A, B> {
    A(A),
    B(B),
}

impl<A, B> fmt::Debug for SelectOutput<A, B> {
    /// Prints only the variant; the payload is elided so `A`/`B` need not implement `Debug`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::A(_) => "SelectOutput::A(..)",
            Self::B(_) => "SelectOutput::B(..)",
        };
        f.write_str(label)
    }
}
77
78pub trait NowOrPanic: Sized {
93 type Output;
94
95 fn now_or_panic(&mut self) -> Self::Output;
96}
97
98impl<F> NowOrPanic for F
99where
100 F: Future,
101{
102 type Output = F::Output;
103
104 fn now_or_panic(&mut self) -> Self::Output {
105 let waker = noop_waker();
106 let cx = &mut Context::from_waker(&waker);
107
108 match unsafe { Pin::new_unchecked(self).poll(cx) } {
111 Poll::Ready(ret) => ret,
112 Poll::Pending => panic!("Future can not be polled to complete"),
113 }
114 }
115}
116
/// Future adapter that converts a panic raised while polling the inner future into an `Err`
/// carrying the panic payload.
pub struct CatchUnwind<Fut> {
    fut: Fut,
}

impl<Fut> CatchUnwind<Fut>
where
    Fut: Future + UnwindSafe,
{
    /// Wrap `fut` so panics escaping its `poll` are caught.
    #[inline]
    pub const fn new(fut: Fut) -> Self {
        Self { fut }
    }
}

impl<Fut> Future for CatchUnwind<Fut>
where
    Fut: Future + UnwindSafe,
{
    /// `Ok` with the inner output, or `Err` with the payload of a panic raised during `poll`.
    type Output = Result<Fut::Output, Box<dyn Any + Send>>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // SAFETY: structural pin projection; `fut` is never moved out of the pinned wrapper.
        let inner = unsafe { self.map_unchecked_mut(|this| &mut this.fut) };

        // `Fut: UnwindSafe` covers the future itself. The `&mut Context` borrow is not
        // `UnwindSafe`, hence the `AssertUnwindSafe` wrapper around the whole call.
        match std::panic::catch_unwind(AssertUnwindSafe(move || inner.poll(cx))) {
            Ok(poll) => poll.map(Ok),
            Err(payload) => Poll::Ready(Err(payload)),
        }
    }
}
147
/// A pinned, heap-allocated, type-erased future whose allocation can be reused when it is
/// replaced by another future with the same size and alignment.
///
/// The trait object carries no `Send` bound, so non-`Send` futures can be stored.
pub struct ReusableLocalBoxFuture<'a, T> {
    // The currently stored future, erased to a trait object.
    boxed: Pin<Box<dyn Future<Output = T> + 'a>>,
}
153
impl<'a, T> ReusableLocalBoxFuture<'a, T> {
    /// Box and pin `future`.
    pub fn new<F>(future: F) -> Self
    where
        F: Future<Output = T> + 'a,
    {
        Self {
            boxed: Box::pin(future),
        }
    }

    /// Replace the stored future with `future`, dropping the old one.
    ///
    /// The existing allocation is reused when `F` has the same size and alignment as the
    /// currently stored future. Otherwise a new box is allocated.
    pub fn set<F>(&mut self, future: F)
    where
        F: Future<Output = T> + 'a,
    {
        if let Err(future) = self.try_set(future) {
            *self = Self::new(future);
        }
    }

    /// Try to move `future` into the existing allocation.
    ///
    /// Returns `Err(future)` when the layouts differ. In that case the previously stored future
    /// has already been dropped and `self.boxed` holds a `pending()` placeholder, so the caller
    /// must overwrite `self` afterwards (as `set` does).
    fn try_set<F>(&mut self, future: F) -> Result<(), F>
    where
        F: Future<Output = T> + 'a,
    {
        // The work is presumably done in a nested fn for type-checking reasons.
        //
        // The placeholder `Box::pin(pending())` must coerce to
        // `dyn Future<Output = T> + 'a`, which needs `T: 'a`. This impl never declares that
        // bound. Inside the nested fn, `T` is spelled `F::Output` with `F: 'a`, from which the
        // compiler can derive the bound.
        #[inline(always)]
        fn real_try_set<'a, F>(this: &mut ReusableLocalBoxFuture<'a, F::Output>, future: F) -> Result<(), F>
        where
            F: Future + 'a,
        {
            // Take ownership of the old box and leave a placeholder in its slot. `pending()`
            // carries no data, so this placeholder presumably does not allocate.
            let boxed = mem::replace(&mut this.boxed, Box::pin(pending()));
            // On success the callback stores the rebuilt box back into `this`.
            reuse_pin_box(boxed, future, |boxed| this.boxed = Pin::from(boxed))
        }

        real_try_set(self, future)
    }

    /// Borrow the stored future as a pinned trait object so it can be polled.
    pub fn get_pin(&mut self) -> Pin<&mut dyn Future<Output = T>> {
        self.boxed.as_mut()
    }
}
198
199impl<T> fmt::Debug for ReusableLocalBoxFuture<'_, T> {
200 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
201 f.debug_struct("ReusableLocalBoxFuture").finish()
202 }
203}
204
/// Move `new_value` into the allocation behind `boxed` when the old and new values have the
/// same layout, then pass the resulting `Box<U>` to `callback`.
///
/// Returns `Ok(callback(..))` when the allocation was reused. When the layouts differ it
/// returns `Err(new_value)`, and `boxed` (with the value it holds) is dropped normally on
/// return.
///
/// On the reuse path the old value is dropped in place before `new_value` is written. If that
/// drop panics, the `CallOnDrop` guard still writes `new_value` and runs `callback` while
/// unwinding, so the allocation is always handed to `callback`.
fn reuse_pin_box<T: ?Sized, U, O, F>(boxed: Pin<Box<T>>, new_value: U, callback: F) -> Result<O, U>
where
    F: FnOnce(Box<U>) -> O,
{
    use std::alloc::Layout;

    // Reusing the allocation requires an exact size and alignment match.
    let layout = Layout::for_value::<T>(&*boxed);
    if layout != Layout::new::<U>() {
        return Err(new_value);
    }

    // SAFETY: from here on no unpinned reference to the old `T` is created. The `T` is dropped
    // in place below before its memory is reused.
    let raw: *mut T = Box::into_raw(unsafe { Pin::into_inner_unchecked(boxed) });

    // Everything after dropping the old value lives in a guard, so it still runs if that drop
    // panics.
    let guard = CallOnDrop::new(|| {
        let raw: *mut U = raw.cast::<U>();
        // SAFETY: the allocation has `U`'s layout (checked above), and the old `T` has already
        // been dropped, so nothing that still needs dropping is overwritten.
        unsafe { raw.write(new_value) };

        // SAFETY: `raw` came from `Box::into_raw`, its layout matches `U`, and it now holds a
        // valid `U`.
        let boxed = unsafe { Box::from_raw(raw) };

        callback(boxed)
    });

    // SAFETY: `raw` points at a live `T`, which is dropped exactly once, here.
    unsafe { ptr::drop_in_place(raw) };

    // Normal path: run the guard's closure explicitly and return its result.
    Ok(guard.call())
}
241
/// Guard that runs a one-shot closure exactly once.
///
/// The closure runs either explicitly through [`CallOnDrop::call`], which returns its value,
/// or when the guard is dropped without `call`, for example during unwinding. In that case the
/// value is discarded.
struct CallOnDrop<O, F: FnOnce() -> O> {
    // `ManuallyDrop` lets the closure be moved out from behind `&mut self` in `Drop`.
    f: ManuallyDrop<F>,
}

impl<O, F: FnOnce() -> O> CallOnDrop<O, F> {
    /// Arm the guard with `f`.
    fn new(f: F) -> Self {
        Self {
            f: ManuallyDrop::new(f),
        }
    }

    /// Run the closure now and return its output; the drop-time call is suppressed.
    fn call(mut self) -> O {
        // SAFETY: `f` is taken exactly once. `mem::forget` below keeps `Drop` from taking it
        // a second time.
        let f = unsafe { ManuallyDrop::take(&mut self.f) };
        mem::forget(self);
        f()
    }
}

impl<O, F: FnOnce() -> O> Drop for CallOnDrop<O, F> {
    fn drop(&mut self) {
        // SAFETY: `drop` runs at most once, and `call` forgets the guard after taking `f`, so
        // the closure is still present whenever this runs.
        let f = unsafe { ManuallyDrop::take(&mut self.f) };
        f();
    }
}
264
/// Vtable shared by every waker built from [`raw_waker`]. Cloning yields another no-op raw
/// waker; waking and dropping do nothing.
const TBL: RawWakerVTable = RawWakerVTable::new(
    // clone
    |_data| raw_waker(),
    // wake
    |_data| (),
    // wake_by_ref
    |_data| (),
    // drop
    |_data| (),
);

/// Build a raw waker with a null data pointer backed by [`TBL`].
const fn raw_waker() -> RawWaker {
    let data: *const () = ptr::null();
    RawWaker::new(data, &TBL)
}

/// Create a [`Waker`] whose operations are all no-ops; wake-up requests are discarded.
pub(crate) fn noop_waker() -> Waker {
    let raw = raw_waker();
    // SAFETY: every vtable entry ignores the data pointer, so the null pointer is never
    // dereferenced, and no entry touches any shared state.
    unsafe { Waker::from_raw(raw) }
}
276
#[cfg(test)]
mod test {
    use core::{future::poll_fn, pin::pin};

    use super::*;

    /// When the first future stays pending, `select` resolves with the second one's output.
    #[test]
    fn test_select() {
        let fut = async {
            poll_fn(|cx| {
                cx.waker().wake_by_ref();
                Poll::<()>::Pending
            })
            .await;
            123
        }
        .select(async { 321 });

        // `matches!` only evaluates to a bool. It must be asserted, otherwise the outcome is
        // discarded and the test can never fail.
        assert!(matches!(pin!(fut).now_or_panic(), SelectOutput::B(321)));
    }

    /// When both futures are ready, the first (left) one wins because it is polled first.
    #[test]
    fn test_select_prefers_first() {
        let fut = async { 1 }.select(async { 2 });

        assert!(matches!(pin!(fut).now_or_panic(), SelectOutput::A(1)));
    }

    /// `set` must always leave the newly supplied future in place, whether or not the old
    /// allocation could be reused.
    #[test]
    fn test_reusable_box_future_set() {
        let mut fut = ReusableLocalBoxFuture::new(async { 1usize });
        assert_eq!(fut.get_pin().now_or_panic(), 1);

        fut.set(async { 2usize });
        assert_eq!(fut.get_pin().now_or_panic(), 2);

        // A future capturing a 64-byte array has a different layout from the previous one.
        let buf = [7u8; 64];
        fut.set(async move { buf.len() });
        assert_eq!(fut.get_pin().now_or_panic(), 64);
    }
}