use crate::{Advance, AdvanceAsync, Buffer};
use alloc::vec::Vec;
use core::{
pin::Pin,
task::{Context, Poll},
};
use futures_core::Stream;
use pin_project_lite::pin_project;
/// Uninhabited error type used by [`InfallibleWrapper`].
///
/// No value of this type can ever be constructed, so a `Result<_, Infallible>`
/// is always `Ok`. It derives the same comparison/copy/hash traits as
/// `core::convert::Infallible` so that such results can be compared, copied and
/// hashed like any other `Result`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum Infallible {}
pin_project! {
    /// Adapter that wraps every item of an `Iterator` or `Stream` in `Ok`,
    /// producing `Result<Item, Infallible>` so an infallible source can be
    /// used where a fallible one is expected (see `IterSource::new_infallible`).
    pub struct InfallibleWrapper<S> {
        // The wrapped iterator/stream. Structurally pinned so that
        // `Stream::poll_next` can obtain a `Pin<&mut S>` via `project()`.
        #[pin]
        underlying: S,
    }
}
impl<S> Iterator for InfallibleWrapper<S>
where
    S: Iterator,
{
    type Item = Result<S::Item, Infallible>;

    /// Pulls the next item from the wrapped iterator and tags it as `Ok`.
    ///
    /// A `None` from the wrapped iterator is passed through unchanged.
    fn next(&mut self) -> Option<Self::Item> {
        let item = self.underlying.next()?;
        Some(Ok(item))
    }
}
impl<S> Stream for InfallibleWrapper<S>
where
    S: Stream,
{
    type Item = Result<S::Item, Infallible>;

    /// Polls the wrapped stream through the pin projection and wraps any
    /// yielded item in `Ok`.
    ///
    /// `Poll::Pending` and end-of-stream (`Ready(None)`) pass through as-is.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let underlying = self.project().underlying;
        underlying.poll_next(cx).map(|next| next.map(Ok))
    }
}
/// Buffered reader over a source of chunks.
///
/// `S` is expected to yield `Result<B, E>` items with `B: AsRef<[T]>` (as
/// required by the `Advance` and `AdvanceAsync` impls for this type). Each
/// successful chunk is appended to an internal `Vec<T>`, whose unconsumed part
/// callers inspect through `Buffer::buffer`.
#[derive(Debug)]
pub struct IterSource<T, S> {
    /// Underlying iterator or stream of chunks.
    stream: S,
    /// Accumulated elements; `buf[pos..]` is the part not yet consumed.
    buf: Vec<T>,
    /// Index of the first unconsumed element in `buf`. Callers advance it via
    /// the `&mut usize` returned from `Buffer::buffer`.
    pos: usize,
    /// Set once the source has yielded `None`; after that the source is not
    /// polled again.
    exhausted: bool,
}
impl<T, S> IterSource<T, S> {
pub fn new(stream: S) -> Self {
IterSource {
stream,
buf: Vec::new(),
pos: 0,
exhausted: false,
}
}
}
impl<T, S> IterSource<T, InfallibleWrapper<S>> {
    /// Creates a source from an infallible iterator/stream of chunks.
    ///
    /// Every item of `stream` is wrapped in `Ok` by [`InfallibleWrapper`]. The
    /// resulting error type is the uninhabited [`Infallible`], so advancing
    /// this source can never produce an error.
    pub fn new_infallible(stream: S) -> Self {
        let wrapped = InfallibleWrapper { underlying: stream };
        Self::new(wrapped)
    }
}
/// Applies one item pulled from the underlying source to the buffer state.
///
/// * `Some(Ok(chunk))`: appends the chunk's elements to `buf`. If `buf` lacks
///   the spare capacity to hold the chunk, the consumed prefix `buf[..pos]` is
///   removed first and `pos` is reset to 0. Compaction is therefore only
///   attempted right before the vector would have to reallocate.
/// * `Some(Err(e))`: returns `Err(e)` and leaves `buf`, `pos` and `exhausted`
///   untouched.
/// * `None`: sets `exhausted` to `true` and returns `Ok(())`.
///
/// # Panics
/// Panics if compaction is needed while `*pos > buf.len()`.
fn handle_buffer_result<T, B, E>(
    buffer_result: Option<Result<B, E>>,
    buf: &mut Vec<T>,
    pos: &mut usize,
    exhausted: &mut bool,
) -> Result<(), E>
where
    T: Clone,
    B: AsRef<[T]>,
{
    let chunk = match buffer_result {
        None => {
            *exhausted = true;
            return Ok(());
        }
        Some(result) => result?,
    };
    let incoming = chunk.as_ref();
    let spare = buf.capacity() - buf.len();
    if incoming.len() > spare {
        // Drop already-consumed elements instead of growing past them.
        buf.drain(..*pos);
        *pos = 0;
    }
    buf.extend_from_slice(incoming);
    Ok(())
}
impl<B, E, T, S> Advance for IterSource<T, S>
where
    T: Clone,
    B: AsRef<[T]>,
    S: Iterator<Item = Result<B, E>>,
{
    type Error = E;

    /// Pulls one chunk from the iterator and appends it to the buffer.
    ///
    /// Once the iterator has returned `None`, the source is marked exhausted.
    /// Later calls then return `Ok(())` without calling the iterator again.
    ///
    /// # Errors
    /// Propagates an `Err` yielded by the iterator. In that case the buffer is
    /// left unchanged and the source is not marked exhausted.
    fn advance(&mut self) -> Result<(), Self::Error> {
        let Self {
            stream,
            buf,
            pos,
            exhausted,
        } = self;
        if *exhausted {
            return Ok(());
        }
        handle_buffer_result(stream.next(), buf, pos, exhausted)
    }
}
impl<B, E, T, S> AdvanceAsync for IterSource<T, S>
where
    T: Clone + Unpin,
    B: AsRef<[T]>,
    S: Stream<Item = Result<B, E>> + Unpin,
{
    type Error = E;

    /// Polls the stream for one chunk and appends it to the buffer when ready.
    ///
    /// Returns `Poll::Pending` while the stream is pending. Once the stream
    /// has yielded `None`, the source is marked exhausted, and later polls
    /// complete immediately with `Ok(())` without polling the stream again.
    /// A stream `Err` is returned as-is, with the buffer unchanged.
    fn poll_advance(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
        let Self {
            stream,
            buf,
            pos,
            exhausted,
        } = self.get_mut();
        if *exhausted {
            return Poll::Ready(Ok(()));
        }
        Pin::new(stream)
            .poll_next(cx)
            .map(|next| handle_buffer_result(next, buf, pos, exhausted))
    }
}
impl<T, S> Buffer for IterSource<T, S> {
    type Output = [T];

    /// Returns the unconsumed part of the buffer (`buf[pos..]`) together with
    /// a mutable reference to `pos`. Callers mark elements as consumed by
    /// increasing it.
    ///
    /// # Panics
    /// Panics if `pos` has been moved past the end of the buffer.
    fn buffer(&mut self) -> (&Self::Output, &mut usize) {
        let Self { buf, pos, .. } = self;
        let unconsumed = &buf[*pos..];
        (unconsumed, pos)
    }

    /// Returns `true` once the underlying source has yielded `None`.
    fn exhausted(&self) -> bool {
        let Self { exhausted, .. } = self;
        *exhausted
    }
}
// Unit tests for `IterSource`. They cover sync (`Advance`) and async
// (`AdvanceAsync`) buffering, read-position handling, error handling,
// retry-after-error, buffer compaction, and not polling an exhausted source.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{advance_async, tests::block_on};
    use core::iter;
    use futures_util::stream;

    #[derive(Debug)]
    struct Error;

    /// Source that yields `None` once and panics if polled again. Used to check
    /// that an exhausted `IterSource` never polls its source a second time.
    #[derive(Default)]
    struct PanickingSource {
        done: bool,
    }

    impl Iterator for PanickingSource {
        type Item = &'static [u8];
        fn next(&mut self) -> Option<Self::Item> {
            if !self.done {
                self.done = true;
                None
            } else {
                panic!("already exhausted")
            }
        }
    }

    impl Stream for PanickingSource {
        type Item = &'static [u8];
        fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
            if !self.done {
                self.done = true;
                Poll::Ready(None)
            } else {
                panic!("already exhausted")
            }
        }
    }

    #[test]
    fn should_fill_buffer_from_regular_reads() {
        let r = [b"12345", b"12346", b"12347"];
        let mut src = IterSource::new_infallible(r.iter());
        src.advance().unwrap();
        assert_eq!(src.buffer().0, b"12345");
        assert!(!src.exhausted());
        src.advance().unwrap();
        assert_eq!(src.buffer().0, b"1234512346");
        assert!(!src.exhausted());
        src.advance().unwrap();
        assert_eq!(src.buffer().0, b"123451234612347");
        assert!(!src.exhausted());
        src.advance().unwrap();
        assert_eq!(src.buffer().0, b"123451234612347");
        assert!(src.exhausted());
    }

    #[test]
    fn should_fill_buffer_from_irregular_reads() {
        let r = [
            b"12345".as_ref(),
            b"".as_ref(),
            b"_12345_".as_ref(),
            b"ab".as_ref(),
            b"c".as_ref(),
        ];
        let mut src = IterSource::new_infallible(r.iter());
        src.advance().unwrap();
        assert_eq!(src.buffer().0, b"12345");
        assert!(!src.exhausted());
        src.advance().unwrap();
        assert_eq!(src.buffer().0, b"12345");
        assert!(!src.exhausted());
        src.advance().unwrap();
        assert_eq!(src.buffer().0, b"12345_12345_");
        assert!(!src.exhausted());
        src.advance().unwrap();
        assert_eq!(src.buffer().0, b"12345_12345_ab");
        assert!(!src.exhausted());
        src.advance().unwrap();
        assert_eq!(src.buffer().0, b"12345_12345_abc");
        assert!(!src.exhausted());
        src.advance().unwrap();
        assert_eq!(src.buffer().0, b"12345_12345_abc");
        assert!(src.exhausted());
    }

    #[test]
    fn should_advance_pos() {
        let r = [
            b"1234_1234_".as_ref(),
            b"1234_1234_".as_ref(),
            b"abc".as_ref(),
        ];
        let mut src = IterSource::new_infallible(r.iter());
        src.advance().unwrap();
        let (buf, pos) = src.buffer();
        assert_eq!(buf, b"1234_1234_");
        *pos += 8;
        let (buf, _) = src.buffer();
        assert_eq!(buf, b"4_");
        src.advance().unwrap();
        let (buf, pos) = src.buffer();
        assert_eq!(buf, b"4_1234_1234_");
        *pos += 12;
        let (buf, _) = src.buffer();
        assert_eq!(buf, b"");
        src.advance().unwrap();
        let (buf, pos) = src.buffer();
        assert_eq!(buf, b"abc");
        *pos += 2;
        let (buf, _) = src.buffer();
        assert_eq!(buf, b"c");
    }

    #[test]
    fn should_leave_buffer_unchanged_on_error() {
        let r = iter::once::<Result<&[u8], Error>>(Err(Error));
        let mut src = IterSource::new(r);
        let result = src.advance();
        assert!(result.is_err());
        assert_eq!(src.buffer().0, b"");
    }

    #[test]
    fn should_not_poll_source_after_none() {
        let mut src = IterSource::new_infallible(PanickingSource::default());
        src.advance().unwrap();
        assert_eq!(src.buffer().0, b"");
        assert!(src.exhausted());
        src.advance().unwrap();
        assert_eq!(src.buffer().0, b"");
        assert!(src.exhausted());
    }

    #[test]
    fn should_fill_buffer_from_regular_reads_async() {
        let r = [b"12345", b"12346", b"12347"];
        let mut src = IterSource::new_infallible(stream::iter(r.iter()));
        block_on(async move {
            advance_async(&mut src).await.unwrap();
            assert_eq!(src.buffer().0, b"12345");
            assert!(!src.exhausted());
            advance_async(&mut src).await.unwrap();
            assert_eq!(src.buffer().0, b"1234512346");
            assert!(!src.exhausted());
            advance_async(&mut src).await.unwrap();
            assert_eq!(src.buffer().0, b"123451234612347");
            assert!(!src.exhausted());
            advance_async(&mut src).await.unwrap();
            assert_eq!(src.buffer().0, b"123451234612347");
            assert!(src.exhausted());
        });
    }

    #[test]
    fn should_fill_buffer_from_irregular_reads_async() {
        let r = [
            b"12345".as_ref(),
            b"".as_ref(),
            b"_12345_".as_ref(),
            b"ab".as_ref(),
            b"c".as_ref(),
        ];
        let mut src = IterSource::new_infallible(stream::iter(r.iter()));
        block_on(async move {
            advance_async(&mut src).await.unwrap();
            assert_eq!(src.buffer().0, b"12345");
            assert!(!src.exhausted());
            advance_async(&mut src).await.unwrap();
            assert_eq!(src.buffer().0, b"12345");
            assert!(!src.exhausted());
            advance_async(&mut src).await.unwrap();
            assert_eq!(src.buffer().0, b"12345_12345_");
            assert!(!src.exhausted());
            advance_async(&mut src).await.unwrap();
            assert_eq!(src.buffer().0, b"12345_12345_ab");
            assert!(!src.exhausted());
            advance_async(&mut src).await.unwrap();
            assert_eq!(src.buffer().0, b"12345_12345_abc");
            assert!(!src.exhausted());
            advance_async(&mut src).await.unwrap();
            assert_eq!(src.buffer().0, b"12345_12345_abc");
            assert!(src.exhausted());
        });
    }

    #[test]
    fn should_advance_pos_async() {
        let r = [
            b"1234_1234_".as_ref(),
            b"1234_1234_".as_ref(),
            b"abc".as_ref(),
        ];
        let mut src = IterSource::new_infallible(stream::iter(r.iter()));
        block_on(async move {
            advance_async(&mut src).await.unwrap();
            let (buf, pos) = src.buffer();
            assert_eq!(buf, b"1234_1234_");
            *pos += 8;
            let (buf, _) = src.buffer();
            assert_eq!(buf, b"4_");
            advance_async(&mut src).await.unwrap();
            let (buf, pos) = src.buffer();
            assert_eq!(buf, b"4_1234_1234_");
            *pos += 12;
            let (buf, _) = src.buffer();
            assert_eq!(buf, b"");
            advance_async(&mut src).await.unwrap();
            let (buf, pos) = src.buffer();
            assert_eq!(buf, b"abc");
            *pos += 2;
            let (buf, _) = src.buffer();
            assert_eq!(buf, b"c");
        });
    }

    #[test]
    fn should_leave_buffer_unchanged_on_error_async() {
        let r = iter::once::<Result<&[u8], Error>>(Err(Error));
        let mut src = IterSource::new(stream::iter(r));
        block_on(async move {
            let result = advance_async(&mut src).await;
            assert!(result.is_err());
            assert_eq!(src.buffer().0, b"");
        });
    }

    #[test]
    fn should_not_poll_source_after_none_async() {
        let mut src = IterSource::new_infallible(PanickingSource::default());
        block_on(async move {
            advance_async(&mut src).await.unwrap();
            assert_eq!(src.buffer().0, b"");
            assert!(src.exhausted());
            advance_async(&mut src).await.unwrap();
            assert_eq!(src.buffer().0, b"");
            assert!(src.exhausted());
        });
    }

    #[test]
    fn should_advance_through_buffer() {
        let r = iter::once(b"1234_1234_abc");
        let mut src = IterSource::new_infallible(r);
        src.advance().unwrap();
        let (buf, pos) = src.buffer();
        assert_eq!(buf, b"1234_1234_abc");
        *pos += 5;
        let (buf, pos) = src.buffer();
        assert_eq!(buf, b"1234_abc");
        *pos += 5;
        let (buf, pos) = src.buffer();
        assert_eq!(buf, b"abc");
        *pos += 2;
        let (buf, _) = src.buffer();
        assert_eq!(buf, b"c");
    }

    // An `Err` must neither clear the buffer nor mark the source exhausted, so
    // a later `advance` keeps appending data from the same source.
    #[test]
    fn should_retry_after_error() {
        let r = iter::once::<Result<&[u8], Error>>(Ok(&b"12"[..]))
            .chain(iter::once(Err(Error)))
            .chain(iter::once(Ok(&b"34"[..])));
        let mut src: IterSource<u8, _> = IterSource::new(r);
        src.advance().unwrap();
        assert_eq!(src.buffer().0, b"12");
        assert!(src.advance().is_err());
        assert_eq!(src.buffer().0, b"12");
        assert!(!src.exhausted());
        src.advance().unwrap();
        assert_eq!(src.buffer().0, b"1234");
        assert!(!src.exhausted());
        src.advance().unwrap();
        assert_eq!(src.buffer().0, b"1234");
        assert!(src.exhausted());
    }

    // When an incoming chunk does not fit in the spare capacity, the consumed
    // prefix is dropped and the read position reset before appending.
    #[test]
    fn should_compact_consumed_prefix_when_growing() {
        let mut buf: Vec<u8> = Vec::with_capacity(4);
        buf.extend_from_slice(b"abcd");
        let mut pos = 3;
        let mut exhausted = false;
        let chunk = [b'x'; 16];
        handle_buffer_result::<u8, &[u8], Error>(
            Some(Ok(&chunk[..])),
            &mut buf,
            &mut pos,
            &mut exhausted,
        )
        .unwrap();
        assert_eq!(pos, 0);
        assert_eq!(buf.len(), 17);
        assert_eq!(buf[0], b'd');
        assert!(buf[1..].iter().all(|&b| b == b'x'));
        assert!(!exhausted);
    }
}