webaves/
io.rs

1//! IO helpers.
2
3use std::io::{BufRead, Error, ErrorKind, Read, Result};
4
5use tokio::io::{AsyncBufRead, AsyncBufReadExt, AsyncReadExt};
6
/// Additional helper methods for any [std::io::BufRead] reader.
pub trait BufReadMoreExt {
    /// Appends bytes from the stream to `buf` up to and including the
    /// delimiter `byte`, stopping early if EOF is reached first.
    ///
    /// This behaves like [std::io::BufRead::read_until] with one addition:
    /// when `limit` bytes have been read and the delimiter still has not
    /// been found, an error is returned.
    ///
    /// On success, the number of bytes read is returned.
    fn read_limit_until(&mut self, byte: u8, buf: &mut Vec<u8>, limit: u64) -> Result<usize>;
}
16
17/// Extension trait for [tokio::io::AsyncBufRead].
18#[async_trait::async_trait]
19pub trait AsyncBufReadMoreExt {
20    /// Reads bytes into `buf` until the delimiter `byte` or EOF is reached.
21    ///
22    /// Equivalent to:
23    ///
24    /// ```ignore
25    /// async fn read_limit_until(&mut self, byte: u8, buf: &mut Vec<u8>, limit: u64) -> Result<usize>;
26    /// ```
27    ///
28    /// This function is similar to [tokio::io::AsyncBufReadExt::read_until].
29    /// In addition, this function returns an error when the number of bytes
30    /// read equals `limit` and the deliminator has not been reached.
31    async fn read_limit_until(&mut self, byte: u8, buf: &mut Vec<u8>, limit: u64) -> Result<usize>;
32}
33
34impl<R: BufRead> BufReadMoreExt for R {
35    fn read_limit_until(&mut self, byte: u8, buf: &mut Vec<u8>, limit: u64) -> Result<usize> {
36        // Compiler won't use Take<&mut R> in trait here so it's in a separate function.
37        read_limit_until(self, byte, buf, limit)
38    }
39}
40
/// Reads from `stream` into `buf` through the delimiter `byte`, reading at
/// most `limit` bytes.
///
/// Bytes are appended to `buf`; anything already in `buf` is left untouched
/// and is never mistaken for part of the line being read.
///
/// Returns the number of bytes appended. The count is smaller than `limit`
/// without a trailing delimiter only when EOF was reached first.
///
/// # Errors
///
/// Returns an [ErrorKind::InvalidData] error ("line too long") when `limit`
/// bytes were read and the last of them is not the delimiter. Errors from
/// the underlying stream are passed through.
fn read_limit_until<R: BufRead>(
    stream: R,
    byte: u8,
    buf: &mut Vec<u8>,
    limit: u64,
) -> Result<usize> {
    let mut stream = stream.take(limit);
    let amount = stream.read_until(byte, buf)?;

    // Only bytes appended by this call may count as the delimiter. When
    // nothing was read (for example `limit == 0`), the tail of `buf` is
    // stale caller data and must not be treated as a match.
    let found_delimiter = amount > 0 && buf.last() == Some(&byte);

    // Widening usize to u64 does not truncate on current Rust targets.
    if amount as u64 == limit && !found_delimiter {
        return Err(Error::new(ErrorKind::InvalidData, "line too long"));
    }

    Ok(amount)
}
56
57#[async_trait::async_trait]
58impl<R: AsyncBufRead + Send + Unpin> AsyncBufReadMoreExt for R {
59    async fn read_limit_until(&mut self, byte: u8, buf: &mut Vec<u8>, limit: u64) -> Result<usize> {
60        let mut stream = self.take(limit);
61        let amount = stream.read_until(byte, buf).await?;
62
63        if amount as u64 == limit && !buf.ends_with(&[byte]) {
64            return Err(Error::new(ErrorKind::InvalidData, "line too long"));
65        }
66
67        Ok(amount)
68    }
69}
70
/// Look at upcoming data without consuming it.
pub trait PeekRead {
    /// Returns upcoming bytes while leaving the stream position unchanged.
    ///
    /// The buffer is filled with at most one read call and a slice of that
    /// buffer is returned. The slice may be shorter than `amount`.
    fn peek(&mut self, amount: usize) -> Result<&[u8]>;

    /// Returns upcoming bytes while leaving the stream position unchanged.
    ///
    /// Unlike a single [Self::peek], this keeps peeking until `amount`
    /// bytes are available and then returns the slice from one final peek.
    /// If a peek makes no progress before that point, an
    /// [ErrorKind::UnexpectedEof] error is returned.
    fn peek_exact(&mut self, amount: usize) -> Result<&[u8]> {
        let mut last_len = 0;

        loop {
            // Keep only the length so the borrow of `self` ends here.
            let available = self.peek(amount)?.len();

            if available >= amount {
                // Enough data is buffered; peek again to hand out the slice.
                return self.peek(amount);
            }

            if available == last_len {
                // The previous peek added nothing: treat it as end of stream.
                return Err(ErrorKind::UnexpectedEof.into());
            }

            last_len = available;
        }
    }
}
101
/// Tracks how many bytes have been read.
pub trait CountRead {
    /// Returns how many bytes have been read from this stream.
    ///
    /// Only bytes marked as consumed are counted; bytes merely held in
    /// internal buffers are not. Seeking, if the stream supports it, leaves
    /// this value unchanged.
    fn read_count(&self) -> u64;
}
111
/// Tracks how many bytes have been pulled from a source stream.
///
/// Intended for readers that wrap another stream and transform its data,
/// such as decoders.
pub trait SourceCountRead {
    /// Returns how many bytes this object has read from the source stream.
    fn source_read_count(&self) -> u64;
}
120
/// Buffered reader that also keeps byte counters.
///
/// It tracks how many bytes callers have consumed and how many bytes were
/// pulled from the wrapped stream.
pub struct ComboReader<R: Read> {
    /// The wrapped source stream.
    stream: R,
    /// Bytes read from the source that have not been consumed yet.
    buf: Vec<u8>,
    /// Size of each buffered read from the source. Caller reads at least
    /// this large bypass the buffer when it is empty.
    buf_len_threshold: usize,
    /// Total number of bytes consumed by callers.
    read_count: u64,
    /// Total number of bytes read from the source stream.
    source_read_count: u64,
}

impl<R: Read> ComboReader<R> {
    /// Creates a reader with the given stream.
    ///
    /// The buffer starts empty, both counters start at zero and the buffer
    /// threshold is 4096 bytes.
    pub fn new(reader: R) -> Self {
        Self {
            stream: reader,
            buf: Vec::new(),
            buf_len_threshold: 4096,
            read_count: 0,
            source_read_count: 0,
        }
    }

    /// Returns a reference to the wrapped stream.
    pub fn get_ref(&self) -> &R {
        &self.stream
    }

    /// Returns a mutable reference to the wrapped stream.
    pub fn get_mut(&mut self) -> &mut R {
        &mut self.stream
    }

    /// Returns the wrapped stream.
    ///
    /// Any bytes still held in the internal buffer are dropped.
    pub fn into_inner(self) -> R {
        self.stream
    }

    /// Returns a reference to the internal buffer.
    pub fn buffer(&self) -> &[u8] {
        &self.buf
    }

    /// Makes one read from the source if fewer than `amount` bytes are
    /// buffered.
    ///
    /// The read may append up to `buf_len_threshold` bytes. If the read
    /// fails, the buffer and the counters are left exactly as they were.
    fn fill_buf_impl(&mut self, amount: usize) -> Result<()> {
        if self.buf.len() >= amount {
            return Ok(());
        }

        let offset = self.buf.len();
        // Temporarily pad the buffer with zeros so the source can write
        // straight into it.
        self.buf.resize(offset + self.buf_len_threshold, 0);
        let outcome = self.stream.read(&mut self.buf[offset..]);

        match outcome {
            Ok(read_len) => {
                self.buf.truncate(offset + read_len);
                self.source_read_count += read_len as u64;
                Ok(())
            }
            Err(error) => {
                // Remove the zero padding so a failed read (for example
                // `Interrupted`) can never surface as buffered data.
                self.buf.truncate(offset);
                Err(error)
            }
        }
    }

    /// Discards the first `amount` bytes of the buffer.
    ///
    /// Panics if `amount` is greater than the buffer length.
    fn shift_buf(&mut self, amount: usize) {
        self.buf.drain(..amount);
    }
}

impl<R: Read> Read for ComboReader<R> {
    /// Reads into `buf`, serving buffered bytes first.
    ///
    /// When the internal buffer is empty and `buf` is at least
    /// `buf_len_threshold` long, the read goes directly to the source.
    /// Otherwise the internal buffer is filled and data is copied from it.
    fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
        if self.buf.is_empty() {
            if buf.len() >= self.buf_len_threshold {
                // Large read with nothing buffered: skip the extra copy.
                let amount = self.stream.read(buf)?;

                self.source_read_count += amount as u64;
                self.read_count += amount as u64;

                return Ok(amount);
            }

            self.fill_buf()?;
        }

        let amount = buf.len().min(self.buf.len());
        buf[..amount].copy_from_slice(&self.buf[..amount]);
        // `consume` drops the copied bytes and updates `read_count`.
        self.consume(amount);

        Ok(amount)
    }
}

impl<R: Read> BufRead for ComboReader<R> {
    /// Returns the buffered bytes, reading from the source first when fewer
    /// than `buf_len_threshold` bytes are buffered.
    fn fill_buf(&mut self) -> Result<&[u8]> {
        // NOTE(review): this may read from the source even when unconsumed
        // data is already buffered, which could block on an interactive
        // stream - confirm this top-up behavior is intended.
        self.fill_buf_impl(self.buf_len_threshold)?;

        Ok(&self.buf)
    }

    /// Marks up to `amount` buffered bytes as consumed.
    ///
    /// `amount` is clamped to the number of bytes currently buffered.
    fn consume(&mut self, amount: usize) {
        let amount = self.buf.len().min(amount);
        self.shift_buf(amount);

        self.read_count += amount as u64;
    }
}
227
228impl<R: Read> PeekRead for ComboReader<R> {
229    fn peek(&mut self, amount: usize) -> Result<&[u8]> {
230        self.fill_buf_impl(amount)?;
231
232        let amount = amount.min(self.buf.len());
233
234        Ok(&self.buf[0..amount])
235    }
236}
237
238impl<R: Read> CountRead for ComboReader<R> {
239    fn read_count(&self) -> u64 {
240        self.read_count
241    }
242}
243
244impl<R: Read> SourceCountRead for ComboReader<R> {
245    fn source_read_count(&self) -> u64 {
246        self.source_read_count
247    }
248}
249
#[cfg(test)]
mod tests_sync {
    use std::io::{BufRead, Cursor, Read};

    use super::{ComboReader, PeekRead};
    use crate::io::{BufReadMoreExt, CountRead, SourceCountRead};

    /// Checks both byte counters of `reader` in one go.
    fn assert_counts<R: Read>(reader: &ComboReader<R>, consumed: u64, from_source: u64) {
        assert_eq!(reader.read_count(), consumed);
        assert_eq!(reader.source_read_count(), from_source);
    }

    // Reading stops right after the first delimiter.
    #[test]
    fn test_read_limit_until() {
        let mut source = Cursor::new(b"a\r\nb\r\n\r\nc");
        let mut line = Vec::new();

        let length = source.read_limit_until(b'\n', &mut line, 9999).unwrap();

        assert_eq!(length, 3);
        assert_eq!(&line, b"a\r\n");
        assert_eq!(source.position(), 3);
    }

    // EOF before the delimiter is not an error while under the limit.
    #[test]
    fn test_read_limit_until_eof() {
        let mut source = Cursor::new(b"abc");
        let mut line = Vec::new();

        let length = source.read_limit_until(b'\n', &mut line, 9999).unwrap();

        assert_eq!(length, 3);
        assert_eq!(&line, b"abc");
        assert_eq!(source.position(), 3);
    }

    // Hitting the limit without a delimiter is an error.
    #[test]
    fn test_read_limit_until_limit() {
        let mut source = Cursor::new(b"aaaaabbbbbccccc");
        let mut line = Vec::new();

        assert!(source.read_limit_until(b'\n', &mut line, 7).is_err());
    }

    // Small reads are served from the internal buffer.
    #[test]
    fn test_combo_reader_read() {
        let mut reader = ComboReader::new(Cursor::new(b"0123456789abcdef"));

        let mut chunk = vec![0u8; 2];
        let length = reader.read(&mut chunk).unwrap();
        assert_eq!(length, 2);
        assert_eq!(chunk, b"01");
        assert_eq!(reader.buffer(), b"23456789abcdef");
        assert_counts(&reader, 2, 16);

        chunk.resize(4, 0);
        let length = reader.read(&mut chunk).unwrap();
        assert_eq!(length, 4);
        assert_eq!(chunk, b"2345");
        assert_eq!(reader.buffer(), b"6789abcdef");
        assert_counts(&reader, 6, 16);

        chunk.resize(100, 0);
        let length = reader.read(&mut chunk).unwrap();
        assert_eq!(length, 10);
        assert_eq!(&chunk[..10], b"6789abcdef");
        assert_eq!(reader.buffer(), b"");
        assert_counts(&reader, 16, 16);

        let length = reader.read(&mut chunk).unwrap();
        assert_eq!(length, 0);
        assert_eq!(reader.buffer(), b"");
        assert_counts(&reader, 16, 16);
    }

    // `fill_buf` exposes buffered data and `consume` advances the counter.
    #[test]
    fn test_combo_reader_bufread() {
        let mut reader = ComboReader::new(Cursor::new(b"0123456789abcdef"));

        assert_eq!(reader.fill_buf().unwrap(), b"0123456789abcdef");
        assert_counts(&reader, 0, 16);

        reader.consume(4);
        assert_eq!(reader.buffer(), b"456789abcdef");
        assert_counts(&reader, 4, 16);

        assert_eq!(reader.fill_buf().unwrap(), b"456789abcdef");
        assert_counts(&reader, 4, 16);

        reader.consume(12);
        assert_eq!(reader.buffer(), b"");
        assert_counts(&reader, 16, 16);
    }

    // Peeking never advances the stream; reading does.
    #[test]
    fn test_combo_reader_peek() {
        let mut reader = ComboReader::new(Cursor::new(b"0123456789abcdef"));

        assert_eq!(reader.peek(4).unwrap(), b"0123");
        assert_eq!(reader.peek_exact(4).unwrap(), b"0123");

        let mut taken = vec![0u8; 6];
        reader.read_exact(&mut taken).unwrap();
        assert_eq!(taken, b"012345");

        assert_eq!(reader.peek(4).unwrap(), b"6789");
        assert_eq!(reader.peek_exact(4).unwrap(), b"6789");

        let mut taken = vec![0u8; 6];
        reader.read_exact(&mut taken).unwrap();
        assert_eq!(taken, b"6789ab");

        assert!(reader.peek_exact(9999).is_err());
    }

    // Reads at least as large as the threshold go straight to the source.
    #[test]
    fn test_combo_reader_big_read() {
        let payload: Vec<u8> = b"0123456789abcdef".repeat(5000);
        let mut reader = ComboReader::new(Cursor::new(payload));

        let mut chunk = vec![0u8; 5000];
        let length = reader.read(&mut chunk).unwrap();

        assert_eq!(length, 5000);
        assert_counts(&reader, 5000, 5000);
    }
}
402
#[cfg(test)]
mod tests_async {
    use std::io::Cursor;

    use crate::io::AsyncBufReadMoreExt;

    // Reading stops right after the first delimiter.
    #[tokio::test]
    async fn test_read_limit_until() {
        let mut source = Cursor::new(b"a\r\nb\r\n\r\nc");
        let mut line = Vec::new();

        let outcome = source.read_limit_until(b'\n', &mut line, 9999).await;
        let length = outcome.unwrap();

        assert_eq!(length, 3);
        assert_eq!(&line, b"a\r\n");
        assert_eq!(source.position(), 3);
    }

    // EOF before the delimiter is not an error while under the limit.
    #[tokio::test]
    async fn test_read_limit_until_eof() {
        let mut source = Cursor::new(b"abc");
        let mut line = Vec::new();

        let outcome = source.read_limit_until(b'\n', &mut line, 9999).await;
        let length = outcome.unwrap();

        assert_eq!(length, 3);
        assert_eq!(&line, b"abc");
        assert_eq!(source.position(), 3);
    }

    // Hitting the limit without a delimiter is an error.
    #[tokio::test]
    async fn test_read_limit_until_limit() {
        let mut source = Cursor::new(b"aaaaabbbbbccccc");
        let mut line = Vec::new();

        let outcome = source.read_limit_until(b'\n', &mut line, 7).await;

        assert!(outcome.is_err());
    }
}