scrappy_utils/
oneshot.rs

1//! A one-shot, futures-aware channel.
2use std::future::Future;
3use std::pin::Pin;
4use std::task::{Context, Poll};
5
6pub use futures::channel::oneshot::Canceled;
7use slab::Slab;
8
9use crate::cell::Cell;
10use crate::task::LocalWaker;
11
12/// Creates a new futures-aware, one-shot channel.
13pub fn channel<T>() -> (Sender<T>, Receiver<T>) {
14    let inner = Cell::new(Inner {
15        value: None,
16        rx_task: LocalWaker::new(),
17    });
18    let tx = Sender {
19        inner: inner.clone(),
20    };
21    let rx = Receiver { inner };
22    (tx, rx)
23}
24
25/// Creates a new futures-aware, pool of one-shot's.
26pub fn pool<T>() -> Pool<T> {
27    Pool(Cell::new(Slab::new()))
28}
29
/// Represents the completion half of a oneshot through which the result of a
/// computation is signaled.
///
/// Created by [`channel`]. Consumed by [`Sender::send`]; dropping it without
/// sending makes the paired [`Receiver`] resolve to `Err(Canceled)`.
#[derive(Debug)]
pub struct Sender<T> {
    /// State shared with the paired `Receiver`.
    inner: Cell<Inner<T>>,
}
36
/// A future representing the completion of a computation happening elsewhere in
/// memory.
///
/// Created by [`channel`]. Resolves to `Ok(value)` once the paired [`Sender`]
/// has sent a value, or to `Err(Canceled)` if the sender is dropped without
/// sending.
#[derive(Debug)]
#[must_use = "futures do nothing unless polled"]
pub struct Receiver<T> {
    /// State shared with the paired `Sender`.
    inner: Cell<Inner<T>>,
}
44
// The channels do not ever project Pin to the inner T, so both halves can be
// `Unpin` regardless of whether `T` is. `Receiver::poll` relies on this impl
// to call `Pin::get_mut`.
impl<T> Unpin for Receiver<T> {}
impl<T> Unpin for Sender<T> {}
48
/// State shared by a `Sender` and its `Receiver`.
#[derive(Debug)]
struct Inner<T> {
    /// The sent value: set by `Sender::send`, taken by `Receiver::poll`.
    value: Option<T>,
    /// Waker registered by `Receiver::poll`; woken by the sender on send and
    /// again when the sender is dropped.
    rx_task: LocalWaker,
}
54
55impl<T> Sender<T> {
56    /// Completes this oneshot with a successful result.
57    ///
58    /// This function will consume `self` and indicate to the other end, the
59    /// `Receiver`, that the error provided is the result of the computation this
60    /// represents.
61    ///
62    /// If the value is successfully enqueued for the remote end to receive,
63    /// then `Ok(())` is returned. If the receiving end was dropped before
64    /// this function was called, however, then `Err` is returned with the value
65    /// provided.
66    pub fn send(mut self, val: T) -> Result<(), T> {
67        if self.inner.strong_count() == 2 {
68            let inner = self.inner.get_mut();
69            inner.value = Some(val);
70            inner.rx_task.wake();
71            Ok(())
72        } else {
73            Err(val)
74        }
75    }
76
77    /// Tests to see whether this `Sender`'s corresponding `Receiver`
78    /// has gone away.
79    pub fn is_canceled(&self) -> bool {
80        self.inner.strong_count() == 1
81    }
82}
83
84impl<T> Drop for Sender<T> {
85    fn drop(&mut self) {
86        self.inner.get_ref().rx_task.wake();
87    }
88}
89
90impl<T> Future for Receiver<T> {
91    type Output = Result<T, Canceled>;
92
93    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
94        let this = self.get_mut();
95
96        // If we've got a value, then skip the logic below as we're done.
97        if let Some(val) = this.inner.get_mut().value.take() {
98            return Poll::Ready(Ok(val));
99        }
100
101        // Check if sender is dropped and return error if it is.
102        if this.inner.strong_count() == 1 {
103            Poll::Ready(Err(Canceled))
104        } else {
105            this.inner.get_ref().rx_task.register(cx.waker());
106            Poll::Pending
107        }
108    }
109}
110
/// Futures-aware pool of one-shot channels.
///
/// Cloning a `Pool` yields another handle to the same slab. Each channel
/// created with `Pool::channel` stores its state in one slab entry.
pub struct Pool<T>(Cell<Slab<PoolInner<T>>>);
113
bitflags::bitflags! {
    /// Tracks which halves of a pooled channel have not been dropped yet.
    pub struct Flags: u8 {
        /// The `PSender` is still alive.
        const SENDER = 0b0000_0001;
        /// The `PReceiver` is still alive.
        const RECEIVER = 0b0000_0010;
    }
}
120
/// Per-channel state stored in one entry of a pool's slab.
#[derive(Debug)]
struct PoolInner<T> {
    /// Which halves are still alive; the last half to drop removes the entry.
    flags: Flags,
    /// The sent value: set by `PSender::send`, taken by `PReceiver::poll`.
    value: Option<T>,
    /// Waker registered by `PReceiver::poll`; woken by the sender.
    waker: LocalWaker,
}
127
128impl<T> Pool<T> {
129    pub fn channel(&mut self) -> (PSender<T>, PReceiver<T>) {
130        let token = self.0.get_mut().insert(PoolInner {
131            flags: Flags::all(),
132            value: None,
133            waker: LocalWaker::default(),
134        });
135
136        (
137            PSender {
138                token,
139                inner: self.0.clone(),
140            },
141            PReceiver {
142                token,
143                inner: self.0.clone(),
144            },
145        )
146    }
147}
148
149impl<T> Clone for Pool<T> {
150    fn clone(&self) -> Self {
151        Pool(self.0.clone())
152    }
153}
154
/// Represents the completion half of a oneshot through which the result of a
/// computation is signaled.
///
/// Pool-backed counterpart of [`Sender`], created by [`Pool::channel`].
#[derive(Debug)]
pub struct PSender<T> {
    /// Key of this channel's entry in the pool's slab.
    token: usize,
    /// Handle to the pool's shared slab.
    inner: Cell<Slab<PoolInner<T>>>,
}
162
/// A future representing the completion of a computation happening elsewhere in
/// memory.
///
/// Pool-backed counterpart of [`Receiver`], created by [`Pool::channel`].
/// Resolves to `Ok(value)` after the paired [`PSender`] sends, or to
/// `Err(Canceled)` if it is dropped without sending.
#[derive(Debug)]
#[must_use = "futures do nothing unless polled"]
pub struct PReceiver<T> {
    /// Key of this channel's entry in the pool's slab.
    token: usize,
    /// Handle to the pool's shared slab.
    inner: Cell<Slab<PoolInner<T>>>,
}
171
// The oneshots do not ever project Pin to the inner T, so both halves can be
// `Unpin` regardless of whether `T` is. `PReceiver::poll` relies on this impl
// to call `Pin::get_mut`.
impl<T> Unpin for PReceiver<T> {}
impl<T> Unpin for PSender<T> {}
175
176impl<T> PSender<T> {
177    /// Completes this oneshot with a successful result.
178    ///
179    /// This function will consume `self` and indicate to the other end, the
180    /// `Receiver`, that the error provided is the result of the computation this
181    /// represents.
182    ///
183    /// If the value is successfully enqueued for the remote end to receive,
184    /// then `Ok(())` is returned. If the receiving end was dropped before
185    /// this function was called, however, then `Err` is returned with the value
186    /// provided.
187    pub fn send(mut self, val: T) -> Result<(), T> {
188        let inner = unsafe { self.inner.get_mut().get_unchecked_mut(self.token) };
189
190        if inner.flags.contains(Flags::RECEIVER) {
191            inner.value = Some(val);
192            inner.waker.wake();
193            Ok(())
194        } else {
195            Err(val)
196        }
197    }
198
199    /// Tests to see whether this `Sender`'s corresponding `Receiver`
200    /// has gone away.
201    pub fn is_canceled(&self) -> bool {
202        !unsafe { self.inner.get_ref().get_unchecked(self.token) }
203            .flags
204            .contains(Flags::RECEIVER)
205    }
206}
207
impl<T> Drop for PSender<T> {
    /// Marks the sender half as gone, or frees the slab entry if the receiver
    /// half has already been dropped.
    fn drop(&mut self) {
        // SAFETY: the entry for `self.token` is created by `Pool::channel` and
        // only removed by the last of the two halves to be dropped (see the
        // `else` branches here and in `PReceiver::drop`). This sender is being
        // dropped only now, so the entry is still occupied.
        let inner = unsafe { self.inner.get_mut().get_unchecked_mut(self.token) };
        if inner.flags.contains(Flags::RECEIVER) {
            // Receiver still alive: wake it so it observes either the value
            // stored by `send` or the cancellation.
            inner.waker.wake();
            inner.flags.remove(Flags::SENDER);
        } else {
            // Receiver already gone: this is the last half, so free the entry.
            self.inner.get_mut().remove(self.token);
        }
    }
}
219
impl<T> Drop for PReceiver<T> {
    /// Marks the receiver half as gone, or frees the slab entry (discarding any
    /// value that was never received) if the sender half has already been
    /// dropped.
    fn drop(&mut self) {
        // SAFETY: the entry for `self.token` is created by `Pool::channel` and
        // only removed by the last of the two halves to be dropped (see the
        // `else` branches here and in `PSender::drop`). This receiver is being
        // dropped only now, so the entry is still occupied.
        let inner = unsafe { self.inner.get_mut().get_unchecked_mut(self.token) };
        if inner.flags.contains(Flags::SENDER) {
            // Sender still alive: it will free the entry when it is dropped.
            inner.flags.remove(Flags::RECEIVER);
        } else {
            // Sender already gone: this is the last half, so free the entry.
            self.inner.get_mut().remove(self.token);
        }
    }
}
230
231impl<T> Future for PReceiver<T> {
232    type Output = Result<T, Canceled>;
233
234    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
235        let this = self.get_mut();
236        let inner = unsafe { this.inner.get_mut().get_unchecked_mut(this.token) };
237
238        // If we've got a value, then skip the logic below as we're done.
239        if let Some(val) = inner.value.take() {
240            return Poll::Ready(Ok(val));
241        }
242
243        // Check if sender is dropped and return error if it is.
244        if !inner.flags.contains(Flags::SENDER) {
245            Poll::Ready(Err(Canceled))
246        } else {
247            inner.waker.register(cx.waker());
248            Poll::Pending
249        }
250    }
251}
252
#[cfg(test)]
mod tests {
    use super::*;
    use futures::future::lazy;

    #[scrappy_rt::test]
    async fn test_oneshot() {
        let (tx, rx) = channel();
        tx.send("test").unwrap();
        assert_eq!(rx.await.unwrap(), "test");

        let (tx, rx) = channel();
        assert!(!tx.is_canceled());
        drop(rx);
        assert!(tx.is_canceled());
        assert!(tx.send("test").is_err());

        let (tx, rx) = channel::<&'static str>();
        drop(tx);
        assert!(rx.await.is_err());

        let (tx, mut rx) = channel::<&'static str>();
        assert_eq!(lazy(|cx| Pin::new(&mut rx).poll(cx)).await, Poll::Pending);
        tx.send("test").unwrap();
        assert_eq!(rx.await.unwrap(), "test");

        let (tx, mut rx) = channel::<&'static str>();
        assert_eq!(lazy(|cx| Pin::new(&mut rx).poll(cx)).await, Poll::Pending);
        drop(tx);
        assert!(rx.await.is_err());
    }

    #[scrappy_rt::test]
    async fn test_pool() {
        let (tx, rx) = pool().channel();
        tx.send("test").unwrap();
        assert_eq!(rx.await.unwrap(), "test");

        let (tx, rx) = pool().channel();
        assert!(!tx.is_canceled());
        drop(rx);
        assert!(tx.is_canceled());
        assert!(tx.send("test").is_err());

        let (tx, rx) = pool::<&'static str>().channel();
        drop(tx);
        assert!(rx.await.is_err());

        let (tx, mut rx) = pool::<&'static str>().channel();
        assert_eq!(lazy(|cx| Pin::new(&mut rx).poll(cx)).await, Poll::Pending);
        tx.send("test").unwrap();
        assert_eq!(rx.await.unwrap(), "test");

        let (tx, mut rx) = pool::<&'static str>().channel();
        assert_eq!(lazy(|cx| Pin::new(&mut rx).poll(cx)).await, Poll::Pending);
        drop(tx);
        assert!(rx.await.is_err());
    }

    /// Several live channels in one pool must not interfere with each other,
    /// and slab slots freed by finished channels must be safely reusable.
    /// This exercises the invariants the `unsafe` unchecked slab accesses
    /// rely on.
    #[scrappy_rt::test]
    async fn test_pool_shared() {
        let mut shared = pool::<u32>();

        // Two concurrent channels, completed out of creation order.
        let (tx1, rx1) = shared.channel();
        let (tx2, rx2) = shared.channel();
        tx2.send(2).unwrap();
        tx1.send(1).unwrap();
        assert_eq!(rx1.await.unwrap(), 1);
        assert_eq!(rx2.await.unwrap(), 2);

        // Cancelling one channel leaves its neighbour untouched.
        let (tx_a, rx_a) = shared.channel();
        let (tx_b, rx_b) = shared.channel();
        drop(rx_a);
        assert!(tx_a.is_canceled());
        assert!(!tx_b.is_canceled());
        tx_b.send(7).unwrap();
        assert_eq!(rx_b.await.unwrap(), 7);
        assert_eq!(tx_a.send(1), Err(1));

        // Channels from a cloned handle share the slab; reused slots start fresh.
        let mut other = shared.clone();
        let (tx3, rx3) = other.channel();
        assert!(!tx3.is_canceled());
        drop(tx3);
        assert!(rx3.await.is_err());
    }
}