vpp_plugin/vlib/process_node/
mpsc.rs1use futures::task::AtomicWaker;
4use std::{
5 cell::Cell,
6 pin::Pin,
7 sync::{Arc, mpsc::TryRecvError},
8 task::{Context, Poll},
9};
10
/// Sending half of the channel created by [`channel`].
///
/// Handles can be cloned to obtain additional producers; every clone feeds
/// the same [`Receiver`].
pub struct Sender<T> {
    /// Underlying standard-library sender that actually queues the values.
    inner: std::sync::mpsc::Sender<T>,
    /// State shared with the receiver; holds the waker that `send` triggers.
    shared_state: Arc<MpscSharedState>,
}
18
/// Receiving half of the channel created by [`channel`].
///
/// Values can be taken without waiting via `try_recv` or awaited via `recv`.
pub struct Receiver<T> {
    /// Underlying standard-library receiver the values are read from.
    inner: std::sync::mpsc::Receiver<T>,
    /// State shared with every sender; holds the waker registered while a
    /// receive is pending.
    shared_state: Arc<MpscSharedState>,
    /// Zero-sized marker: `Cell<()>` is not `Sync`, so this keeps `Receiver`
    /// from being shared by reference between threads.
    _not_sync: std::marker::PhantomData<Cell<()>>,
}
25
/// State shared (through an `Arc`) between the receiver and all senders.
struct MpscSharedState {
    /// Waker of the task currently waiting to receive: registered by
    /// `Receiver::poll_recv` and woken by `Sender::send`.
    rx_waker: AtomicWaker,
}
29
// SAFETY: a `Sender<T>` owns only a `std::sync::mpsc::Sender<T>` and an
// `Arc<MpscSharedState>`; the only `T` values it touches are moved into the
// channel by `send`, so moving the handle to another thread needs `T: Send`.
// NOTE(review): on recent toolchains both fields are presumably already `Send`
// for `T: Send`, which would make this manual impl redundant — confirm.
unsafe impl<T: Send> Send for Sender<T> {}
32
// SAFETY: the `&self` operations visible in this file (`send`, `clone`) only
// forward to the inner `std::sync::mpsc::Sender` and to the `AtomicWaker` in
// the shared state.
// NOTE(review): soundness therefore relies on `std::sync::mpsc::Sender<T>`
// being safe to share between threads; std documents it as `Sync` for
// `T: Send` only on newer releases — confirm against the minimum supported
// toolchain.
unsafe impl<T: Send> Sync for Sender<T> {}
35
// SAFETY: moving a `Receiver<T>` to another thread moves the inner
// `std::sync::mpsc::Receiver<T>`, an `Arc<MpscSharedState>` and a zero-sized
// marker; the values it later yields are `T: Send`. No `Sync` impl is
// provided — the `_not_sync` field exists to prevent one.
// NOTE(review): presumably equivalent to the auto-derived impl — confirm.
unsafe impl<T: Send> Send for Receiver<T> {}
39
40impl<T> Sender<T> {
41 pub fn send(&self, value: T) -> Result<(), T> {
43 self.inner
44 .send(value)
45 .map_err(|std::sync::mpsc::SendError(value)| value)?;
46
47 self.shared_state.rx_waker.wake();
49
50 Ok(())
51 }
52}
53
54impl<T> Clone for Sender<T> {
55 fn clone(&self) -> Self {
56 Self {
57 inner: self.inner.clone(),
58 shared_state: self.shared_state.clone(),
59 }
60 }
61}
62
63impl<T> Receiver<T> {
64 pub fn try_recv(&self) -> Option<T> {
66 self.inner.try_recv().ok()
67 }
68
69 pub fn recv(&self) -> ReceiverFuture<'_, T> {
71 ReceiverFuture { receiver: self }
72 }
73
74 fn poll_recv(&self, cx: &mut Context<'_>) -> Poll<Option<T>> {
75 match self.inner.try_recv() {
76 Ok(value) => Poll::Ready(Some(value)),
77 Err(TryRecvError::Disconnected) => Poll::Ready(None),
78 Err(TryRecvError::Empty) => {
79 self.shared_state.rx_waker.register(cx.waker());
80
81 Poll::Pending
82 }
83 }
84 }
85}
86
/// Future returned by `Receiver::recv`.
///
/// Resolves to `Option<T>`; each poll is forwarded to the borrowed receiver.
pub struct ReceiverFuture<'a, T> {
    /// Receiver this future polls; borrowed for the future's lifetime.
    receiver: &'a Receiver<T>,
}
91
92impl<'a, T> Future for ReceiverFuture<'a, T> {
93 type Output = Option<T>;
94
95 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
96 self.receiver.poll_recv(cx)
97 }
98}
99
100pub fn channel<T>() -> (Sender<T>, Receiver<T>) {
102 let (sender, receiver) = std::sync::mpsc::channel();
103
104 let shared_state = Arc::new(MpscSharedState {
105 rx_waker: AtomicWaker::new(),
106 });
107
108 (
109 Sender {
110 inner: sender,
111 shared_state: shared_state.clone(),
112 },
113 Receiver {
114 inner: receiver,
115 shared_state,
116 _not_sync: std::marker::PhantomData,
117 },
118 )
119}
120
#[cfg(test)]
mod tests {
    use super::channel;

    use std::{
        pin::Pin,
        sync::{
            Arc,
            atomic::{AtomicUsize, Ordering},
        },
        task::{Context, Poll, Wake, Waker},
        thread,
    };

    /// Waker that records how many times it has been woken, so tests can
    /// assert that a send really notifies the pending receiver.
    struct CountingWaker {
        wakes: AtomicUsize,
    }

    impl CountingWaker {
        fn count(&self) -> usize {
            self.wakes.load(Ordering::SeqCst)
        }
    }

    impl Wake for CountingWaker {
        fn wake(self: Arc<Self>) {
            self.wakes.fetch_add(1, Ordering::SeqCst);
        }
    }

    /// Values come out in send order and an empty or closed channel yields `None`.
    #[test]
    fn mpsc_channel_basic_send_recv() {
        let (tx, rx) = channel();
        assert!(tx.send(10).is_ok());
        assert!(tx.send(20).is_ok());

        assert_eq!(rx.try_recv(), Some(10));
        assert_eq!(rx.try_recv(), Some(20));
        assert_eq!(rx.try_recv(), None);

        drop(tx);
        assert!(rx.try_recv().is_none());
    }

    /// Every value sent from two producer threads is received exactly once.
    #[test]
    fn mpsc_channel_multithreaded_producers() {
        let (tx, rx) = channel();
        let tx1 = tx.clone();
        let tx2 = tx.clone();

        let t1 = thread::spawn(move || {
            for i in 0..4 {
                assert!(tx1.send(i).is_ok());
            }
        });
        let t2 = thread::spawn(move || {
            for i in 4..8 {
                assert!(tx2.send(i).is_ok());
            }
        });

        t1.join().unwrap();
        t2.join().unwrap();

        let mut seen = [false; 8];
        for _ in 0..8 {
            let value = rx.try_recv().expect("channel should return value");
            assert!(value < 8);
            seen[value] = true;
        }

        assert!(seen.iter().all(|&v| v));
    }

    /// A pending receive registers its waker, a send wakes it, and the next
    /// poll yields the value; after the sender is dropped the future resolves
    /// to `None`.
    #[test]
    fn mpsc_channel_async_poll_wakes() {
        let (tx, rx) = channel();
        let mut rx_future = rx.recv();
        let counter = Arc::new(CountingWaker {
            wakes: AtomicUsize::new(0),
        });
        let waker = Waker::from(Arc::clone(&counter));
        let mut cx = Context::from_waker(&waker);

        assert!(matches!(
            Pin::new(&mut rx_future).poll(&mut cx),
            Poll::Pending
        ));
        // Polling an empty channel must not wake anything by itself.
        assert_eq!(counter.count(), 0);

        assert!(tx.send(42).is_ok());
        // The send must have woken the waker registered by the pending poll.
        assert_eq!(counter.count(), 1);

        match Pin::new(&mut rx_future).poll(&mut cx) {
            Poll::Ready(Some(v)) => assert_eq!(v, 42),
            other => panic!("expected ready after send, got {:?}", other),
        }

        drop(tx);
        let mut rx_future2 = rx.recv();
        assert!(matches!(
            Pin::new(&mut rx_future2).poll(&mut cx),
            Poll::Ready(None)
        ));
    }
}