xitca_unsafe_collection/
futures.rs

1use core::{
2    any::Any,
3    fmt,
4    future::{pending, Future},
5    mem::{self, ManuallyDrop},
6    panic::{AssertUnwindSafe, UnwindSafe},
7    pin::Pin,
8    ptr,
9    task::{Context, Poll, RawWaker, RawWakerVTable, Waker},
10};
11
12/// Biased select always prioritize polling Self.
13pub trait Select: Sized {
14    fn select<Fut>(self, other: Fut) -> SelectFuture<Self, Fut>
15    where
16        Fut: Future;
17}
18
19impl<F> Select for F
20where
21    F: Future,
22{
23    #[inline]
24    fn select<Fut>(self, other: Fut) -> SelectFuture<Self, Fut>
25    where
26        Fut: Future,
27    {
28        SelectFuture {
29            fut1: self,
30            fut2: other,
31        }
32    }
33}
34
35pub struct SelectFuture<Fut1, Fut2> {
36    fut1: Fut1,
37    fut2: Fut2,
38}
39
40impl<Fut1, Fut2> Future for SelectFuture<Fut1, Fut2>
41where
42    Fut1: Future,
43    Fut2: Future,
44{
45    type Output = SelectOutput<Fut1::Output, Fut2::Output>;
46
47    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
48        // SAFETY:
49        // This is safe as Fut1 and Fut2 do not move.
50        // They both accessed only through a single Pin<&mut _>.
51        unsafe {
52            let Self { fut1, fut2 } = self.get_unchecked_mut();
53
54            if let Poll::Ready(a) = Pin::new_unchecked(fut1).poll(cx) {
55                return Poll::Ready(SelectOutput::A(a));
56            }
57
58            Pin::new_unchecked(fut2).poll(cx).map(SelectOutput::B)
59        }
60    }
61}
62
/// Output of a [`SelectFuture`]: records which of the two futures completed.
pub enum SelectOutput<A, B> {
    /// The first future finished; carries its output.
    A(A),
    /// The second future finished; carries its output.
    B(B),
}

impl<A, B> fmt::Debug for SelectOutput<A, B> {
    /// Prints only the variant name; the payload is elided so `A` and `B` need no `Debug` bound.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::A(_) => "SelectOutput::A(..)",
            Self::B(_) => "SelectOutput::B(..)",
        };
        f.write_str(label)
    }
}
77
78/// Trait for trying to complete future in a single poll. And panic when fail to do so.
79///
80/// Useful for async code that only expose async API but not really doing stuff in async manner.
81///
82/// # Examples
83/// ```rust
84/// # use xitca_unsafe_collection::futures::NowOrPanic;
85///
86/// async fn looks_like() {
87///     // nothing async really happened.
88/// }
89///
90/// looks_like().now_or_panic();
91/// ```
92pub trait NowOrPanic: Sized {
93    type Output;
94
95    fn now_or_panic(&mut self) -> Self::Output;
96}
97
98impl<F> NowOrPanic for F
99where
100    F: Future,
101{
102    type Output = F::Output;
103
104    fn now_or_panic(&mut self) -> Self::Output {
105        let waker = noop_waker();
106        let cx = &mut Context::from_waker(&waker);
107
108        // SAFETY:
109        // self is not moved.
110        match unsafe { Pin::new_unchecked(self).poll(cx) } {
111            Poll::Ready(ret) => ret,
112            Poll::Pending => panic!("Future can not be polled to complete"),
113        }
114    }
115}
116
/// Future adapter that turns a panic raised while polling the inner future into an `Err`.
pub struct CatchUnwind<Fut> {
    // The wrapped future; polled inside `catch_unwind`.
    fut: Fut,
}

impl<Fut: Future + UnwindSafe> CatchUnwind<Fut> {
    /// Wrap `fut` so that panics during its `poll` are captured instead of propagated.
    #[inline]
    pub const fn new(fut: Fut) -> Self {
        Self { fut }
    }
}

impl<Fut: Future + UnwindSafe> Future for CatchUnwind<Fut> {
    /// `Ok` with the inner output, or `Err` with the payload of the captured panic.
    type Output = Result<Fut::Output, Box<dyn Any + Send>>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // SAFETY:
        // fut is not moved; this is a plain field projection of the pinned `Self`.
        let inner = unsafe { self.map_unchecked_mut(|this| &mut this.fut) };

        match std::panic::catch_unwind(AssertUnwindSafe(|| inner.poll(cx))) {
            Ok(Poll::Ready(output)) => Poll::Ready(Ok(output)),
            Ok(Poll::Pending) => Poll::Pending,
            // The inner poll panicked: resolve immediately with the panic payload.
            Err(payload) => Poll::Ready(Err(payload)),
        }
    }
}
147
148/// Copied from [ReusableBoxFuture](https://docs.rs/tokio-util/latest/tokio_util/sync/struct.ReusableBoxFuture.html).
149/// But without `Send` bound.
150pub struct ReusableLocalBoxFuture<'a, T> {
151    boxed: Pin<Box<dyn Future<Output = T> + 'a>>,
152}
153
154impl<'a, T> ReusableLocalBoxFuture<'a, T> {
155    pub fn new<F>(future: F) -> Self
156    where
157        F: Future<Output = T> + 'a,
158    {
159        Self {
160            boxed: Box::pin(future),
161        }
162    }
163
164    pub fn set<F>(&mut self, future: F)
165    where
166        F: Future<Output = T> + 'a,
167    {
168        if let Err(future) = self.try_set(future) {
169            *self = Self::new(future);
170        }
171    }
172
173    fn try_set<F>(&mut self, future: F) -> Result<(), F>
174    where
175        F: Future<Output = T> + 'a,
176    {
177        // If we try to inline the contents of this function, the type checker complains because
178        // the bound `T: 'a` is not satisfied in the call to `pending()`. But by putting it in an
179        // inner function that doesn't have `T` as a generic parameter, we implicitly get the bound
180        // `F::Output: 'a` transitively through `F: 'a`, allowing us to call `pending()`.
181        #[inline(always)]
182        fn real_try_set<'a, F>(this: &mut ReusableLocalBoxFuture<'a, F::Output>, future: F) -> Result<(), F>
183        where
184            F: Future + 'a,
185        {
186            // future::Pending<T> is a ZST so this never allocates.
187            let boxed = mem::replace(&mut this.boxed, Box::pin(pending()));
188            reuse_pin_box(boxed, future, |boxed| this.boxed = Pin::from(boxed))
189        }
190
191        real_try_set(self, future)
192    }
193
194    pub fn get_pin(&mut self) -> Pin<&mut dyn Future<Output = T>> {
195        self.boxed.as_mut()
196    }
197}
198
199impl<T> fmt::Debug for ReusableLocalBoxFuture<'_, T> {
200    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
201        f.debug_struct("ReusableLocalBoxFuture").finish()
202    }
203}
204
205fn reuse_pin_box<T: ?Sized, U, O, F>(boxed: Pin<Box<T>>, new_value: U, callback: F) -> Result<O, U>
206where
207    F: FnOnce(Box<U>) -> O,
208{
209    use std::alloc::Layout;
210
211    let layout = Layout::for_value::<T>(&*boxed);
212    if layout != Layout::new::<U>() {
213        return Err(new_value);
214    }
215
216    // SAFETY: We don't ever construct a non-pinned reference to the old `T` from now on, and we
217    // always drop the `T`.
218    let raw: *mut T = Box::into_raw(unsafe { Pin::into_inner_unchecked(boxed) });
219
220    // When dropping the old value panics, we still want to call `callback` — so move the rest of
221    // the code into a guard type.
222    let guard = CallOnDrop::new(|| {
223        let raw: *mut U = raw.cast::<U>();
224        unsafe { raw.write(new_value) };
225
226        // SAFETY:
227        // - `T` and `U` have the same layout.
228        // - `raw` comes from a `Box` that uses the same allocator as this one.
229        // - `raw` points to a valid instance of `U` (we just wrote it in).
230        let boxed = unsafe { Box::from_raw(raw) };
231
232        callback(boxed)
233    });
234
235    // Drop the old value.
236    unsafe { ptr::drop_in_place(raw) };
237
238    // Run the rest of the code.
239    Ok(guard.call())
240}
241
/// Guard owning a closure that is run exactly once: either explicitly through
/// [`CallOnDrop::call`], or implicitly when the guard is dropped (including during unwinding).
struct CallOnDrop<O, F: FnOnce() -> O> {
    // Wrapped in `ManuallyDrop` so the closure can be moved out of `&mut self` in `drop`.
    f: ManuallyDrop<F>,
}

impl<O, F: FnOnce() -> O> CallOnDrop<O, F> {
    /// Arm the guard with `f`.
    fn new(f: F) -> Self {
        Self {
            f: ManuallyDrop::new(f),
        }
    }

    /// Run the closure now and return its result; the guard's `Drop` impl is skipped.
    fn call(self) -> O {
        // Disarm `Drop::drop` so the closure is not taken a second time.
        let mut disarmed = ManuallyDrop::new(self);
        // SAFETY: `disarmed` is never dropped, so this is the only place the closure is taken.
        let closure = unsafe { ManuallyDrop::take(&mut disarmed.f) };
        closure()
    }
}

impl<O, F: FnOnce() -> O> Drop for CallOnDrop<O, F> {
    /// Run the closure if `call` was never invoked; its result is discarded.
    fn drop(&mut self) {
        // SAFETY: `drop` runs at most once and `call` prevents it from running at all, so the
        // closure has not been taken before.
        let closure = unsafe { ManuallyDrop::take(&mut self.f) };
        let _ = closure();
    }
}
264
// Vtable of the no-op waker: cloning hands out another no-op `RawWaker`, while wake,
// wake_by_ref and drop do nothing.
const TBL: RawWakerVTable = {
    fn clone(_: *const ()) -> RawWaker {
        raw_waker()
    }

    fn ignore(_: *const ()) {}

    RawWakerVTable::new(clone, ignore, ignore, ignore)
};

/// Build the `RawWaker` backing the no-op waker. It carries no state, hence the null data pointer.
const fn raw_waker() -> RawWaker {
    RawWaker::new(ptr::null(), &TBL)
}

/// Create a [`Waker`] whose every operation is a no-op.
pub(crate) fn noop_waker() -> Waker {
    let raw = raw_waker();
    // SAFETY:
    // no op waker uphold all the rules of RawWaker and RawWakerVTable
    unsafe { Waker::from_raw(raw) }
}
276
#[cfg(test)]
mod test {
    use core::{future::poll_fn, pin::pin};

    use super::*;

    /// The first future stays pending, so the select must resolve with the second one.
    #[test]
    fn test_select() {
        let fut = async {
            poll_fn(|cx| {
                cx.waker().wake_by_ref();
                Poll::<()>::Pending
            })
            .await;
            123
        }
        .select(async { 321 });

        // `matches!` only evaluates to a bool; it has to be asserted or the test can never fail.
        assert!(matches!(pin!(fut).now_or_panic(), SelectOutput::B(321)));
    }

    /// When both futures are ready, the first one wins because it is polled first.
    #[test]
    fn test_select_biased() {
        let fut = async { 123 }.select(async { 321 });

        assert!(matches!(pin!(fut).now_or_panic(), SelectOutput::A(123)));
    }
}
297}