// studiole_command/services/cli_progress.rs
use crate::prelude::*;
use indicatif::ProgressBar;
use std::sync::atomic::{AtomicBool, Ordering};
use tokio::spawn;
use tokio::sync::Notify;
use tokio::sync::broadcast::error::{RecvError, TryRecvError};
use tracing::{error, warn};
7
8pub struct CliProgress<T: ICommandInfo> {
10 mediator: Arc<CommandMediator<T>>,
11 bar: Arc<ProgressBar>,
12 handle: Mutex<Option<JoinHandle<()>>>,
13 finished: Arc<AtomicBool>,
14}
15
16impl<T: ICommandInfo + 'static> CliProgress<T> {
17 #[must_use]
19 pub fn new(mediator: Arc<CommandMediator<T>>) -> Self {
20 Self {
21 mediator,
22 bar: Arc::new(ProgressBar::new(0)),
23 handle: Mutex::default(),
24 finished: Arc::new(AtomicBool::new(false)),
25 }
26 }
27
28 pub async fn start(&self) {
30 let mut handle_guard = self.handle.lock().await;
31 if handle_guard.is_some() {
32 return;
33 }
34 let mediator = self.mediator.clone();
35 let mut receiver = mediator.subscribe();
36 let bar = self.bar.clone();
37 let finished = self.finished.clone();
38 let mut total: u64 = 0;
39 let handle = spawn(async move {
40 while !finished.load(Ordering::Acquire) {
41 match receiver.recv().await {
42 Ok(event) => Self::handle_event(&bar, &mut total, event),
43 Err(RecvError::Lagged(count)) => {
44 warn!("CLI Progress missed {count} events due to lagging");
45 }
46 Err(RecvError::Closed) => {
47 error!("Event pipe was closed. CLI Progress can't proceed.");
48 break;
49 }
50 }
51 }
52 });
53 *handle_guard = Some(handle);
54 }
55
56 fn handle_event(bar: &ProgressBar, total: &mut u64, event: T::Event) {
57 match event.get_kind() {
58 EventKind::Queued => {
59 *total += 1;
60 bar.set_length(*total);
61 }
62 EventKind::Executing => {}
63 EventKind::Succeeded | EventKind::Failed => {
64 bar.inc(1);
65 }
66 }
67 }
68
69 pub async fn finish(&self) {
71 self.finished.store(true, Ordering::Release);
72 let mut handle_guard = self.handle.lock().await;
73 if let Some(handle) = handle_guard.take() {
74 handle.abort();
75 }
76 drop(handle_guard);
77 self.bar.finish();
78 }
79
80 #[cfg(test)]
82 pub fn hide(&self) {
83 self.bar
84 .set_draw_target(indicatif::ProgressDrawTarget::hidden());
85 }
86
87 #[cfg(test)]
89 pub fn position(&self) -> u64 {
90 self.bar.position()
91 }
92
93 #[cfg(test)]
95 pub fn length(&self) -> Option<u64> {
96 self.bar.length()
97 }
98}
99
100impl<T: ICommandInfo + 'static> Service for CliProgress<T> {
101 type Error = ServiceError;
102
103 async fn from_services(services: &ServiceProvider) -> Result<Self, Report<Self::Error>> {
104 Ok(Self::new(services.get_service().await?))
105 }
106}
107
#[cfg(all(test, feature = "server"))]
mod tests {
    use super::*;

    /// Twice the broadcast channel capacity, so the listener must keep up
    /// while the channel wraps.
    const COMMAND_COUNT: usize = CHANNEL_CAPACITY * 2;
    const WORKER_COUNT: usize = 4;
    const DELAY_MS: u64 = 1;

    /// Runs `COMMAND_COUNT` short delay commands and checks that the bar's
    /// length and position both equal the number of commands.
    #[tokio::test]
    async fn cli_progress_receives_all_events() {
        let services = ServiceProvider::new()
            .with_commands()
            .await
            .expect("should be able to create services with commands");
        let runner = services
            .get_service::<CommandRunner<CommandInfo>>()
            .await
            .expect("should be able to get runner");
        let progress = services
            .get_service::<CliProgress<CommandInfo>>()
            .await
            .expect("should be able to get progress");
        let _logger = init_test_logger();
        progress.hide();

        progress.start().await;
        runner.start(WORKER_COUNT).await;
        for i in 1..=COMMAND_COUNT {
            let request = DelayRequest::new(format!("P{i}"), DELAY_MS);
            runner
                .queue_request(request)
                .await
                .expect("should be able to queue request");
        }
        runner.drain().await;
        progress.finish().await;

        // Checked conversion instead of an `as` cast.
        let expected = u64::try_from(COMMAND_COUNT).expect("command count should fit in u64");
        assert_eq!(
            progress.length(),
            Some(expected),
            "progress bar total should match queued commands"
        );
        assert_eq!(
            progress.position(),
            expected,
            "progress bar position should match completed commands"
        );
    }
}