Skip to main content

vpp_plugin/vlib/process_node/
mpsc.rs

1//! Multi-producer, single-consumer channel for sending data into asynchronous tasks
2
3use futures::task::AtomicWaker;
4use std::{
5    cell::Cell,
6    pin::Pin,
7    sync::{Arc, mpsc::TryRecvError},
8    task::{Context, Poll},
9};
10
11/// A multiple-producer, single-consumer channel for async process events.
12///
13/// This channel is unbounded.
14pub struct Sender<T> {
15    inner: std::sync::mpsc::Sender<T>,
16    shared_state: Arc<MpscSharedState>,
17}
18
19/// Receiver for a single consumer side of an MPSC channel.
20pub struct Receiver<T> {
21    inner: std::sync::mpsc::Receiver<T>,
22    shared_state: Arc<MpscSharedState>,
23    _not_sync: std::marker::PhantomData<Cell<()>>,
24}
25
26struct MpscSharedState {
27    rx_waker: AtomicWaker,
28}
29
30// SAFETY: `MpscSender` uses a mutex/condvar-protected buffer and only stores `T` when `T: Send`, so it is safe to send across threads.
31unsafe impl<T: Send> Send for Sender<T> {}
32
33// SAFETY: `MpscSender` can be shared across threads because internal synchronization ensures correctness.
34unsafe impl<T: Send> Sync for Sender<T> {}
35
36// Receiver is not Sync by design (single consumer), but it may be moved between threads safely.
37// SAFETY: `MpscReceiver` guarantees single-consumer semantics while preserving `T: Send`.
38unsafe impl<T: Send> Send for Receiver<T> {}
39
40impl<T> Sender<T> {
41    /// Send a value into the channel.
42    pub fn send(&self, value: T) -> Result<(), T> {
43        self.inner
44            .send(value)
45            .map_err(|std::sync::mpsc::SendError(value)| value)?;
46
47        // Notify the process/executor to resume via VPP event mechanism.
48        self.shared_state.rx_waker.wake();
49
50        Ok(())
51    }
52}
53
54impl<T> Clone for Sender<T> {
55    fn clone(&self) -> Self {
56        Self {
57            inner: self.inner.clone(),
58            shared_state: self.shared_state.clone(),
59        }
60    }
61}
62
63impl<T> Receiver<T> {
64    /// Try to receive a value without blocking.
65    pub fn try_recv(&self) -> Option<T> {
66        self.inner.try_recv().ok()
67    }
68
69    /// Returns a future that waits for the next value or channel close.
70    pub fn recv(&self) -> ReceiverFuture<'_, T> {
71        ReceiverFuture { receiver: self }
72    }
73
74    fn poll_recv(&self, cx: &mut Context<'_>) -> Poll<Option<T>> {
75        match self.inner.try_recv() {
76            Ok(value) => Poll::Ready(Some(value)),
77            Err(TryRecvError::Disconnected) => Poll::Ready(None),
78            Err(TryRecvError::Empty) => {
79                self.shared_state.rx_waker.register(cx.waker());
80
81                Poll::Pending
82            }
83        }
84    }
85}
86
87/// Future for receiver for a single consumer side of an MPSC channel.
88pub struct ReceiverFuture<'a, T> {
89    receiver: &'a Receiver<T>,
90}
91
92impl<'a, T> Future for ReceiverFuture<'a, T> {
93    type Output = Option<T>;
94
95    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
96        self.receiver.poll_recv(cx)
97    }
98}
99
100/// Create an unbounded multi-producer single-consumer channel.
101pub fn channel<T>() -> (Sender<T>, Receiver<T>) {
102    let (sender, receiver) = std::sync::mpsc::channel();
103
104    let shared_state = Arc::new(MpscSharedState {
105        rx_waker: AtomicWaker::new(),
106    });
107
108    (
109        Sender {
110            inner: sender,
111            shared_state: shared_state.clone(),
112        },
113        Receiver {
114            inner: receiver,
115            shared_state,
116            _not_sync: std::marker::PhantomData,
117        },
118    )
119}
120
121#[cfg(test)]
122mod tests {
123    use super::channel;
124    use futures_task::noop_waker;
125
126    use std::{
127        pin::Pin,
128        task::{Context, Poll},
129        thread,
130    };
131
132    #[test]
133    fn mpsc_channel_basic_send_recv() {
134        let (tx, rx) = channel();
135        assert!(tx.send(10).is_ok());
136        assert!(tx.send(20).is_ok());
137
138        assert_eq!(rx.try_recv(), Some(10));
139        assert_eq!(rx.try_recv(), Some(20));
140        assert_eq!(rx.try_recv(), None);
141
142        drop(tx);
143        assert!(rx.try_recv().is_none());
144    }
145
146    #[test]
147    fn mpsc_channel_multithreaded_producers() {
148        let (tx, rx) = channel();
149        let tx1 = tx.clone();
150        let tx2 = tx.clone();
151
152        let t1 = thread::spawn(move || {
153            for i in 0..4 {
154                assert!(tx1.send(i).is_ok());
155            }
156        });
157        let t2 = thread::spawn(move || {
158            for i in 4..8 {
159                assert!(tx2.send(i).is_ok());
160            }
161        });
162
163        t1.join().unwrap();
164        t2.join().unwrap();
165
166        let mut seen = [false; 8];
167        for _ in 0..8 {
168            let value = rx.try_recv().expect("channel should return value");
169            assert!(value < 8);
170            seen[value] = true;
171        }
172
173        assert!(seen.iter().all(|&v| v));
174    }
175
176    #[test]
177    fn mpsc_channel_async_poll_wakes() {
178        let (tx, rx) = channel();
179        let mut rx_future = rx.recv();
180        let waker = noop_waker();
181        let mut cx = Context::from_waker(&waker);
182
183        assert!(matches!(
184            Pin::new(&mut rx_future).poll(&mut cx),
185            Poll::Pending
186        ));
187
188        assert!(tx.send(42).is_ok());
189
190        match Pin::new(&mut rx_future).poll(&mut cx) {
191            Poll::Ready(Some(v)) => assert_eq!(v, 42),
192            other => panic!("expected ready after send, got {:?}", other),
193        }
194
195        drop(tx);
196        let mut rx_future2 = rx.recv();
197        assert!(matches!(
198            Pin::new(&mut rx_future2).poll(&mut cx),
199            Poll::Ready(None)
200        ));
201    }
202}