zzz/
lib.rs

1#![doc = include_str!("../README.md")]
2
3use std::{
4    fmt::{self, Write as _},
5    io::{self, stderr, IsTerminal, Write as _},
6    sync::atomic::{AtomicUsize, Ordering::Relaxed},
7    sync::RwLock,
8    time::{Duration, Instant},
9};
10
11// ============================================================================================== //
12// [Prelude module]                                                                               //
13// ============================================================================================== //
14
15/// Mass-import for the main progress bar type as well as the convenience extension traits.
16pub mod prelude {
17    #[cfg(feature = "streams")]
18    pub use crate::ProgressBarStreamExt;
19    pub use crate::{ProgressBar, ProgressBarIterExt};
20}
21
22// ============================================================================================== //
23// [General configuration]                                                                        //
24// ============================================================================================== //
25
26#[doc(hidden)]
27#[deprecated(note = "renamed to just `Config`")]
28pub type ProgressBarConfig = Config;
29
30/// Configuration for a progress bar.
31///
32/// This is a separate struct from the actual progress bar in order to allow a
33/// configuration to be reused in different progress bar instances.
34#[derive(Clone)]
35pub struct Config {
36    /// Width of the progress bar.
37    pub width: Option<u32>,
38    /// Minimum width to bother with drawing the bar for.
39    pub min_bar_width: u32,
40    /// Theme to use when drawing.
41    pub theme: &'static dyn Theme,
42    /// Maximum redraw rate rate (draws per second).
43    pub max_fps: f32,
44    /// Called to determine whether the progress bar should be drawn or not.
45    ///
46    /// The default value always returns `true` if stderr is a terminal.
47    pub should_draw: &'static (dyn Fn() -> bool + Sync),
48}
49
50static DEFAULT_CFG: Config = Config::const_default();
51
52impl Config {
53    /// `const` variant of [`Config::default`].
54    pub const fn const_default() -> Self {
55        Config {
56            width: None,
57            min_bar_width: 5,
58            theme: &DefaultTheme,
59            max_fps: 60.0,
60            should_draw: &|| stderr().is_terminal(),
61        }
62    }
63}
64
65impl Default for Config {
66    #[inline]
67    fn default() -> Self {
68        Config::const_default()
69    }
70}
71
72/// Selects the currently active global configuration.
73///
74/// This stores a `*const ProgressBarConfig`. We use `AtomicUsize` instead of
75/// the seemingly more idiomatic `AtomicPtr` here because the latter requires a
76/// **mutable** pointer, which would in turn force us to take the config as
77/// mutable reference to not run into UB. There is no const variant of `AtomicPtr`.
78/// Using `AtomicUsize` seemed like the lesser evil here.
79static GLOBAL_CFG: AtomicUsize = AtomicUsize::new(0);
80
81/// Gets the currently active global configuration.
82pub fn global_config() -> &'static Config {
83    match GLOBAL_CFG.load(Relaxed) {
84        0 => &DEFAULT_CFG,
85        ptr => unsafe { &*(ptr as *const Config) },
86    }
87}
88
89/// Set a new global default configuration.
90///
91/// This configuration is used when no explicit per instance configuration
92/// is specified via [`ProgressBar::config`].
93pub fn set_global_config(new_cfg: &'static Config) {
94    GLOBAL_CFG.store(new_cfg as *const _ as _, Relaxed);
95}
96
97// ============================================================================================== //
98// [Utils]                                                                                        //
99// ============================================================================================== //
100
/// Wrapper that pads and aligns its contents to the size of a cache line.
///
/// This keeps neighbouring hot values (the progress bar's atomic counters) out of each other's
/// cache lines. The alignment is 128 bytes on x86-64 and 64 bytes elsewhere.
///
/// Adapted from crossbeam:
/// <https://docs.rs/crossbeam/0.7.3/crossbeam/utils/struct.CachePadded.html>
#[cfg_attr(target_arch = "x86_64", repr(align(128)))]
#[cfg_attr(not(target_arch = "x86_64"), repr(align(64)))]
struct CachePadded<T>(T);

impl<T> std::ops::Deref for CachePadded<T> {
    type Target = T;

    /// Borrows the padded value.
    fn deref(&self) -> &Self::Target {
        let CachePadded(inner) = self;
        inner
    }
}

impl<T> std::ops::DerefMut for CachePadded<T> {
    /// Mutably borrows the padded value.
    fn deref_mut(&mut self) -> &mut Self::Target {
        let CachePadded(inner) = self;
        inner
    }
}
122
123// ============================================================================================== //
124// [Error type]                                                                                   //
125// ============================================================================================== //
126
/// Errors that can occur while drawing the progress bar.
///
/// This is a transparent wrapper: both `Display` and `Error::source` forward to the wrapped
/// error, so error reporters don't print the same message twice.
#[derive(Debug)]
pub enum RenderError {
    /// Writing to the output stream failed.
    Io(io::Error),
    /// Formatting part of the output failed.
    Fmt(fmt::Error),
}

impl fmt::Display for RenderError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            RenderError::Fmt(e) => fmt::Display::fmt(e, f),
            RenderError::Io(e) => fmt::Display::fmt(e, f),
        }
    }
}

impl std::error::Error for RenderError {
    /// Forwards to the wrapped error's source (the wrapped error's own message is already what
    /// `Display` prints).
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            RenderError::Fmt(e) => std::error::Error::source(e),
            RenderError::Io(e) => std::error::Error::source(e),
        }
    }
}

impl From<io::Error> for RenderError {
    fn from(e: io::Error) -> Self {
        RenderError::Io(e)
    }
}

impl From<fmt::Error> for RenderError {
    fn from(e: fmt::Error) -> Self {
        RenderError::Fmt(e)
    }
}
157
158// ============================================================================================== //
159// [Customizable printing]                                                                        //
160// ============================================================================================== //
161
162/// Trait defining how the progress bar is rendered.
163pub trait Theme: Sync {
164    fn render(&self, pb: &ProgressBar) -> Result<(), RenderError>;
165}
166
167#[derive(Debug, Default)]
168struct DefaultTheme;
169
/// Creates a unicode progress bar that is exactly `length` characters wide.
///
/// The bar is framed by `|` on both sides. Inside, full blocks are followed by one partial
/// "tail" block with eighth-of-a-cell resolution, then blank padding.
///
/// `progress` is clamped to `0.0..=1.0`, so values past the target (e.g. from an
/// underestimated size hint) render as a full bar instead of underflowing the padding.
/// Widths too small for the frame plus one cell (`length < 3`) degrade to a row of `|`.
fn bar(progress: f32, length: u32) -> String {
    // Partial blocks covering 1/8 (index 0) up to 8/8 (index 7) of a cell.
    const EIGHTHS: [char; 8] = ['▏', '▎', '▍', '▌', '▋', '▊', '▉', '█'];

    // Two cells are taken by the `|` frame.
    let inner_len = length.saturating_sub(2);
    if inner_len == 0 {
        return "|".repeat(length as usize);
    }

    let progress = progress.clamp(0.0, 1.0);
    // One cell is always reserved for the tail block, so full blocks + tail share the
    // remaining `inner_len - 1` cells at 8 steps per cell.
    let rescaled = (progress * (inner_len - 1) as f32 * 8.0).round() as u32;
    // `min` guards against f32 rounding pushing past the last cell for enormous widths.
    let full = (rescaled / 8).min(inner_len - 1);
    let tail = EIGHTHS[(rescaled % 8) as usize];
    let pad_len = inner_len - full - 1 /* tail */;

    let bar = format!(
        "|{}{}{}|",
        "█".repeat(full as usize),
        tail,
        " ".repeat(pad_len as usize)
    );
    debug_assert_eq!(bar.chars().count() as u32, length);
    bar
}
189
/// Formats a duration as `HH:MM:SS`, dropping sub-second precision.
///
/// Hours are not wrapped, so durations of 100 hours or more simply use more digits.
fn human_time(duration: Duration) -> String {
    let secs = duration.as_secs();
    let (hours, rest) = (secs / 3600, secs % 3600);
    let (minutes, seconds) = (rest / 60, rest % 60);
    format!("{:02}:{:02}:{:02}", hours, minutes, seconds)
}
197
/// Renders an indeterminate "bouncing ball" spinner.
///
/// `x` is the position within one animation period (`0.0..1.0`). The ball moves
/// left-to-right during the first half of the period and back during the second half, with
/// cubic in-out easing. Its moon-phase emoji changes every 8 cells of travel.
///
/// The result is `max(width, 3)` characters: a `[`/`]` frame, space padding and the ball.
fn spinner(x: f32, width: u32) -> String {
    /// Cubic ease-in-out on `0..=1`: slow start, fast middle, slow end.
    fn easing_inout_cubic(mut x: f32) -> f32 {
        x *= 2.0;

        if x < 1.0 {
            0.5 * x.powi(3)
        } else {
            x -= 2.;
            0.5 * (x.powi(3) + 2.)
        }
    }

    // Subtract two brackets + spinner char.
    let inner_width = width.saturating_sub(3);

    // Make the spinner turn around after half the period (triangle wave 0 -> 1 -> 0).
    let x = ((-x + 0.5).abs() - 0.5) * -2.;

    // Apply easing function.
    let x = easing_inout_cubic(x).max(0.).min(1.);
    // Transform 0..1 scale to int width.
    let x = ((inner_width as f32) * x).round() as u32;

    let lpad = x as usize;
    let rpad = inner_width.saturating_sub(x) as usize;

    // Slow the phase animation down: one phase step per 8 cells of travel.
    // U+1F311..=U+1F318 are the eight moon-phase emoji, so the lookup always succeeds.
    let ball_offs = x / 8 % 8;
    let ball = char::from_u32('🌑' as u32 + ball_offs).unwrap_or('🌑');

    let spinner = format!("[{}{}{}]", " ".repeat(lpad), ball, " ".repeat(rpad));
    // The frame and ball always need 3 chars, even if a smaller width was requested.
    debug_assert_eq!(spinner.chars().count() as u32, width.max(3));
    spinner
}
245
246/*
247barr1 = UInt32[0x00, 0x40, 0x04, 0x02, 0x01]
248barr2 = UInt32[0x00, 0x80, 0x20, 0x10, 0x08]
249function braille(a::Float64, b::Float64)
250    bchar(a::UInt32) = '⠀' + a
251    idx(x) = min(x * 4 + 1, 5) |> round |> UInt32
252
253    x = barr1[1:idx(a)] |> sum
254    x |= barr2[1:idx(b)] |> sum
255
256    x |> UInt32 |> bchar
257end
258*/
259
/// Terminal size `(columns, rows)` assumed when the real size is unknown.
const FALLBACK_DIMENSIONS: (usize, usize) = (80, 30);

/// Determines the dimensions of stderr as `(columns, rows)`.
#[cfg(feature = "auto-width")]
fn stderr_dimensions() -> (usize, usize) {
    // term_size doesn't support stderr on Windows, so just use stdout and
    // hope for the best. We should probably replace term_size anyway in the
    // long run since it's unmaintained, but this works for the moment.
    #[cfg(target_os = "windows")]
    let detected = term_size::dimensions_stdout();
    #[cfg(not(target_os = "windows"))]
    let detected = term_size::dimensions_stderr();

    detected.unwrap_or(FALLBACK_DIMENSIONS)
}

/// Determines the dimensions of stderr. Without the `auto-width` feature this is always the
/// fallback size.
#[cfg(not(feature = "auto-width"))]
fn stderr_dimensions() -> (usize, usize) {
    FALLBACK_DIMENSIONS
}
278
279impl Theme for DefaultTheme {
280    fn render(&self, pb: &ProgressBar) -> Result<(), RenderError> {
281        let mut o = stderr();
282        let cfg = pb.active_config();
283
284        // Draw left side.
285        let left = {
286            let mut buf = String::new();
287
288            // If a description is set, print it.
289            if let Some(desc) = pb.message() {
290                write!(buf, "{} ", desc)?;
291            }
292
293            if let Some(progress) = pb.progress() {
294                write!(buf, "{:>6.2}% ", progress * 100.0)?;
295            }
296
297            buf
298        };
299
300        // Draw right side.
301        let right = {
302            let mut buf = String::new();
303
304            // Print "done/total" part
305            buf.write_char(' ')?;
306            pb.unit.write_total(&mut buf, pb.value())?;
307            buf.write_char('/')?;
308            match pb.target {
309                Some(target) => pb.unit.write_total(&mut buf, target)?,
310                None => buf.write_char('?')?,
311            }
312
313            // Print ETA / time elapsed.
314            if let Some(eta) = pb.eta() {
315                write!(buf, " [{}]", human_time(eta))?;
316            } else {
317                write!(buf, " [{}]", human_time(pb.elapsed()))?;
318            }
319
320            // Print iteration rate.
321            buf.write_str(" (")?;
322            pb.unit.write_rate(&mut buf, pb.iters_per_sec())?;
323            buf.write_char(')')?;
324
325            buf
326        };
327
328        let max_width = cfg.width.unwrap_or_else(|| stderr_dimensions().0 as u32);
329
330        let bar_width = max_width
331            .saturating_sub(left.len() as u32)
332            .saturating_sub(right.len() as u32);
333
334        write!(o, "{}", left)?;
335
336        if bar_width > cfg.min_bar_width {
337            // Draw a progress bar for known-length bars.
338            if let Some(progress) = pb.progress() {
339                write!(o, "{}", bar(progress, bar_width))?;
340            }
341            // And a spinner for unknown-length bars.
342            else {
343                let duration = Duration::from_secs(3);
344                let pos = pb.timer_progress(duration);
345                // Sub 1 from width because many terms render emojis with double width.
346                write!(o, "{}", spinner(pos, bar_width - 1))?;
347            }
348        }
349
350        write!(o, "{}\r", right)?;
351
352        o.flush().map_err(Into::into)
353    }
354}
355
356// ============================================================================================== //
357// [Units]                                                                                        //
358// ============================================================================================== //
359
/// Determines the unit used for printing values and iteration speed.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Unit {
    /// Plain counts with decimal suffixes (`K`, `M`, `B`); rates as `it/s` or `s/it`.
    Iterations,
    /// Byte counts with binary suffixes (`KiB`, `MiB`, `GiB`, `TiB`); rates per second.
    Bytes,
}
367
/// Picks a decimal magnitude suffix for an iteration count.
///
/// Returns `(suffix, divisor)` such that `x as f32 / divisor` is the number to print in front
/// of `suffix` (`B` = billions). Thresholds are inclusive, so e.g. exactly `1_000` renders as
/// `1.00K` rather than `1000.00`.
fn human_iter_unit(x: usize) -> (&'static str, f32) {
    if x >= 1_000_000_000 {
        ("B", 1e9)
    } else if x >= 1_000_000 {
        ("M", 1e6)
    } else if x >= 1_000 {
        ("K", 1e3)
    } else {
        ("", 1e0)
    }
}
379
/// Picks a binary magnitude suffix for a byte count.
///
/// Returns `(suffix, divisor)` such that `x as f32 / divisor` is the number to print in front
/// of `suffix`. Thresholds are inclusive, so exactly 1024 bytes renders as `1.00KiB`.
fn bytes_unit(x: usize) -> (&'static str, f32) {
    // Compare in u64: 1024^4 does not fit into a 32-bit `usize`, where the previous
    // `1024usize.pow(4)` overflowed (panicking in debug, wrapping to 0 in release).
    let x = x as u64;
    if x >= 1 << 40 {
        ("TiB", 1024_f32.powi(4))
    } else if x >= 1 << 30 {
        ("GiB", 1024_f32.powi(3))
    } else if x >= 1 << 20 {
        ("MiB", 1024_f32.powi(2))
    } else if x >= 1 << 10 {
        ("KiB", 1024_f32.powi(1))
    } else {
        ("b", 1024_f32.powi(0))
    }
}
393
394impl Unit {
395    /// Formats an absolute amount, e.g. "1200 iterations".
396    fn write_total<W: fmt::Write>(self, mut out: W, amount: usize) -> fmt::Result {
397        match self {
398            Unit::Iterations => {
399                let (unit, div) = human_iter_unit(amount);
400                write!(out, "{:.2}{}", (amount as f32) / div, unit)
401            }
402            Unit::Bytes => {
403                let (unit, div) = bytes_unit(amount);
404                write!(out, "{:.2}{}", (amount as f32) / div, unit)
405            }
406        }
407    }
408
409    /// Formats a rate of change, e.g. "1200 it/sec".
410    fn write_rate<W: fmt::Write>(self, mut out: W, rate: f32) -> fmt::Result {
411        match self {
412            Unit::Iterations => {
413                if rate >= 1.0 {
414                    let (unit, div) = human_iter_unit(rate as usize);
415                    write!(out, "{:.2}{} it/s", rate / div, unit)
416                } else {
417                    write!(out, "{:.0} s/it", 1.0 / rate)
418                }
419            }
420            Unit::Bytes => {
421                let (unit, div) = bytes_unit(rate as usize);
422                write!(out, "{:.2}{}/s", rate / div, unit)
423            }
424        }
425    }
426}
427
428// ============================================================================================== //
429// [Main progress bar struct]                                                                     //
430// ============================================================================================== //
431
432/// Progress bar to be rendered on the terminal.
433///
434/// # Example
435///
436/// ```rust
437/// use zzz::prelude::*;
438///
439/// let mut bar = ProgressBar::with_target(123);
440/// for _ in 0..123 {
441///     bar.add(1);
442/// }
443/// ```
444pub struct ProgressBar {
445    /// Configuration to use.
446    cfg: Option<&'static Config>,
447    /// The expected, possibly approximate target of the progress bar.
448    target: Option<usize>,
449    /// Whether the target was specified explicitly.
450    explicit_target: bool,
451    /// Whether the target was specified explicitly.
452    pub(crate) unit: Unit,
453    /// Creation time of the progress bar.
454    start: Instant,
455    /// Description of the progress bar, e.g. "Downloading image".
456    message: RwLock<Option<String>>,
457    /// Progress value displayed to the user.
458    value: CachePadded<AtomicUsize>,
459    /// Number of progress bar updates so far.
460    update_ctr: CachePadded<AtomicUsize>,
461    /// Next print at `update_ctr == next_print`.
462    next_print: CachePadded<AtomicUsize>,
463}
464
465impl Drop for ProgressBar {
466    fn drop(&mut self) {
467        if (self.active_config().should_draw)() {
468            self.redraw();
469            eprintln!();
470        }
471    }
472}
473
474/// Constructors.
475impl ProgressBar {
476    fn new(target: Option<usize>, explicit_target: bool) -> Self {
477        Self {
478            cfg: None,
479            target,
480            explicit_target,
481            start: Instant::now(),
482            unit: Unit::Iterations,
483            value: CachePadded(0.into()),
484            update_ctr: CachePadded(0.into()),
485            next_print: CachePadded(1.into()),
486            message: RwLock::new(None),
487        }
488    }
489
490    /// Creates a smart progress bar, attempting to infer the target from size hints.
491    pub fn smart() -> Self {
492        Self::new(None, false)
493    }
494
495    /// Creates a spinner, a progress bar with indeterminate target value.
496    pub fn spinner() -> Self {
497        Self::new(None, true)
498    }
499
500    /// Creates a progress bar with an explicit target value.
501    pub fn with_target(target: usize) -> Self {
502        Self::new(Some(target), true)
503    }
504}
505
506/// Builder-style methods.
507impl ProgressBar {
508    /// Replace the config of the progress bar.
509    ///
510    /// Takes precedence over a global config set via [`set_global_config`].
511    pub fn config(mut self, cfg: &'static Config) -> Self {
512        self.cfg = Some(cfg);
513        self
514    }
515
516    /// Force display as a spinner even if size hints are present.
517    pub fn force_spinner(mut self) -> Self {
518        self.explicit_target = true;
519        self.target = None;
520        self
521    }
522
523    /// Set the unit to be used when formatting values.
524    pub fn unit(mut self, unit: Unit) -> Self {
525        self.unit = unit;
526        self
527    }
528}
529
530/// Value manipulation and access.
531impl ProgressBar {
532    /// Returns the currently active configuration.
533    #[inline]
534    pub fn active_config(&self) -> &'static Config {
535        self.cfg.unwrap_or_else(global_config)
536    }
537
538    #[rustfmt::skip]
539    pub fn process_size_hint(&mut self, hint: (usize, Option<usize>)) {
540        // If an explicit target is set, disregard size hints.
541        if self.explicit_target {
542            return;
543        }
544
545        // Prefer hi over lo, treat lo = 0 as unknown.
546        self.target = match hint {
547            (_ , Some(hi)) => Some(hi),
548            (0 , None    ) => None,
549            (lo, None    ) => Some(lo),
550        };
551    }
552
553    /// Set the progress bar value to a new, absolute value.
554    ///
555    /// This doesn't automatically redraw the progress-bar.
556    ///
557    /// See `set_sync` for a thread-safe version.
558    #[inline]
559    pub fn set(&mut self, n: usize) {
560        *self.update_ctr.get_mut() += 1;
561        *self.value.get_mut() = n;
562    }
563
564    /// Synchronized version fo `set`.
565    #[inline]
566    pub fn set_sync(&self, n: usize) {
567        self.update_ctr.fetch_add(1, Relaxed);
568        self.value.store(n, Relaxed);
569    }
570
571    /// Add `n` to the value of the progress bar.
572    ///
573    /// See `add_sync` for a thread-safe version.
574    #[inline]
575    pub fn add(&mut self, n: usize) -> usize {
576        *self.value.get_mut() += n;
577        let prev = *self.update_ctr.get_mut();
578        *self.update_ctr.get_mut() += 1;
579        self.maybe_redraw(prev);
580        prev
581    }
582
583    /// Synchronized version fo `add`.
584    #[inline]
585    pub fn add_sync(&self, n: usize) -> usize {
586        self.value.fetch_add(n, Relaxed);
587        let prev = self.update_ctr.fetch_add(1, Relaxed);
588        self.maybe_redraw(prev);
589        prev
590    }
591
592    /// How often has the value been changed since creation?
593    #[inline]
594    pub fn update_ctr(&self) -> usize {
595        self.update_ctr.load(Relaxed)
596    }
597
598    /// Get the current value of the progress bar.
599    #[inline]
600    pub fn value(&self) -> usize {
601        self.value.load(Relaxed)
602    }
603
604    /// Get the current task description text.
605    pub fn message(&self) -> Option<String> {
606        self.message.read().unwrap().clone()
607    }
608
609    /// Set the current task description text.
610    pub fn set_message(&mut self, text: Option<impl Into<String>>) {
611        *self.message.get_mut().unwrap() = text.map(Into::into);
612    }
613
614    /// Synchronized version for `set_message`.
615    pub fn set_message_sync(&self, text: Option<impl Into<String>>) {
616        let mut message_lock = self.message.write().unwrap();
617        *message_lock = text.map(Into::into);
618    }
619
620    /// Calculate the current progress, `0.0 .. 1.0`.
621    #[inline]
622    pub fn progress(&self) -> Option<f32> {
623        let target = self.target?;
624        if target == 0 {
625            return None;
626        }
627        Some(self.value() as f32 / target as f32)
628    }
629
630    /// Calculate the elapsed time since creation of the progress bar.
631    pub fn elapsed(&self) -> Duration {
632        self.start.elapsed()
633    }
634
635    /// Estimate the duration until completion.
636    pub fn eta(&self) -> Option<Duration> {
637        let progress = self.progress()?;
638        if progress == 0.0 {
639            return None;
640        }
641
642        let left = 1. / progress;
643        let elapsed = self.elapsed();
644        let estimated_total = elapsed.mul_f32(left);
645        Some(estimated_total.saturating_sub(elapsed))
646    }
647
648    /// Calculate the mean iterations per second since creation of the progress bar.
649    pub fn iters_per_sec(&self) -> f32 {
650        let elapsed_sec = self.elapsed().as_secs_f32();
651        self.value() as f32 / elapsed_sec
652    }
653
654    /// Calculate the mean progress bar updates per second since creation of the progress bar.
655    pub fn updates_per_sec(&self) -> f32 {
656        let elapsed_sec = self.elapsed().as_secs_f32();
657        self.update_ctr() as f32 / elapsed_sec
658    }
659
660    /// Calculates the progress of a rolling timer.
661    ///
662    /// Returned values are always between 0 and 1. Timers are calculated
663    /// from the start of the progress bar.
664    pub fn timer_progress(&self, timer: Duration) -> f32 {
665        let elapsed_sec = self.elapsed().as_secs_f32();
666        let timer_sec = timer.as_secs_f32();
667
668        (elapsed_sec % timer_sec) / timer_sec
669    }
670
671    /// Forces a redraw of the progress bar.
672    pub fn redraw(&self) {
673        self.active_config().theme.render(self).unwrap();
674        self.update_next_print();
675    }
676}
677
678/// Internals.
679impl ProgressBar {
680    #[inline]
681    fn next_print(&self) -> usize {
682        self.next_print.load(Relaxed)
683    }
684
685    /// Calculate next print
686    fn update_next_print(&self) {
687        // Give the loop some time to warm up.
688        if self.update_ctr() < 10 {
689            self.next_print.fetch_add(1, Relaxed);
690            return;
691        }
692
693        let freq = (self.updates_per_sec() / self.active_config().max_fps) as usize;
694        let freq = freq.max(1);
695
696        self.next_print.fetch_add(freq as usize, Relaxed);
697    }
698
699    #[inline]
700    fn maybe_redraw(&self, prev: usize) {
701        #[cold]
702        fn cold_redraw(this: &ProgressBar) {
703            if (this.active_config().should_draw)() {
704                this.redraw();
705            }
706        }
707
708        if prev == self.next_print() {
709            cold_redraw(self);
710        }
711    }
712}
713
714// ============================================================================================== //
715// [Iterator integration]                                                                         //
716// ============================================================================================== //
717
718/// Iterator / stream wrapper that automatically updates a progress bar during iteration.
719pub struct ProgressBarIter<Inner> {
720    bar: ProgressBar,
721    inner: Inner,
722}
723
724impl<Inner> ProgressBarIter<Inner> {
725    pub fn into_inner(self) -> Inner {
726        self.inner
727    }
728}
729
730impl<Inner: Iterator> Iterator for ProgressBarIter<Inner> {
731    type Item = Inner::Item;
732
733    fn next(&mut self) -> Option<Self::Item> {
734        let next = self.inner.next()?;
735        self.bar.add(1);
736        Some(next)
737    }
738}
739
740/// Extension trait implemented for all iterators, adding methods for
741/// conveniently adding a progress bar to an existing iterator.
742///
743/// # Example
744///
745/// ```rust
746/// # fn main() {
747/// use zzz::prelude::*;
748///
749/// for _ in (0..123).progress() {
750///     // ...
751/// }
752/// # }
753/// ```
754pub trait ProgressBarIterExt: Iterator + Sized {
755    fn progress(self) -> ProgressBarIter<Self> {
756        let mut bar = ProgressBar::smart();
757        bar.process_size_hint(self.size_hint());
758        ProgressBarIter { bar, inner: self }
759    }
760
761    fn with_progress(self, mut bar: ProgressBar) -> ProgressBarIter<Self> {
762        bar.process_size_hint(self.size_hint());
763        ProgressBarIter { bar, inner: self }
764    }
765}
766
767impl<Inner: Iterator + Sized> ProgressBarIterExt for Inner {}
768
769// ============================================================================================== //
770// [Stream integration]                                                                           //
771// ============================================================================================== //
772
#[cfg(feature = "streams")]
pub mod streams {
    use super::*;
    use core::pin::Pin;
    use futures_core::{
        task::{Context, Poll},
        Stream,
    };

    impl<Inner: Stream> Stream for ProgressBarIter<Inner> {
        type Item = Inner::Item;

        /// Polls the wrapped stream, advancing the bar by one for every item it yields.
        fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
            // SAFETY: This is no different than what pin_project would do, except without
            //         requiring the dependency on the lib: `inner` is treated as structurally
            //         pinned, while `bar` is never treated as pinned and is only accessed
            //         through a plain `&mut`.
            let (inner, bar) = unsafe {
                let this = self.get_unchecked_mut();
                (Pin::new_unchecked(&mut this.inner), &mut this.bar)
            };

            match inner.poll_next(cx) {
                x @ Poll::Ready(Some(_)) => {
                    bar.add(1);
                    x
                }
                x => x,
            }
        }

        /// Forwards the wrapped stream's size hint.
        ///
        /// The wrapper yields exactly the inner items; without this, nested wrappers and
        /// consumers lose the hint.
        fn size_hint(&self) -> (usize, Option<usize>) {
            self.inner.size_hint()
        }
    }

    /// Extension trait implemented for all streams, adding methods for conveniently adding a
    /// progress bar to an existing stream.
    pub trait ProgressBarStreamExt: Stream + Sized {
        /// Wraps the stream with a fresh [`ProgressBar::smart`] bar whose target is inferred
        /// from the stream's size hint.
        fn progress(self) -> ProgressBarIter<Self> {
            let mut bar = ProgressBar::smart();
            bar.process_size_hint(self.size_hint());
            ProgressBarIter { bar, inner: self }
        }

        /// Wraps the stream with the given bar, feeding it the stream's size hint (ignored by
        /// the bar if its target was set explicitly).
        fn with_progress(self, mut bar: ProgressBar) -> ProgressBarIter<Self> {
            bar.process_size_hint(self.size_hint());
            ProgressBarIter { bar, inner: self }
        }
    }

    impl<Inner: Stream + Sized> ProgressBarStreamExt for Inner {}
}
820
821#[cfg(feature = "streams")]
822pub use streams::*;
823
824// ============================================================================================== //
825// [Tests]                                                                                        //
826// ============================================================================================== //
827
#[cfg(test)]
mod tests {
    use super::*;

    /// A target of zero must never produce a progress fraction or an ETA, whatever the value.
    #[test]
    fn zero_target() {
        let values = [0, 1, 100];
        for &value in values.iter() {
            let mut bar = ProgressBar::smart();
            bar.target = Some(0);
            bar.value.0.store(value, Relaxed);

            assert_eq!(bar.progress(), None);
            assert_eq!(bar.eta(), None);
        }
    }
}
843
#[cfg(doctest)]
mod doctests {
    /// Attaches the given Markdown as documentation to an empty item, so that rustdoc runs the
    /// code blocks inside it as doctests.
    macro_rules! external_doc_test {
        ($markdown:expr) => {
            #[doc = $markdown]
            extern "C" {}
        };
    }

    // Ensure the examples in README.md work.
    external_doc_test!(include_str!("../README.md"));
}
856
857// ============================================================================================== //