Skip to main content

studiole_command/services/
command_runner.rs

1#![allow(dead_code)]
2
3use crate::prelude::*;
4use tokio::sync::MutexGuard;
5
6/// Current state of the [`CommandRunner`].
7#[derive(Clone, Copy, Debug, Default, Eq, Error, PartialEq)]
8pub enum RunnerStatus {
9    #[default]
10    #[error("Runner is stopped")]
11    Stopped,
12    #[error("Stopping when the active commands are complete")]
13    Stopping,
14    #[error("Stopping when the queue is empty")]
15    Draining,
16    #[error("Running")]
17    Running,
18}
19
20/// Queue and execute commands across a pool of workers.
21pub struct CommandRunner<T: ICommandInfo> {
22    mediator: Arc<CommandMediator<T>>,
23    registry: Arc<CommandRegistry<T>>,
24    workers: Arc<WorkerPool<T>>,
25}
26
27impl<T: ICommandInfo + 'static> Service for CommandRunner<T> {
28    type Error = ServiceError;
29
30    async fn from_services(services: &ServiceProvider) -> Result<Self, Report<Self::Error>> {
31        Ok(Self::new(
32            services.get_service().await?,
33            services.get_service().await?,
34            services.get_service().await?,
35        ))
36    }
37}
38
39impl<T: ICommandInfo + 'static> CommandRunner<T> {
40    /// Create a new [`CommandRunner`].
41    #[must_use]
42    pub fn new(
43        mediator: Arc<CommandMediator<T>>,
44        registry: Arc<CommandRegistry<T>>,
45        workers: Arc<WorkerPool<T>>,
46    ) -> Self {
47        Self {
48            mediator,
49            registry,
50            workers,
51        }
52    }
53
54    /// Start any number of workers.
55    ///
56    /// Each worker will have a unique ID.
57    ///
58    /// Status will be set to `Running`.
59    pub async fn start(&self, worker_count: usize) {
60        self.workers.start(worker_count).await;
61    }
62
63    /// Stop workers after draining the queue.
64    pub async fn drain(&self) {
65        self.mediator
66            .set_runner_status(RunnerStatus::Draining)
67            .await;
68        self.workers.wait_for_stop().await;
69    }
70
71    /// Stop workers after their current work is complete
72    pub async fn stop(&self) {
73        self.mediator
74            .set_runner_status(RunnerStatus::Stopping)
75            .await;
76        self.workers.wait_for_stop().await;
77    }
78
79    /// Queue a command as a request.
80    pub async fn queue_request<R: Executable + Into<T::Request> + Send + Sync + 'static>(
81        &self,
82        request: R,
83    ) -> Result<(), Report<QueueError>> {
84        trace!(%request, type = type_name::<R>(), "Queueing");
85        let command = self.registry.resolve(request.clone())?;
86        trace!(%request, type = type_name::<R>(), "Resolved command");
87        self.mediator.queue(request.into(), command).await;
88        Ok(())
89    }
90
91    /// Lock and return the current command status map.
92    ///
93    /// The [`MutexGuard`] must be dropped promptly or [`Worker`] execution will block.
94    pub async fn get_commands(&self) -> MutexGuard<'_, HashMap<T::Request, CommandStatus<T>>> {
95        self.mediator.get_commands().await
96    }
97}
98
#[cfg(all(test, feature = "server"))]
mod tests {
    //! Timing-based test for [`CommandRunner`].
    //!
    //! Every queued command is a `DelayRequest`, so queue lengths at given points in time
    //! are predictable. The waits below include slack for scheduling jitter.

    use super::*;

    use std::time::Duration;
    use tokio::time::sleep;

    const WORKER_COUNT: usize = 3;
    const A_COUNT: usize = 10;
    const B_COUNT: usize = 10;
    const A_DURATION: u64 = 100;
    const B_DURATION: u64 = 100;
    /// Expected run time of batch A: the number of worker rounds times one command's delay.
    ///
    /// The round count is rounded *up*: a partially filled final round still takes a full
    /// `A_DURATION` (10 commands on 3 workers take 4 rounds, not 3).
    // usize -> u64 is lossless on every Rust target; `u64::try_from` is not usable in a const.
    #[allow(clippy::as_conversions)]
    const A_TOTAL_DURATION: u64 = A_COUNT.div_ceil(WORKER_COUNT) as u64 * A_DURATION;

    #[tokio::test]
    async fn command_runner() {
        // Arrange
        // Initialise logging first so that service construction is captured too.
        let _logger = init_test_logger();
        let services = ServiceProvider::new()
            .with_commands()
            .await
            .expect("should be able to create services with commands");
        let runner = services
            .get_service::<CommandRunner<CommandInfo>>()
            .await
            .expect("should be able to get runner");
        let events = services
            .get_service::<CommandEvents<CommandInfo>>()
            .await
            .expect("should be able to get events");
        events.start().await;

        // Act
        runner.start(WORKER_COUNT).await;
        queue_batch(&runner, "A", A_COUNT, A_DURATION).await;

        // Assert
        // Deliberately not asserted: workers start taking commands immediately, so the
        // queue length right after queueing is racy. It is logged for diagnostics only.
        let length = events
            .count()
            .await
            .get_currently_queued()
            .expect("should be able to subtract");
        debug!("Queue: {length}");

        wait(50).await;
        let length = events
            .count()
            .await
            .get_currently_queued()
            .expect("should be able to subtract");
        debug!("Queue: {length}");
        assert_ne!(length, 0, "Queue soon after adding batch A");

        wait(A_TOTAL_DURATION + 100).await;
        let length = events
            .count()
            .await
            .get_currently_queued()
            .expect("should be able to subtract");
        debug!("Queue: {length}");
        assert_eq!(length, 0, "Queue after batch A should have completed");

        queue_batch(&runner, "B", B_COUNT, B_DURATION).await;

        // Let each worker pick up one B command, then stop the pool. The asserts below expect
        // WORKER_COUNT (3) in-flight commands to finish (10 + 3 = 13 succeeded) and
        // B_COUNT - WORKER_COUNT (7) to stay queued.
        wait(50).await;
        info!("Requesting stop");
        runner.workers.stop().await;
        info!("Completed stop");

        let count = events.count().await;
        let length = count
            .get_currently_queued()
            .expect("should be able to subtract");
        debug!("Queue: {length}");
        assert_eq!(length, 7, "Queue after stop");
        let succeeded = count.succeeded;
        debug!("Succeeded: {succeeded}");
        assert_eq!(succeeded, 13, "Succeeded after stop");
    }

    /// Queue `count` delay commands named `{prefix}1` ..= `{prefix}{count}`, each with a delay
    /// of `duration` (treated as milliseconds by this test's timing).
    async fn queue_batch(
        runner: &CommandRunner<CommandInfo>,
        prefix: &str,
        count: usize,
        duration: u64,
    ) {
        info!("Adding {count} commands to queue");
        for i in 1..=count {
            let request = DelayRequest::new(format!("{prefix}{i}"), duration);
            runner
                .queue_request(request)
                .await
                .expect("should be able to queue command");
        }
        info!("Added {count} commands to queue");
    }

    /// Log the upcoming wait, then sleep for `ms` milliseconds.
    async fn wait(ms: u64) {
        info!("Waiting {ms} ms");
        sleep(Duration::from_millis(ms)).await;
    }
}
198    async fn wait(wait: u64) {
199        sleep(Duration::from_millis(wait)).await;
200        info!("Waiting {wait} ms");
201    }
202}