Skip to main content

studiole_command/services/
command_runner.rs

1#![allow(dead_code)]
2
3use crate::prelude::*;
4use tokio::sync::MutexGuard;
5
6/// Current state of the [`CommandRunner`].
7#[derive(Clone, Copy, Debug, Default, Eq, Error, PartialEq)]
8pub enum RunnerStatus {
9    #[default]
10    #[error("Runner is stopped")]
11    Stopped,
12    #[error("Stopping when the active commands are complete")]
13    Stopping,
14    #[error("Stopping when the queue is empty")]
15    Draining,
16    #[error("Running")]
17    Running,
18}
19
20/// Queue and execute commands across a pool of workers.
21pub struct CommandRunner<T: ICommandInfo> {
22    mediator: Arc<CommandMediator<T>>,
23    registry: Arc<CommandRegistry<T>>,
24    workers: Arc<WorkerPool<T>>,
25}
26
27impl<T: ICommandInfo + 'static> FromServicesAsync for CommandRunner<T> {
28    type Error = ResolveError;
29
30    async fn from_services_async(services: &ServiceProvider) -> Result<Self, Report<Self::Error>> {
31        Ok(Self::new(
32            services.get::<CommandMediator<T>>()?,
33            services.get_async::<CommandRegistry<T>>().await?,
34            services.get::<WorkerPool<T>>()?,
35        ))
36    }
37}
38
39impl<T: ICommandInfo + 'static> CommandRunner<T> {
40    /// Create a new [`CommandRunner`].
41    #[must_use]
42    pub fn new(
43        mediator: Arc<CommandMediator<T>>,
44        registry: Arc<CommandRegistry<T>>,
45        workers: Arc<WorkerPool<T>>,
46    ) -> Self {
47        Self {
48            mediator,
49            registry,
50            workers,
51        }
52    }
53
54    /// Start any number of workers.
55    ///
56    /// Each worker will have a unique ID.
57    ///
58    /// Status will be set to `Running`.
59    pub async fn start(&self, worker_count: usize) {
60        self.workers.start(worker_count).await;
61    }
62
63    /// Stop workers after draining the queue.
64    pub async fn drain(&self) {
65        self.mediator
66            .set_runner_status(RunnerStatus::Draining)
67            .await;
68        self.workers.wait_for_stop().await;
69    }
70
71    /// Stop workers after their current work is complete
72    pub async fn stop(&self) {
73        self.mediator
74            .set_runner_status(RunnerStatus::Stopping)
75            .await;
76        self.workers.wait_for_stop().await;
77    }
78
79    /// Queue a command as a request.
80    pub async fn queue_request<R: Executable + Into<T::Request> + Send + Sync + 'static>(
81        &self,
82        request: R,
83    ) -> Result<(), Report<QueueError>> {
84        trace!(%request, type = type_name::<R>(), "Queueing");
85        let command = self.registry.resolve(request.clone())?;
86        trace!(%request, type = type_name::<R>(), "Resolved command");
87        self.mediator.queue(request.into(), command).await;
88        Ok(())
89    }
90
91    /// Lock and return the current command status map.
92    ///
93    /// The [`MutexGuard`] must be dropped promptly or [`Worker`] execution will block.
94    pub async fn get_commands(&self) -> MutexGuard<'_, HashMap<T::Request, CommandStatus<T>>> {
95        self.mediator.get_commands().await
96    }
97}
98
#[cfg(all(test, feature = "server"))]
mod tests {
    use super::*;

    use std::time::Duration;
    use tokio::time::sleep;

    const WORKER_COUNT: usize = 3;
    const A_COUNT: usize = 10;
    const B_COUNT: usize = 10;
    /// Duration of each batch A command, in milliseconds.
    const A_DURATION: u64 = 100;
    /// Duration of each batch B command, in milliseconds.
    const B_DURATION: u64 = 100;
    /// Worst-case wall time (ms) for batch A: each worker runs up to
    /// `ceil(A_COUNT / WORKER_COUNT)` commands back to back. Truncating
    /// division would under-count whenever the batch doesn't split evenly.
    #[allow(clippy::as_conversions)] // usize -> u64 is lossless on every current Rust target
    const A_TOTAL_DURATION: u64 = A_COUNT.div_ceil(WORKER_COUNT) as u64 * A_DURATION;

    // The final assertions hard-code 7 and 13. Fail the build if the batch or
    // worker constants change without those expectations being updated.
    const _: () = assert!(B_COUNT - WORKER_COUNT == 7 && A_COUNT + WORKER_COUNT == 13);

    /// End-to-end check: fully process batch A, then stop part-way through batch B.
    ///
    /// Timing-based: relies on the 50 ms pauses being shorter than one command
    /// duration, so no command has finished when the queue is sampled or stop is requested.
    #[tokio::test]
    async fn command_runner() {
        // Arrange
        // Initialise logging first so service construction is captured as well.
        let _logger = init_test_logger();
        let services = ServiceBuilder::new().with_commands().build();
        let runner = services
            .get_async::<CommandRunner<CommandInfo>>()
            .await
            .expect("should be able to get runner");
        let events = services
            .get::<CommandEvents<CommandInfo>>()
            .expect("should be able to get events");
        events.start().await;

        // Act
        runner.start(WORKER_COUNT).await;
        queue_batch(&runner, "A", A_COUNT, A_DURATION).await;

        // Assert
        // Only logged, not asserted: the workers were started before queueing and
        // may already have dequeued some of batch A, so the value is timing-dependent.
        let length = events
            .count()
            .await
            .get_currently_queued()
            .expect("should be able to subtract");
        debug!("Queue: {length}");

        wait(50).await;
        let length = events
            .count()
            .await
            .get_currently_queued()
            .expect("should be able to subtract");
        debug!("Queue: {length}");
        assert_ne!(length, 0, "Queue soon after adding batch A");

        wait(A_TOTAL_DURATION + 100).await;
        let length = events
            .count()
            .await
            .get_currently_queued()
            .expect("should be able to subtract");
        debug!("Queue: {length}");
        assert_eq!(length, 0, "Queue after batch A should have completed");

        queue_batch(&runner, "B", B_COUNT, B_DURATION).await;

        wait(50).await;
        info!("Requesting stop");
        runner.workers.stop().await;
        info!("Completed stop");

        // Expected: only the WORKER_COUNT commands in flight at stop time finish;
        // the rest of batch B stays queued.
        let count = events.count().await;
        let length = count
            .get_currently_queued()
            .expect("should be able to subtract");
        debug!("Queue: {length}");
        assert_eq!(length, 7, "Queue after stop"); // B_COUNT - WORKER_COUNT
        let length = count.succeeded;
        debug!("Succeeded: {length}");
        assert_eq!(length, 13, "Succeeded after stop"); // A_COUNT + WORKER_COUNT
    }

    /// Queue `count` delay requests named `{prefix}1` ..= `{prefix}{count}`.
    ///
    /// `duration` is passed straight to `DelayRequest::new` (milliseconds,
    /// judging by how the test's timing constants are used).
    async fn queue_batch(
        runner: &CommandRunner<CommandInfo>,
        prefix: &str,
        count: usize,
        duration: u64,
    ) {
        info!("Adding {count} commands to queue");
        for i in 1..=count {
            let request = DelayRequest::new(format!("{prefix}{i}"), duration);
            runner
                .queue_request(request)
                .await
                .expect("should be able to queue command");
        }
        info!("Added {count} commands to queue");
    }

    /// Log the upcoming pause, then sleep for `ms` milliseconds.
    async fn wait(ms: u64) {
        // Log before sleeping so the message marks the start of the pause.
        info!("Waiting {ms} ms");
        sleep(Duration::from_millis(ms)).await;
    }
}