Skip to main content

xet_runtime/utils/
async_read.rs

1use std::io::Write;
2use std::pin::Pin;
3use std::task::{Context, Poll, ready};
4
5use futures::{AsyncRead, AsyncReadExt};
6
/// An [`AsyncRead`] adaptor that tees everything it reads.
///
/// Every successful read from the wrapped source `R` is also written, in full,
/// to `writer`. Once the reader has been consumed, `writer` holds an exact copy
/// of all bytes that passed through. Nothing is buffered inside the adaptor
/// itself; the copy is written to `writer` during each `poll_read` call.
pub struct CopyReader<'r, 'w, R, W> {
    /// The wrapped source, pinned so it can be polled through `AsyncRead::poll_read`.
    src: Pin<&'r mut R>,
    /// Destination that receives a copy of every byte read from `src`.
    writer: &'w mut W,
}
15
16impl<R: AsyncRead + Unpin, W: Write> AsyncRead for CopyReader<'_, '_, R, W> {
17    fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<std::io::Result<usize>> {
18        let res = ready!(self.src.as_mut().poll_read(cx, buf))?;
19        self.writer.write_all(&buf[..res])?;
20        Poll::Ready(Ok(res))
21    }
22}
23
24impl<R: AsyncRead + Unpin, W: Write> Unpin for CopyReader<'_, '_, R, W> {}
25
26impl<'r, 'w, R: AsyncRead + Unpin, W: Write> CopyReader<'r, 'w, R, W> {
27    pub fn new(src: &'r mut R, writer: &'w mut W) -> Self {
28        let src = Pin::new(src);
29        Self { src, writer }
30    }
31}
32
33/// An extension trait for `AsyncRead` that provides additional methods beyond the normal `AsyncReadExt`.
34#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
35#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
36pub trait AsyncReadCustomExt: AsyncReadExt + Unpin {
37    /// Reads all data from the stream until EOF and discards it.
38    async fn drain<'a>(&'a mut self) -> std::io::Result<()> {
39        const BUFFER_SIZE: usize = 8192;
40        let mut buf = [0u8; BUFFER_SIZE];
41        loop {
42            let n = self.read(&mut buf).await?;
43            if n == 0 {
44                break; // EOF
45            }
46        }
47        Ok(())
48    }
49}
50
51#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
52#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
53impl<T: AsyncReadExt + Unpin> AsyncReadCustomExt for T {}
54
#[cfg(test)]
mod tests {
    use std::io::{Read, Seek, SeekFrom};

    use bytes::Bytes;
    use futures::io::Cursor;
    use futures::{AsyncReadExt, TryStreamExt};
    use tempfile::tempfile;

    use super::*;

    /// Reading a whole source through `CopyReader` mirrors every byte into an in-memory writer.
    #[tokio::test]
    async fn test_copy_reader() {
        let readers: Vec<Box<dyn AsyncRead + Unpin>> = vec![
            Box::new(Cursor::new("abcdef".as_bytes())),
            Box::new(Cursor::new(vec![0x88; 1024])),
            n_stream(3),
            n_stream(500),
            n_stream(10000),
        ];

        for mut reader in readers {
            let mut writer = Vec::new();
            let mut copy_reader = CopyReader::new(&mut reader, &mut writer);
            let mut buf = Vec::new();
            assert!(copy_reader.read_to_end(&mut buf).await.is_ok());
            assert_eq!(buf, writer);
        }
    }

    /// Same as `test_copy_reader`, but the copy goes to a real file that is read back afterwards.
    #[tokio::test]
    async fn test_copy_reader_to_file() {
        let readers: Vec<Box<dyn AsyncRead + Unpin>> = vec![
            Box::new(Cursor::new("abcdeflaksjdlakjsldkajlfkjal".as_bytes())),
            Box::new(Cursor::new(vec![0x88; 10000])),
            n_stream(3),
            n_stream(500),
            n_stream(10000),
        ];

        for mut reader in readers {
            let mut writer = tempfile().unwrap();
            let mut copy_reader = CopyReader::new(&mut reader, &mut writer);
            let mut buf = Vec::new();
            assert!(copy_reader.read_to_end(&mut buf).await.is_ok());

            // read file to compare contents
            assert!(writer.seek(SeekFrom::Start(0)).is_ok());
            let mut file_contents = Vec::new();
            assert!(writer.read_to_end(&mut file_contents).is_ok());

            assert_eq!(buf, file_contents);
        }
    }

    /// A partial read (`read_exact` of 512 bytes) copies exactly the bytes that were consumed.
    #[tokio::test]
    async fn test_copy_reader_partially() {
        let readers: Vec<Box<dyn AsyncRead + Unpin>> = vec![Box::new(Cursor::new(vec![0x88; 1024])), n_stream(10000)];

        for mut reader in readers {
            let mut writer = Vec::new();
            let mut copy_reader = CopyReader::new(&mut reader, &mut writer);
            let mut buf = vec![0; 512];
            assert!(copy_reader.read_exact(&mut buf).await.is_ok());
            assert_eq!(buf, writer);
        }
    }

    /// `drain` consumes everything up to EOF, including sources larger than its scratch buffer.
    #[tokio::test]
    async fn test_drain() {
        // Empty source: draining succeeds without reading anything.
        let mut empty = Cursor::new(Vec::<u8>::new());
        assert!(empty.drain().await.is_ok());
        assert_eq!(empty.position(), 0);

        // Larger than the 8 KiB scratch buffer, so several reads are required.
        let mut cursor = Cursor::new(vec![0x88u8; 20000]);
        assert!(cursor.drain().await.is_ok());
        assert_eq!(cursor.position(), 20000);

        // Chunked stream source: nothing must be left to read afterwards.
        let mut stream = futures::stream::iter(NIter::new(1000)).into_async_read();
        assert!(stream.drain().await.is_ok());
        let mut rest = Vec::new();
        assert_eq!(stream.read_to_end(&mut rest).await.unwrap(), 0);
    }

    /// Iterator yielding `remaining` copies of the chunk `b"hello world"`.
    struct NIter {
        remaining: usize,
    }

    impl NIter {
        fn new(n: usize) -> Self {
            Self { remaining: n }
        }
    }

    impl Iterator for NIter {
        type Item = Result<Bytes, std::io::Error>;

        fn next(&mut self) -> Option<Self::Item> {
            if self.remaining == 0 {
                return None;
            }
            self.remaining -= 1;
            Some(Ok(Bytes::from_static(b"hello world")))
        }
    }

    /// Boxed `AsyncRead` over a stream of `n` `b"hello world"` chunks.
    fn n_stream(n: usize) -> Box<dyn AsyncRead + Unpin> {
        Box::new(futures::stream::iter(NIter::new(n)).into_async_read())
    }
}