scrappy_utils/
mpsc.rs

1//! A multi-producer, single-consumer, futures-aware, FIFO queue.
2use std::any::Any;
3use std::collections::VecDeque;
4use std::error::Error;
5use std::fmt;
6use std::pin::Pin;
7use std::task::{Context, Poll};
8
9use futures::{Sink, Stream};
10
11use crate::cell::Cell;
12use crate::task::LocalWaker;
13
14/// Creates a unbounded in-memory channel with buffered storage.
15pub fn channel<T>() -> (Sender<T>, Receiver<T>) {
16    let shared = Cell::new(Shared {
17        has_receiver: true,
18        buffer: VecDeque::new(),
19        blocked_recv: LocalWaker::new(),
20    });
21    let sender = Sender {
22        shared: shared.clone(),
23    };
24    let receiver = Receiver { shared };
25    (sender, receiver)
26}
27
/// State shared by every `Sender` and the `Receiver` of one channel.
#[derive(Debug)]
struct Shared<T> {
    /// Messages that were sent but not yet received, in FIFO order
    /// (`send` pushes to the back, `poll_next` pops from the front).
    buffer: VecDeque<T>,
    /// Waker of the receiving task. It is registered by `poll_next` when the
    /// buffer is empty, and woken by `send` and from `Sender`'s `Drop` impl.
    blocked_recv: LocalWaker,
    /// Whether the channel still accepts messages. Despite the name, it is
    /// cleared both when the `Receiver` is dropped and when any `Sender` calls
    /// `close`. `send` fails while it is `false`.
    has_receiver: bool,
}
34
/// The transmission end of a channel.
///
/// This is created by the `channel` function. Clones of a sender, and senders
/// obtained from `Receiver::sender`, all feed the same buffer.
#[derive(Debug)]
pub struct Sender<T> {
    // Handle to the state shared with the receiver and every other sender.
    shared: Cell<Shared<T>>,
}
42
// `Sender` is declared `Unpin` for every `T`, so it can be pinned with
// `Pin::new` and moved freely regardless of the message type.
// NOTE(review): this relies on `Cell` holding no address-sensitive data;
// confirm against `crate::cell::Cell`.
impl<T> Unpin for Sender<T> {}
44
45impl<T> Sender<T> {
46    /// Sends the provided message along this channel.
47    pub fn send(&self, item: T) -> Result<(), SendError<T>> {
48        let shared = unsafe { self.shared.get_mut_unsafe() };
49        if !shared.has_receiver {
50            return Err(SendError(item)); // receiver was dropped
51        };
52        shared.buffer.push_back(item);
53        shared.blocked_recv.wake();
54        Ok(())
55    }
56
57    /// Closes the sender half
58    ///
59    /// This prevents any further messages from being sent on the channel while
60    /// still enabling the receiver to drain messages that are buffered.
61    pub fn close(&mut self) {
62        self.shared.get_mut().has_receiver = false;
63    }
64}
65
66impl<T> Clone for Sender<T> {
67    fn clone(&self) -> Self {
68        Sender {
69            shared: self.shared.clone(),
70        }
71    }
72}
73
74impl<T> Sink<T> for Sender<T> {
75    type Error = SendError<T>;
76
77    fn poll_ready(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
78        Poll::Ready(Ok(()))
79    }
80
81    fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), SendError<T>> {
82        self.send(item)
83    }
84
85    fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), SendError<T>>> {
86        Poll::Ready(Ok(()))
87    }
88
89    fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
90        Poll::Ready(Ok(()))
91    }
92}
93
94impl<T> Drop for Sender<T> {
95    fn drop(&mut self) {
96        let count = self.shared.strong_count();
97        let shared = self.shared.get_mut();
98
99        // check is last sender is about to drop
100        if shared.has_receiver && count == 2 {
101            // Wake up receiver as its stream has ended
102            shared.blocked_recv.wake();
103        }
104    }
105}
106
/// The receiving end of a channel which implements the `Stream` trait.
///
/// This is created by the `channel` function. The stream yields messages in
/// the order they were sent. It ends once every `Sender` has been dropped and
/// the buffer has been drained.
#[derive(Debug)]
pub struct Receiver<T> {
    // Handle to the state shared with every `Sender` of this channel.
    shared: Cell<Shared<T>>,
}
114
115impl<T> Receiver<T> {
116    /// Create Sender
117    pub fn sender(&self) -> Sender<T> {
118        Sender {
119            shared: self.shared.clone(),
120        }
121    }
122}
123
// `Receiver` is declared `Unpin` for every `T`. `poll_next` relies on this to
// get mutable access through `Pin<&mut Self>`, and callers can poll it with
// `Pin::new(&mut rx)`, as the tests do.
// NOTE(review): assumes `Cell` holds no address-sensitive data; confirm
// against `crate::cell::Cell`.
impl<T> Unpin for Receiver<T> {}
125
126impl<T> Stream for Receiver<T> {
127    type Item = T;
128
129    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
130        if self.shared.strong_count() == 1 {
131            // All senders have been dropped, so drain the buffer and end the
132            // stream.
133            Poll::Ready(self.shared.get_mut().buffer.pop_front())
134        } else if let Some(msg) = self.shared.get_mut().buffer.pop_front() {
135            Poll::Ready(Some(msg))
136        } else {
137            self.shared.get_mut().blocked_recv.register(cx.waker());
138            Poll::Pending
139        }
140    }
141}
142
143impl<T> Drop for Receiver<T> {
144    fn drop(&mut self) {
145        let shared = self.shared.get_mut();
146        shared.buffer.clear();
147        shared.has_receiver = false;
148    }
149}
150
/// Error type for sending, returned by `Sender::send` when the channel no
/// longer accepts messages. This happens when the receiving end was dropped,
/// or when `Sender::close` was called on any sender of the channel.
///
/// The rejected message is kept inside and can be recovered with
/// `SendError::into_inner`.
pub struct SendError<T>(T);
154
155impl<T> fmt::Debug for SendError<T> {
156    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
157        fmt.debug_tuple("SendError").field(&"...").finish()
158    }
159}
160
161impl<T> fmt::Display for SendError<T> {
162    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
163        write!(fmt, "send failed because receiver is gone")
164    }
165}
166
167impl<T: Any> Error for SendError<T> {
168    fn description(&self) -> &str {
169        "send failed because receiver is gone"
170    }
171}
172
173impl<T> SendError<T> {
174    /// Returns the message that was attempted to be sent but failed.
175    pub fn into_inner(self) -> T {
176        self.0
177    }
178}
179
#[cfg(test)]
mod tests {
    use super::*;
    use futures::future::lazy;
    use futures::{Stream, StreamExt};

    #[scrappy_rt::test]
    async fn test_mpsc() {
        // Messages from the original sender and from a clone arrive in order.
        let (tx, mut rx) = channel();
        tx.send("test").unwrap();
        assert_eq!(rx.next().await.unwrap(), "test");

        let tx2 = tx.clone();
        tx2.send("test2").unwrap();
        assert_eq!(rx.next().await.unwrap(), "test2");

        // An empty channel stays pending while at least one sender is alive.
        assert_eq!(
            lazy(|cx| Pin::new(&mut rx).poll_next(cx)).await,
            Poll::Pending
        );
        drop(tx2);
        assert_eq!(
            lazy(|cx| Pin::new(&mut rx).poll_next(cx)).await,
            Poll::Pending
        );
        // Dropping the last sender ends the stream.
        drop(tx);
        assert_eq!(rx.next().await, None);

        // Sending fails once the receiver has been dropped, and the rejected
        // message is handed back.
        let (tx, rx) = channel();
        tx.send("test").unwrap();
        drop(rx);
        let err = tx.send("test").unwrap_err();
        assert_eq!(err.into_inner(), "test");

        // `close` on one sender rejects sends through every clone. The
        // receiver is bound to a name rather than `_`, because `_` would drop
        // it immediately. That way the failures below are caused by `close`,
        // not by a missing receiver.
        let (mut tx, mut rx) = channel();
        let tx2 = tx.clone();
        tx.send("buffered").unwrap();
        tx.close();
        assert!(tx.send("test").is_err());
        assert!(tx2.send("test").is_err());
        // Messages buffered before `close` can still be drained.
        assert_eq!(rx.next().await.unwrap(), "buffered");
    }
}
219}