simple_async_pipe/
lib.rs
1use std::pin::Pin;
21use std::sync::{Arc, Mutex};
22use std::task::{Context, Poll, Waker};
23
24use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
25use tokio::sync::mpsc;
26use tokio::sync::mpsc::error::TrySendError;
27
28pub struct PipeWrite {
29 sender: mpsc::Sender<Vec<u8>>,
30 shared: Arc<Mutex<PipeShared>>,
31}
32
33pub struct PipeRead {
34 read_remaining: Vec<u8>,
35 receiver: mpsc::Receiver<Vec<u8>>,
36 shared: Arc<Mutex<PipeShared>>,
37}
38
39struct PipeShared {
40 read_waker: Option<Waker>,
41 write_waker: Option<Waker>,
42}
43
44pub fn pipe(buffer: usize) -> (PipeRead, PipeWrite) {
46 let (sender, receiver) = mpsc::channel(buffer);
47 let shared = Arc::new(Mutex::new(PipeShared {
48 read_waker: Default::default(),
49 write_waker: Default::default(),
50 }));
51
52 let read = PipeRead {
53 receiver,
54 read_remaining: Default::default(),
55 shared: shared.clone(),
56 };
57 let write = PipeWrite {
58 sender,
59 shared: shared.clone(),
60 };
61
62 (read, write)
63}
64
65impl AsyncWrite for PipeWrite {
66 fn poll_write(
67 self: Pin<&mut Self>,
68 cx: &mut Context<'_>,
69 buf: &[u8],
70 ) -> Poll<Result<usize, std::io::Error>> {
71 match self.sender.try_send(buf.to_vec()) {
72 Ok(_) => {
73 if let Some(read_waker) = self.shared.lock().unwrap().read_waker.take() {
74 read_waker.wake();
75 }
76 Poll::Ready(Ok(buf.len()))
77 }
78 Err(e) => match e {
79 TrySendError::Full(_) => {
80 self.shared.lock().unwrap().write_waker = Some(cx.waker().clone());
81 Poll::Pending
82 }
83 TrySendError::Closed(_) => Poll::Ready(Err(std::io::Error::new(
84 std::io::ErrorKind::BrokenPipe,
85 "receiver closed",
86 ))),
87 },
88 }
89 }
90
91 fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
92 Poll::Ready(Ok(()))
94 }
95
96 fn poll_shutdown(
97 self: Pin<&mut Self>,
98 _cx: &mut Context<'_>,
99 ) -> Poll<Result<(), std::io::Error>> {
100 Poll::Ready(Ok(()))
102 }
103}
104
105impl AsyncRead for PipeRead {
106 fn poll_read(
107 mut self: Pin<&mut Self>,
108 cx: &mut Context<'_>,
109 buf: &mut ReadBuf<'_>,
110 ) -> Poll<std::io::Result<()>> {
111 let mut write_to_buf = |vec: &[u8]| -> Vec<u8> {
112 let end = std::cmp::min(buf.remaining(), vec.len());
113 let slice_to_write = &vec[0..end];
114 buf.put_slice(slice_to_write);
115
116 let rest_of_vec = &vec[end..];
117 rest_of_vec.to_vec()
118 };
119
120 if self.read_remaining.len() > 0 {
121 self.read_remaining = write_to_buf(&mut self.read_remaining);
122 return Poll::Ready(Ok(()));
123 }
124
125 match self.receiver.poll_recv(cx) {
126 Poll::Ready(v) => match v {
127 None => Poll::Ready(Err(std::io::Error::new(
128 std::io::ErrorKind::BrokenPipe,
129 "sender closed",
130 ))),
131 Some(v) => {
132 self.read_remaining = write_to_buf(&v);
133 if let Some(waker) = self.shared.lock().unwrap().write_waker.take() {
134 waker.wake();
135 }
136 Poll::Ready(Ok(()))
137 }
138 },
139 Poll::Pending => Poll::Pending,
140 }
141 }
142}
143
#[cfg(test)]
mod tests {
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    use super::*;

    /// One chunk written to the pipe is read back unchanged.
    #[tokio::test]
    async fn test_single_write() {
        let (mut reader, mut writer) = pipe(512);

        let to_send = b"hello world";
        writer.write_all(to_send).await.expect("error writing");

        let mut buffer = vec![0u8; to_send.len()];
        reader.read_exact(&mut buffer).await.expect("error reading");

        assert_eq!(&buffer, to_send);
    }

    /// Two chunks are read back as one contiguous byte stream.
    #[tokio::test]
    async fn test_multi_write() {
        let (mut reader, mut writer) = pipe(512);

        let to_send = b"hello world";
        writer.write_all(b"hello").await.expect("error writing");
        writer.write_all(b" world").await.expect("error writing");

        let mut buffer = vec![0u8; to_send.len()];
        reader.read_exact(&mut buffer).await.expect("error reading");

        assert_eq!(&buffer, to_send);
    }

    /// The writer has to wait for the reader when the channel is full.
    #[tokio::test]
    async fn test_write_more_than_buffer() {
        // Capacity is counted in chunks, not bytes: with a single slot the
        // second write cannot be queued until the reader takes the first one.
        let (mut reader, mut writer) = pipe(1);

        let to_send = b"hello world";
        let writer_task = tokio::spawn(async move {
            writer.write_all(b"hello").await.expect("error writing");
            writer.write_all(b" world").await.expect("error writing");
        });

        let mut buffer = vec![0u8; to_send.len()];
        reader.read_exact(&mut buffer).await.expect("error reading");

        // Surface a failure inside the spawned writer instead of ignoring it.
        writer_task.await.expect("writer task panicked");

        assert_eq!(&buffer, to_send);
    }

    /// A chunk larger than the read buffer is delivered across several reads.
    #[tokio::test]
    async fn test_read_smaller_than_chunk() {
        let (mut reader, mut writer) = pipe(512);

        let to_send = b"hello world";
        writer.write_all(to_send).await.expect("error writing");

        let mut head = [0u8; 4];
        reader.read_exact(&mut head).await.expect("error reading");

        let mut tail = vec![0u8; to_send.len() - head.len()];
        reader.read_exact(&mut tail).await.expect("error reading");

        assert_eq!(&head[..], &b"hell"[..]);
        assert_eq!(&tail[..], &b"o world"[..]);
    }
}