tycho_util/futures/shared.rs

1use std::cell::UnsafeCell;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::atomic::{AtomicUsize, Ordering};
5use std::sync::{Arc, Weak};
6use std::task::{Context, Poll};
7
8use tokio::sync::{AcquireError, OwnedSemaphorePermit, Semaphore, TryAcquireError};
9
10#[must_use = "futures do nothing unless you `.await` or poll them"]
11pub struct Shared<Fut: Future> {
12    inner: Option<Arc<Inner<Fut>>>,
13    permit_fut: Option<SyncBoxFuture<Result<OwnedSemaphorePermit, AcquireError>>>,
14    permit: Option<OwnedSemaphorePermit>,
15}
16
/// A pinned, boxed, type-erased future that is `Send + Sync + 'static`.
///
/// Used to keep a semaphore acquisition alive across polls.
type SyncBoxFuture<T> = Pin<Box<dyn Future<Output = T> + Send + Sync + 'static>>;
18
19impl<Fut: Future> Clone for Shared<Fut> {
20    fn clone(&self) -> Self {
21        Self {
22            inner: self.inner.clone(),
23            permit_fut: None,
24            permit: None,
25        }
26    }
27}
28
29impl<Fut: Future> Shared<Fut> {
30    pub fn new(future: Fut) -> Self {
31        let semaphore = Arc::new(Semaphore::new(1));
32        let inner = Arc::new(Inner {
33            state: AtomicUsize::new(POLLING),
34            future_or_output: UnsafeCell::new(FutureOrOutput::Future(future)),
35            semaphore,
36        });
37
38        Self {
39            inner: Some(inner),
40            permit_fut: None,
41            permit: None,
42        }
43    }
44
45    pub fn weak_future(&self) -> Option<WeakShared<Fut>> {
46        self.inner.as_ref().map(|inner| WeakShared {
47            inner: Some(Arc::downgrade(inner)),
48            permit_fut: None,
49            permit: None,
50        })
51    }
52
53    pub fn downgrade(&self) -> Option<WeakSharedHandle<Fut>> {
54        self.inner
55            .as_ref()
56            .map(|inner| WeakSharedHandle(Arc::downgrade(inner)))
57    }
58
59    /// Drops the future, returning whether it was the last instance.
60    pub fn consume(mut self) -> bool {
61        self.inner
62            .take()
63            .map(|inner| Arc::into_inner(inner).is_some())
64            .unwrap_or_default()
65    }
66}
67
68fn poll_impl<'cx, Fut>(
69    this_inner: &mut Option<Arc<Inner<Fut>>>,
70    this_permit_fut: &mut Option<SyncBoxFuture<Result<OwnedSemaphorePermit, AcquireError>>>,
71    this_permit: &mut Option<OwnedSemaphorePermit>,
72    cx: &mut Context<'cx>,
73) -> Poll<(Fut::Output, bool)>
74where
75    Fut: Future,
76    Fut::Output: Clone,
77{
78    let inner = this_inner
79        .take()
80        .expect("Shared future polled again after completion");
81
82    // Fast path for when the wrapped future has already completed
83    if inner.state.load(Ordering::Acquire) == COMPLETE {
84        // Safety: We're in the COMPLETE state
85        return unsafe { Poll::Ready(inner.take_or_clone_output()) };
86    }
87
88    if this_permit.is_none() {
89        *this_permit = Some('permit: {
90            // Poll semaphore future
91            let permit_fut = if let Some(fut) = this_permit_fut.as_mut() {
92                fut
93            } else {
94                // Avoid allocations completely if we can grab a permit immediately
95                match Arc::clone(&inner.semaphore).try_acquire_owned() {
96                    Ok(permit) => break 'permit permit,
97                    Err(TryAcquireError::NoPermits) => {}
98                    // NOTE: We don't expect the semaphore to be closed
99                    Err(TryAcquireError::Closed) => unreachable!(),
100                }
101
102                let next_fut = Arc::clone(&inner.semaphore).acquire_owned();
103                this_permit_fut.get_or_insert(Box::pin(next_fut))
104            };
105
106            // Acquire a permit to poll the inner future
107            match permit_fut.as_mut().poll(cx) {
108                Poll::Pending => {
109                    *this_inner = Some(inner);
110                    return Poll::Pending;
111                }
112                Poll::Ready(Ok(permit)) => {
113                    // Reset the permit future as we don't need it anymore
114                    *this_permit_fut = None;
115                    permit
116                }
117                // NOTE: We don't expect the semaphore to be closed
118                Poll::Ready(Err(_e)) => unreachable!(),
119            }
120        });
121    }
122
123    assert!(this_permit_fut.is_none(), "permit already acquired");
124
125    match inner.state.load(Ordering::Acquire) {
126        COMPLETE => {
127            // SAFETY: We're in the COMPLETE state
128            return unsafe { Poll::Ready(inner.take_or_clone_output()) };
129        }
130        POISONED => panic!("inner future panicked during poll"),
131        _ => {}
132    }
133
134    // Create poison guard
135    struct Reset<'a> {
136        state: &'a AtomicUsize,
137        did_not_panic: bool,
138    }
139
140    impl Drop for Reset<'_> {
141        fn drop(&mut self) {
142            if !self.did_not_panic {
143                self.state.store(POISONED, Ordering::Release);
144            }
145        }
146    }
147
148    let mut reset = Reset {
149        state: &inner.state,
150        did_not_panic: false,
151    };
152
153    let output = {
154        // SAFETY: We are now a sole owner of the permit to poll the inner future
155        let future = unsafe {
156            match &mut *inner.future_or_output.get() {
157                FutureOrOutput::Future(fut) => Pin::new_unchecked(fut),
158                FutureOrOutput::Output(_) => unreachable!(),
159            }
160        };
161
162        let poll_result = future.poll(cx);
163        reset.did_not_panic = true;
164
165        match poll_result {
166            Poll::Pending => {
167                drop(reset); // Make borrow checker happy
168                *this_inner = Some(inner);
169                return Poll::Pending;
170            }
171            Poll::Ready(output) => output,
172        }
173    };
174
175    unsafe {
176        *inner.future_or_output.get() = FutureOrOutput::Output(output);
177    }
178
179    inner.state.store(COMPLETE, Ordering::Release);
180
181    drop(reset); // Make borrow checker happy
182
183    // permit gets dropped because this future is consumed in exchange for result
184
185    // SAFETY: We're in the COMPLETE state
186    unsafe { Poll::Ready(inner.take_or_clone_output()) }
187}
188
189impl<Fut> Future for Shared<Fut>
190where
191    Fut: Future,
192    Fut::Output: Clone,
193{
194    type Output = (Fut::Output, bool);
195
196    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
197        let Shared {
198            inner,
199            permit_fut,
200            permit,
201        } = &mut *self;
202
203        poll_impl(inner, permit_fut, permit, cx)
204    }
205}
206
207/// A future that preserves its place in wait queue but does not own a shared future.
208/// Use [`WeakSharedHandle`] if you want to poll an upgraded future and only pass a weak ref around.
209#[must_use = "futures do nothing unless you `.await` or poll them"]
210pub struct WeakShared<Fut: Future> {
211    inner: Option<Weak<Inner<Fut>>>,
212    permit_fut: Option<SyncBoxFuture<Result<OwnedSemaphorePermit, AcquireError>>>,
213    permit: Option<OwnedSemaphorePermit>,
214}
215
216impl<Fut> Future for WeakShared<Fut>
217where
218    Fut: Future,
219    Fut::Output: Clone,
220{
221    type Output = Option<(Fut::Output, bool)>;
222
223    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
224        let WeakShared {
225            inner,
226            permit_fut,
227            permit,
228        } = &mut *self;
229
230        let weak_inner = inner
231            .take()
232            .expect("Weak shared future polled again after completion");
233
234        let mut strong_inner = weak_inner.upgrade();
235
236        if strong_inner.is_none() {
237            return Poll::Ready(None);
238        };
239
240        let poll_result = poll_impl(&mut strong_inner, permit_fut, permit, cx);
241
242        *inner = strong_inner.is_some().then_some(weak_inner);
243
244        poll_result.map(Some)
245    }
246}
247
248/// A handle can be upgraded to a shared future, but cannot be directly awaited.
249/// Use [`WeakShared`] if you want to poll without an upgrade.
250#[repr(transparent)]
251pub struct WeakSharedHandle<Fut: Future>(Weak<Inner<Fut>>);
252
253impl<Fut: Future> WeakSharedHandle<Fut> {
254    pub fn upgrade(&self) -> Option<Shared<Fut>> {
255        self.0.upgrade().map(|inner| Shared {
256            inner: Some(inner),
257            permit_fut: None,
258            permit: None,
259        })
260    }
261
262    pub fn strong_count(&self) -> usize {
263        self.0.strong_count()
264    }
265}
266
267struct Inner<Fut: Future> {
268    state: AtomicUsize,
269    future_or_output: UnsafeCell<FutureOrOutput<Fut>>,
270    semaphore: Arc<Semaphore>,
271}
272
273impl<Fut> Inner<Fut>
274where
275    Fut: Future,
276    Fut::Output: Clone,
277{
278    /// Safety: callers must first ensure that `inner.state`
279    /// is `COMPLETE`
280    unsafe fn take_or_clone_output(self: Arc<Self>) -> (Fut::Output, bool) {
281        match Arc::try_unwrap(self) {
282            Ok(inner) => match inner.future_or_output.into_inner() {
283                FutureOrOutput::Output(item) => (item, true),
284                FutureOrOutput::Future(_) => unreachable!(),
285            },
286            Err(inner) => match unsafe { &*inner.future_or_output.get() } {
287                FutureOrOutput::Output(item) => (item.clone(), false),
288                FutureOrOutput::Future(_) => unreachable!(),
289            },
290        }
291    }
292}
293
294unsafe impl<Fut> Send for Inner<Fut>
295where
296    Fut: Future + Send,
297    Fut::Output: Send + Sync,
298{
299}
300
301unsafe impl<Fut> Sync for Inner<Fut>
302where
303    Fut: Future + Send,
304    Fut::Output: Send + Sync,
305{
306}
307
/// Contents of `Inner::future_or_output`: the wrapped future until it
/// completes, then the value it produced.
enum FutureOrOutput<Fut>
where
    Fut: Future,
{
    /// The wrapped future, not completed yet.
    Future(Fut),
    /// The output produced by the completed future.
    Output(Fut::Output),
}
312
// Values of `Inner::state` (the value `1` is not used).

/// Initial state: the wrapped future has not produced its output yet.
const POLLING: usize = 0;
/// The output has been stored in `Inner::future_or_output`.
const COMPLETE: usize = 2;
/// The wrapped future panicked while being polled.
const POISONED: usize = 3;
316
#[cfg(test)]
mod tests {
    //! Addresses the original `Shared` futures issue:
    //! <https://github.com/rust-lang/futures-rs/issues/2706/>

    use futures_util::FutureExt;

    use super::*;

    /// Returns `Pending` exactly once (waking itself immediately), then `Ready`.
    async fn yield_now() {
        struct YieldNow {
            yielded: bool,
        }

        impl Future for YieldNow {
            type Output = ();

            fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
                if self.yielded {
                    return Poll::Ready(());
                }

                self.yielded = true;
                cx.waker().wake_by_ref();
                Poll::Pending
            }
        }

        YieldNow { yielded: false }.await;
    }

    /// Stress test: a handle that is polled once and dropped (possibly while
    /// holding the permit) must not leave the other handle waiting forever.
    #[tokio::test(flavor = "multi_thread")]
    async fn must_not_hang_up() {
        for _ in 0..200_000 {
            test_fut().await;
        }
    }

    async fn test_fut() {
        let f1 = Shared::new(yield_now());
        let f2 = f1.clone();
        let x1 = tokio::spawn(async move {
            f1.now_or_never();
        });
        let x2 = tokio::spawn(async move {
            f2.await;
        });
        // Propagate panics from the spawned tasks instead of silently ignoring
        // them, so a panic inside `Shared` fails the test.
        x1.await.expect("task polling `f1` panicked");
        x2.await.expect("task awaiting `f2` panicked");
    }
}