simple_async_pipe/
lib.rs

//! A simple in-memory, pipe-like byte stream for async (tokio) code.
//!
//! [`pipe`] returns a connected [`PipeRead`] / [`PipeWrite`] pair implementing tokio's
//! `AsyncRead` and `AsyncWrite`.
//!
//! # Example
//!
//! ```rust
//! # #[tokio::main]
//! # async fn main() {
//! #    use tokio::io::{AsyncWriteExt, AsyncReadExt};
//!     let (mut reader, mut writer) = simple_async_pipe::pipe(64);
//!
//!     let message = b"hello world";
//!     writer.write_all(message).await.unwrap();
//!
//!     let mut buffer = vec![0u8; message.len()];
//!     reader.read_exact(&mut buffer).await.unwrap();
//!     assert_eq!(&buffer, message);
//! # }
//! ```
19
20use std::pin::Pin;
21use std::sync::{Arc, Mutex};
22use std::task::{Context, Poll, Waker};
23
24use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
25use tokio::sync::mpsc;
26use tokio::sync::mpsc::error::TrySendError;
27
28pub struct PipeWrite {
29    sender: mpsc::Sender<Vec<u8>>,
30    shared: Arc<Mutex<PipeShared>>,
31}
32
33pub struct PipeRead {
34    read_remaining: Vec<u8>,
35    receiver: mpsc::Receiver<Vec<u8>>,
36    shared: Arc<Mutex<PipeShared>>,
37}
38
/// Waker bookkeeping shared between the two halves of a pipe.
#[derive(Debug, Default)]
struct PipeShared {
    /// Waker of a reader waiting for data. `PipeWrite::poll_write` wakes it after a successful
    /// send, but `PipeRead::poll_read` never stores a waker here (it waits through the
    /// channel's own `poll_recv`), so as written this stays `None`.
    read_waker: Option<Waker>,
    /// Waker of a writer that found the channel full. The reader takes and wakes it after
    /// receiving a chunk, which frees a slot in the channel.
    write_waker: Option<Waker>,
}
43
44/// Creates a in-memory pipe. [PipeWrite] will not succeed instant if the internal buffer is full.
45pub fn pipe(buffer: usize) -> (PipeRead, PipeWrite) {
46    let (sender, receiver) = mpsc::channel(buffer);
47    let shared = Arc::new(Mutex::new(PipeShared {
48        read_waker: Default::default(),
49        write_waker: Default::default(),
50    }));
51
52    let read = PipeRead {
53        receiver,
54        read_remaining: Default::default(),
55        shared: shared.clone(),
56    };
57    let write = PipeWrite {
58        sender,
59        shared: shared.clone(),
60    };
61
62    (read, write)
63}
64
65impl AsyncWrite for PipeWrite {
66    fn poll_write(
67        self: Pin<&mut Self>,
68        cx: &mut Context<'_>,
69        buf: &[u8],
70    ) -> Poll<Result<usize, std::io::Error>> {
71        match self.sender.try_send(buf.to_vec()) {
72            Ok(_) => {
73                if let Some(read_waker) = self.shared.lock().unwrap().read_waker.take() {
74                    read_waker.wake();
75                }
76                Poll::Ready(Ok(buf.len()))
77            }
78            Err(e) => match e {
79                TrySendError::Full(_) => {
80                    self.shared.lock().unwrap().write_waker = Some(cx.waker().clone());
81                    Poll::Pending
82                }
83                TrySendError::Closed(_) => Poll::Ready(Err(std::io::Error::new(
84                    std::io::ErrorKind::BrokenPipe,
85                    "receiver closed",
86                ))),
87            },
88        }
89    }
90
91    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
92        // TODO: Implement kind of flushing
93        Poll::Ready(Ok(()))
94    }
95
96    fn poll_shutdown(
97        self: Pin<&mut Self>,
98        _cx: &mut Context<'_>,
99    ) -> Poll<Result<(), std::io::Error>> {
100        // TODO: Check if there is something to do. Maybe drop sender?
101        Poll::Ready(Ok(()))
102    }
103}
104
105impl AsyncRead for PipeRead {
106    fn poll_read(
107        mut self: Pin<&mut Self>,
108        cx: &mut Context<'_>,
109        buf: &mut ReadBuf<'_>,
110    ) -> Poll<std::io::Result<()>> {
111        let mut write_to_buf = |vec: &[u8]| -> Vec<u8> {
112            let end = std::cmp::min(buf.remaining(), vec.len());
113            let slice_to_write = &vec[0..end];
114            buf.put_slice(slice_to_write);
115
116            let rest_of_vec = &vec[end..];
117            rest_of_vec.to_vec()
118        };
119
120        if self.read_remaining.len() > 0 {
121            self.read_remaining = write_to_buf(&mut self.read_remaining);
122            return Poll::Ready(Ok(()));
123        }
124
125        match self.receiver.poll_recv(cx) {
126            Poll::Ready(v) => match v {
127                None => Poll::Ready(Err(std::io::Error::new(
128                    std::io::ErrorKind::BrokenPipe,
129                    "sender closed",
130                ))),
131                Some(v) => {
132                    self.read_remaining = write_to_buf(&v);
133                    if let Some(waker) = self.shared.lock().unwrap().write_waker.take() {
134                        waker.wake();
135                    }
136                    Poll::Ready(Ok(()))
137                }
138            },
139            Poll::Pending => Poll::Pending,
140        }
141    }
142}
143
#[cfg(test)]
mod tests {
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    use super::*;

    #[tokio::test]
    async fn test_single_write() {
        let (mut reader, mut writer) = pipe(512);

        let to_send = b"hello world";
        writer.write_all(to_send).await.expect("error writing");

        let mut buffer = vec![0u8; to_send.len()];
        reader.read_exact(&mut buffer).await.expect("error reading");

        assert_eq!(&buffer, to_send);
    }

    #[tokio::test]
    async fn test_multi_write() {
        let (mut reader, mut writer) = pipe(512);

        let to_send = b"hello world";
        writer.write_all(b"hello").await.expect("error writing");
        writer.write_all(b" world").await.expect("error writing");

        let mut buffer = vec![0u8; to_send.len()];
        reader.read_exact(&mut buffer).await.expect("error reading");

        assert_eq!(&buffer, to_send);
    }

    #[tokio::test]
    async fn test_write_more_than_buffer() {
        let (mut reader, mut writer) = pipe(2);

        let to_send = b"hello world";
        let write_task = tokio::spawn(async move {
            writer.write_all(b"hello").await.expect("error writing");
            writer.write_all(b" world").await.expect("error writing");
        });

        let mut buffer = vec![0u8; to_send.len()];
        reader.read_exact(&mut buffer).await.expect("error reading");

        assert_eq!(&buffer, to_send);
        // Surface a panic in the writer task instead of silently detaching it.
        write_task.await.expect("writer task panicked");
    }

    #[tokio::test]
    async fn test_many_chunks_through_single_slot() {
        let (mut reader, mut writer) = pipe(1);

        let chunks: Vec<Vec<u8>> = (0u8..32).map(|i| vec![i; 3]).collect();
        let expected: Vec<u8> = chunks.concat();

        let write_task = tokio::spawn(async move {
            for chunk in chunks {
                writer.write_all(&chunk).await.expect("error writing");
            }
        });

        let mut buffer = vec![0u8; expected.len()];
        reader.read_exact(&mut buffer).await.expect("error reading");

        assert_eq!(buffer, expected);
        write_task.await.expect("writer task panicked");
    }

    #[tokio::test]
    async fn test_partial_reads() {
        let (mut reader, mut writer) = pipe(4);
        writer.write_all(b"hello world").await.expect("error writing");

        let mut first = [0u8; 4];
        reader.read_exact(&mut first).await.expect("error reading");
        assert_eq!(&first, b"hell");

        let mut rest = [0u8; 7];
        reader.read_exact(&mut rest).await.expect("error reading");
        assert_eq!(&rest, b"o world");
    }

    #[tokio::test]
    async fn test_read_after_writer_dropped() {
        let (mut reader, mut writer) = pipe(4);
        writer.write_all(b"bye").await.expect("error writing");
        drop(writer);

        let mut buffer = [0u8; 3];
        reader
            .read_exact(&mut buffer)
            .await
            .expect("buffered data must still be readable");
        assert_eq!(&buffer, b"bye");

        let err = reader.read(&mut buffer).await.expect_err("writer is gone");
        assert_eq!(err.kind(), std::io::ErrorKind::BrokenPipe);
    }

    #[tokio::test]
    async fn test_write_after_reader_dropped() {
        let (reader, mut writer) = pipe(4);
        drop(reader);

        let err = writer.write_all(b"x").await.expect_err("reader is gone");
        assert_eq!(err.kind(), std::io::ErrorKind::BrokenPipe);
    }
}