Skip to main content

xet_runtime/utils/
pipe.rs

1use std::io;
2use std::pin::Pin;
3use std::task::{Context, Poll};
4
5use bytes::Bytes;
6use futures::Stream;
7use tokio::io::AsyncWrite;
8use tokio::sync::mpsc;
9use tokio::sync::mpsc::error::TrySendError;
10use tokio_util::io::StreamReader;
11
12pub fn pipe(buffer_size: usize) -> (ChannelWriter, ChannelStream) {
13    let (sender, receiver) = mpsc::channel(buffer_size);
14    (ChannelWriter::new(sender), ChannelStream::new(receiver))
15}
16
17/// Adapter that implements AsyncRead from an mpsc Receiver
18pub struct ChannelStream(mpsc::Receiver<io::Result<Bytes>>);
19
20impl ChannelStream {
21    fn new(rx: mpsc::Receiver<io::Result<Bytes>>) -> Self {
22        Self(rx)
23    }
24
25    pub fn reader(self) -> ChannelReader {
26        ChannelReader::new(self)
27    }
28}
29
30impl Stream for ChannelStream {
31    type Item = io::Result<Bytes>;
32
33    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
34        self.0.poll_recv(cx)
35    }
36}
37
38type ChannelReader = StreamReader<ChannelStream, Bytes>;
39
40/// Adapter that implements AsyncWrite from a mpsc Sender
41pub struct ChannelWriter(mpsc::Sender<io::Result<Bytes>>);
42
43impl ChannelWriter {
44    fn new(tx: mpsc::Sender<io::Result<Bytes>>) -> Self {
45        Self(tx)
46    }
47}
48
49impl AsyncWrite for ChannelWriter {
50    fn poll_write(self: Pin<&mut Self>, _cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
51        let perm = match self.0.try_reserve() {
52            Ok(p) => p,
53            Err(TrySendError::Closed(_)) => {
54                return Poll::Ready(Err(io::Error::new(io::ErrorKind::BrokenPipe, "receiver closed")));
55            },
56            Err(TrySendError::Full(_)) => return Poll::Pending,
57        };
58
59        let data = Bytes::copy_from_slice(buf);
60        let len = data.len();
61        perm.send(Ok(data));
62
63        Poll::Ready(Ok(len))
64    }
65
66    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
67        // mpsc channels don't buffer in the same way, so flush is a no-op
68        Poll::Ready(Ok(()))
69    }
70
71    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
72        // Dropping the sender will close the channel
73        Poll::Ready(Ok(()))
74    }
75}
76
#[cfg(test)]
mod tests {
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    use super::*;

    #[tokio::test]
    async fn test_channel_read_write() {
        let (mut writer, stream) = pipe(10);

        // Push the payload through the writer as two separate chunks.
        for chunk in [b"Hello, ".as_slice(), b"World!".as_slice()] {
            writer.write_all(chunk).await.unwrap();
        }

        // With the only sender gone, the reader reaches EOF after the queued chunks.
        drop(writer);

        let mut received = Vec::new();
        stream.reader().read_to_end(&mut received).await.unwrap();

        assert_eq!(received, b"Hello, World!");
    }
}