zng_task/
io.rs

1//! IO tasks.
2//!
3//! Most of the types in this module are re-exported from [`futures_lite::io`].
4//!
5//! [`futures_lite::io`]: https://docs.rs/futures-lite/latest/futures_lite/io/index.html
6
7use std::{
8    fmt,
9    io::{BufRead, ErrorKind, Read, Write},
10    pin::Pin,
11    sync::Arc,
12    task::{self, Poll},
13    time::Duration,
14};
15
16use crate::{McWaker, Progress};
17
18#[doc(no_inline)]
19pub use futures_lite::io::{
20    AsyncBufRead, AsyncBufReadExt, AsyncRead, AsyncReadExt, AsyncSeek, AsyncSeekExt, AsyncWrite, AsyncWriteExt, BoxedReader, BoxedWriter,
21    BufReader, BufWriter, Cursor, ReadHalf, WriteHalf, copy, empty, repeat, sink, split,
22};
23use parking_lot::Mutex;
24use std::io::{Error, Result};
25use zng_time::{DInstant, INSTANT};
26use zng_txt::formatx;
27use zng_unit::{ByteLength, ByteUnits};
28use zng_var::{Var, impl_from_and_into_var, var};
29
30struct MeasureInner {
31    metrics: Var<Metrics>,
32    start_time: DInstant,
33    last_write: DInstant,
34    last_read: DInstant,
35}
36impl MeasureInner {
37    fn new(read_progress: (ByteLength, ByteLength), write_progress: (ByteLength, ByteLength)) -> Self {
38        let now = INSTANT.now();
39        Self {
40            metrics: var(Metrics {
41                read_progress,
42                read_speed: 0.bytes(),
43                write_progress,
44                write_speed: 0.bytes(),
45                total_time: Duration::ZERO,
46            }),
47            start_time: now,
48            last_write: now,
49            last_read: now,
50        }
51    }
52
53    fn on_read(&mut self, bytes: usize) {
54        if bytes == 0 {
55            return;
56        }
57
58        let bytes = bytes.bytes();
59
60        let now = INSTANT.now();
61        let elapsed = now - self.last_read;
62
63        self.last_read = now;
64        let read_speed = bytes_per_sec(bytes, elapsed);
65
66        let total_time = now - self.start_time;
67
68        self.metrics.modify(move |m| {
69            m.read_progress.0 += bytes;
70            m.read_speed = read_speed;
71            m.total_time = total_time;
72        });
73    }
74
75    fn on_write(&mut self, bytes: usize) {
76        if bytes == 0 {
77            return;
78        }
79
80        let bytes = bytes.bytes();
81
82        let now = INSTANT.now();
83        let elapsed = now - self.last_write;
84
85        self.last_write = now;
86        let write_speed = bytes_per_sec(bytes, elapsed);
87
88        let total_time = now - self.start_time;
89
90        self.metrics.modify(move |m| {
91            m.write_progress.0 += bytes;
92            m.write_speed = write_speed;
93            m.total_time = total_time;
94        });
95    }
96}
97
98/// Measure read/write of an async task.
99///
100/// Metrics are updated after each read/write, if you read/write all bytes in one call
101/// the metrics will only update once.
102pub struct Measure<T> {
103    task: T,
104    inner: MeasureInner,
105}
106impl<T> Measure<T> {
107    /// Start measuring a new read/write task.
108    pub fn new(task: T, total_read: ByteLength, total_write: ByteLength) -> Self {
109        Self::new_ongoing(task, (0.bytes(), total_read), (0.bytes(), total_write))
110    }
111
112    /// Continue measuring a read/write task.
113    pub fn new_ongoing(task: T, read_progress: (ByteLength, ByteLength), write_progress: (ByteLength, ByteLength)) -> Self {
114        Measure {
115            task,
116            inner: MeasureInner::new(read_progress, write_progress),
117        }
118    }
119
120    /// Current metrics.
121    ///
122    /// This value is updated after every read/write.
123    pub fn metrics(&mut self) -> Var<Metrics> {
124        self.inner.metrics.read_only()
125    }
126
127    /// Unwrap the inner task and final metrics.
128    pub fn finish(self) -> (T, Metrics) {
129        let mut metrics = self.inner.metrics.get();
130        metrics.total_time = self.inner.start_time.elapsed();
131        (self.task, metrics)
132    }
133}
134
135fn bytes_per_sec(bytes: ByteLength, elapsed: Duration) -> ByteLength {
136    let bytes_per_sec = bytes.0 as u128 / elapsed.as_nanos() / Duration::from_secs(1).as_nanos();
137    ByteLength(bytes_per_sec as usize)
138}
139
140impl<T: AsyncRead> AsyncRead for Measure<T> {
141    fn poll_read(self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &mut [u8]) -> Poll<Result<usize>> {
142        // SAFETY: we don't move anything.
143        let self_ = unsafe { self.get_unchecked_mut() };
144
145        // SAFETY: we don't move task
146        match unsafe { Pin::new_unchecked(&mut self_.task) }.poll_read(cx, buf) {
147            Poll::Ready(Ok(bytes)) => {
148                self_.inner.on_read(bytes);
149                Poll::Ready(Ok(bytes))
150            }
151            p => p,
152        }
153    }
154}
155impl<T: AsyncWrite> AsyncWrite for Measure<T> {
156    fn poll_write(self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &[u8]) -> Poll<Result<usize>> {
157        // SAFETY: we don't move anything.
158        let self_ = unsafe { self.get_unchecked_mut() };
159
160        // SAFETY: we don't move task
161        match unsafe { Pin::new_unchecked(&mut self_.task) }.poll_write(cx, buf) {
162            Poll::Ready(Ok(bytes)) => {
163                self_.inner.on_write(bytes);
164                Poll::Ready(Ok(bytes))
165            }
166            p => p,
167        }
168    }
169
170    fn poll_flush(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Result<()>> {
171        // SAFETY: we don't move anything.
172        let self_ = unsafe { self.get_unchecked_mut() };
173
174        // SAFETY: we don't move task
175        unsafe { Pin::new_unchecked(&mut self_.task) }.poll_flush(cx)
176    }
177
178    fn poll_close(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Result<()>> {
179        // SAFETY: we don't move anything.
180        let self_ = unsafe { self.get_unchecked_mut() };
181
182        // SAFETY: we don't move task
183        unsafe { Pin::new_unchecked(&mut self_.task) }.poll_flush(cx)
184    }
185}
186impl<T: AsyncBufRead> AsyncBufRead for Measure<T> {
187    fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Result<&[u8]>> {
188        // SAFETY: we don't move anything.
189        let self_ = unsafe { self.get_unchecked_mut() };
190
191        // SAFETY: we don't move task
192        unsafe { Pin::new_unchecked(&mut self_.task) }.poll_fill_buf(cx)
193    }
194
195    fn consume(self: Pin<&mut Self>, amt: usize) {
196        // SAFETY: we don't move anything.
197        let self_ = unsafe { self.get_unchecked_mut() };
198        // SAFETY: we don't move task
199        unsafe { Pin::new_unchecked(&mut self_.task) }.consume(amt);
200        self_.inner.on_read(amt);
201    }
202}
203impl<T: Read> Read for Measure<T> {
204    fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
205        match self.task.read(buf) {
206            Ok(bytes) => {
207                self.inner.on_read(bytes);
208                Ok(bytes)
209            }
210            r => r,
211        }
212    }
213}
214impl<T: Write> Write for Measure<T> {
215    fn write(&mut self, buf: &[u8]) -> Result<usize> {
216        match self.task.write(buf) {
217            Ok(bytes) => {
218                self.inner.on_write(bytes);
219                Ok(bytes)
220            }
221            r => r,
222        }
223    }
224
225    fn flush(&mut self) -> Result<()> {
226        self.task.flush()
227    }
228}
229impl<T: BufRead> BufRead for Measure<T> {
230    fn fill_buf(&mut self) -> Result<&[u8]> {
231        self.task.fill_buf()
232    }
233
234    fn consume(&mut self, amount: usize) {
235        self.task.consume(amount);
236        self.inner.on_read(amount);
237    }
238}
239
240/// Information about the state of an async IO task.
241///
242/// Read is also called *receive* or *download*. Write is also called *send* or *upload*. The default
243/// display print uses arrows ↓ and ↑ for read and write.
244///
245/// Use [`Measure`] to measure a task.
246#[derive(Debug, Clone, PartialEq, Eq)]
247#[non_exhaustive]
248pub struct Metrics {
249    /// Number of bytes read / estimated total.
250    pub read_progress: (ByteLength, ByteLength),
251
252    /// Average read speed in bytes/second.
253    pub read_speed: ByteLength,
254
255    /// Number of bytes written / estimated total.
256    pub write_progress: (ByteLength, ByteLength),
257
258    /// Average write speed in bytes/second.
259    pub write_speed: ByteLength,
260
261    /// Total time for the entire task. This will continuously increase until
262    /// the task is finished.
263    pub total_time: Duration,
264}
265impl Metrics {
266    /// All zeros.
267    pub fn zero() -> Self {
268        Self {
269            read_progress: (0.bytes(), 0.bytes()),
270            read_speed: 0.bytes(),
271            write_progress: (0.bytes(), 0.bytes()),
272            write_speed: 0.bytes(),
273            total_time: Duration::ZERO,
274        }
275    }
276}
277impl fmt::Display for Metrics {
278    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
279        let mut nl = false;
280        if self.read_progress.1 > 0.bytes() {
281            nl = true;
282            if self.read_progress.0 != self.read_progress.1 {
283                write!(f, "↓ {}-{}, {}/s", self.read_progress.0, self.read_progress.1, self.read_speed)?;
284                nl = true;
285            } else {
286                write!(f, "↓ {} . {:?}", self.read_progress.0, self.total_time)?;
287            }
288        }
289        if self.write_progress.1 > 0.bytes() {
290            if nl {
291                writeln!(f)?;
292            }
293            if self.write_progress.0 != self.write_progress.1 {
294                write!(f, "↑ {} - {}, {}/s", self.write_progress.0, self.write_progress.1, self.write_speed)?;
295            } else {
296                write!(f, "↑ {} . {:?}", self.write_progress.0, self.total_time)?;
297            }
298        }
299
300        Ok(())
301    }
302}
303impl_from_and_into_var! {
304    fn from(metrics: Metrics) -> Progress {
305        let mut status = Progress::indeterminate();
306        if metrics.read_progress.1 > 0.bytes() {
307            status = Progress::from_n_of(metrics.read_progress.0.0, metrics.read_progress.1.0);
308        }
309        if metrics.write_progress.1 > 0.bytes() {
310            let w_status = Progress::from_n_of(metrics.write_progress.0.0, metrics.write_progress.1.0);
311            if status.is_indeterminate() {
312                status = w_status;
313            } else {
314                status = status.and_fct(w_status.fct());
315            }
316        }
317        status.with_msg(formatx!("{metrics}")).with_meta_mut(|mut m| {
318            m.set(*METRICS_ID, metrics);
319        })
320    }
321}
322
zng_state_map::static_id! {
    /// Identifies the [`Metrics`] stored in [`Progress`] metadata.
    ///
    /// The `Metrics -> Progress` conversion in this module sets it using [`Progress::with_meta_mut`].
    pub static ref METRICS_ID: zng_state_map::StateId<Metrics>;
}
327
/// Extension methods for [`std::io::Error`] to be used with errors returned by [`McBufReader`].
pub trait McBufErrorExt {
    /// Returns `true` if this error represents the condition where there are only [`McBufReader::is_lazy`] readers
    /// left, the buffer is drained and the inner reader is not EOF.
    ///
    /// You can recover from this error by turning the reader non-lazy using [`McBufReader::set_lazy`].
    fn is_only_lazy_left(&self) -> bool;
}
impl McBufErrorExt for std::io::Error {
    fn is_only_lazy_left(&self) -> bool {
        // The lazy error is always created with `ErrorKind::Other`.
        if self.kind() != ErrorKind::Other {
            return false;
        }
        let debug_text = format!("{self:?}");
        debug_text.contains(ONLY_NON_LAZY_ERROR_MSG)
    }
}
/// Message of the error returned to lazy readers when no non-lazy reader is left.
const ONLY_NON_LAZY_ERROR_MSG: &str = "no non-lazy readers left to read";
342
343/// Multiple consumer buffered read.
344///
345/// Clone an instance to create a new consumer, already read bytes stay in the buffer until all clones have read it,
346/// clones continue reading from the same offset as the reader they cloned.
347///
348/// A single instance of this reader behaves like a `BufReader`.
349///
350/// # Result
351///
352/// The result is *repeats* ready when `EOF` or an [`Error`] occurs, unfortunately the IO error is not cloneable
353/// so the error is recreated using [`CloneableError`] for subsequent poll attempts.
354///
355/// The inner reader is dropped as soon as it finishes.
356///
357/// # Lazy Clones
358///
359/// You can mark clones as [lazy], lazy clones don't pull from the inner reader, only advance when another clone reads, if
360/// all living clones are lazy they stop reading with an error. You can identify this custom error using the [`McBufErrorExt::is_only_lazy_left`]
361/// extension method.
362///
363/// [lazy]: Self::set_lazy
364pub struct McBufReader<S: AsyncRead> {
365    inner: Arc<Mutex<McBufInner<S>>>,
366    index: usize,
367    lazy: bool,
368}
369struct McBufInner<S: AsyncRead> {
370    source: Option<S>,
371    waker: McWaker,
372    lazy_wakers: Vec<task::Waker>,
373
374    buf: Vec<u8>,
375
376    clones: Vec<usize>,
377    non_lazy_count: usize,
378
379    result: ReadState,
380}
381impl<S: AsyncRead> McBufReader<S> {
382    /// Creates a buffered reader.
383    pub fn new(source: S) -> Self {
384        let mut clones = Vec::with_capacity(2);
385        clones.push(0);
386        McBufReader {
387            inner: Arc::new(Mutex::new(McBufInner {
388                source: Some(source),
389                waker: McWaker::empty(),
390                lazy_wakers: vec![],
391
392                buf: Vec::with_capacity(10.kilobytes().0),
393
394                clones,
395                non_lazy_count: 1,
396
397                result: ReadState::Running,
398            })),
399            index: 0,
400            lazy: false,
401        }
402    }
403
404    /// Returns `true` if this reader does not pull from the inner reader, only advancing when a non-lazy reader advances.
405    ///
406    /// The initial reader is not lazy, only clones of lazy readers are lazy by default.
407    pub fn is_lazy(&self) -> bool {
408        self.lazy
409    }
410
411    /// Sets [`is_lazy`].
412    ///
413    /// [`is_lazy`]: Self::is_lazy
414    pub fn set_lazy(&mut self, lazy: bool) {
415        if self.lazy != lazy {
416            if lazy {
417                self.inner.lock().non_lazy_count -= 1;
418            } else {
419                self.inner.lock().non_lazy_count += 1;
420            }
421            self.lazy = lazy;
422        }
423    }
424}
425impl<S: AsyncRead> Clone for McBufReader<S> {
426    fn clone(&self) -> Self {
427        let mut inner = self.inner.lock();
428
429        let offset = inner.clones[self.index];
430        let index = inner.clones.len();
431        inner.clones.push(offset);
432
433        if !self.lazy {
434            inner.non_lazy_count += 1;
435        }
436
437        Self {
438            inner: self.inner.clone(),
439            index,
440            lazy: self.lazy,
441        }
442    }
443}
444impl<S: AsyncRead> Drop for McBufReader<S> {
445    fn drop(&mut self) {
446        let mut inner = self.inner.lock();
447        inner.clones[self.index] = usize::MAX;
448        if !self.lazy {
449            inner.non_lazy_count -= 1;
450            if inner.non_lazy_count == 0 {
451                // notify lazy so they get the error.
452                for waker in inner.lazy_wakers.drain(..) {
453                    waker.wake();
454                }
455            }
456        }
457    }
458}
459impl<S: AsyncRead> AsyncRead for McBufReader<S> {
460    fn poll_read(self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &mut [u8]) -> Poll<Result<usize>> {
461        let self_ = self.as_ref();
462        let mut inner = self_.inner.lock();
463        let inner = &mut *inner;
464
465        // ready data for this clone.
466        let mut i = inner.clones[self_.index];
467        let mut ready;
468
469        match &inner.result {
470            ReadState::Running => {
471                // source has not finished yet.
472
473                ready = &inner.buf[i..];
474
475                if ready.is_empty() {
476                    if self.lazy {
477                        if inner.non_lazy_count == 0 {
478                            // user can make this reader non-lazy and try again.
479                            return Poll::Ready(Err(Error::other(ONLY_NON_LAZY_ERROR_MSG)));
480                        } else {
481                            // register waker for after non-lazy poll.
482                            inner.lazy_wakers.push(cx.waker().clone());
483
484                            // wait non-lazy to pull.
485                            return Poll::Pending;
486                        }
487                    }
488
489                    // time to poll source.
490
491                    ready = &[];
492
493                    let waker = match inner.waker.push(cx.waker().clone()) {
494                        Some(w) => w,
495                        None => {
496                            // already polling from another clone.
497                            return Poll::Pending;
498                        }
499                    };
500
501                    let min_i = inner.clones.iter().copied().min().unwrap();
502                    if min_i > 0 {
503                        // reuse front.
504                        inner.buf.copy_within(min_i.., 0);
505                        inner.buf.truncate(inner.buf.len() - min_i);
506
507                        i -= min_i;
508                        for i in &mut inner.clones {
509                            *i -= min_i;
510                        }
511                    }
512
513                    let new_start = inner.buf.len();
514
515                    inner.buf.resize(inner.buf.len() + buf.len().max(10.kilobytes().0), 0);
516
517                    let mut inner_cx = task::Context::from_waker(&waker);
518
519                    // SAFETY: we don't move `source`.
520                    let source = unsafe { Pin::new_unchecked(inner.source.as_mut().unwrap()) };
521                    let result = source.poll_read(&mut inner_cx, &mut inner.buf[new_start..]);
522
523                    match result {
524                        Poll::Ready(result) => {
525                            // notify lazy readers.
526                            for waker in inner.lazy_wakers.drain(..) {
527                                waker.wake();
528                            }
529
530                            match result {
531                                Ok(0) => {
532                                    inner.waker.cancel();
533
534                                    // EOF
535                                    inner.buf.truncate(new_start);
536                                    inner.result = ReadState::Eof;
537                                    inner.source = None;
538
539                                    // continue 'copy ready
540                                }
541                                Ok(read) => {
542                                    inner.waker.cancel();
543
544                                    // Read > 0
545                                    inner.buf.truncate(new_start + read);
546                                    ready = &inner.buf[i..];
547
548                                    // continue 'copy ready
549                                }
550                                Err(e) => {
551                                    inner.waker.cancel();
552
553                                    // Error
554                                    inner.result = ReadState::Err(CloneableError::new(&e));
555                                    inner.buf = vec![];
556                                    inner.source = None;
557
558                                    return Poll::Ready(Err(e));
559                                }
560                            }
561                        }
562
563                        Poll::Pending => {
564                            inner.buf.truncate(new_start);
565                            return Poll::Pending;
566                        }
567                    }
568                }
569            }
570            ReadState::Eof => {
571                ready = &inner.buf[i..];
572
573                // continue 'copy ready
574            }
575            ReadState::Err(e) => return Poll::Ready(e.err()),
576        }
577
578        // 'copy ready
579
580        let max_ready = buf.len().min(ready.len());
581        buf[..max_ready].copy_from_slice(&ready[..max_ready]);
582
583        i += max_ready;
584        inner.clones[self_.index] = i;
585
586        Poll::Ready(Ok(max_ready))
587    }
588}
589
/// Represents the cloneable parts of an [`Error`].
///
/// Unfortunately [`Error`] does not implement clone, this is needed to implement
/// IO futures that repeat the ready result after subsequent polls. This type partially
/// works around the issue by copying enough information to recreate an error that is still useful.
///
/// For OS errors only the raw OS error code is kept, and the kind and message are derived from it again.
/// For other errors the [`ErrorKind`] and display message are preserved. Note that this is not an error type,
/// it must be converted to [`Error`] using `into` or [`err`].
///
/// [`err`]: Self::err
#[derive(Debug, Clone)]
pub struct CloneableError {
    info: ErrorInfo,
}
#[derive(Debug, Clone)]
enum ErrorInfo {
    /// Raw OS error code, see [`Error::raw_os_error`].
    OsError(i32),
    /// Kind and display message of a non-OS error.
    Other(ErrorKind, String),
}
impl CloneableError {
    /// Copy the cloneable information from the [`Error`].
    pub fn new(e: &Error) -> Self {
        let info = if let Some(code) = e.raw_os_error() {
            ErrorInfo::OsError(code)
        } else {
            ErrorInfo::Other(e.kind(), format!("{e}"))
        };

        Self { info }
    }

    /// Returns an `Err(Error)` generated from the cloneable information.
    pub fn err<T>(&self) -> Result<T> {
        Err(self.clone().into())
    }
}
impl From<CloneableError> for Error {
    fn from(e: CloneableError) -> Self {
        match e.info {
            ErrorInfo::OsError(code) => Error::from_raw_os_error(code),
            ErrorInfo::Other(kind, msg) => Error::new(kind, msg),
        }
    }
}
634
635/// Represents a stream reader that generates an error if the source stream exceeds a limit.
636///
637/// Note that some bytes over the limit may be read once if the source stream is buffered.
638pub struct ReadLimited<S> {
639    source: S,
640    limit: usize,
641    on_limit: fn() -> std::io::Error,
642}
643impl<S> ReadLimited<S> {
644    /// Construct a limited reader.
645    ///
646    /// The `on_limit` closure is called for every read attempt after the limit is reached.
647    pub fn new(source: S, limit: ByteLength, on_limit: fn() -> std::io::Error) -> Self {
648        Self {
649            source,
650            limit: limit.0,
651            on_limit,
652        }
653    }
654
655    /// New with default on limit error.
656    pub fn new_default_err(source: S, limit: ByteLength) -> Self {
657        Self::new(source, limit, || {
658            std::io::Error::new(std::io::ErrorKind::UnexpectedEof, "source exceeded read limit")
659        })
660    }
661}
662impl<S> AsyncRead for ReadLimited<S>
663where
664    S: AsyncRead,
665{
666    fn poll_read(self: Pin<&mut Self>, cx: &mut task::Context<'_>, mut buf: &mut [u8]) -> Poll<Result<usize>> {
667        // SAFETY: we don't move anything.
668        let self_ = unsafe { self.get_unchecked_mut() };
669
670        if self_.limit == 0 {
671            let err = (self_.on_limit)();
672            return Poll::Ready(Err(err));
673        }
674
675        if buf.len() > self_.limit {
676            buf = &mut buf[..self_.limit];
677        }
678
679        // SAFETY: we never move `source`.
680        match unsafe { Pin::new_unchecked(&mut self_.source) }.poll_read(cx, buf) {
681            Poll::Ready(Ok(n)) => {
682                self_.limit = self_.limit.saturating_sub(n);
683                Poll::Ready(Ok(n))
684            }
685            r => r,
686        }
687    }
688}
689impl<S> AsyncBufRead for ReadLimited<S>
690where
691    S: AsyncBufRead,
692{
693    fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Result<&[u8]>> {
694        // SAFETY: we don't move anything.
695        let self_ = unsafe { self.get_unchecked_mut() };
696
697        if self_.limit == 0 {
698            let err = (self_.on_limit)();
699            return Poll::Ready(Err(err));
700        }
701
702        // SAFETY: we never move `source`.
703        unsafe { Pin::new_unchecked(&mut self_.source) }.poll_fill_buf(cx)
704    }
705
706    fn consume(self: Pin<&mut Self>, amt: usize) {
707        // SAFETY: we don't move anything.
708        let self_ = unsafe { self.get_unchecked_mut() };
709        // SAFETY: we never move `source`.
710        unsafe { Pin::new_unchecked(&mut self_.source) }.consume(amt);
711        self_.limit = self_.limit.saturating_sub(amt);
712    }
713}
714impl<S> Read for ReadLimited<S>
715where
716    S: Read,
717{
718    fn read(&mut self, mut buf: &mut [u8]) -> Result<usize> {
719        if self.limit == 0 {
720            let err = (self.on_limit)();
721            return Err(err);
722        }
723
724        if buf.len() > self.limit {
725            buf = &mut buf[..self.limit];
726        }
727
728        match self.source.read(buf) {
729            Ok(n) => {
730                self.limit = self.limit.saturating_sub(n);
731                Ok(n)
732            }
733            r => r,
734        }
735    }
736}
737impl<S> BufRead for ReadLimited<S>
738where
739    S: BufRead,
740{
741    fn fill_buf(&mut self) -> Result<&[u8]> {
742        if self.limit == 0 {
743            let err = (self.on_limit)();
744            return Err(err);
745        }
746
747        self.source.fill_buf()
748    }
749
750    fn consume(&mut self, amount: usize) {
751        self.source.consume(amount);
752        self.limit = self.limit.saturating_sub(amount);
753    }
754}
755
/// State of the source reader shared by the [`McBufReader`] clones.
enum ReadState {
    /// Source has not finished, more data may be pulled from it.
    Running,
    /// Source returned `Ok(0)`, only the data left in the buffer remains.
    Eof,
    /// Source returned an error, recreated for every subsequent read.
    Err(CloneableError),
}
761
762#[cfg(test)]
763mod tests {
764    use super::*;
765    use crate as task;
766    use zng_unit::TimeUnits;
767
768    #[test]
769    pub fn mc_buf_reader_parallel() {
770        let data = Data::new(60.kilobytes().0);
771
772        let mut expected = vec![0; data.len];
773        let _ = data.clone().blocking_read(&mut expected[..]);
774
775        let mut a = McBufReader::new(data);
776        let mut b = a.clone();
777        let mut c = a.clone();
778
779        let (a, b, c) = async_test(async move {
780            let a = task::run(async move {
781                let mut buf = vec![];
782                a.read_to_end(&mut buf).await.unwrap();
783                buf
784            });
785            let b = task::run(async move {
786                let mut buf: Vec<u8> = vec![];
787                b.read_to_end(&mut buf).await.unwrap();
788                buf
789            });
790            let c = task::run(async move {
791                let mut buf: Vec<u8> = vec![];
792                c.read_to_end(&mut buf).await.unwrap();
793                buf
794            });
795
796            task::all!(a, b, c).await
797        });
798
799        crate::assert_vec_eq!(expected, a);
800        crate::assert_vec_eq!(expected, b);
801        crate::assert_vec_eq!(expected, c);
802    }
803
804    #[test]
805    pub fn mc_buf_reader_single() {
806        let data = Data::new(60.kilobytes().0);
807
808        let mut expected = vec![0; data.len];
809        let _ = data.clone().blocking_read(&mut expected[..]);
810
811        let mut a = McBufReader::new(data);
812
813        let a = async_test(async move {
814            let a = task::run(async move {
815                let mut buf = vec![];
816                a.read_to_end(&mut buf).await.unwrap();
817                buf
818            });
819
820            a.await
821        });
822
823        crate::assert_vec_eq!(expected, a);
824    }
825
826    #[test]
827    pub fn mc_buf_reader_sequential() {
828        let data = Data::new(60.kilobytes().0);
829
830        let mut expected = vec![0; data.len];
831        let _ = data.clone().blocking_read(&mut expected[..]);
832
833        let mut clones = vec![McBufReader::new(data)];
834        for _ in 0..5 {
835            clones.push(clones[0].clone());
836        }
837
838        let r = async_test(async move {
839            let mut r = vec![];
840
841            for mut clone in clones {
842                let mut buf = vec![];
843                clone.read_to_end(&mut buf).await.unwrap();
844                r.push(buf);
845            }
846
847            r
848        });
849
850        for r in r {
851            crate::assert_vec_eq!(expected, r);
852        }
853    }
854
855    #[test]
856    pub fn mc_buf_reader_completed() {
857        let data = Data::new(60.kilobytes().0);
858        let mut buf = Vec::with_capacity(data.len);
859        let mut a = McBufReader::new(data);
860
861        let r = async_test(async move {
862            a.read_to_end(&mut buf).await.unwrap();
863
864            let mut b = a.clone();
865            buf.clear();
866
867            b.read_to_end(&mut buf).await.unwrap();
868            buf.len()
869        });
870
871        assert_eq!(0, r);
872    }
873
874    #[test]
875    pub fn mc_buf_reader_error() {
876        let mut data = Data::new(20.kilobytes().0);
877        data.set_error();
878
879        let mut expected = vec![0; data.len];
880        let _ = data.clone().blocking_read(&mut expected[..]);
881
882        let mut a = McBufReader::new(data);
883        let mut b = a.clone();
884
885        let (a, b) = async_test(async move {
886            let a = task::run(async move {
887                let mut buf = vec![];
888                a.read_to_end(&mut buf).await.unwrap_err()
889            });
890            let b = task::run(async move {
891                let mut buf: Vec<u8> = vec![];
892                b.read_to_end(&mut buf).await.unwrap_err()
893            });
894
895            task::all!(a, b).await
896        });
897
898        assert_eq!(ErrorKind::InvalidData, a.kind());
899        assert_eq!(ErrorKind::InvalidData, b.kind());
900    }
901
902    #[test]
903    pub fn mc_buf_reader_error_completed() {
904        let mut data = Data::new(20.kilobytes().0);
905        data.set_error();
906
907        let mut buf = Vec::with_capacity(data.len);
908        let mut a = McBufReader::new(data);
909
910        let (a, b) = async_test(async move {
911            let a_err = a.read_to_end(&mut buf).await.unwrap_err();
912
913            let mut b = a.clone();
914            buf.clear();
915
916            let b_err = b.read_to_end(&mut buf).await.unwrap_err();
917
918            (a_err, b_err)
919        });
920
921        assert_eq!(ErrorKind::InvalidData, a.kind());
922        assert_eq!(ErrorKind::InvalidData, b.kind());
923    }
924
925    #[test]
926    pub fn mc_buf_reader_parallel_with_delay1() {
927        let mut data = Data::new(60.kilobytes().0);
928        data.enable_pending();
929
930        let mut expected = vec![0; data.len];
931        let _ = data.clone().blocking_read(&mut expected[..]);
932
933        let mut a = McBufReader::new(data);
934        let mut b = a.clone();
935        let mut c = a.clone();
936
937        let (a, b, c) = async_test(async move {
938            let a = task::run(async move {
939                let mut buf = vec![];
940                a.read_to_end(&mut buf).await.unwrap();
941                buf
942            });
943            let b = task::run(async move {
944                let mut buf: Vec<u8> = vec![];
945                b.read_to_end(&mut buf).await.unwrap();
946                buf
947            });
948            let c = task::run(async move {
949                let mut buf: Vec<u8> = vec![];
950                c.read_to_end(&mut buf).await.unwrap();
951                buf
952            });
953
954            task::all!(a, b, c).await
955        });
956
957        crate::assert_vec_eq!(expected, a);
958        crate::assert_vec_eq!(expected, b);
959        crate::assert_vec_eq!(expected, c);
960    }
961
962    #[test]
963    pub fn mc_buf_reader_parallel_with_delay2() {
964        let mut data = Data::new(60.kilobytes().0);
965        data.enable_pending();
966
967        let mut expected = vec![0; data.len];
968        let _ = data.clone().blocking_read(&mut expected[..]);
969
970        let mut a = McBufReader::new(data);
971        let mut b = a.clone();
972        let mut c = a.clone();
973
974        let (a, b, c) = async_test(async move {
975            let a = task::run(async move {
976                let mut buf = vec![];
977                a.read_to_end(&mut buf).await.unwrap();
978                buf
979            });
980            let b = task::run(async move {
981                let mut buf: Vec<u8> = vec![];
982                task::deadline(5.ms()).await;
983                b.read_to_end(&mut buf).await.unwrap();
984                buf
985            });
986            let c = task::run(async move {
987                let mut buf: Vec<u8> = vec![];
988                c.read_to_end(&mut buf).await.unwrap();
989                buf
990            });
991
992            task::all!(a, b, c).await
993        });
994
995        crate::assert_vec_eq!(expected, a);
996        crate::assert_vec_eq!(expected, b);
997        crate::assert_vec_eq!(expected, c);
998    }
999
1000    #[derive(Clone)]
1001    struct Data {
1002        b: u8,
1003        len: usize,
1004        error: Option<CloneableError>,
1005        delay: Duration,
1006        pending: bool,
1007    }
1008    impl Data {
1009        pub fn new(len: usize) -> Self {
1010            Self {
1011                b: 0,
1012                len,
1013                error: None,
1014                delay: 0.ms(),
1015                pending: false,
1016            }
1017        }
1018        pub fn blocking_read(&mut self, buf: &mut [u8]) -> Result<usize> {
1019            let len = self.len;
1020            for b in buf.iter_mut().take(len) {
1021                *b = self.b;
1022                self.len -= 1;
1023                self.b = self.b.wrapping_add(1);
1024            }
1025
1026            if len == 0
1027                && let Some(e) = &self.error
1028            {
1029                return e.err();
1030            }
1031
1032            Ok(buf.len().min(len))
1033        }
1034        pub fn set_error(&mut self) {
1035            self.error = Some(CloneableError::new(&Error::new(ErrorKind::InvalidData, "test error")));
1036        }
1037
1038        pub fn enable_pending(&mut self) {
1039            self.delay = 3.ms();
1040        }
1041    }
1042    impl AsyncRead for Data {
1043        fn poll_read(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, buf: &mut [u8]) -> Poll<Result<usize>> {
1044            if self.delay > Duration::ZERO {
1045                self.pending = !self.pending;
1046                if self.pending {
1047                    let waker = cx.waker().clone();
1048                    let delay = self.delay;
1049                    task::spawn(async move {
1050                        task::deadline(delay).await;
1051                        waker.wake();
1052                    });
1053                    return Poll::Pending;
1054                }
1055            }
1056
1057            let r = self.as_mut().blocking_read(buf);
1058            Poll::Ready(r)
1059        }
1060    }
1061
1062    #[track_caller]
1063    fn async_test<F>(test: F) -> F::Output
1064    where
1065        F: Future,
1066    {
1067        task::block_on(task::with_deadline(test, 5.secs())).unwrap()
1068    }
1069
1070    /// Assert vector equality with better error message.
1071    #[macro_export]
1072    macro_rules! assert_vec_eq {
1073        ($a:expr, $b: expr) => {
1074            match (&$a, &$b) {
1075                (ref a, ref b) => {
1076                    let len_not_eq = a.len() != b.len();
1077                    let mut data_not_eq = None;
1078                    for (i, (a, b)) in a.iter().zip(b.iter()).enumerate() {
1079                        if a != b {
1080                            data_not_eq = Some(i);
1081                            break;
1082                        }
1083                    }
1084
1085                    if len_not_eq || data_not_eq.is_some() {
1086                        use std::fmt::*;
1087
1088                        let mut error = format!("`{}` != `{}`", stringify!($a), stringify!($b));
1089                        if len_not_eq {
1090                            let _ = write!(&mut error, "\n  lengths not equal: {} != {}", a.len(), b.len());
1091                        }
1092                        if let Some(i) = data_not_eq {
1093                            let _ = write!(&mut error, "\n  data not equal at index {}: {} != {:?}", i, a[i], b[i]);
1094                        }
1095                        panic!("{error}")
1096                    }
1097                }
1098            }
1099        };
1100    }
1101}