// studiole_command/services/cli_progress.rs
use crate::prelude::*;
use indicatif::ProgressBar;
use std::future::{poll_fn, Future};
use std::pin::{pin, Pin};
use std::sync::atomic::{AtomicBool, Ordering};
use std::task::Poll;
use tokio::spawn;
use tokio::sync::broadcast::error::{RecvError, TryRecvError};
use tokio::sync::oneshot;
use tracing::{error, warn};

8pub struct CliProgress<T: ICommandInfo> {
10 mediator: Arc<CommandMediator<T>>,
11 bar: Arc<ProgressBar>,
12 handle: Mutex<Option<JoinHandle<()>>>,
13 finished: Arc<AtomicBool>,
14}
15
16impl<T: ICommandInfo + 'static> CliProgress<T> {
17 #[must_use]
19 pub fn new(mediator: Arc<CommandMediator<T>>) -> Self {
20 Self {
21 mediator,
22 bar: Arc::new(ProgressBar::new(0)),
23 handle: Mutex::default(),
24 finished: Arc::new(AtomicBool::new(false)),
25 }
26 }
27
28 pub async fn start(&self) {
30 let mut handle_guard = self.handle.lock().await;
31 if handle_guard.is_some() {
32 return;
33 }
34 let mediator = self.mediator.clone();
35 let mut receiver = mediator.subscribe();
36 let bar = self.bar.clone();
37 let finished = self.finished.clone();
38 let mut total: u64 = 0;
39 let handle = spawn(async move {
40 while !finished.load(Ordering::Acquire) {
41 match receiver.recv().await {
42 Ok(event) => Self::handle_event(&bar, &mut total, event),
43 Err(RecvError::Lagged(count)) => {
44 warn!("CLI Progress missed {count} events due to lagging");
45 }
46 Err(RecvError::Closed) => {
47 error!("Event pipe was closed. CLI Progress can't proceed.");
48 break;
49 }
50 }
51 }
52 });
53 *handle_guard = Some(handle);
54 }
55
56 fn handle_event(bar: &ProgressBar, total: &mut u64, event: T::Event) {
57 match event.get_kind() {
58 EventKind::Queued => {
59 *total += 1;
60 bar.set_length(*total);
61 }
62 EventKind::Executing => {}
63 EventKind::Succeeded | EventKind::Failed => {
64 bar.inc(1);
65 }
66 }
67 }
68
69 pub async fn finish(&self) {
71 self.finished.store(true, Ordering::Release);
72 let mut handle_guard = self.handle.lock().await;
73 if let Some(handle) = handle_guard.take() {
74 handle.abort();
75 }
76 drop(handle_guard);
77 self.bar.finish();
78 }
79
80 #[cfg(test)]
82 pub fn hide(&self) {
83 self.bar
84 .set_draw_target(indicatif::ProgressDrawTarget::hidden());
85 }
86
87 #[cfg(test)]
89 pub fn position(&self) -> u64 {
90 self.bar.position()
91 }
92
93 #[cfg(test)]
95 pub fn length(&self) -> Option<u64> {
96 self.bar.length()
97 }
98}
99
100impl<T: ICommandInfo + 'static> FromServices for CliProgress<T> {
101 type Error = ResolveError;
102
103 fn from_services(services: &ServiceProvider) -> Result<Self, Report<Self::Error>> {
104 Ok(Self::new(services.get::<CommandMediator<T>>()?))
105 }
106}
107
#[cfg(all(test, feature = "server"))]
mod tests {
    #![expect(
        clippy::as_conversions,
        reason = "usize to u64 cast in test assertions"
    )]
    use super::*;

    const COMMAND_COUNT: usize = CHANNEL_CAPACITY * 2;
    const WORKER_COUNT: usize = 4;
    const DELAY_MS: u64 = 1;

    /// Queues twice `CHANNEL_CAPACITY` commands and checks that, once the runner has drained and
    /// the progress tracker has finished, the bar length equals the number of queued commands
    /// and the position equals the number of completed commands.
    #[tokio::test]
    async fn cli_progress_receives_all_events() {
        // Arrange: resolve services and silence terminal drawing.
        let services = ServiceBuilder::new().with_commands().build();
        let runner = services
            .get_async::<CommandRunner<CommandInfo>>()
            .await
            .expect("should be able to get runner");
        let progress = services
            .get::<CliProgress<CommandInfo>>()
            .expect("should be able to get progress");
        let _logger = init_test_logger();
        progress.hide();
        let expected = COMMAND_COUNT as u64;

        // Act: listen, run every command to completion, then stop listening.
        progress.start().await;
        runner.start(WORKER_COUNT).await;
        for index in 1..=COMMAND_COUNT {
            runner
                .queue_request(DelayRequest::new(format!("P{index}"), DELAY_MS))
                .await
                .expect("should be able to queue request");
        }
        runner.drain().await;
        progress.finish().await;

        // Assert
        assert_eq!(
            progress.length(),
            Some(expected),
            "progress bar total should match queued commands"
        );
        assert_eq!(
            progress.position(),
            expected,
            "progress bar position should match completed commands"
        );
    }
}