// studiole_command/services/command_mediator.rs

use crate::prelude::*;
use tokio::sync::broadcast::{Receiver, Sender, channel};

/// Capacity of the broadcast event channel created in [`CommandMediator::new`];
/// a subscriber that lags more than this many events behind starts missing
/// messages (standard tokio broadcast semantics).
const CHANNEL_CAPACITY: usize = 16;
5
/// A mediator between the [`CommandRunner`], [`Worker`] and [`CliProgress`] services.
///
/// Owns the FIFO request queue, the per-request status map, the worker wakeup
/// primitive and the broadcast channel that fans lifecycle events out to
/// subscribers.
pub struct CommandMediator<T: ICommandInfo> {
    /// Broadcast sender for command lifecycle events; receivers are created on
    /// demand via [`Self::subscribe`].
    events: Sender<T::Event>,
    /// FIFO queue of requests awaiting execution by a worker.
    queue: Mutex<VecDeque<T::Request>>,
    /// Map of requests to their current status.
    commands: Mutex<HashMap<T::Request, CommandStatus<T>>>,
    /// Notify workers when new work is available (or the runner status changed).
    notify_workers: Notify,
    /// Current status of the runner.
    runner_status: Mutex<RunnerStatus>,
}
19
impl<T: ICommandInfo + 'static> Service for CommandMediator<T> {
    // Construction cannot fail, so the error type is uninhabited.
    type Error = Infallible;

    /// Build the mediator; it needs nothing from the [`ServiceProvider`].
    async fn from_services(_services: &ServiceProvider) -> Result<Self, Report<Self::Error>> {
        Ok(Self::new())
    }
}
27
28impl<T: ICommandInfo> CommandMediator<T> {
29    pub(super) fn new() -> Self {
30        let (events, _) = channel::<T::Event>(CHANNEL_CAPACITY);
31        Self {
32            events,
33            queue: Mutex::default(),
34            notify_workers: Notify::default(),
35            runner_status: Mutex::default(),
36            commands: Mutex::default(),
37        }
38    }
39
40    async fn get_runner_status(&self) -> RunnerStatus {
41        *self.runner_status.lock().await
42    }
43}
44
45// Implementation for `CommandRunner`
46impl<T: ICommandInfo> CommandMediator<T> {
47    pub(super) async fn set_runner_status(&self, status: RunnerStatus) {
48        trace!(?status, "Set runner status");
49        let mut status_guard = self.runner_status.lock().await;
50        *status_guard = status;
51        drop(status_guard);
52        self.notify_workers.notify_waiters();
53    }
54
55    /// Add a command to the queue.
56    ///
57    /// If the request is already queued or executing then it's ignored and `false` is returned.
58    ///
59    /// If added to the queue then progress is updated and subscribers are notified.
60    pub(super) async fn queue(&self, request: T::Request, command: T::Command) -> bool {
61        trace!(?request, "Queueing");
62        let mut commands = self.commands.lock().await;
63        if let Some(CommandStatus::Queued(_) | CommandStatus::Executing) = commands.get(&request) {
64            trace!(?request, "Skipping as already queued or executing");
65            return false;
66        }
67        commands.insert(request.clone(), CommandStatus::Queued(command));
68        drop(commands);
69        let _ = self
70            .events
71            .send(T::Event::new(EventKind::Queued, request.clone(), None));
72        let mut queue = self.queue.lock().await;
73        queue.push_back(request.clone());
74        drop(queue);
75        trace!(?request, "Queued");
76        trace!(?request, "Notifying worker");
77        self.notify_workers.notify_one();
78        true
79    }
80
81    /// Get the commands.
82    ///
83    /// Note: The [`MutexGuard`] must be dropped or the [`Worker`] will be unable to finish
84    /// execution.
85    pub(super) async fn get_commands(
86        &self,
87    ) -> MutexGuard<'_, HashMap<T::Request, CommandStatus<T>>> {
88        self.commands.lock().await
89    }
90}
91
// Implementation for `Worker`
impl<T: ICommandInfo> CommandMediator<T> {
    /// Get the next instruction.
    ///
    /// Returns [`Instruction::Stop`] when the runner is `Stopping` (or
    /// `Draining` with an empty queue), [`Instruction::Execute`] when a
    /// request was dequeued, otherwise [`Instruction::Wait`] holding a future
    /// that resolves on the next worker notification.
    #[allow(clippy::panic)]
    pub(super) async fn get_instruction(&self) -> Instruction<'_, T> {
        // Create the `Notified` future *before* inspecting the queue so that a
        // notification sent while we are looking cannot be lost if we end up
        // returning `Instruction::Wait` (standard tokio `Notify` pattern).
        let notify = self.notify_workers.notified();
        let mut queue_guard = self.queue.lock().await;
        // `Stopping` wins over remaining work: abandon the queue immediately.
        if self.get_runner_status().await == RunnerStatus::Stopping {
            return Instruction::Stop;
        }
        if let Some(request) = queue_guard.pop_front() {
            drop(queue_guard);
            // Best-effort broadcast; an error only means there are no subscribers.
            let _ = self
                .events
                .send(T::Event::new(EventKind::Executing, request.clone(), None));
            let mut commands = self.commands.lock().await;
            let option = commands.insert(request.clone(), CommandStatus::Executing);
            drop(commands);
            // Invariant: anything popped from the queue was inserted as
            // `Queued` by `queue()` before being pushed, so any other status
            // here is a bug worth panicking over.
            let Some(CommandStatus::Queued(command)) = option else {
                panic!("command should be queued but was {option:?}");
            };
            return Instruction::Execute(request, command);
        }
        drop(queue_guard);
        // `Draining` keeps executing queued work but stops once the queue is empty.
        if self.get_runner_status().await == RunnerStatus::Draining {
            return Instruction::Stop;
        }
        Instruction::Wait(notify)
    }

    /// Add the result of a completed execution.
    ///
    /// Records the final [`CommandStatus`] for `request` and broadcasts a
    /// `Succeeded`/`Failed` event to subscribers.
    pub(super) async fn completed(
        &self,
        request: T::Request,
        result: Result<T::Success, T::Failure>,
    ) {
        let mut commands = self.commands.lock().await;
        match result {
            Ok(success) => {
                trace!(?request, "Command succeeded");
                commands.insert(request.clone(), CommandStatus::Succeeded(success.clone()));
                let _ =
                    self.events
                        .send(T::Event::new(EventKind::Succeeded, request, Some(success)));
            }
            Err(failure) => {
                warn!(?request, error = ?failure, "Command failed");
                // NOTE(review): the failure itself is only stored in the status
                // map — the broadcast event appears to carry just an optional
                // success payload, so subscribers must query the map for details.
                commands.insert(request.clone(), CommandStatus::Failed(failure));
                let _ = self
                    .events
                    .send(T::Event::new(EventKind::Failed, request, None));
            }
        }
        drop(commands);
    }
}
148
// Implementation for event subscribers
impl<T: ICommandInfo> CommandMediator<T> {
    /// Subscribe to events.
    ///
    /// Returns a fresh broadcast [`Receiver`]; only events sent *after* this
    /// call are delivered, and a receiver that falls more than
    /// `CHANNEL_CAPACITY` events behind will miss messages.
    pub fn subscribe(&self) -> Receiver<T::Event> {
        self.events.subscribe()
    }
}
155}