xet_runtime/utils/
async_read.rs1use std::io::Write;
2use std::pin::Pin;
3use std::task::{Context, Poll, ready};
4
5use futures::{AsyncRead, AsyncReadExt};
6
7pub struct CopyReader<'r, 'w, R, W> {
12 src: Pin<&'r mut R>,
13 writer: &'w mut W,
14}
15
16impl<R: AsyncRead + Unpin, W: Write> AsyncRead for CopyReader<'_, '_, R, W> {
17 fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<std::io::Result<usize>> {
18 let res = ready!(self.src.as_mut().poll_read(cx, buf))?;
19 self.writer.write_all(&buf[..res])?;
20 Poll::Ready(Ok(res))
21 }
22}
23
24impl<R: AsyncRead + Unpin, W: Write> Unpin for CopyReader<'_, '_, R, W> {}
25
26impl<'r, 'w, R: AsyncRead + Unpin, W: Write> CopyReader<'r, 'w, R, W> {
27 pub fn new(src: &'r mut R, writer: &'w mut W) -> Self {
28 let src = Pin::new(src);
29 Self { src, writer }
30 }
31}
32
33#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
35#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
36pub trait AsyncReadCustomExt: AsyncReadExt + Unpin {
37 async fn drain<'a>(&'a mut self) -> std::io::Result<()> {
39 const BUFFER_SIZE: usize = 8192;
40 let mut buf = [0u8; BUFFER_SIZE];
41 loop {
42 let n = self.read(&mut buf).await?;
43 if n == 0 {
44 break; }
46 }
47 Ok(())
48 }
49}
50
51#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
52#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
53impl<T: AsyncReadExt + Unpin> AsyncReadCustomExt for T {}
54
#[cfg(test)]
mod tests {
    use std::io::{Read, Seek, SeekFrom};

    use bytes::Bytes;
    use futures::io::Cursor;
    use futures::{AsyncReadExt, TryStreamExt};
    use tempfile::tempfile;

    use super::*;

    /// Payload of every chunk produced by [`n_stream`].
    const CHUNK: &[u8] = b"hello world";

    #[tokio::test]
    async fn test_copy_reader() {
        // Each reader is paired with the exact bytes it yields, so both the returned data and the
        // mirrored copy are checked against the source rather than only against each other.
        let cases: Vec<(Box<dyn AsyncRead + Unpin>, Vec<u8>)> = vec![
            (Box::new(Cursor::new("abcdef".as_bytes())), b"abcdef".to_vec()),
            (Box::new(Cursor::new(vec![0x88; 1024])), vec![0x88; 1024]),
            (n_stream(3), CHUNK.repeat(3)),
            (n_stream(500), CHUNK.repeat(500)),
            (n_stream(10000), CHUNK.repeat(10000)),
        ];

        for (mut reader, expected) in cases {
            let mut writer: Vec<u8> = Vec::new();
            let mut copy_reader = CopyReader::new(&mut reader, &mut writer);
            let mut buf = Vec::new();
            copy_reader.read_to_end(&mut buf).await.unwrap();
            assert_eq!(buf, expected);
            assert_eq!(writer, expected);
        }
    }

    #[tokio::test]
    async fn test_copy_reader_to_file() {
        const TEXT: &[u8] = b"abcdeflaksjdlakjsldkajlfkjal";
        let cases: Vec<(Box<dyn AsyncRead + Unpin>, Vec<u8>)> = vec![
            (Box::new(Cursor::new(TEXT)), TEXT.to_vec()),
            (Box::new(Cursor::new(vec![0x88; 10000])), vec![0x88; 10000]),
            (n_stream(3), CHUNK.repeat(3)),
            (n_stream(500), CHUNK.repeat(500)),
            (n_stream(10000), CHUNK.repeat(10000)),
        ];

        for (mut reader, expected) in cases {
            let mut writer = tempfile().unwrap();
            let mut copy_reader = CopyReader::new(&mut reader, &mut writer);
            let mut buf = Vec::new();
            copy_reader.read_to_end(&mut buf).await.unwrap();

            // The file position is at the end after writing; rewind before reading it back.
            writer.seek(SeekFrom::Start(0)).unwrap();
            let mut file_contents = Vec::new();
            writer.read_to_end(&mut file_contents).unwrap();

            assert_eq!(buf, expected);
            assert_eq!(file_contents, expected);
        }
    }

    #[tokio::test]
    async fn test_copy_reader_partially() {
        const PREFIX_LEN: usize = 512;
        let cases: Vec<(Box<dyn AsyncRead + Unpin>, Vec<u8>)> = vec![
            (Box::new(Cursor::new(vec![0x88; 1024])), vec![0x88; 1024]),
            (n_stream(10000), CHUNK.repeat(10000)),
        ];

        for (mut reader, source) in cases {
            let mut writer: Vec<u8> = Vec::new();
            let mut copy_reader = CopyReader::new(&mut reader, &mut writer);
            let mut buf = vec![0u8; PREFIX_LEN];
            copy_reader.read_exact(&mut buf).await.unwrap();
            // Only the bytes actually delivered to the caller may be mirrored.
            assert_eq!(buf, &source[..PREFIX_LEN]);
            assert_eq!(writer, &source[..PREFIX_LEN]);
        }
    }

    #[tokio::test]
    async fn test_drain_consumes_to_eof() {
        // Large enough that draining takes more than one read.
        let mut reader = Cursor::new(vec![0x42u8; 20_000]);
        reader.drain().await.unwrap();
        assert_eq!(reader.position(), 20_000);

        let mut rest = Vec::new();
        reader.read_to_end(&mut rest).await.unwrap();
        assert!(rest.is_empty());
    }

    #[tokio::test]
    async fn test_drain_empty_reader() {
        let mut reader = Cursor::new(Vec::<u8>::new());
        reader.drain().await.unwrap();
        assert_eq!(reader.position(), 0);
    }

    /// An `AsyncRead` that yields `n` copies of [`CHUNK`] and then reports EOF.
    fn n_stream(n: usize) -> Box<dyn AsyncRead + Unpin> {
        let chunks = (0..n).map(|_| Ok::<_, std::io::Error>(Bytes::from_static(CHUNK)));
        Box::new(futures::stream::iter(chunks).into_async_read())
    }
}