squads_temporal_client/worker_registry/
mod.rs1use parking_lot::RwLock;
6use slotmap::SlotMap;
7use std::collections::{HashMap, hash_map::Entry::Vacant};
8
9use squads_temporal_sdk_core_protos::temporal::api::workflowservice::v1::PollWorkflowTaskQueueResponse;
10
slotmap::new_key_type! {
    /// Handle returned by [`SlotManager::register`] for a successfully
    /// registered provider; pass it to [`SlotManager::unregister`] to remove
    /// that provider again.
    pub struct WorkerKey;
}
15
/// Source of workflow-task slots for a single (namespace, task queue) pair.
///
/// Providers are registered with [`SlotManager`], which uses the values of
/// [`namespace`](Self::namespace) and [`task_queue`](Self::task_queue) as the
/// registration key; only one provider per pair is kept.
#[cfg_attr(test, mockall::automock)]
pub trait SlotProvider: std::fmt::Debug {
    /// Namespace this provider serves (first half of its registration key).
    fn namespace(&self) -> &str;
    /// Task queue this provider serves (second half of its registration key).
    fn task_queue(&self) -> &str;
    /// Attempts to reserve a workflow-task slot.
    ///
    /// `None` means no slot could be reserved right now.
    fn try_reserve_wft_slot(&self) -> Option<Box<dyn Slot + Send>>;
}
26
/// A reserved workflow-task slot handed out by a [`SlotProvider`].
///
/// [`schedule_wft`](Self::schedule_wft) takes `self: Box<Self>`, so a slot
/// can be used for at most one task.
#[cfg_attr(test, mockall::automock)]
pub trait Slot {
    /// Consumes the slot and hands `task` to it for scheduling.
    ///
    /// # Errors
    /// Returns an error if the implementor fails to schedule the task.
    fn schedule_wft(
        self: Box<Self>,
        task: PollWorkflowTaskQueueResponse,
    ) -> Result<(), anyhow::Error>;
}
36
/// Registry key: a provider is identified by its namespace and task queue.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
struct SlotKey {
    /// Namespace reported by the provider.
    namespace: String,
    /// Task queue reported by the provider.
    task_queue: String,
}

impl SlotKey {
    /// Builds a key from owned namespace and task-queue strings.
    fn new(namespace: String, task_queue: String) -> Self {
        Self { namespace, task_queue }
    }
}
51
/// Unsynchronized registry state; [`SlotManager`] wraps it in an `RwLock`.
///
/// Holds at most one provider per [`SlotKey`]. `index` maps every
/// [`WorkerKey`] handed out by `register` back to the key its provider was
/// stored under, so `unregister` can find the provider to remove.
#[derive(Default, Debug)]
struct SlotManagerImpl {
    /// Registered providers, one per (namespace, task queue) pair.
    providers: HashMap<SlotKey, Box<dyn SlotProvider + Send + Sync>>,
    /// Handle -> provider key, for lookups in `unregister`.
    index: SlotMap<WorkerKey, SlotKey>,
}
60
61impl SlotManagerImpl {
62 fn new() -> Self {
64 Self {
65 index: Default::default(),
66 providers: Default::default(),
67 }
68 }
69
70 fn try_reserve_wft_slot(
71 &self,
72 namespace: String,
73 task_queue: String,
74 ) -> Option<Box<dyn Slot + Send>> {
75 let key = SlotKey::new(namespace, task_queue);
76 if let Some(p) = self.providers.get(&key)
77 && let Some(slot) = p.try_reserve_wft_slot()
78 {
79 return Some(slot);
80 }
81 None
82 }
83
84 fn register(&mut self, provider: Box<dyn SlotProvider + Send + Sync>) -> Option<WorkerKey> {
85 let key = SlotKey::new(
86 provider.namespace().to_string(),
87 provider.task_queue().to_string(),
88 );
89 if let Vacant(p) = self.providers.entry(key.clone()) {
90 p.insert(provider);
91 Some(self.index.insert(key))
92 } else {
93 warn!("Ignoring registration for worker: {key:?}.");
94 None
95 }
96 }
97
98 fn unregister(&mut self, id: WorkerKey) -> Option<Box<dyn SlotProvider + Send + Sync>> {
99 if let Some(key) = self.index.remove(id) {
100 self.providers.remove(&key)
101 } else {
102 None
103 }
104 }
105
106 #[cfg(test)]
107 fn num_providers(&self) -> (usize, usize) {
108 (self.index.len(), self.providers.len())
109 }
110}
111
/// Shared registry of [`SlotProvider`]s keyed by namespace and task queue.
///
/// The registry state lives behind a `parking_lot::RwLock`: slot reservation
/// takes the read lock, while registration and unregistration take the write
/// lock.
#[derive(Default, Debug)]
pub struct SlotManager {
    manager: RwLock<SlotManagerImpl>,
}
119
120impl SlotManager {
121 pub fn new() -> Self {
123 Self {
124 manager: RwLock::new(SlotManagerImpl::new()),
125 }
126 }
127
128 pub(crate) fn try_reserve_wft_slot(
130 &self,
131 namespace: String,
132 task_queue: String,
133 ) -> Option<Box<dyn Slot + Send>> {
134 self.manager
135 .read()
136 .try_reserve_wft_slot(namespace, task_queue)
137 }
138
139 pub fn register(&self, provider: Box<dyn SlotProvider + Send + Sync>) -> Option<WorkerKey> {
141 self.manager.write().register(provider)
142 }
143
144 pub fn unregister(&self, id: WorkerKey) -> Option<Box<dyn SlotProvider + Send + Sync>> {
146 self.manager.write().unregister(id)
147 }
148
149 #[cfg(test)]
150 pub fn num_providers(&self) -> (usize, usize) {
153 self.manager.read().num_providers()
154 }
155}
156
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a boxed mock slot whose `schedule_wft` succeeds, or fails when
    /// `with_error` is set.
    fn new_mock_slot(with_error: bool) -> Box<MockSlot> {
        let mut mock_slot = MockSlot::new();
        if with_error {
            mock_slot
                .expect_schedule_wft()
                .returning(|_| Err(anyhow::anyhow!("Changed my mind")));
        } else {
            mock_slot.expect_schedule_wft().returning(|_| Ok(()));
        }
        Box::new(mock_slot)
    }

    /// Builds a mock provider for `namespace`/`task_queue`. With `no_slots`
    /// every reservation returns `None`; otherwise it hands out slots from
    /// `new_mock_slot(with_error)`.
    fn new_mock_provider(
        namespace: String,
        task_queue: String,
        with_error: bool,
        no_slots: bool,
    ) -> MockSlotProvider {
        let mut mock_provider = MockSlotProvider::new();
        mock_provider
            .expect_try_reserve_wft_slot()
            .returning(move || {
                if no_slots {
                    None
                } else {
                    Some(new_mock_slot(with_error))
                }
            });
        mock_provider.expect_namespace().return_const(namespace);
        mock_provider.expect_task_queue().return_const(task_queue);
        mock_provider
    }

    #[test]
    fn registry_respects_registration_order() {
        let mock_provider1 =
            new_mock_provider("foo".to_string(), "bar_q".to_string(), false, false);
        let mock_provider2 = new_mock_provider("foo".to_string(), "bar_q".to_string(), false, true);

        let manager = SlotManager::new();
        let some_slots = manager.register(Box::new(mock_provider1));
        let no_slots = manager.register(Box::new(mock_provider2));
        assert!(some_slots.is_some());
        assert!(no_slots.is_none());

        let mut found = 0;
        for _ in 0..10 {
            if manager
                .try_reserve_wft_slot("foo".to_string(), "bar_q".to_string())
                .is_some()
            {
                found += 1;
            }
        }
        assert_eq!(found, 10);
        assert_eq!((1, 1), manager.num_providers());

        assert!(manager.unregister(some_slots.unwrap()).is_some());
        assert_eq!((0, 0), manager.num_providers());

        let mock_provider1 =
            new_mock_provider("foo".to_string(), "bar_q".to_string(), false, false);
        let mock_provider2 = new_mock_provider("foo".to_string(), "bar_q".to_string(), false, true);

        let no_slots = manager.register(Box::new(mock_provider2));
        let some_slots = manager.register(Box::new(mock_provider1));
        assert!(no_slots.is_some());
        assert!(some_slots.is_none());

        let mut not_found = 0;
        for _ in 0..10 {
            if manager
                .try_reserve_wft_slot("foo".to_string(), "bar_q".to_string())
                .is_none()
            {
                not_found += 1;
            }
        }
        assert_eq!(not_found, 10);
        assert_eq!((1, 1), manager.num_providers());
        assert!(manager.unregister(no_slots.unwrap()).is_some());
        assert_eq!((0, 0), manager.num_providers());
    }

    #[test]
    fn registry_keeps_one_provider_per_namespace() {
        let manager = SlotManager::new();
        let mut worker_keys = vec![];
        for i in 0..10 {
            let namespace = format!("myId{}", i % 3);
            let mock_provider = new_mock_provider(namespace, "bar_q".to_string(), false, false);
            worker_keys.push(manager.register(Box::new(mock_provider)));
        }
        assert_eq!((3, 3), manager.num_providers());

        let count = worker_keys.iter().flatten().fold(0, |count, key| {
            assert!(manager.unregister(*key).is_some());
            // Unregistering the same (now stale) key again must be a no-op.
            assert!(manager.unregister(*key).is_none());
            count + 1
        });
        assert_eq!(3, count);
        assert_eq!((0, 0), manager.num_providers());
    }
}