Skip to main content

tempest_rt/sync/
oneshot.rs

//! Single-value, single-producer, single-consumer channel.
//!
//! Create a channel with [`channel`]. The [`Sender`] sends exactly one value; the [`Receiver`]
//! is a future that resolves to that value. If the receiver is dropped first, [`Sender::send`]
//! hands the value back as `Err`; if the sender is dropped without sending, the receiver
//! observes a closed error ([`RecvError`] or [`TryRecvError::Closed`]).
6
7use std::{
8    cell::RefCell,
9    future::poll_fn,
10    rc::Rc,
11    task::{Poll, Waker},
12};
13
14use derive_more::{Display, Error};
15
/// State shared by the two halves of a oneshot channel.
struct Inner<T> {
    /// Slot for the value in flight: written by `Sender::send`, emptied by the receiver.
    value: Option<T>,
    /// Waker stored by `Receiver::poll_recv` when it returns `Poll::Pending`; taken and woken
    /// by `Sender::send`.
    waker: Option<Waker>,
}
20
21/// Sending half of a oneshot channel.
22#[derive(derive_more::Debug)]
23pub struct Sender<T> {
24    #[debug("{:p}", Rc::as_ptr(inner))]
25    inner: Rc<RefCell<Inner<T>>>,
26}
27
28/// Receiving half of a oneshot channel. Implements [`Future`] and resolves to the sent value.
29#[derive(derive_more::Debug)]
30#[must_use = "futures do nothing unless awaited"]
31pub struct Receiver<T> {
32    #[debug("{:p}", Rc::as_ptr(inner))]
33    inner: Rc<RefCell<Inner<T>>>,
34}
35
36/// Creates a oneshot channel, returning the sender and receiver halves.
37pub fn channel<T>() -> (Sender<T>, Receiver<T>) {
38    let inner = Rc::new(RefCell::new(Inner {
39        value: None,
40        waker: None,
41    }));
42    (
43        Sender {
44            inner: inner.clone(),
45        },
46        Receiver { inner },
47    )
48}
49
50impl<T> Sender<T> {
51    /// Tries to send `val` to the [`Receiver`].
52    ///
53    /// # Returns
54    ///
55    /// - `Ok(())` when the send has succeeded
56    /// - `Err(val)` when the channel was closed (receiver was dropped)
57    pub fn send(self, val: T) -> Result<(), T> {
58        if Rc::strong_count(&self.inner) == 1 {
59            return Err(val);
60        }
61        let mut borrow = self.inner.borrow_mut();
62        borrow.value = Some(val);
63        if let Some(waker) = borrow.waker.take() {
64            waker.wake();
65        }
66        Ok(())
67    }
68}
69
70/// Error returned by [`Receiver::recv`] when the sender has been dropped without sending a value.
71#[derive(Debug, Display, Error, PartialEq, Eq)]
72#[display("sender has been dropped")]
73pub struct RecvError;
74
75/// Error returned by [`Receiver::try_recv`].
76#[derive(Debug, Display, Error, PartialEq, Eq)]
77pub enum TryRecvError {
78    /// The sender is still alive but has not sent a value yet.
79    #[display("channel is empty")]
80    Empty,
81    /// The sender was dropped without sending a value.
82    #[display("sender has been dropped")]
83    Closed,
84}
85
86impl<T> Receiver<T> {
87    pub(crate) fn poll_recv(
88        &mut self,
89        cx: &mut std::task::Context<'_>,
90    ) -> Poll<Result<T, RecvError>> {
91        let mut borrow = self.inner.borrow_mut();
92        if let Some(val) = borrow.value.take() {
93            return Poll::Ready(Ok(val));
94        }
95        if Rc::strong_count(&self.inner) == 1 {
96            return Poll::Ready(Err(RecvError));
97        }
98        borrow.waker = Some(cx.waker().clone());
99        Poll::Pending
100    }
101
102    /// Receives the value, parking until it arrives.
103    ///
104    /// Returns `Err` if the sender was dropped without sending a value.
105    pub async fn recv(mut self) -> Result<T, RecvError> {
106        poll_fn(|cx| self.poll_recv(cx)).await
107    }
108
109    /// Receives without waiting.
110    ///
111    /// Returns `Err(Empty)` if no value has been sent yet, or `Err(Closed)` if the sender was
112    /// dropped without sending.
113    pub fn try_recv(&mut self) -> Result<T, TryRecvError> {
114        let mut borrow = self.inner.borrow_mut();
115        if let Some(val) = borrow.value.take() {
116            return Ok(val);
117        }
118        if Rc::strong_count(&self.inner) == 1 {
119            Err(TryRecvError::Closed)
120        } else {
121            Err(TryRecvError::Empty)
122        }
123    }
124}
125
#[cfg(test)]
mod tests {
    use tempest_io::VirtualIo;

    use crate::{block_on, spawn};

    use super::*;

    /// A value sent before the receiver waits is delivered by `recv`.
    #[test]
    fn test_oneshot_send_recv() {
        block_on(VirtualIo::default(), async {
            let (sender, receiver) = channel();
            sender.send(5).unwrap();
            let received = receiver.recv().await.unwrap();
            assert_eq!(received, 5);
        })
    }

    /// Dropping the sender before waiting makes `recv` fail with `RecvError`.
    #[test]
    fn test_oneshot_sender_dropped() {
        block_on(VirtualIo::default(), async {
            let (sender, receiver) = channel::<i32>();
            drop(sender);
            let outcome = receiver.recv().await;
            assert_eq!(outcome, Err(RecvError));
        });
    }

    /// Sending after the receiver is gone hands the value back.
    #[test]
    fn test_oneshot_receiver_dropped() {
        block_on(VirtualIo::default(), async {
            let (sender, receiver) = channel::<i32>();
            drop(receiver);
            let outcome = sender.send(99);
            assert_eq!(outcome, Err(99));
        });
    }

    /// A value sent from a spawned task reaches the receiver once that task has finished.
    #[test]
    fn test_oneshot_from_task() {
        block_on(VirtualIo::default(), async {
            let (sender, receiver) = channel();
            let task = spawn(async {
                sender.send(42).unwrap();
            });

            task.await.unwrap();
            assert_eq!(receiver.recv().await.unwrap(), 42);
        });
    }

    /// `try_recv` reports `Empty` while the sender is alive and has not sent.
    #[test]
    fn test_oneshot_try_recv_empty() {
        block_on(VirtualIo::default(), async {
            let (_sender, mut receiver) = channel::<i32>();
            let outcome = receiver.try_recv();
            assert_eq!(outcome, Err(TryRecvError::Empty));
        });
    }

    /// `try_recv` reports `Closed` once the sender was dropped without sending.
    #[test]
    fn test_oneshot_try_recv_closed() {
        block_on(VirtualIo::default(), async {
            let (sender, mut receiver) = channel::<i32>();
            drop(sender);
            let outcome = receiver.try_recv();
            assert_eq!(outcome, Err(TryRecvError::Closed));
        });
    }

    /// `try_recv` returns a value that has already been sent.
    #[test]
    fn test_oneshot_try_recv_value() {
        block_on(VirtualIo::default(), async {
            let (sender, mut receiver) = channel();
            sender.send(42).unwrap();
            let outcome = receiver.try_recv();
            assert_eq!(outcome, Ok(42));
        });
    }
}