squads_temporal_client/worker_registry/
mod.rs

//! This module tracks the workers that are associated with a client instance.
//! It is needed to implement Eager Workflow Start, a latency optimization in which the client,
//! after reserving a slot, forwards a WFT directly to a local worker.
4
5use parking_lot::RwLock;
6use slotmap::SlotMap;
7use std::collections::{HashMap, hash_map::Entry::Vacant};
8
9use squads_temporal_sdk_core_protos::temporal::api::workflowservice::v1::PollWorkflowTaskQueueResponse;
10
11slotmap::new_key_type! {
12    /// Registration key for a worker
13    pub struct WorkerKey;
14}
15
16/// This trait is implemented by an object associated with a worker, which provides WFT processing slots.
17#[cfg_attr(test, mockall::automock)]
18pub trait SlotProvider: std::fmt::Debug {
19    /// The namespace for the WFTs that it can process.
20    fn namespace(&self) -> &str;
21    /// The task queue this provider listens to.
22    fn task_queue(&self) -> &str;
23    /// Try to reserve a slot on this worker.
24    fn try_reserve_wft_slot(&self) -> Option<Box<dyn Slot + Send>>;
25}
26
27/// This trait represents a slot reserved for processing a WFT by a worker.
28#[cfg_attr(test, mockall::automock)]
29pub trait Slot {
30    /// Consumes this slot by dispatching a WFT to its worker. This can only be called once.
31    fn schedule_wft(
32        self: Box<Self>,
33        task: PollWorkflowTaskQueueResponse,
34    ) -> Result<(), anyhow::Error>;
35}
36
/// Bucket key for providers: one (namespace, task queue) pair.
#[derive(PartialEq, Eq, Hash, Debug, Clone)]
struct SlotKey {
    namespace: String,
    task_queue: String,
}

impl SlotKey {
    /// Builds a key from an owned namespace and task queue name.
    fn new(namespace: String, task_queue: String) -> SlotKey {
        Self { namespace, task_queue }
    }
}
51
52/// This is an inner class for [SlotManager] needed to hide the mutex.
53#[derive(Default, Debug)]
54struct SlotManagerImpl {
55    /// Maps keys, i.e., namespace#task_queue, to provider.
56    providers: HashMap<SlotKey, Box<dyn SlotProvider + Send + Sync>>,
57    /// Maps ids to keys in `providers`.
58    index: SlotMap<WorkerKey, SlotKey>,
59}
60
61impl SlotManagerImpl {
62    /// Factory method.
63    fn new() -> Self {
64        Self {
65            index: Default::default(),
66            providers: Default::default(),
67        }
68    }
69
70    fn try_reserve_wft_slot(
71        &self,
72        namespace: String,
73        task_queue: String,
74    ) -> Option<Box<dyn Slot + Send>> {
75        let key = SlotKey::new(namespace, task_queue);
76        if let Some(p) = self.providers.get(&key)
77            && let Some(slot) = p.try_reserve_wft_slot()
78        {
79            return Some(slot);
80        }
81        None
82    }
83
84    fn register(&mut self, provider: Box<dyn SlotProvider + Send + Sync>) -> Option<WorkerKey> {
85        let key = SlotKey::new(
86            provider.namespace().to_string(),
87            provider.task_queue().to_string(),
88        );
89        if let Vacant(p) = self.providers.entry(key.clone()) {
90            p.insert(provider);
91            Some(self.index.insert(key))
92        } else {
93            warn!("Ignoring registration for worker: {key:?}.");
94            None
95        }
96    }
97
98    fn unregister(&mut self, id: WorkerKey) -> Option<Box<dyn SlotProvider + Send + Sync>> {
99        if let Some(key) = self.index.remove(id) {
100            self.providers.remove(&key)
101        } else {
102            None
103        }
104    }
105
106    #[cfg(test)]
107    fn num_providers(&self) -> (usize, usize) {
108        (self.index.len(), self.providers.len())
109    }
110}
111
112/// Enables local workers to make themselves visible to a shared client instance.
113/// There can only be one worker registered per namespace+queue_name+client, others will get ignored.
114/// It also provides a convenient method to find compatible slots within the collection.
115#[derive(Default, Debug)]
116pub struct SlotManager {
117    manager: RwLock<SlotManagerImpl>,
118}
119
120impl SlotManager {
121    /// Factory method.
122    pub fn new() -> Self {
123        Self {
124            manager: RwLock::new(SlotManagerImpl::new()),
125        }
126    }
127
128    /// Try to reserve a compatible processing slot in any of the registered workers.
129    pub(crate) fn try_reserve_wft_slot(
130        &self,
131        namespace: String,
132        task_queue: String,
133    ) -> Option<Box<dyn Slot + Send>> {
134        self.manager
135            .read()
136            .try_reserve_wft_slot(namespace, task_queue)
137    }
138
139    /// Register a local worker that can provide WFT processing slots.
140    pub fn register(&self, provider: Box<dyn SlotProvider + Send + Sync>) -> Option<WorkerKey> {
141        self.manager.write().register(provider)
142    }
143
144    /// Unregister a provider, typically when its worker starts shutdown.
145    pub fn unregister(&self, id: WorkerKey) -> Option<Box<dyn SlotProvider + Send + Sync>> {
146        self.manager.write().unregister(id)
147    }
148
149    #[cfg(test)]
150    /// Returns (num_providers, num_buckets), where a bucket key is namespace+task_queue.
151    /// There is only one provider per bucket so `num_providers` should be equal to `num_buckets`.
152    pub fn num_providers(&self) -> (usize, usize) {
153        self.manager.read().num_providers()
154    }
155}
156
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a mock slot whose `schedule_wft` either succeeds or fails with a fixed error.
    fn new_mock_slot(with_error: bool) -> Box<MockSlot> {
        let mut slot = MockSlot::new();
        let expectation = slot.expect_schedule_wft();
        if with_error {
            expectation.returning(|_| Err(anyhow::anyhow!("Changed my mind")));
        } else {
            expectation.returning(|_| Ok(()));
        }
        Box::new(slot)
    }

    /// Builds a mock provider for the given bucket; `no_slots` makes every reservation fail.
    fn new_mock_provider(
        namespace: String,
        task_queue: String,
        with_error: bool,
        no_slots: bool,
    ) -> MockSlotProvider {
        let mut provider = MockSlotProvider::new();
        provider.expect_namespace().return_const(namespace);
        provider.expect_task_queue().return_const(task_queue);
        provider.expect_try_reserve_wft_slot().returning(move || {
            if !no_slots {
                Some(new_mock_slot(with_error))
            } else {
                None
            }
        });
        provider
    }

    /// Shorthand for a boxed provider on the `foo`/`bar_q` bucket.
    fn foo_provider(no_slots: bool) -> Box<MockSlotProvider> {
        Box::new(new_mock_provider(
            "foo".to_string(),
            "bar_q".to_string(),
            false,
            no_slots,
        ))
    }

    /// Counts how many of `attempts` reservations on `foo`/`bar_q` succeed.
    fn count_reserved(manager: &SlotManager, attempts: usize) -> usize {
        (0..attempts)
            .filter(|_| {
                manager
                    .try_reserve_wft_slot("foo".to_string(), "bar_q".to_string())
                    .is_some()
            })
            .count()
    }

    #[test]
    fn registry_respects_registration_order() {
        let manager = SlotManager::new();

        // The provider with free slots registers first, so it owns the bucket.
        let with_slots = manager.register(foo_provider(false));
        let rejected = manager.register(foo_provider(true));
        assert!(rejected.is_none());
        assert_eq!(count_reserved(&manager, 10), 10);
        assert_eq!((1, 1), manager.num_providers());
        manager.unregister(with_slots.unwrap());
        assert_eq!((0, 0), manager.num_providers());

        // The slot-less provider registers first, so every reservation fails.
        let without_slots = manager.register(foo_provider(true));
        let rejected = manager.register(foo_provider(false));
        assert!(rejected.is_none());
        assert_eq!(count_reserved(&manager, 10), 0);
        assert_eq!((1, 1), manager.num_providers());
        manager.unregister(without_slots.unwrap());
        assert_eq!((0, 0), manager.num_providers());
    }

    #[test]
    fn registry_keeps_one_provider_per_namespace() {
        let manager = SlotManager::new();
        let worker_keys: Vec<Option<WorkerKey>> = (0..10)
            .map(|i| {
                let provider = new_mock_provider(
                    format!("myId{}", i % 3),
                    "bar_q".to_string(),
                    false,
                    false,
                );
                manager.register(Box::new(provider))
            })
            .collect();
        assert_eq!((3, 3), manager.num_providers());

        let accepted: Vec<WorkerKey> = worker_keys.into_iter().flatten().collect();
        assert_eq!(3, accepted.len());
        for key in accepted {
            manager.unregister(key);
            // Unregistering the same key again must be a harmless no-op.
            manager.unregister(key);
        }
        assert_eq!((0, 0), manager.num_providers());
    }
}