1use std::future::Future;
3use std::pin::Pin;
4use std::task::{Context, Poll};
5
6pub use futures::channel::oneshot::Canceled;
7use slab::Slab;
8
9use crate::cell::Cell;
10use crate::task::LocalWaker;
11
12pub fn channel<T>() -> (Sender<T>, Receiver<T>) {
14 let inner = Cell::new(Inner {
15 value: None,
16 rx_task: LocalWaker::new(),
17 });
18 let tx = Sender {
19 inner: inner.clone(),
20 };
21 let rx = Receiver { inner };
22 (tx, rx)
23}
24
25pub fn pool<T>() -> Pool<T> {
27 Pool(Cell::new(Slab::new()))
28}
29
30#[derive(Debug)]
33pub struct Sender<T> {
34 inner: Cell<Inner<T>>,
35}
36
37#[derive(Debug)]
40#[must_use = "futures do nothing unless polled"]
41pub struct Receiver<T> {
42 inner: Cell<Inner<T>>,
43}
44
45impl<T> Unpin for Receiver<T> {}
47impl<T> Unpin for Sender<T> {}
48
49#[derive(Debug)]
50struct Inner<T> {
51 value: Option<T>,
52 rx_task: LocalWaker,
53}
54
55impl<T> Sender<T> {
56 pub fn send(mut self, val: T) -> Result<(), T> {
67 if self.inner.strong_count() == 2 {
68 let inner = self.inner.get_mut();
69 inner.value = Some(val);
70 inner.rx_task.wake();
71 Ok(())
72 } else {
73 Err(val)
74 }
75 }
76
77 pub fn is_canceled(&self) -> bool {
80 self.inner.strong_count() == 1
81 }
82}
83
84impl<T> Drop for Sender<T> {
85 fn drop(&mut self) {
86 self.inner.get_ref().rx_task.wake();
87 }
88}
89
90impl<T> Future for Receiver<T> {
91 type Output = Result<T, Canceled>;
92
93 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
94 let this = self.get_mut();
95
96 if let Some(val) = this.inner.get_mut().value.take() {
98 return Poll::Ready(Ok(val));
99 }
100
101 if this.inner.strong_count() == 1 {
103 Poll::Ready(Err(Canceled))
104 } else {
105 this.inner.get_ref().rx_task.register(cx.waker());
106 Poll::Pending
107 }
108 }
109}
110
111pub struct Pool<T>(Cell<Slab<PoolInner<T>>>);
113
114bitflags::bitflags! {
115 pub struct Flags: u8 {
116 const SENDER = 0b0000_0001;
117 const RECEIVER = 0b0000_0010;
118 }
119}
120
121#[derive(Debug)]
122struct PoolInner<T> {
123 flags: Flags,
124 value: Option<T>,
125 waker: LocalWaker,
126}
127
128impl<T> Pool<T> {
129 pub fn channel(&mut self) -> (PSender<T>, PReceiver<T>) {
130 let token = self.0.get_mut().insert(PoolInner {
131 flags: Flags::all(),
132 value: None,
133 waker: LocalWaker::default(),
134 });
135
136 (
137 PSender {
138 token,
139 inner: self.0.clone(),
140 },
141 PReceiver {
142 token,
143 inner: self.0.clone(),
144 },
145 )
146 }
147}
148
149impl<T> Clone for Pool<T> {
150 fn clone(&self) -> Self {
151 Pool(self.0.clone())
152 }
153}
154
155#[derive(Debug)]
158pub struct PSender<T> {
159 token: usize,
160 inner: Cell<Slab<PoolInner<T>>>,
161}
162
163#[derive(Debug)]
166#[must_use = "futures do nothing unless polled"]
167pub struct PReceiver<T> {
168 token: usize,
169 inner: Cell<Slab<PoolInner<T>>>,
170}
171
172impl<T> Unpin for PReceiver<T> {}
174impl<T> Unpin for PSender<T> {}
175
176impl<T> PSender<T> {
177 pub fn send(mut self, val: T) -> Result<(), T> {
188 let inner = unsafe { self.inner.get_mut().get_unchecked_mut(self.token) };
189
190 if inner.flags.contains(Flags::RECEIVER) {
191 inner.value = Some(val);
192 inner.waker.wake();
193 Ok(())
194 } else {
195 Err(val)
196 }
197 }
198
199 pub fn is_canceled(&self) -> bool {
202 !unsafe { self.inner.get_ref().get_unchecked(self.token) }
203 .flags
204 .contains(Flags::RECEIVER)
205 }
206}
207
208impl<T> Drop for PSender<T> {
209 fn drop(&mut self) {
210 let inner = unsafe { self.inner.get_mut().get_unchecked_mut(self.token) };
211 if inner.flags.contains(Flags::RECEIVER) {
212 inner.waker.wake();
213 inner.flags.remove(Flags::SENDER);
214 } else {
215 self.inner.get_mut().remove(self.token);
216 }
217 }
218}
219
220impl<T> Drop for PReceiver<T> {
221 fn drop(&mut self) {
222 let inner = unsafe { self.inner.get_mut().get_unchecked_mut(self.token) };
223 if inner.flags.contains(Flags::SENDER) {
224 inner.flags.remove(Flags::RECEIVER);
225 } else {
226 self.inner.get_mut().remove(self.token);
227 }
228 }
229}
230
231impl<T> Future for PReceiver<T> {
232 type Output = Result<T, Canceled>;
233
234 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
235 let this = self.get_mut();
236 let inner = unsafe { this.inner.get_mut().get_unchecked_mut(this.token) };
237
238 if let Some(val) = inner.value.take() {
240 return Poll::Ready(Ok(val));
241 }
242
243 if !inner.flags.contains(Flags::SENDER) {
245 Poll::Ready(Err(Canceled))
246 } else {
247 inner.waker.register(cx.waker());
248 Poll::Pending
249 }
250 }
251}
252
#[cfg(test)]
mod tests {
    use super::*;
    use futures::future::lazy;

    #[scrappy_rt::test]
    async fn test_oneshot() {
        let (tx, rx) = channel();
        tx.send("test").unwrap();
        assert_eq!(rx.await.unwrap(), "test");

        // A failed send must hand the value back to the caller.
        let (tx, rx) = channel();
        assert!(!tx.is_canceled());
        drop(rx);
        assert!(tx.is_canceled());
        assert_eq!(tx.send("test"), Err("test"));

        let (tx, rx) = channel::<&'static str>();
        drop(tx);
        assert!(rx.await.is_err());

        let (tx, mut rx) = channel::<&'static str>();
        assert_eq!(lazy(|cx| Pin::new(&mut rx).poll(cx)).await, Poll::Pending);
        tx.send("test").unwrap();
        assert_eq!(rx.await.unwrap(), "test");

        let (tx, mut rx) = channel::<&'static str>();
        assert_eq!(lazy(|cx| Pin::new(&mut rx).poll(cx)).await, Poll::Pending);
        drop(tx);
        assert!(rx.await.is_err());
    }

    #[scrappy_rt::test]
    async fn test_pool() {
        let (tx, rx) = pool().channel();
        tx.send("test").unwrap();
        assert_eq!(rx.await.unwrap(), "test");

        // A failed send must hand the value back to the caller.
        let (tx, rx) = pool().channel();
        assert!(!tx.is_canceled());
        drop(rx);
        assert!(tx.is_canceled());
        assert_eq!(tx.send("test"), Err("test"));

        let (tx, rx) = pool::<&'static str>().channel();
        drop(tx);
        assert!(rx.await.is_err());

        let (tx, mut rx) = pool::<&'static str>().channel();
        assert_eq!(lazy(|cx| Pin::new(&mut rx).poll(cx)).await, Poll::Pending);
        tx.send("test").unwrap();
        assert_eq!(rx.await.unwrap(), "test");

        let (tx, mut rx) = pool::<&'static str>().channel();
        assert_eq!(lazy(|cx| Pin::new(&mut rx).poll(cx)).await, Poll::Pending);
        drop(tx);
        assert!(rx.await.is_err());
    }

    /// The unchecked slab accesses rely on each slot being removed exactly once,
    /// by whichever half drops last. Verify that for every drop order.
    #[scrappy_rt::test]
    async fn test_pool_releases_slots() {
        let mut p = pool::<u32>();

        // Completed channels free their slot once both halves are gone.
        for i in 0..4u32 {
            let (tx, rx) = p.channel();
            tx.send(i).unwrap();
            assert_eq!(rx.await.unwrap(), i);
        }
        assert_eq!(p.0.get_ref().len(), 0);

        // Receiver dropped first, then sender.
        let (tx, rx) = p.channel();
        drop(rx);
        assert!(tx.is_canceled());
        drop(tx);
        assert_eq!(p.0.get_ref().len(), 0);

        // Sender dropped first, then receiver.
        let (tx, rx) = p.channel();
        drop(tx);
        drop(rx);
        assert_eq!(p.0.get_ref().len(), 0);
    }
}
311}