Skip to main content

studiole_command/services/
cli_progress.rs

1use crate::prelude::*;
2use indicatif::ProgressBar;
3use std::sync::atomic::{AtomicBool, Ordering};
4use tokio::spawn;
5use tokio::sync::broadcast::error::RecvError;
6use tracing::{error, warn};
7
8/// Display command progress as a terminal progress bar.
9pub struct CliProgress<T: ICommandInfo> {
10    mediator: Arc<CommandMediator<T>>,
11    bar: Arc<ProgressBar>,
12    handle: Mutex<Option<JoinHandle<()>>>,
13    finished: Arc<AtomicBool>,
14}
15
16impl<T: ICommandInfo + 'static> CliProgress<T> {
17    /// Create a new [`CliProgress`] backed by a [`CommandMediator`].
18    #[must_use]
19    pub fn new(mediator: Arc<CommandMediator<T>>) -> Self {
20        Self {
21            mediator,
22            bar: Arc::new(ProgressBar::new(0)),
23            handle: Mutex::default(),
24            finished: Arc::new(AtomicBool::new(false)),
25        }
26    }
27
28    /// Start listening for events and updating the progress bar.
29    pub async fn start(&self) {
30        let mut handle_guard = self.handle.lock().await;
31        if handle_guard.is_some() {
32            return;
33        }
34        let mediator = self.mediator.clone();
35        let mut receiver = mediator.subscribe();
36        let bar = self.bar.clone();
37        let finished = self.finished.clone();
38        let mut total: u64 = 0;
39        let handle = spawn(async move {
40            while !finished.load(Ordering::Acquire) {
41                match receiver.recv().await {
42                    Ok(event) => Self::handle_event(&bar, &mut total, event),
43                    Err(RecvError::Lagged(count)) => {
44                        warn!("CLI Progress missed {count} events due to lagging");
45                    }
46                    Err(RecvError::Closed) => {
47                        error!("Event pipe was closed. CLI Progress can't proceed.");
48                        break;
49                    }
50                }
51            }
52        });
53        *handle_guard = Some(handle);
54    }
55
56    fn handle_event(bar: &ProgressBar, total: &mut u64, event: T::Event) {
57        match event.get_kind() {
58            EventKind::Queued => {
59                *total += 1;
60                bar.set_length(*total);
61            }
62            EventKind::Executing => {}
63            EventKind::Succeeded | EventKind::Failed => {
64                bar.inc(1);
65            }
66        }
67    }
68
69    /// Signal completion and abort the listener task.
70    pub async fn finish(&self) {
71        self.finished.store(true, Ordering::Release);
72        let mut handle_guard = self.handle.lock().await;
73        if let Some(handle) = handle_guard.take() {
74            handle.abort();
75        }
76        drop(handle_guard);
77        self.bar.finish();
78    }
79
80    /// Hide the progress bar output.
81    #[cfg(test)]
82    pub fn hide(&self) {
83        self.bar
84            .set_draw_target(indicatif::ProgressDrawTarget::hidden());
85    }
86
87    /// Progress bar position (completed items).
88    #[cfg(test)]
89    pub fn position(&self) -> u64 {
90        self.bar.position()
91    }
92
93    /// Progress bar total length (queued items).
94    #[cfg(test)]
95    pub fn length(&self) -> Option<u64> {
96        self.bar.length()
97    }
98}
99
100impl<T: ICommandInfo + 'static> FromServices for CliProgress<T> {
101    type Error = ResolveError;
102
103    fn from_services(services: &ServiceProvider) -> Result<Self, Report<Self::Error>> {
104        Ok(Self::new(services.get::<CommandMediator<T>>()?))
105    }
106}
107
#[cfg(all(test, feature = "server"))]
mod tests {
    #![expect(
        clippy::as_conversions,
        reason = "usize to u64 cast in test assertions"
    )]
    use super::*;

    /// Number of commands queued: twice the channel capacity.
    const COMMAND_COUNT: usize = CHANNEL_CAPACITY * 2;
    /// Worker count passed to the runner.
    const WORKER_COUNT: usize = 4;
    /// Delay value passed to each `DelayRequest`.
    const DELAY_MS: u64 = 1;

    /// Queues `COMMAND_COUNT` delay requests and checks that the progress bar's
    /// length and position both end up equal to that count.
    #[tokio::test]
    async fn cli_progress_receives_all_events() {
        // Arrange
        let provider = ServiceBuilder::new().with_commands().build();
        let command_runner = provider
            .get_async::<CommandRunner<CommandInfo>>()
            .await
            .expect("should be able to get runner");
        let cli_progress = provider
            .get::<CliProgress<CommandInfo>>()
            .expect("should be able to get progress");
        let _logger = init_test_logger();
        cli_progress.hide();
        let expected = COMMAND_COUNT as u64;

        // Act
        cli_progress.start().await;
        command_runner.start(WORKER_COUNT).await;
        for index in 1..=COMMAND_COUNT {
            let name = format!("P{index}");
            command_runner
                .queue_request(DelayRequest::new(name, DELAY_MS))
                .await
                .expect("should be able to queue request");
        }
        command_runner.drain().await;
        cli_progress.finish().await;

        // Assert
        let observed_length = cli_progress.length();
        assert_eq!(
            observed_length,
            Some(expected),
            "progress bar total should match queued commands"
        );
        let observed_position = cli_progress.position();
        assert_eq!(
            observed_position, expected,
            "progress bar position should match completed commands"
        );
    }
}