Skip to main content

scirs2_core/
progress.rs

1//! # Progress Tracking for Iterative Algorithms
2//!
3//! Pure-Rust progress tracking with no external dependencies.
4//!
5//! ## Features
6//!
7//! - `ProgressBar` -- configurable progress display (percentage, bar, ETA)
8//! - Callback-based progress notifications
9//! - Nested progress bars via `ProgressGroup`
10//! - Rate estimation (iterations/sec)
11//! - Pure text output (no terminal escape sequences required)
12//!
13//! ## Example
14//!
15//! ```rust
16//! use scirs2_core::progress::{ProgressBar, ProgressStyle};
17//!
18//! let mut pb = ProgressBar::new(100)
19//!     .with_style(ProgressStyle::Bar)
20//!     .with_prefix("Training");
21//!
22//! for i in 0..100 {
23//!     // ... do work ...
24//!     pb.inc(1);
25//! }
26//! pb.finish();
27//! ```
28
29use std::fmt;
30use std::sync::{Arc, Mutex};
31use std::time::{Duration, Instant};
32
33// ---------------------------------------------------------------------------
34// ProgressStyle
35// ---------------------------------------------------------------------------
36
/// Selects how a progress bar renders itself as text.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProgressStyle {
    /// Percentage only, e.g. `[Training] 42%`.
    Percentage,
    /// Full text bar, e.g. `[Training] [=====>    ] 50% (5/10) 2.3 it/s ETA 2s`.
    Bar,
    /// Plain counter, e.g. `[Training] 5/10`.
    Counter,
    /// Nothing is printed; registered callbacks still fire.
    Silent,
}

/// The default style is [`ProgressStyle::Bar`].
impl Default for ProgressStyle {
    fn default() -> ProgressStyle {
        ProgressStyle::Bar
    }
}
55
56// ---------------------------------------------------------------------------
57// ProgressState (internal)
58// ---------------------------------------------------------------------------
59
/// Internal bookkeeping for a progress bar: position, goal and timing.
#[derive(Debug, Clone)]
struct ProgressState {
    /// Steps completed so far.
    current: u64,
    /// Steps expected in total (`0` is reported as fully complete).
    total: u64,
    /// When tracking started; the basis for elapsed time, rate and ETA.
    start_time: Instant,
    /// When a progress line was last printed.
    last_print_time: Instant,
    /// Set once the run has been marked as finished.
    finished: bool,
}

impl ProgressState {
    /// Create a fresh state at position 0 whose clock starts now.
    fn new(total: u64) -> Self {
        let now = Instant::now();
        Self {
            current: 0,
            total,
            start_time: now,
            last_print_time: now,
            finished: false,
        }
    }

    /// Completed fraction `current / total`; `1.0` when `total` is zero.
    fn fraction(&self) -> f64 {
        if self.total == 0 {
            return 1.0;
        }
        self.current as f64 / self.total as f64
    }

    /// Time elapsed since `start_time`.
    fn elapsed(&self) -> Duration {
        self.start_time.elapsed()
    }

    /// Average throughput in steps per second; `0.0` if no time has elapsed.
    fn rate(&self) -> f64 {
        let secs = self.elapsed().as_secs_f64();
        if secs > 0.0 {
            self.current as f64 / secs
        } else {
            0.0
        }
    }

    /// Estimated time remaining, extrapolated from the average rate so far.
    ///
    /// Returns `None` when no estimate can be made: nothing has been
    /// completed yet, `total` is zero, the rate is not positive, or the
    /// estimate is too large to be represented as a [`Duration`].
    fn eta(&self) -> Option<Duration> {
        if self.current == 0 || self.total == 0 {
            return None;
        }
        let elapsed = self.elapsed().as_secs_f64();
        // A zero `elapsed` yields an infinite rate, which in turn yields a
        // zero remaining time below -- that is fine and intentional.
        let rate = self.current as f64 / elapsed;
        if rate <= 0.0 {
            return None;
        }
        // `saturating_sub` keeps this safe even if `current` ever exceeds `total`.
        let remaining = self.total.saturating_sub(self.current) as f64 / rate;
        // `Duration::from_secs_f64` panics on NaN, negative or overflowing
        // input; a huge total combined with a low rate can easily overflow,
        // so report "unknown" instead of aborting the caller's computation.
        if !remaining.is_finite() || remaining < 0.0 || remaining >= u64::MAX as f64 {
            return None;
        }
        Some(Duration::from_secs_f64(remaining))
    }
}
114
115// ---------------------------------------------------------------------------
116// ProgressCallback
117// ---------------------------------------------------------------------------
118
/// A boxed callback invoked each time a progress bar's position is updated.
///
/// It is called as `callback(current, total, elapsed)`: the current position,
/// the total number of steps, and the time elapsed since tracking started.
/// The closure must be `Send + 'static`.
pub type ProgressCallback = Box<dyn Fn(u64, u64, Duration) + Send + 'static>;
123
124// ---------------------------------------------------------------------------
125// ProgressBar
126// ---------------------------------------------------------------------------
127
128/// A configurable progress bar for iterative algorithms.
129///
130/// Thread-safe: the internal state is behind an `Arc<Mutex<_>>`.
131pub struct ProgressBar {
132    state: Arc<Mutex<ProgressState>>,
133    style: ProgressStyle,
134    prefix: String,
135    bar_width: usize,
136    min_print_interval: Duration,
137    callbacks: Vec<ProgressCallback>,
138    print_on_finish: bool,
139}
140
141impl ProgressBar {
142    /// Create a new progress bar with the given total number of steps.
143    pub fn new(total: u64) -> Self {
144        Self {
145            state: Arc::new(Mutex::new(ProgressState::new(total))),
146            style: ProgressStyle::default(),
147            prefix: String::new(),
148            bar_width: 30,
149            min_print_interval: Duration::from_millis(100),
150            callbacks: Vec::new(),
151            print_on_finish: true,
152        }
153    }
154
155    /// Set the display style.
156    pub fn with_style(mut self, style: ProgressStyle) -> Self {
157        self.style = style;
158        self
159    }
160
161    /// Set the prefix label.
162    pub fn with_prefix(mut self, prefix: impl Into<String>) -> Self {
163        self.prefix = prefix.into();
164        self
165    }
166
167    /// Set the visual bar width (number of characters). Default: 30.
168    pub fn with_bar_width(mut self, width: usize) -> Self {
169        self.bar_width = width;
170        self
171    }
172
173    /// Set the minimum interval between printed updates. Default: 100ms.
174    pub fn with_min_print_interval(mut self, interval: Duration) -> Self {
175        self.min_print_interval = interval;
176        self
177    }
178
179    /// Disable automatic printing on `finish()`.
180    pub fn silent_finish(mut self) -> Self {
181        self.print_on_finish = false;
182        self
183    }
184
185    /// Register a callback that fires on each `inc()` / `set()` call.
186    pub fn on_progress<F>(mut self, callback: F) -> Self
187    where
188        F: Fn(u64, u64, Duration) + Send + 'static,
189    {
190        self.callbacks.push(Box::new(callback));
191        self
192    }
193
194    /// Increment the progress by `delta` steps.
195    pub fn inc(&mut self, delta: u64) {
196        let (current, total, elapsed, should_print) = {
197            let mut st = match self.state.lock() {
198                Ok(g) => g,
199                Err(poisoned) => poisoned.into_inner(),
200            };
201            st.current = st.current.saturating_add(delta).min(st.total);
202            let now = Instant::now();
203            let should_print = now.duration_since(st.last_print_time) >= self.min_print_interval;
204            if should_print {
205                st.last_print_time = now;
206            }
207            (st.current, st.total, st.elapsed(), should_print)
208        };
209
210        // Fire callbacks
211        for cb in &self.callbacks {
212            cb(current, total, elapsed);
213        }
214
215        // Print if enough time has passed
216        if should_print && self.style != ProgressStyle::Silent {
217            self.print_line();
218        }
219    }
220
221    /// Set the progress to an absolute value.
222    pub fn set(&mut self, value: u64) {
223        let (current, total, elapsed) = {
224            let mut st = match self.state.lock() {
225                Ok(g) => g,
226                Err(poisoned) => poisoned.into_inner(),
227            };
228            st.current = value.min(st.total);
229            (st.current, st.total, st.elapsed())
230        };
231
232        for cb in &self.callbacks {
233            cb(current, total, elapsed);
234        }
235
236        if self.style != ProgressStyle::Silent {
237            self.print_line();
238        }
239    }
240
241    /// Mark the progress as finished and print a final line.
242    pub fn finish(&mut self) {
243        {
244            let mut st = match self.state.lock() {
245                Ok(g) => g,
246                Err(poisoned) => poisoned.into_inner(),
247            };
248            st.current = st.total;
249            st.finished = true;
250        }
251
252        if self.print_on_finish && self.style != ProgressStyle::Silent {
253            self.print_line();
254        }
255    }
256
257    /// Mark the progress as finished with a custom message.
258    pub fn finish_with_message(&mut self, msg: &str) {
259        {
260            let mut st = match self.state.lock() {
261                Ok(g) => g,
262                Err(poisoned) => poisoned.into_inner(),
263            };
264            st.current = st.total;
265            st.finished = true;
266        }
267
268        if self.style != ProgressStyle::Silent {
269            let prefix = if self.prefix.is_empty() {
270                String::new()
271            } else {
272                format!("[{}] ", self.prefix)
273            };
274            eprintln!("{prefix}{msg}");
275        }
276    }
277
278    /// Reset the progress bar to 0 with a new total.
279    pub fn reset(&mut self, total: u64) {
280        let mut st = match self.state.lock() {
281            Ok(g) => g,
282            Err(poisoned) => poisoned.into_inner(),
283        };
284        *st = ProgressState::new(total);
285    }
286
287    /// Get the current position.
288    pub fn position(&self) -> u64 {
289        match self.state.lock() {
290            Ok(st) => st.current,
291            Err(poisoned) => poisoned.into_inner().current,
292        }
293    }
294
295    /// Get the total.
296    pub fn total(&self) -> u64 {
297        match self.state.lock() {
298            Ok(st) => st.total,
299            Err(poisoned) => poisoned.into_inner().total,
300        }
301    }
302
303    /// Get elapsed time.
304    pub fn elapsed(&self) -> Duration {
305        match self.state.lock() {
306            Ok(st) => st.elapsed(),
307            Err(poisoned) => poisoned.into_inner().elapsed(),
308        }
309    }
310
311    /// Get current rate (iterations/sec).
312    pub fn rate(&self) -> f64 {
313        match self.state.lock() {
314            Ok(st) => st.rate(),
315            Err(poisoned) => poisoned.into_inner().rate(),
316        }
317    }
318
319    /// Get the estimated time remaining.
320    pub fn eta(&self) -> Option<Duration> {
321        match self.state.lock() {
322            Ok(st) => st.eta(),
323            Err(poisoned) => poisoned.into_inner().eta(),
324        }
325    }
326
327    /// Format the current state as a string (without printing).
328    pub fn format_line(&self) -> String {
329        let st = match self.state.lock() {
330            Ok(g) => g.clone(),
331            Err(poisoned) => poisoned.into_inner().clone(),
332        };
333        self.format_state(&st)
334    }
335
336    // -- internal helpers --
337
338    fn print_line(&self) {
339        let st = match self.state.lock() {
340            Ok(g) => g.clone(),
341            Err(poisoned) => poisoned.into_inner().clone(),
342        };
343        let line = self.format_state(&st);
344        eprintln!("{line}");
345    }
346
347    fn format_state(&self, st: &ProgressState) -> String {
348        let prefix = if self.prefix.is_empty() {
349            String::new()
350        } else {
351            format!("[{}] ", self.prefix)
352        };
353
354        match self.style {
355            ProgressStyle::Percentage => {
356                let pct = (st.fraction() * 100.0) as u32;
357                format!("{prefix}{pct}%")
358            }
359            ProgressStyle::Counter => {
360                format!("{prefix}{}/{}", st.current, st.total)
361            }
362            ProgressStyle::Bar => {
363                let pct = (st.fraction() * 100.0) as u32;
364                let filled = (st.fraction() * self.bar_width as f64) as usize;
365                let empty = self.bar_width.saturating_sub(filled);
366
367                let bar: String = "=".repeat(filled);
368                let arrow = if filled < self.bar_width && !st.finished {
369                    ">"
370                } else {
371                    ""
372                };
373                let spaces_count = if !arrow.is_empty() {
374                    empty.saturating_sub(1)
375                } else {
376                    empty
377                };
378                let spaces: String = " ".repeat(spaces_count);
379
380                let rate = st.rate();
381                let rate_str = if rate >= 1.0 {
382                    format!("{rate:.1} it/s")
383                } else if rate > 0.0 {
384                    let spi = 1.0 / rate;
385                    format!("{spi:.1} s/it")
386                } else {
387                    "-- it/s".to_string()
388                };
389
390                let eta_str = match st.eta() {
391                    Some(eta) => format_duration_short(eta),
392                    None => "?".to_string(),
393                };
394
395                let elapsed_str = format_duration_short(st.elapsed());
396
397                if st.finished {
398                    format!(
399                        "{prefix}[{bar}{arrow}{spaces}] {pct}% ({}/{}) {rate_str} elapsed {elapsed_str}",
400                        st.current, st.total,
401                    )
402                } else {
403                    format!(
404                        "{prefix}[{bar}{arrow}{spaces}] {pct}% ({}/{}) {rate_str} ETA {eta_str}",
405                        st.current, st.total,
406                    )
407                }
408            }
409            ProgressStyle::Silent => String::new(),
410        }
411    }
412}
413
414impl fmt::Debug for ProgressBar {
415    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
416        let st = match self.state.lock() {
417            Ok(g) => g.clone(),
418            Err(poisoned) => poisoned.into_inner().clone(),
419        };
420        f.debug_struct("ProgressBar")
421            .field("current", &st.current)
422            .field("total", &st.total)
423            .field("style", &self.style)
424            .field("prefix", &self.prefix)
425            .finish()
426    }
427}
428
429// ---------------------------------------------------------------------------
430// ProgressGroup (nested progress bars)
431// ---------------------------------------------------------------------------
432
433/// A group of named progress bars for tracking nested / multi-phase operations.
434///
435/// Each bar is identified by a string key.
436pub struct ProgressGroup {
437    bars: Vec<(String, ProgressBar)>,
438}
439
440impl ProgressGroup {
441    /// Create an empty progress group.
442    pub fn new() -> Self {
443        Self { bars: Vec::new() }
444    }
445
446    /// Add a new progress bar to the group.
447    pub fn add(&mut self, name: impl Into<String>, total: u64, style: ProgressStyle) {
448        let name = name.into();
449        let pb = ProgressBar::new(total)
450            .with_style(style)
451            .with_prefix(name.clone())
452            .silent_finish(); // we manage printing ourselves
453        self.bars.push((name, pb));
454    }
455
456    /// Get a mutable reference to a named bar. Returns `None` if not found.
457    pub fn get_mut(&mut self, name: &str) -> Option<&mut ProgressBar> {
458        self.bars
459            .iter_mut()
460            .find(|(n, _)| n == name)
461            .map(|(_, pb)| pb)
462    }
463
464    /// Print a summary of all progress bars.
465    pub fn print_summary(&self) {
466        for (name, pb) in &self.bars {
467            let line = pb.format_line();
468            if !line.is_empty() {
469                eprintln!("{line}");
470            } else {
471                // Silent bar -- print basic info
472                eprintln!("[{name}] {}/{}", pb.position(), pb.total());
473            }
474        }
475    }
476
477    /// Get the overall fraction complete (average across all bars).
478    pub fn overall_fraction(&self) -> f64 {
479        if self.bars.is_empty() {
480            return 0.0;
481        }
482        let sum: f64 = self
483            .bars
484            .iter()
485            .map(|(_, pb)| {
486                let total = pb.total();
487                if total == 0 {
488                    1.0
489                } else {
490                    pb.position() as f64 / total as f64
491                }
492            })
493            .sum();
494        sum / self.bars.len() as f64
495    }
496
497    /// Number of bars in the group.
498    pub fn len(&self) -> usize {
499        self.bars.len()
500    }
501
502    /// Whether the group is empty.
503    pub fn is_empty(&self) -> bool {
504        self.bars.is_empty()
505    }
506}
507
508impl Default for ProgressGroup {
509    fn default() -> Self {
510        Self::new()
511    }
512}
513
514impl fmt::Debug for ProgressGroup {
515    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
516        f.debug_struct("ProgressGroup")
517            .field("count", &self.bars.len())
518            .field(
519                "bars",
520                &self
521                    .bars
522                    .iter()
523                    .map(|(n, pb)| format!("{n}: {}/{}", pb.position(), pb.total()))
524                    .collect::<Vec<_>>(),
525            )
526            .finish()
527    }
528}
529
530// ---------------------------------------------------------------------------
531// Helper: format duration
532// ---------------------------------------------------------------------------
533
/// Render a duration compactly with whole-second precision.
///
/// Produces `"5s"` below one minute, `"1m05s"` below one hour, and
/// `"1h01m01s"` from one hour upwards. Sub-second parts are truncated.
fn format_duration_short(d: Duration) -> String {
    const SECS_PER_MINUTE: u64 = 60;
    const SECS_PER_HOUR: u64 = 60 * SECS_PER_MINUTE;

    let total_secs = d.as_secs();
    let h = total_secs / SECS_PER_HOUR;
    let m = (total_secs % SECS_PER_HOUR) / SECS_PER_MINUTE;
    let s = total_secs % SECS_PER_MINUTE;

    match (h, m) {
        (0, 0) => format!("{}s", total_secs),
        (0, _) => format!("{m}m{s:02}s"),
        _ => format!("{h}h{m:02}m{s:02}s"),
    }
}
549
550// ---------------------------------------------------------------------------
551// Tests
552// ---------------------------------------------------------------------------
553
/// Unit tests for the progress bar, the progress group and the duration
/// formatter. Most bars use the `Silent` style so the tests produce no output.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_progress_bar_basic() {
        let mut pb = ProgressBar::new(10).with_style(ProgressStyle::Silent);
        assert_eq!(pb.position(), 0);
        assert_eq!(pb.total(), 10);

        pb.inc(3);
        assert_eq!(pb.position(), 3);

        pb.inc(7);
        assert_eq!(pb.position(), 10);
    }

    #[test]
    fn test_progress_bar_set() {
        let mut pb = ProgressBar::new(100).with_style(ProgressStyle::Silent);
        pb.set(50);
        assert_eq!(pb.position(), 50);
        pb.set(200); // clamped to total
        assert_eq!(pb.position(), 100);
    }

    #[test]
    fn test_progress_bar_finish() {
        let mut pb = ProgressBar::new(100).with_style(ProgressStyle::Silent);
        pb.inc(50);
        pb.finish();
        assert_eq!(pb.position(), 100);
    }

    #[test]
    fn test_progress_bar_reset() {
        let mut pb = ProgressBar::new(100).with_style(ProgressStyle::Silent);
        pb.inc(50);
        pb.reset(200);
        assert_eq!(pb.position(), 0);
        assert_eq!(pb.total(), 200);
    }

    #[test]
    fn test_progress_bar_rate() {
        let mut pb = ProgressBar::new(100).with_style(ProgressStyle::Silent);
        // No work done yet: the rate is exactly zero.
        assert_eq!(pb.rate(), 0.0);
        pb.inc(10);
        // After some work the rate is a finite, non-negative number (it may
        // still be zero if the clock has not advanced).
        let rate = pb.rate();
        assert!(rate.is_finite());
        assert!(rate >= 0.0);
    }

    #[test]
    fn test_progress_bar_eta() {
        let mut pb = ProgressBar::new(100).with_style(ProgressStyle::Silent);
        // Nothing completed yet, so no estimate can be made.
        assert!(pb.eta().is_none());
        pb.inc(50);
        // Once progress has been made an estimate is available.
        assert!(pb.eta().is_some());
        pb.finish();
        // Nothing remains after finishing.
        assert_eq!(pb.eta(), Some(Duration::ZERO));
    }

    #[test]
    fn test_progress_bar_eta_zero_total() {
        let mut pb = ProgressBar::new(0).with_style(ProgressStyle::Silent);
        pb.inc(1);
        assert!(pb.eta().is_none());
    }

    #[test]
    fn test_progress_bar_format_percentage() {
        let mut pb = ProgressBar::new(100)
            .with_style(ProgressStyle::Percentage)
            .with_prefix("Test");
        pb.set(42);
        let line = pb.format_line();
        assert_eq!(line, "[Test] 42%");
    }

    #[test]
    fn test_progress_bar_format_counter() {
        let mut pb = ProgressBar::new(100)
            .with_style(ProgressStyle::Counter)
            .with_prefix("Count");
        pb.set(25);
        let line = pb.format_line();
        assert_eq!(line, "[Count] 25/100");
    }

    #[test]
    fn test_progress_bar_format_bar() {
        let mut pb = ProgressBar::new(100)
            .with_style(ProgressStyle::Bar)
            .with_prefix("Work")
            .with_bar_width(20);
        pb.set(50);
        let line = pb.format_line();
        assert!(line.contains("50%"));
        assert!(line.contains("(50/100)"));
        assert!(line.contains("[Work]"));
        assert!(line.contains("ETA"));
        // Half of a 20-wide gauge: ten `=`, the arrow head, nine spaces.
        assert!(line.contains("[==========>         ]"));
    }

    #[test]
    fn test_progress_bar_format_bar_finished() {
        let mut pb = ProgressBar::new(100)
            .with_style(ProgressStyle::Bar)
            .with_bar_width(10);
        pb.finish();
        let line = pb.format_line();
        assert!(line.contains("100%"));
        assert!(line.contains("elapsed"));
        assert!(!line.contains("ETA"));
        // A finished bar is completely filled and has no arrow head.
        assert!(line.starts_with("[==========] "));
    }

    #[test]
    fn test_progress_bar_silent() {
        let mut pb = ProgressBar::new(100).with_style(ProgressStyle::Silent);
        pb.inc(50);
        let line = pb.format_line();
        assert!(line.is_empty());
    }

    #[test]
    fn test_progress_bar_callback() {
        let called = Arc::new(Mutex::new(0u64));
        let called_clone = Arc::clone(&called);

        let mut pb = ProgressBar::new(10)
            .with_style(ProgressStyle::Silent)
            .on_progress(move |current, _total, _elapsed| {
                let mut c = called_clone.lock().expect("lock failed");
                *c = current;
            });

        pb.inc(5);
        let val = {
            let c = called.lock().expect("lock failed");
            *c
        };
        assert_eq!(val, 5);

        // `set()` notifies callbacks as well.
        pb.set(8);
        let val = {
            let c = called.lock().expect("lock failed");
            *c
        };
        assert_eq!(val, 8);
    }

    #[test]
    fn test_progress_bar_overflow_protection() {
        let mut pb = ProgressBar::new(10).with_style(ProgressStyle::Silent);
        pb.inc(100); // should clamp to 10
        assert_eq!(pb.position(), 10);
        pb.inc(u64::MAX); // saturating add, still clamped
        assert_eq!(pb.position(), 10);
    }

    #[test]
    fn test_progress_group_basic() {
        let mut group = ProgressGroup::new();
        group.add("phase1", 100, ProgressStyle::Silent);
        group.add("phase2", 200, ProgressStyle::Silent);

        assert_eq!(group.len(), 2);
        assert!(!group.is_empty());

        if let Some(pb) = group.get_mut("phase1") {
            pb.inc(50);
        }

        // (50/100 + 0/200) / 2 == 0.25
        let frac = group.overall_fraction();
        assert!((frac - 0.25).abs() < 1e-12);
    }

    #[test]
    fn test_progress_group_get_unknown() {
        let mut group = ProgressGroup::new();
        assert!(group.get_mut("nonexistent").is_none());
    }

    #[test]
    fn test_progress_group_empty() {
        let group = ProgressGroup::new();
        assert!(group.is_empty());
        assert_eq!(group.overall_fraction(), 0.0);
    }

    #[test]
    fn test_format_duration_short() {
        assert_eq!(format_duration_short(Duration::from_secs(5)), "5s");
        assert_eq!(format_duration_short(Duration::from_secs(65)), "1m05s");
        assert_eq!(format_duration_short(Duration::from_secs(3661)), "1h01m01s");
    }

    #[test]
    fn test_progress_bar_debug() {
        let pb = ProgressBar::new(100)
            .with_style(ProgressStyle::Bar)
            .with_prefix("Debug");
        let dbg = format!("{pb:?}");
        assert!(dbg.contains("ProgressBar"));
        assert!(dbg.contains("Debug"));
    }

    #[test]
    fn test_progress_bar_elapsed() {
        let pb = ProgressBar::new(100).with_style(ProgressStyle::Silent);
        // Elapsed time never goes backwards between two reads.
        let first = pb.elapsed();
        let second = pb.elapsed();
        assert!(second >= first);
    }

    #[test]
    fn test_progress_group_debug() {
        let mut group = ProgressGroup::new();
        group.add("a", 10, ProgressStyle::Silent);
        let dbg = format!("{group:?}");
        assert!(dbg.contains("ProgressGroup"));
        assert!(dbg.contains("a: 0/10"));
    }

    #[test]
    fn test_progress_bar_zero_total() {
        let mut pb = ProgressBar::new(0).with_style(ProgressStyle::Silent);
        pb.inc(1); // should clamp
        assert_eq!(pb.position(), 0);
    }

    #[test]
    fn test_progress_bar_finish_with_message() {
        let mut pb = ProgressBar::new(100).with_style(ProgressStyle::Silent);
        pb.finish_with_message("Done!");
        assert_eq!(pb.position(), 100);
    }
}