use super::*;
use crate::{Advance, AdvanceAsync, Buffer};
use alloc::vec::Vec;
/// Multiplier applied to `chunk_size` to choose the initial buffer capacity.
///
/// Reserving room for more than one chunk lets several reads append to the
/// buffer before the consumed prefix has to be compacted away.
const CAPACITY_FACTOR: usize = 2;

/// Buffers data pulled from a stream `S` in reads of `chunk_size` elements.
///
/// Elements in `buf[pos..end]` have been read from the stream but not yet
/// consumed by the caller.
#[derive(Debug)]
pub struct ReadSource<T, S> {
    /// Stream that data is read from.
    stream: S,
    /// Number of destination slots offered to the stream on each read.
    chunk_size: usize,
    /// Backing storage; only `buf[..end]` holds data obtained from the stream,
    /// slots past `end` are scratch space for the next read.
    buf: Vec<T>,
    /// Index of the first element not yet consumed.
    pos: usize,
    /// One past the last element obtained from the stream.
    end: usize,
    /// Set once a read reported zero elements; never cleared afterwards.
    exhausted: bool,
}

impl<T, S> ReadSource<T, S> {
    /// Creates a source that reads from `stream` in chunks of `chunk_size`
    /// elements, with initial capacity for `CAPACITY_FACTOR` chunks.
    ///
    /// # Panics
    ///
    /// Panics if `chunk_size` is zero. Every read would then be handed an empty
    /// destination slice and report 0 elements, which `finish_read` treats as
    /// end-of-stream, so the source would claim to be exhausted without ever
    /// reading anything.
    pub fn new(stream: S, chunk_size: usize) -> Self {
        assert!(chunk_size > 0, "ReadSource chunk_size must be non-zero");
        ReadSource {
            stream,
            chunk_size,
            buf: Vec::with_capacity(CAPACITY_FACTOR * chunk_size),
            pos: 0,
            end: 0,
            exhausted: false,
        }
    }

    /// Prepares `buf[end..end + chunk_size]` as the destination of the next read.
    ///
    /// When the spare capacity after `end` cannot hold a full chunk, the
    /// consumed prefix `buf[..pos]` is discarded first so the unconsumed data
    /// moves to the front of the existing allocation. The buffer is then
    /// resized (growing the allocation if that is still necessary) so that
    /// exactly `chunk_size` writable slots follow `end`.
    fn begin_read(&mut self)
    where
        T: Default,
    {
        // Only compact when there is a consumed prefix to drop; with `pos == 0`
        // the drain would be a no-op.
        if self.buf.capacity() - self.end < self.chunk_size && self.pos > 0 {
            self.buf.drain(..self.pos);
            self.end -= self.pos;
            self.pos = 0;
        }
        self.buf.resize_with(self.end + self.chunk_size, T::default);
    }

    /// Records that `read` elements were written into the slots prepared by
    /// `begin_read`; a zero-length read marks the source exhausted.
    fn finish_read(&mut self, read: usize) {
        self.end += read;
        if read == 0 {
            self.exhausted = true;
        }
    }
}
impl<S> Advance for ReadSource<u8, S>
where
    S: Read,
{
    type Error = S::Error;

    /// Performs a single read of up to `chunk_size` bytes from the stream and
    /// appends whatever was read after the unconsumed data.
    ///
    /// If the stream fails, its error is returned unchanged and no bytes are
    /// recorded as read. The unconsumed data stays the same, although the
    /// buffer may have been compacted or resized internally.
    fn advance(&mut self) -> Result<(), Self::Error> {
        self.begin_read();
        let start = self.end;
        let outcome = self.stream.read(&mut self.buf[start..]);
        outcome.map(|read| self.finish_read(read))
    }
}
impl<S> AdvanceAsync for ReadSource<u8, S>
where
    S: AsyncRead + Unpin,
{
    type Error = S::Error;

    /// Polls the stream once for up to `chunk_size` bytes and appends them
    /// after the unconsumed data.
    ///
    /// `Pending` and errors are forwarded as-is without recording any bytes
    /// as read; the buffer may still have been compacted or resized.
    fn poll_advance(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
        let source = Pin::into_inner(self);
        source.begin_read();
        let start = source.end;
        let outcome = Pin::new(&mut source.stream).poll_read(cx, &mut source.buf[start..]);
        match outcome {
            Poll::Pending => Poll::Pending,
            Poll::Ready(result) => Poll::Ready(result.map(|read| source.finish_read(read))),
        }
    }
}
impl<T, S> Buffer for ReadSource<T, S> {
    type Output = [T];

    /// Returns the elements read but not yet consumed, paired with a mutable
    /// reference to the consumption cursor.
    ///
    /// Advancing the cursor marks elements as consumed. Moving it past the end
    /// of the returned slice makes later calls panic when the slice is taken.
    fn buffer(&mut self) -> (&Self::Output, &mut usize) {
        let Self { buf, pos, end, .. } = self;
        let unconsumed = &buf[*pos..*end];
        (unconsumed, pos)
    }

    /// Reports whether a read from the stream has returned zero elements.
    fn exhausted(&self) -> bool {
        self.exhausted
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{advance_async, tests::block_on, Advance, Buffer};
    use core::{
        cmp::min,
        future::Future,
        pin::Pin,
        task::{Context, Poll},
    };
    use pin_project_lite::pin_project;

    /// Error type produced by the failing test streams.
    #[derive(Debug)]
    struct Error;

    /// In-memory stream over `data` that returns at most `max_read` bytes per read.
    struct TestRead<'a> {
        data: &'a [u8],
        max_read: usize,
    }

    impl<'a> TestRead<'a> {
        /// Stream limited only by the length of the destination slice.
        fn new(data: &'a [u8]) -> Self {
            Self::new_with_max_read(data, usize::MAX)
        }

        /// Stream that returns at most `max_read` bytes per read, to simulate short reads.
        fn new_with_max_read(data: &'a [u8], max_read: usize) -> Self {
            TestRead { data, max_read }
        }
    }

    impl Read for TestRead<'_> {
        type Error = Error;
        fn read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error> {
            let to_read = min(min(self.data.len(), buf.len()), self.max_read);
            buf[..to_read].copy_from_slice(&self.data[..to_read]);
            self.data = &self.data[to_read..];
            Ok(to_read)
        }
    }

    impl AsyncRead for TestRead<'_> {
        type Error = Error;
        fn poll_read(
            mut self: Pin<&mut Self>,
            _cx: &mut Context,
            buf: &mut [u8],
        ) -> Poll<Result<usize, Self::Error>> {
            Poll::Ready(self.read(buf))
        }
    }

    /// Stream whose every read fails.
    struct ErrorRead;

    impl Read for ErrorRead {
        type Error = Error;
        fn read(&mut self, _buf: &mut [u8]) -> Result<usize, Self::Error> {
            Err(Error)
        }
    }

    impl AsyncRead for ErrorRead {
        type Error = Error;
        fn poll_read(
            self: Pin<&mut Self>,
            _cx: &mut Context,
            _buf: &mut [u8],
        ) -> Poll<Result<usize, Self::Error>> {
            Poll::Ready(Err(Error))
        }
    }

    /// Async stream that never becomes ready.
    struct PendingRead;

    impl AsyncRead for PendingRead {
        type Error = Error;
        fn poll_read(
            self: Pin<&mut Self>,
            _cx: &mut Context,
            _buf: &mut [u8],
        ) -> Poll<Result<usize, Self::Error>> {
            Poll::Pending
        }
    }

    pin_project! {
        /// Future that polls `underlying` once and then completes, whether or
        /// not `underlying` was ready.
        struct PollOnce<F> {
            #[pin]
            underlying: F,
        }
    }

    impl<F: Future> Future for PollOnce<F> {
        type Output = ();
        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            let _ = self.project().underlying.poll(cx);
            Poll::Ready(())
        }
    }

    #[test]
    fn should_fill_buffer() {
        let r = TestRead::new(b"1234_1234_abc");
        let mut src = ReadSource::new(r, 5);
        src.advance().unwrap();
        assert_eq!(src.buffer().0, b"1234_");
        assert!(!src.exhausted());
        src.advance().unwrap();
        assert_eq!(src.buffer().0, b"1234_1234_");
        assert!(!src.exhausted());
        src.advance().unwrap();
        assert_eq!(src.buffer().0, b"1234_1234_abc");
        assert!(!src.exhausted());
        src.advance().unwrap();
        assert_eq!(src.buffer().0, b"1234_1234_abc");
        assert!(src.exhausted());
    }

    #[test]
    fn should_advance_pos() {
        let r = TestRead::new(b"1234_1234_1234_1234_abc");
        let mut src = ReadSource::new(r, 10);
        src.advance().unwrap();
        let (buf, pos) = src.buffer();
        assert_eq!(buf, b"1234_1234_");
        *pos += 8;
        let (buf, _) = src.buffer();
        assert_eq!(buf, b"4_");
        src.advance().unwrap();
        let (buf, pos) = src.buffer();
        assert_eq!(buf, b"4_1234_1234_");
        *pos += 12;
        let (buf, _) = src.buffer();
        assert_eq!(buf, b"");
        src.advance().unwrap();
        let (buf, pos) = src.buffer();
        assert_eq!(buf, b"abc");
        *pos += 2;
        let (buf, _) = src.buffer();
        assert_eq!(buf, b"c");
    }

    #[test]
    fn should_leave_buffer_unchanged_on_error() {
        let mut src = ReadSource::new(ErrorRead, 10);
        let result = src.advance();
        assert!(result.is_err());
        assert_eq!(src.buffer().0, b"");
        // A failed read must not be mistaken for end-of-stream.
        assert!(!src.exhausted());
    }

    #[test]
    fn should_fill_buffer_async() {
        let r = TestRead::new(b"1234_1234_abc");
        let mut src = ReadSource::new(r, 5);
        block_on(async move {
            advance_async(&mut src).await.unwrap();
            assert_eq!(src.buffer().0, b"1234_");
            assert!(!src.exhausted());
            advance_async(&mut src).await.unwrap();
            assert_eq!(src.buffer().0, b"1234_1234_");
            assert!(!src.exhausted());
            advance_async(&mut src).await.unwrap();
            assert_eq!(src.buffer().0, b"1234_1234_abc");
            assert!(!src.exhausted());
            advance_async(&mut src).await.unwrap();
            assert_eq!(src.buffer().0, b"1234_1234_abc");
            assert!(src.exhausted());
        });
    }

    #[test]
    fn should_advance_pos_async() {
        let r = TestRead::new(b"1234_1234_1234_1234_abc");
        let mut src = ReadSource::new(r, 10);
        block_on(async move {
            advance_async(&mut src).await.unwrap();
            let (buf, pos) = src.buffer();
            assert_eq!(buf, b"1234_1234_");
            *pos += 8;
            let (buf, _) = src.buffer();
            assert_eq!(buf, b"4_");
            advance_async(&mut src).await.unwrap();
            let (buf, pos) = src.buffer();
            assert_eq!(buf, b"4_1234_1234_");
            *pos += 12;
            let (buf, _) = src.buffer();
            assert_eq!(buf, b"");
            advance_async(&mut src).await.unwrap();
            let (buf, pos) = src.buffer();
            assert_eq!(buf, b"abc");
            *pos += 2;
            let (buf, _) = src.buffer();
            assert_eq!(buf, b"c");
        });
    }

    #[test]
    fn should_leave_buffer_unchanged_on_error_async() {
        let mut src = ReadSource::new(ErrorRead, 10);
        block_on(async move {
            let result = advance_async(&mut src).await;
            assert!(result.is_err());
            assert_eq!(src.buffer().0, b"");
            assert!(!src.exhausted());
        });
    }

    #[test]
    fn should_leave_buffer_unchanged_on_pending_poll() {
        let mut src = ReadSource::new(PendingRead, 10);
        block_on(async move {
            PollOnce {
                underlying: advance_async(&mut src),
            }
            .await;
            assert_eq!(src.buffer().0, b"");
            // A pending read must not be mistaken for end-of-stream.
            assert!(!src.exhausted());
        });
    }

    #[test]
    fn should_handle_short_reads() {
        let r = TestRead::new_with_max_read(b"1234_1234_1234_1234_", 4);
        let mut src = ReadSource::new(r, 5);
        src.advance().unwrap();
        assert_eq!(src.buffer().0, b"1234");
        assert!(!src.exhausted());
        src.advance().unwrap();
        assert_eq!(src.buffer().0, b"1234_123");
        assert!(!src.exhausted());
        src.advance().unwrap();
        assert_eq!(src.buffer().0, b"1234_1234_12");
        assert!(!src.exhausted());
        src.advance().unwrap();
        assert_eq!(src.buffer().0, b"1234_1234_1234_1");
        assert!(!src.exhausted());
        src.advance().unwrap();
        assert_eq!(src.buffer().0, b"1234_1234_1234_1234_");
        // A short (but non-empty) read is not end-of-stream.
        assert!(!src.exhausted());
        src.advance().unwrap();
        assert_eq!(src.buffer().0, b"1234_1234_1234_1234_");
        assert!(src.exhausted());
    }
}