1use std::io::{BufRead, Error, ErrorKind, Read, Result};
4
5use tokio::io::{AsyncBufRead, AsyncBufReadExt, AsyncReadExt};
6
/// Extension methods for synchronous [`BufRead`] readers.
pub trait BufReadMoreExt {
    /// Reads into `buf` until `byte` is found (inclusive) or EOF, reading at most `limit` bytes.
    ///
    /// Bytes are appended to `buf`. Returns the number of bytes appended by this call.
    ///
    /// # Errors
    ///
    /// Fails with `ErrorKind::InvalidData` when `limit` bytes are read without the last of
    /// them being `byte`; I/O errors from the reader are propagated.
    fn read_limit_until(
        &mut self,
        byte: u8,
        buf: &mut Vec<u8>,
        limit: u64,
    ) -> Result<usize>;
}
16
17#[async_trait::async_trait]
19pub trait AsyncBufReadMoreExt {
20 async fn read_limit_until(&mut self, byte: u8, buf: &mut Vec<u8>, limit: u64) -> Result<usize>;
32}
33
34impl<R: BufRead> BufReadMoreExt for R {
35 fn read_limit_until(&mut self, byte: u8, buf: &mut Vec<u8>, limit: u64) -> Result<usize> {
36 read_limit_until(self, byte, buf, limit)
38 }
39}
40
/// Shared implementation of `BufReadMoreExt::read_limit_until`.
///
/// Reads from `stream` into `buf` until `byte` is found (the delimiter is included), EOF is
/// reached, or `limit` bytes have been read. Bytes are appended to `buf`; whatever `buf`
/// already held is left untouched and plays no part in the result.
///
/// Returns the number of bytes appended by this call.
///
/// # Errors
///
/// Returns `ErrorKind::InvalidData` ("line too long") when `limit` bytes were read and the
/// last of them is not `byte`; those bytes stay appended to `buf`. Nothing beyond `limit` is
/// read, so an unterminated final line of exactly `limit` bytes is also reported this way,
/// and `limit == 0` always yields this error. I/O errors from `stream` are propagated.
fn read_limit_until<R: BufRead>(
    stream: R,
    byte: u8,
    buf: &mut Vec<u8>,
    limit: u64,
) -> Result<usize> {
    let start = buf.len();
    let amount = stream.take(limit).read_until(byte, buf)?;

    // Inspect only the bytes appended by this call: pre-existing contents of `buf` must not
    // decide whether the line was terminated.
    let terminated = buf[start..].last() == Some(&byte);

    if amount as u64 == limit && !terminated {
        return Err(Error::new(ErrorKind::InvalidData, "line too long"));
    }

    Ok(amount)
}
56
57#[async_trait::async_trait]
58impl<R: AsyncBufRead + Send + Unpin> AsyncBufReadMoreExt for R {
59 async fn read_limit_until(&mut self, byte: u8, buf: &mut Vec<u8>, limit: u64) -> Result<usize> {
60 let mut stream = self.take(limit);
61 let amount = stream.read_until(byte, buf).await?;
62
63 if amount as u64 == limit && !buf.ends_with(&[byte]) {
64 return Err(Error::new(ErrorKind::InvalidData, "line too long"));
65 }
66
67 Ok(amount)
68 }
69}
70
/// A reader that can look ahead at upcoming bytes without consuming them.
pub trait PeekRead {
    /// Returns up to `amount` bytes from the front of the stream without consuming them.
    ///
    /// The slice may be shorter than `amount`. [`PeekRead::peek_exact`] calls this
    /// repeatedly and treats a call that yields no more bytes than the previous one as the end
    /// of the stream, so implementations should make progress on each call until EOF.
    fn peek(&mut self, amount: usize) -> Result<&[u8]>;

    /// Calls [`PeekRead::peek`] until at least `amount` bytes are available, then returns the
    /// slice from a final `peek(amount)` call.
    ///
    /// Like `std::io::Read::read_exact`, errors of kind `ErrorKind::Interrupted` from the
    /// polling calls are retried instead of returned.
    ///
    /// # Errors
    ///
    /// Returns `ErrorKind::UnexpectedEof` if a `peek` call yields no more bytes than the
    /// previous one before `amount` bytes are available. Other errors from `peek` are
    /// propagated.
    fn peek_exact(&mut self, amount: usize) -> Result<&[u8]> {
        let mut prev_buf_len = 0;

        loop {
            let available = match self.peek(amount) {
                Ok(buffer) => buffer.len(),
                // Transient: try again, matching `Read::read_exact`.
                Err(error) if error.kind() == ErrorKind::Interrupted => continue,
                Err(error) => return Err(error),
            };

            if available >= amount {
                break;
            }
            if available == prev_buf_len {
                // No progress since the last call: the stream is exhausted.
                return Err(ErrorKind::UnexpectedEof.into());
            }

            prev_buf_len = available;
        }

        // The slice is fetched again instead of being returned from inside the loop;
        // returning a borrow of `self` out of a loop that re-borrows it is rejected by the
        // current (NLL) borrow checker.
        self.peek(amount)
    }
}
101
/// A reader that reports how many bytes it has handed to its caller.
pub trait CountRead {
    /// Returns the running total of bytes delivered by this reader so far.
    fn read_count(&self) -> u64;
}
111
/// A reader that reports how many bytes it has pulled from its underlying source.
pub trait SourceCountRead {
    /// Returns the running total of bytes read from the wrapped source so far.
    ///
    /// For a buffering reader this can exceed the bytes delivered to the caller, since data
    /// may be buffered but not yet consumed.
    fn source_read_count(&self) -> u64;
}
120
/// A reader that layers buffering, look-ahead (`PeekRead`) and byte counting over `R`.
///
/// The `Read` bound lives on the `impl` blocks that need it rather than on the type itself.
#[derive(Debug)]
pub struct ComboReader<R> {
    /// The wrapped source reader.
    stream: R,
    /// Bytes read from `stream` but not yet handed to the caller.
    buf: Vec<u8>,
    /// Chunk size for buffer refills; `read` calls at least this large bypass an empty buffer.
    buf_len_threshold: usize,
    /// Bytes delivered to the caller.
    read_count: u64,
    /// Bytes pulled from `stream`.
    source_read_count: u64,
}
129
130impl<R: Read> ComboReader<R> {
131 pub fn new(reader: R) -> Self {
133 Self {
134 stream: reader,
135 buf: Vec::new(),
136 buf_len_threshold: 4096,
137 read_count: 0,
138 source_read_count: 0,
139 }
140 }
141
142 pub fn get_ref(&self) -> &R {
144 &self.stream
145 }
146
147 pub fn get_mut(&mut self) -> &mut R {
149 &mut self.stream
150 }
151
152 pub fn into_inner(self) -> R {
154 self.stream
155 }
156
157 pub fn buffer(&self) -> &[u8] {
159 &self.buf
160 }
161
162 fn fill_buf_impl(&mut self, amount: usize) -> Result<()> {
163 if self.buf.len() < amount {
164 let offset = self.buf.len();
165 self.buf.resize(offset + self.buf_len_threshold, 0);
166 let amount = self.stream.read(&mut self.buf[offset..])?;
167 self.buf.truncate(offset + amount);
168
169 self.source_read_count += amount as u64;
170 }
171
172 Ok(())
173 }
174
175 fn shift_buf(&mut self, amount: usize) {
176 self.buf.copy_within(amount.., 0);
177 self.buf.truncate(self.buf.len() - amount);
178 }
179}
180
181impl<R: Read> Read for ComboReader<R> {
182 fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
183 if !self.buf.is_empty() {
184 let amount = self.buf.len().min(buf.len());
185 (&mut buf[0..amount]).copy_from_slice(&self.buf[0..amount]);
186 self.shift_buf(amount);
187
188 self.read_count += amount as u64;
189
190 Ok(amount)
191 } else if buf.len() >= self.buf_len_threshold {
192 debug_assert!(self.buf.is_empty());
193
194 let amount = self.stream.read(buf)?;
195
196 self.source_read_count += amount as u64;
197 self.read_count += amount as u64;
198
199 Ok(amount)
200 } else {
201 debug_assert!(self.buf.is_empty());
202
203 self.fill_buf()?;
204 let amount = buf.len().min(self.buf.len());
205 (&mut buf[0..amount]).copy_from_slice(&self.buf[0..amount]);
206 self.consume(amount);
207
208 Ok(amount)
209 }
210 }
211}
212
213impl<R: Read> BufRead for ComboReader<R> {
214 fn fill_buf(&mut self) -> Result<&[u8]> {
215 self.fill_buf_impl(self.buf_len_threshold)?;
216
217 Ok(&self.buf)
218 }
219
220 fn consume(&mut self, amount: usize) {
221 let amount = self.buf.len().min(amount);
222 self.shift_buf(amount);
223
224 self.read_count += amount as u64;
225 }
226}
227
228impl<R: Read> PeekRead for ComboReader<R> {
229 fn peek(&mut self, amount: usize) -> Result<&[u8]> {
230 self.fill_buf_impl(amount)?;
231
232 let amount = amount.min(self.buf.len());
233
234 Ok(&self.buf[0..amount])
235 }
236}
237
238impl<R: Read> CountRead for ComboReader<R> {
239 fn read_count(&self) -> u64 {
240 self.read_count
241 }
242}
243
244impl<R: Read> SourceCountRead for ComboReader<R> {
245 fn source_read_count(&self) -> u64 {
246 self.source_read_count
247 }
248}
249
#[cfg(test)]
mod tests_sync {
    use std::io::{BufRead, Cursor, ErrorKind, Read};

    use super::{BufReadMoreExt, ComboReader, CountRead, PeekRead, SourceCountRead};

    #[test]
    fn test_read_limit_until() {
        let mut input = Cursor::new(b"a\r\nb\r\n\r\nc");
        let mut output = Vec::new();
        let count = input.read_limit_until(b'\n', &mut output, 9999).unwrap();

        assert_eq!(count, 3);
        assert_eq!(&output, b"a\r\n");
        assert_eq!(input.position(), 3);
    }

    #[test]
    fn test_read_limit_until_eof() {
        let mut input = Cursor::new(b"abc");
        let mut output = Vec::new();
        let count = input.read_limit_until(b'\n', &mut output, 9999).unwrap();

        assert_eq!(count, 3);
        assert_eq!(&output, b"abc");
        assert_eq!(input.position(), 3);
    }

    #[test]
    fn test_read_limit_until_limit() {
        let mut input = Cursor::new(b"aaaaabbbbbccccc");
        let mut output = Vec::new();
        let result = input.read_limit_until(b'\n', &mut output, 7);

        assert_eq!(result.unwrap_err().kind(), ErrorKind::InvalidData);
    }

    #[test]
    fn test_read_limit_until_delimiter_at_limit() {
        // A delimiter landing exactly on the limit is a complete line, not an overflow.
        let mut input = Cursor::new(b"ab\ncd");
        let mut output = Vec::new();
        let count = input.read_limit_until(b'\n', &mut output, 3).unwrap();

        assert_eq!(count, 3);
        assert_eq!(&output, b"ab\n");
        assert_eq!(input.position(), 3);
    }

    #[test]
    fn test_read_limit_until_appends() {
        // Existing contents of the output buffer are kept and not counted.
        let mut input = Cursor::new(b"b\nc");
        let mut output = b"a".to_vec();
        let count = input.read_limit_until(b'\n', &mut output, 9999).unwrap();

        assert_eq!(count, 2);
        assert_eq!(&output, b"ab\n");
        assert_eq!(input.position(), 2);
    }

    #[test]
    fn test_combo_reader_read() {
        let input = Cursor::new(b"0123456789abcdef");
        let mut reader = ComboReader::new(input);
        let mut output = vec![0u8; 2];

        let amount = reader.read(&mut output).unwrap();
        assert_eq!(amount, 2);
        assert_eq!(output, b"01");
        assert_eq!(reader.buffer(), b"23456789abcdef");
        assert_eq!(reader.read_count(), 2);
        assert_eq!(reader.source_read_count(), 16);

        output.resize(4, 0);
        let amount = reader.read(&mut output).unwrap();
        assert_eq!(amount, 4);
        assert_eq!(output, b"2345");
        assert_eq!(reader.buffer(), b"6789abcdef");
        assert_eq!(reader.read_count(), 6);
        assert_eq!(reader.source_read_count(), 16);

        output.resize(100, 0);
        let amount = reader.read(&mut output).unwrap();
        assert_eq!(amount, 10);
        assert_eq!(&output[..10], b"6789abcdef");
        assert_eq!(reader.buffer(), b"");
        assert_eq!(reader.read_count(), 16);
        assert_eq!(reader.source_read_count(), 16);

        let amount = reader.read(&mut output).unwrap();
        assert_eq!(amount, 0);
        assert_eq!(reader.buffer(), b"");
        assert_eq!(reader.read_count(), 16);
        assert_eq!(reader.source_read_count(), 16);
    }

    #[test]
    fn test_combo_reader_bufread() {
        let input = Cursor::new(b"0123456789abcdef");
        let mut reader = ComboReader::new(input);

        let buffer = reader.fill_buf().unwrap();
        assert_eq!(buffer, b"0123456789abcdef");
        assert_eq!(reader.read_count(), 0);
        assert_eq!(reader.source_read_count(), 16);

        reader.consume(4);
        assert_eq!(reader.buffer(), b"456789abcdef");
        assert_eq!(reader.read_count(), 4);
        assert_eq!(reader.source_read_count(), 16);

        let buffer = reader.fill_buf().unwrap();
        assert_eq!(buffer, b"456789abcdef");
        assert_eq!(reader.read_count(), 4);
        assert_eq!(reader.source_read_count(), 16);

        reader.consume(12);
        assert_eq!(reader.buffer(), b"");
        assert_eq!(reader.read_count(), 16);
        assert_eq!(reader.source_read_count(), 16);
    }

    #[test]
    fn test_combo_reader_peek() {
        let input = Cursor::new(b"0123456789abcdef");
        let mut reader = ComboReader::new(input);

        let output = reader.peek(4).unwrap();
        assert_eq!(output, b"0123");
        let output = reader.peek_exact(4).unwrap();
        assert_eq!(output, b"0123");

        let mut output = vec![0u8; 6];
        reader.read_exact(&mut output).unwrap();
        assert_eq!(output, b"012345");

        let output = reader.peek(4).unwrap();
        assert_eq!(output, b"6789");
        let output = reader.peek_exact(4).unwrap();
        assert_eq!(output, b"6789");

        let mut output = vec![0u8; 6];
        reader.read_exact(&mut output).unwrap();
        assert_eq!(output, b"6789ab");

        let error = reader.peek_exact(9999).unwrap_err();
        assert_eq!(error.kind(), ErrorKind::UnexpectedEof);
    }

    #[test]
    fn test_combo_reader_big_read() {
        let input = Cursor::new(b"0123456789abcdef".repeat(5000));
        let mut reader = ComboReader::new(input);

        // At least `buf_len_threshold` bytes with an empty buffer: served straight from source.
        let mut output = vec![0u8; 5000];
        let amount = reader.read(&mut output).unwrap();
        assert_eq!(amount, 5000);
        assert_eq!(reader.read_count(), 5000);
        assert_eq!(reader.source_read_count(), 5000);
    }
}
402
#[cfg(test)]
mod tests_async {
    use std::io::{Cursor, ErrorKind};

    use super::AsyncBufReadMoreExt;

    #[tokio::test]
    async fn test_read_limit_until() {
        let mut input = Cursor::new(b"a\r\nb\r\n\r\nc");
        let mut output = Vec::new();
        let count = input
            .read_limit_until(b'\n', &mut output, 9999)
            .await
            .unwrap();

        assert_eq!(count, 3);
        assert_eq!(&output, b"a\r\n");
        assert_eq!(input.position(), 3);
    }

    #[tokio::test]
    async fn test_read_limit_until_eof() {
        let mut input = Cursor::new(b"abc");
        let mut output = Vec::new();
        let count = input
            .read_limit_until(b'\n', &mut output, 9999)
            .await
            .unwrap();

        assert_eq!(count, 3);
        assert_eq!(&output, b"abc");
        assert_eq!(input.position(), 3);
    }

    #[tokio::test]
    async fn test_read_limit_until_limit() {
        let mut input = Cursor::new(b"aaaaabbbbbccccc");
        let mut output = Vec::new();
        let result = input.read_limit_until(b'\n', &mut output, 7).await;

        assert_eq!(result.unwrap_err().kind(), ErrorKind::InvalidData);
    }

    #[tokio::test]
    async fn test_read_limit_until_delimiter_at_limit() {
        // A delimiter landing exactly on the limit is a complete line, not an overflow.
        let mut input = Cursor::new(b"ab\ncd");
        let mut output = Vec::new();
        let count = input.read_limit_until(b'\n', &mut output, 3).await.unwrap();

        assert_eq!(count, 3);
        assert_eq!(&output, b"ab\n");
        assert_eq!(input.position(), 3);
    }
}