tempest_rt/sync/
oneshot.rs1use std::{
8 cell::RefCell,
9 future::poll_fn,
10 rc::Rc,
11 task::{Poll, Waker},
12};
13
14use derive_more::{Display, Error};
15
/// State shared between a [`Sender`] and its [`Receiver`].
struct Inner<T> {
    /// The value handed over by `Sender::send`, held until the receiver takes it.
    value: Option<T>,
    /// Waker recorded by the most recent pending `Receiver::poll_recv`;
    /// `Sender::send` takes it and wakes it after storing a value.
    waker: Option<Waker>,
}
20
21#[derive(derive_more::Debug)]
23pub struct Sender<T> {
24 #[debug("{:p}", Rc::as_ptr(inner))]
25 inner: Rc<RefCell<Inner<T>>>,
26}
27
28#[derive(derive_more::Debug)]
30#[must_use = "futures do nothing unless awaited"]
31pub struct Receiver<T> {
32 #[debug("{:p}", Rc::as_ptr(inner))]
33 inner: Rc<RefCell<Inner<T>>>,
34}
35
36pub fn channel<T>() -> (Sender<T>, Receiver<T>) {
38 let inner = Rc::new(RefCell::new(Inner {
39 value: None,
40 waker: None,
41 }));
42 (
43 Sender {
44 inner: inner.clone(),
45 },
46 Receiver { inner },
47 )
48}
49
50impl<T> Sender<T> {
51 pub fn send(self, val: T) -> Result<(), T> {
58 if Rc::strong_count(&self.inner) == 1 {
59 return Err(val);
60 }
61 let mut borrow = self.inner.borrow_mut();
62 borrow.value = Some(val);
63 if let Some(waker) = borrow.waker.take() {
64 waker.wake();
65 }
66 Ok(())
67 }
68}
69
70#[derive(Debug, Display, Error, PartialEq, Eq)]
72#[display("sender has been dropped")]
73pub struct RecvError;
74
75#[derive(Debug, Display, Error, PartialEq, Eq)]
77pub enum TryRecvError {
78 #[display("channel is empty")]
80 Empty,
81 #[display("sender has been dropped")]
83 Closed,
84}
85
86impl<T> Receiver<T> {
87 pub(crate) fn poll_recv(
88 &mut self,
89 cx: &mut std::task::Context<'_>,
90 ) -> Poll<Result<T, RecvError>> {
91 let mut borrow = self.inner.borrow_mut();
92 if let Some(val) = borrow.value.take() {
93 return Poll::Ready(Ok(val));
94 }
95 if Rc::strong_count(&self.inner) == 1 {
96 return Poll::Ready(Err(RecvError));
97 }
98 borrow.waker = Some(cx.waker().clone());
99 Poll::Pending
100 }
101
102 pub async fn recv(mut self) -> Result<T, RecvError> {
106 poll_fn(|cx| self.poll_recv(cx)).await
107 }
108
109 pub fn try_recv(&mut self) -> Result<T, TryRecvError> {
114 let mut borrow = self.inner.borrow_mut();
115 if let Some(val) = borrow.value.take() {
116 return Ok(val);
117 }
118 if Rc::strong_count(&self.inner) == 1 {
119 Err(TryRecvError::Closed)
120 } else {
121 Err(TryRecvError::Empty)
122 }
123 }
124}
125
#[cfg(test)]
mod tests {
    use tempest_io::VirtualIo;

    use crate::{block_on, spawn};

    use super::*;

    /// A value sent before the receiver is awaited is delivered.
    #[test]
    fn test_oneshot_send_recv() {
        block_on(VirtualIo::default(), async {
            let (sender, receiver) = channel();
            assert_eq!(sender.send(5), Ok(()));
            assert_eq!(receiver.recv().await, Ok(5));
        });
    }

    /// Awaiting after the sender is dropped unsent yields `RecvError`.
    #[test]
    fn test_oneshot_sender_dropped() {
        block_on(VirtualIo::default(), async {
            let (sender, receiver) = channel::<i32>();
            drop(sender);
            assert_eq!(receiver.recv().await, Err(RecvError));
        });
    }

    /// Sending to a dropped receiver hands the value back.
    #[test]
    fn test_oneshot_receiver_dropped() {
        block_on(VirtualIo::default(), async {
            let (sender, receiver) = channel::<i32>();
            drop(receiver);
            assert_eq!(sender.send(99), Err(99));
        });
    }

    /// A value sent from a spawned task reaches the awaiting receiver.
    #[test]
    fn test_oneshot_from_task() {
        block_on(VirtualIo::default(), async {
            let (sender, receiver) = channel();
            let task = spawn(async {
                sender.send(42).unwrap();
            });

            task.await.unwrap();
            assert_eq!(receiver.recv().await, Ok(42));
        });
    }

    /// `try_recv` reports `Empty` while the sender is alive and nothing was sent.
    #[test]
    fn test_oneshot_try_recv_empty() {
        block_on(VirtualIo::default(), async {
            let (_sender, mut receiver) = channel::<i32>();
            assert_eq!(receiver.try_recv(), Err(TryRecvError::Empty));
        });
    }

    /// `try_recv` reports `Closed` once the sender is dropped unsent.
    #[test]
    fn test_oneshot_try_recv_closed() {
        block_on(VirtualIo::default(), async {
            let (sender, mut receiver) = channel::<i32>();
            drop(sender);
            assert_eq!(receiver.try_recv(), Err(TryRecvError::Closed));
        });
    }

    /// `try_recv` returns an already-sent value.
    #[test]
    fn test_oneshot_try_recv_value() {
        block_on(VirtualIo::default(), async {
            let (sender, mut receiver) = channel();
            sender.send(42).unwrap();
            assert_eq!(receiver.try_recv(), Ok(42));
        });
    }
}