1use anyhow::bail;
4use parking_lot::RwLock;
5use rand::seq::SliceRandom;
6use std::{
7 collections::{
8 HashMap,
9 hash_map::Entry::{Occupied, Vacant},
10 },
11 sync::Arc,
12};
13use temporalio_common::{
14 protos::temporal::api::{
15 worker::v1::WorkerHeartbeat, workflowservice::v1::PollWorkflowTaskQueueResponse,
16 },
17 worker::{WorkerDeploymentOptions, WorkerTaskTypes},
18};
19use uuid::Uuid;
20
/// A reserved workflow-task slot, as returned by [`ClientWorker::try_reserve_wft_slot`].
///
/// A mock implementation (`MockSlot`) is generated for test builds.
#[cfg_attr(test, mockall::automock)]
pub trait Slot {
    /// Consumes the slot (it is taken as `Box<Self>`) and hands it the polled workflow `task`.
    ///
    /// # Errors
    /// Returns an error when the implementation fails to schedule the task.
    // NOTE(review): what "schedule" entails is implementation-defined and not visible here.
    fn schedule_wft(
        self: Box<Self>,
        task: PollWorkflowTaskQueueResponse,
    ) -> Result<(), anyhow::Error>;
}
30
/// The outcome of a successful workflow-task slot reservation.
pub(crate) struct SlotReservation {
    /// The slot obtained from the selected worker's `try_reserve_wft_slot`.
    pub slot: Box<dyn Slot + Send>,
    /// The deployment options reported by the worker that provided the slot, if it has any.
    pub deployment_options: Option<WorkerDeploymentOptions>,
}
38
/// Map key identifying a (namespace, task queue) pair.
#[derive(PartialEq, Eq, Hash, Debug, Clone)]
struct SlotKey {
    /// Namespace part of the key.
    namespace: String,
    /// Task queue part of the key.
    task_queue: String,
}

impl SlotKey {
    /// Builds a key from an owned namespace and task queue name.
    fn new(namespace: String, task_queue: String) -> Self {
        Self {
            namespace,
            task_queue,
        }
    }
}
53
54#[derive(Debug, Clone)]
56struct RegisteredWorkerInfo {
57 worker_id: Uuid,
59 build_id: Option<String>,
61 task_types: WorkerTaskTypes,
63}
64
65impl RegisteredWorkerInfo {
66 fn new(worker_id: Uuid, build_id: Option<String>, task_types: WorkerTaskTypes) -> Self {
67 Self {
68 worker_id,
69 build_id,
70 task_types,
71 }
72 }
73}
74
/// Registry state for a set of client workers: the workers themselves plus the indexes used to
/// find them.
struct ClientWorkerSetImpl {
    /// Registration records grouped by (namespace, task queue).
    slot_providers: HashMap<SlotKey, Vec<RegisteredWorkerInfo>>,
    /// Every registered worker, keyed by its worker instance key.
    all_workers: HashMap<Uuid, Arc<dyn ClientWorker + Send + Sync>>,
    /// Shared namespace workers keyed by namespace name; an entry is created when the first
    /// heartbeat-enabled worker of a namespace registers.
    shared_worker: HashMap<String, Box<dyn SharedNamespaceWorkerTrait + Send + Sync>>,
}
84
85impl ClientWorkerSetImpl {
86 fn new() -> Self {
88 Self {
89 slot_providers: Default::default(),
90 all_workers: Default::default(),
91 shared_worker: Default::default(),
92 }
93 }
94
95 fn try_reserve_wft_slot(
96 &self,
97 namespace: String,
98 task_queue: String,
99 ) -> Option<SlotReservation> {
100 let key = SlotKey::new(namespace, task_queue);
101 if let Some(worker_list) = self.slot_providers.get(&key) {
102 let workflow_workers: Vec<&RegisteredWorkerInfo> = worker_list
103 .iter()
104 .filter(|info| info.task_types.enable_workflows)
105 .collect();
106
107 for worker_id in Self::worker_ids_in_selection_order(&workflow_workers) {
108 if let Some(worker) = self.all_workers.get(&worker_id)
109 && let Some(slot) = worker.try_reserve_wft_slot()
110 {
111 let deployment_options = worker.deployment_options();
112 return Some(SlotReservation {
113 slot,
114 deployment_options,
115 });
116 }
117 }
118 }
119 None
120 }
121
122 fn worker_ids_in_selection_order(worker_list: &[&RegisteredWorkerInfo]) -> Vec<Uuid> {
123 if cfg!(test) {
126 worker_list.iter().map(|info| info.worker_id).collect()
127 } else {
128 let mut rng = rand::rng();
129 let mut shuffled: Vec<_> = worker_list.to_vec();
130 shuffled.shuffle(&mut rng);
131 shuffled.iter().map(|info| info.worker_id).collect()
132 }
133 }
134
135 fn register(
136 &mut self,
137 worker: Arc<dyn ClientWorker + Send + Sync>,
138 skip_client_worker_set_check: bool,
139 ) -> Result<(), anyhow::Error> {
140 let slot_key = SlotKey::new(
141 worker.namespace().to_string(),
142 worker.task_queue().to_string(),
143 );
144 let build_id = worker
145 .deployment_options()
146 .map(|opts| opts.version.build_id);
147 let task_types = worker.worker_task_types();
148
149 if !task_types.enable_workflows
150 && !task_types.enable_local_activities
151 && !task_types.enable_remote_activities
152 && !task_types.enable_nexus
153 {
154 bail!(
155 "Worker must have at least one capability enabled (workflows, activities, or nexus)"
156 );
157 }
158
159 if !task_types.enable_workflows && task_types.enable_local_activities {
160 bail!("Local activities cannot be enabled without workflows")
161 }
162
163 if !skip_client_worker_set_check
164 && let Some(existing_workers) = self.slot_providers.get(&slot_key)
165 {
166 for existing_worker_info in existing_workers {
167 if existing_worker_info.build_id.as_ref() == build_id.as_ref()
168 && task_types.overlaps_with(&existing_worker_info.task_types)
169 {
170 bail!(
171 "Registration of multiple workers with overlapping worker task types \
172 on the same namespace, task queue, and deployment build ID not allowed: \
173 {slot_key:?}, worker_instance_key: {:?} \
174 build_id: {build_id:?}, \
175 new task types: {task_types:?}, \
176 existing task types: {:?}.",
177 existing_worker_info.task_types,
178 worker.worker_instance_key()
179 );
180 }
181 }
182 }
183
184 if worker.heartbeat_enabled()
185 && let Some(heartbeat_callback) = worker.heartbeat_callback()
186 {
187 let worker_instance_key = worker.worker_instance_key();
188 let namespace = worker.namespace().to_string();
189
190 let shared_worker = match self.shared_worker.entry(namespace.clone()) {
191 Occupied(o) => o.into_mut(),
192 Vacant(v) => {
193 let shared_worker = worker.new_shared_namespace_worker()?;
194 v.insert(shared_worker)
195 }
196 };
197 shared_worker.register_callback(worker_instance_key, heartbeat_callback);
198 }
199
200 let worker_info =
201 RegisteredWorkerInfo::new(worker.worker_instance_key(), build_id, task_types);
202
203 match self.slot_providers.entry(slot_key.clone()) {
204 Occupied(o) => o.into_mut().push(worker_info),
205 Vacant(v) => {
206 v.insert(vec![worker_info]);
207 }
208 };
209
210 self.all_workers
211 .insert(worker.worker_instance_key(), worker);
212
213 Ok(())
214 }
215
216 fn unregister_slot_provider(&mut self, worker_instance_key: Uuid) -> Result<(), anyhow::Error> {
219 let worker = self.all_workers.get(&worker_instance_key).ok_or_else(|| {
220 anyhow::anyhow!("Worker not in all_workers during slot provider unregister")
221 })?;
222
223 let slot_key = SlotKey::new(
224 worker.namespace().to_string(),
225 worker.task_queue().to_string(),
226 );
227 if let Some(slot_vec) = self.slot_providers.get_mut(&slot_key) {
228 slot_vec.retain(|info| info.worker_id != worker_instance_key);
229 if slot_vec.is_empty() {
230 self.slot_providers.remove(&slot_key);
231 }
232 }
233 Ok(())
234 }
235
236 fn finalize_unregister(
237 &mut self,
238 worker_instance_key: Uuid,
239 ) -> Result<Arc<dyn ClientWorker + Send + Sync>, anyhow::Error> {
240 if let Some(worker) = self.all_workers.get(&worker_instance_key)
241 && let Some(slot_vec) = self.slot_providers.get(&SlotKey::new(
242 worker.namespace().to_string(),
243 worker.task_queue().to_string(),
244 ))
245 && slot_vec
246 .iter()
247 .any(|info| info.worker_id == worker_instance_key)
248 {
249 return Err(anyhow::anyhow!(
250 "Worker still in slot_providers during finalize"
251 ));
252 }
253
254 let worker = self
255 .all_workers
256 .remove(&worker_instance_key)
257 .ok_or_else(|| anyhow::anyhow!("Worker not found in all_workers"))?;
258
259 if let Some(w) = self.shared_worker.get_mut(worker.namespace()) {
260 let (callback, is_empty) = w.unregister_callback(worker.worker_instance_key());
261 if callback.is_some() && is_empty {
262 self.shared_worker.remove(worker.namespace());
263 }
264 }
265
266 Ok(worker)
267 }
268
269 #[cfg(test)]
270 fn num_providers(&self) -> usize {
271 self.slot_providers.values().map(|v| v.len()).sum()
272 }
273
274 #[cfg(test)]
275 fn num_heartbeat_workers(&self) -> usize {
276 self.shared_worker.values().map(|v| v.num_workers()).sum()
277 }
278}
279
/// A per-namespace worker that holds the heartbeat callbacks of the workers registered in that
/// namespace.
pub trait SharedNamespaceWorkerTrait {
    /// Returns the namespace this shared worker belongs to.
    fn namespace(&self) -> String;

    /// Registers `heartbeat_callback` under `worker_instance_key`.
    fn register_callback(&self, worker_instance_key: Uuid, heartbeat_callback: HeartbeatCallback);

    /// Unregisters the callback stored under `worker_instance_key`.
    ///
    /// Returns the removed callback, if there was one, together with a flag the caller treats
    /// as "no callbacks remain": the worker set drops this shared worker when a callback was
    /// removed and the flag is `true`.
    fn unregister_callback(&self, worker_instance_key: Uuid) -> (Option<HeartbeatCallback>, bool);

    /// Returns the number of workers currently registered with this shared worker.
    fn num_workers(&self) -> usize;
}
297
/// Registry of client workers whose state is guarded by a read/write lock.
pub struct ClientWorkerSet {
    /// Random identifier generated when the set is created; see `worker_grouping_key()`.
    worker_grouping_key: Uuid,
    /// The registry state; read-locked for slot reservation, write-locked for (un)registration.
    worker_manager: RwLock<ClientWorkerSetImpl>,
}
307
308impl Default for ClientWorkerSet {
309 fn default() -> Self {
310 Self::new()
311 }
312}
313
314impl ClientWorkerSet {
315 pub fn new() -> Self {
317 Self {
318 worker_grouping_key: Uuid::new_v4(),
319 worker_manager: RwLock::new(ClientWorkerSetImpl::new()),
320 }
321 }
322
323 pub(crate) fn try_reserve_wft_slot(
326 &self,
327 namespace: String,
328 task_queue: String,
329 ) -> Option<SlotReservation> {
330 self.worker_manager
331 .read()
332 .try_reserve_wft_slot(namespace, task_queue)
333 }
334
335 pub fn register_worker(
337 &self,
338 worker: Arc<dyn ClientWorker + Send + Sync>,
339 skip_client_worker_set_check: bool,
340 ) -> Result<(), anyhow::Error> {
341 self.worker_manager
342 .write()
343 .register(worker, skip_client_worker_set_check)
344 }
345
346 pub fn unregister_slot_provider(&self, worker_instance_key: Uuid) -> Result<(), anyhow::Error> {
349 self.worker_manager
350 .write()
351 .unregister_slot_provider(worker_instance_key)
352 }
353
354 pub fn finalize_unregister(
358 &self,
359 worker_instance_key: Uuid,
360 ) -> Result<Arc<dyn ClientWorker + Send + Sync>, anyhow::Error> {
361 self.worker_manager
362 .write()
363 .finalize_unregister(worker_instance_key)
364 }
365
366 pub fn worker_grouping_key(&self) -> Uuid {
368 self.worker_grouping_key
369 }
370
371 #[cfg(test)]
372 pub fn num_providers(&self) -> usize {
375 self.worker_manager.read().num_providers()
376 }
377
378 #[cfg(test)]
379 pub fn num_heartbeat_workers(&self) -> usize {
381 self.worker_manager.read().num_heartbeat_workers()
382 }
383}
384
385impl std::fmt::Debug for ClientWorkerSet {
386 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
387 f.debug_struct("ClientWorkerSet")
388 .field("worker_grouping_key", &self.worker_grouping_key)
389 .finish()
390 }
391}
392
/// Shared closure (`Send + Sync`) that produces a [`WorkerHeartbeat`] each time it is called.
pub type HeartbeatCallback = Arc<dyn Fn() -> WorkerHeartbeat + Send + Sync>;
395
/// Interface a worker exposes so it can be registered in a [`ClientWorkerSet`].
///
/// A mock implementation (`MockClientWorker`) is generated for test builds.
#[cfg_attr(test, mockall::automock)]
pub trait ClientWorker: Send + Sync {
    /// The namespace the worker belongs to; forms the registry key together with the task queue.
    fn namespace(&self) -> &str;

    /// The task queue the worker is attached to; forms the registry key with the namespace.
    fn task_queue(&self) -> &str;

    /// Tries to reserve a workflow-task slot, returning `None` when the worker has none to give.
    fn try_reserve_wft_slot(&self) -> Option<Box<dyn Slot + Send>>;

    /// The worker's deployment options, if any. The build ID inside is used by the worker set
    /// to decide whether two workers on the same task queue conflict.
    fn deployment_options(&self) -> Option<WorkerDeploymentOptions>;

    /// Identifier of this worker instance; used as its key inside the worker set.
    fn worker_instance_key(&self) -> Uuid;

    /// Whether heartbeating is enabled; the heartbeat callback is only consulted when it is.
    fn heartbeat_enabled(&self) -> bool;

    /// The callback that produces this worker's heartbeat, if it has one.
    fn heartbeat_callback(&self) -> Option<HeartbeatCallback>;

    /// Creates a shared namespace worker. The worker set calls this when the first
    /// heartbeat-enabled worker of a namespace registers.
    fn new_shared_namespace_worker(
        &self,
    ) -> Result<Box<dyn SharedNamespaceWorkerTrait + Send + Sync>, anyhow::Error>;

    /// The task types this worker has enabled.
    fn worker_task_types(&self) -> WorkerTaskTypes;
}
434
435#[cfg(test)]
436mod tests {
437 use super::*;
438
439 fn new_mock_slot(with_error: bool) -> Box<MockSlot> {
440 let mut mock_slot = MockSlot::new();
441 if with_error {
442 mock_slot
443 .expect_schedule_wft()
444 .returning(|_| Err(anyhow::anyhow!("Changed my mind")));
445 } else {
446 mock_slot.expect_schedule_wft().returning(|_| Ok(()));
447 }
448 Box::new(mock_slot)
449 }
450
451 fn new_mock_provider(
452 namespace: String,
453 task_queue: String,
454 with_error: bool,
455 no_slots: bool,
456 heartbeat_enabled: bool,
457 ) -> MockClientWorker {
458 let mut mock_provider = MockClientWorker::new();
459 mock_provider
460 .expect_try_reserve_wft_slot()
461 .returning(move || {
462 if no_slots {
463 None
464 } else {
465 Some(new_mock_slot(with_error))
466 }
467 });
468 mock_provider.expect_namespace().return_const(namespace);
469 mock_provider.expect_task_queue().return_const(task_queue);
470 mock_provider.expect_deployment_options().return_const(None);
471 mock_provider
472 .expect_heartbeat_enabled()
473 .return_const(heartbeat_enabled);
474 mock_provider
475 .expect_worker_instance_key()
476 .return_const(Uuid::new_v4());
477 mock_provider
478 .expect_worker_task_types()
479 .return_const(WorkerTaskTypes {
480 enable_workflows: true,
481 enable_local_activities: true,
482 enable_remote_activities: true,
483 enable_nexus: true,
484 });
485 mock_provider
486 }
487
488 #[test]
489 fn reserve_wft_slot_retries_another_worker_when_first_has_no_slot() {
490 let mut manager = ClientWorkerSetImpl::new();
491 let namespace = "retry_namespace".to_string();
492 let task_queue = "retry_queue".to_string();
493
494 let failing_worker_id = Uuid::new_v4();
495 let mut failing_worker = MockClientWorker::new();
496 failing_worker
497 .expect_try_reserve_wft_slot()
498 .times(1)
499 .returning(|| None);
500 failing_worker
501 .expect_namespace()
502 .return_const(namespace.clone());
503 failing_worker
504 .expect_task_queue()
505 .return_const(task_queue.clone());
506 failing_worker
507 .expect_deployment_options()
508 .return_const(WorkerDeploymentOptions {
509 version: temporalio_common::worker::WorkerDeploymentVersion {
510 deployment_name: "test-deployment".to_string(),
511 build_id: "build-fail".to_string(),
512 },
513 use_worker_versioning: true,
514 default_versioning_behavior: None,
515 });
516 failing_worker
517 .expect_worker_instance_key()
518 .return_const(failing_worker_id);
519 failing_worker
520 .expect_heartbeat_enabled()
521 .return_const(false);
522 failing_worker
523 .expect_worker_task_types()
524 .return_const(WorkerTaskTypes {
525 enable_workflows: true,
526 enable_local_activities: true,
527 enable_remote_activities: true,
528 enable_nexus: true,
529 });
530
531 let succeeding_worker_id = Uuid::new_v4();
532 let mut succeeding_worker = MockClientWorker::new();
533 succeeding_worker
534 .expect_try_reserve_wft_slot()
535 .times(1)
536 .returning(|| Some(new_mock_slot(false)));
537 succeeding_worker
538 .expect_namespace()
539 .return_const(namespace.clone());
540 succeeding_worker
541 .expect_task_queue()
542 .return_const(task_queue.clone());
543 let success_deployment_options = WorkerDeploymentOptions {
544 version: temporalio_common::worker::WorkerDeploymentVersion {
545 deployment_name: "test-deployment".to_string(),
546 build_id: "build-success".to_string(),
547 },
548 use_worker_versioning: true,
549 default_versioning_behavior: None,
550 };
551 succeeding_worker
552 .expect_deployment_options()
553 .return_const(success_deployment_options.clone());
554 succeeding_worker
555 .expect_worker_instance_key()
556 .return_const(succeeding_worker_id);
557 succeeding_worker
558 .expect_heartbeat_enabled()
559 .return_const(false);
560 succeeding_worker
561 .expect_worker_task_types()
562 .return_const(WorkerTaskTypes {
563 enable_workflows: true,
564 enable_local_activities: true,
565 enable_remote_activities: true,
566 enable_nexus: true,
567 });
568
569 manager
570 .register(Arc::new(failing_worker), false)
571 .expect("failing worker registration succeeds");
572 manager
573 .register(Arc::new(succeeding_worker), false)
574 .expect("succeeding worker registration succeeds");
575
576 let reservation = manager.try_reserve_wft_slot(namespace.clone(), task_queue.clone());
577
578 let reservation_deployment_options = reservation
579 .expect("succeeding worker was used after failing worker failed")
580 .deployment_options
581 .unwrap();
582 assert_eq!(
583 reservation_deployment_options, success_deployment_options,
584 "deployment options bubble through from succeeding worker"
585 );
586 }
587
588 #[test]
589 fn reserve_wft_slot_retries_respects_slot_boundary() {
590 let mut manager = ClientWorkerSetImpl::new();
591 let namespace = "retry_namespace".to_string();
592 let task_queue = "retry_queue".to_string();
593
594 let failing_worker_id = Uuid::new_v4();
595 let mut failing_worker = MockClientWorker::new();
596 failing_worker
597 .expect_try_reserve_wft_slot()
598 .times(1)
599 .returning(|| None);
600 failing_worker
601 .expect_namespace()
602 .return_const(namespace.clone());
603 failing_worker
604 .expect_task_queue()
605 .return_const(task_queue.clone());
606 failing_worker
607 .expect_deployment_options()
608 .return_const(WorkerDeploymentOptions {
609 version: temporalio_common::worker::WorkerDeploymentVersion {
610 deployment_name: "test-deployment".to_string(),
611 build_id: "build-fail".to_string(),
612 },
613 use_worker_versioning: true,
614 default_versioning_behavior: None,
615 });
616 failing_worker
617 .expect_worker_instance_key()
618 .return_const(failing_worker_id);
619 failing_worker
620 .expect_heartbeat_enabled()
621 .return_const(false);
622 failing_worker
623 .expect_worker_task_types()
624 .return_const(WorkerTaskTypes {
625 enable_workflows: true,
626 enable_local_activities: true,
627 enable_remote_activities: true,
628 enable_nexus: true,
629 });
630
631 let succeeding_worker_id = Uuid::new_v4();
633 let mut succeeding_worker = MockClientWorker::new();
634 succeeding_worker.expect_try_reserve_wft_slot().times(0);
635 succeeding_worker
636 .expect_namespace()
637 .return_const(namespace.clone());
638 succeeding_worker
639 .expect_task_queue()
640 .return_const("other_task_queue".to_string());
641 succeeding_worker
642 .expect_deployment_options()
643 .return_const(None);
644 succeeding_worker
645 .expect_worker_instance_key()
646 .return_const(succeeding_worker_id);
647 succeeding_worker
648 .expect_heartbeat_enabled()
649 .return_const(false);
650 succeeding_worker
651 .expect_worker_task_types()
652 .return_const(WorkerTaskTypes {
653 enable_workflows: true,
654 enable_local_activities: true,
655 enable_remote_activities: true,
656 enable_nexus: true,
657 });
658
659 manager
660 .register(Arc::new(failing_worker), false)
661 .expect("failing worker registration succeeds");
662 manager
663 .register(Arc::new(succeeding_worker), false)
664 .expect("succeeding worker registration succeeds");
665
666 let reservation = manager.try_reserve_wft_slot(namespace.clone(), task_queue.clone());
667 assert!(
668 reservation.is_none(),
669 "succeeding_worker should not be picked due to it being on a separate task queue"
670 );
671 }
672
673 #[test]
674 fn registry_keeps_one_provider_per_namespace() {
675 let manager = ClientWorkerSet::new();
676 let mut worker_keys = vec![];
677 let mut successful_registrations = 0;
678
679 for i in 0..10 {
680 let namespace = format!("myId{}", i % 3);
681 let mock_provider =
682 new_mock_provider(namespace, "bar_q".to_string(), false, false, false);
683 let worker_instance_key = mock_provider.worker_instance_key();
684
685 let result = manager.register_worker(Arc::new(mock_provider), false);
686 if let Err(err) = result {
687 assert!(err.to_string().contains(
689 "Registration of multiple workers with overlapping worker task types"
690 ));
691 } else {
692 successful_registrations += 1;
693 worker_keys.push(worker_instance_key);
694 }
695 }
696
697 assert_eq!(successful_registrations, 3);
698 assert_eq!(3, manager.num_providers());
699
700 let count = worker_keys.iter().fold(0, |count, key| {
701 manager.unregister_slot_provider(*key).unwrap();
702 manager.finalize_unregister(*key).unwrap();
703 let result = manager.unregister_slot_provider(*key);
705 assert!(result.is_err());
706 let result = manager.finalize_unregister(*key);
707 assert!(result.is_err());
708 count + 1
709 });
710 assert_eq!(3, count);
711 assert_eq!(0, manager.num_providers());
712 }
713
    /// Hand-written test double for a shared namespace worker that simply stores callbacks.
    struct MockSharedNamespaceWorker {
        /// Namespace reported by `namespace()`.
        namespace: String,
        /// Registered heartbeat callbacks keyed by worker instance key.
        callbacks: Arc<RwLock<HashMap<Uuid, HeartbeatCallback>>>,
    }
718
719 impl std::fmt::Debug for MockSharedNamespaceWorker {
720 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
721 f.debug_struct("MockSharedNamespaceWorker")
722 .field("namespace", &self.namespace)
723 .field("callbacks_count", &self.callbacks.read().len())
724 .finish()
725 }
726 }
727
728 impl MockSharedNamespaceWorker {
729 fn new(namespace: String) -> Self {
730 Self {
731 namespace,
732 callbacks: Arc::new(RwLock::new(HashMap::new())),
733 }
734 }
735 }
736
737 impl SharedNamespaceWorkerTrait for MockSharedNamespaceWorker {
738 fn namespace(&self) -> String {
739 self.namespace.clone()
740 }
741
742 fn register_callback(
743 &self,
744 worker_instance_key: Uuid,
745 heartbeat_callback: HeartbeatCallback,
746 ) {
747 self.callbacks
748 .write()
749 .insert(worker_instance_key, heartbeat_callback);
750 }
751
752 fn unregister_callback(
753 &self,
754 worker_instance_key: Uuid,
755 ) -> (Option<HeartbeatCallback>, bool) {
756 let mut callbacks = self.callbacks.write();
757 let callback = callbacks.remove(&worker_instance_key);
758 let is_empty = callbacks.is_empty();
759 (callback, is_empty)
760 }
761
762 fn num_workers(&self) -> usize {
763 self.callbacks.read().len()
764 }
765 }
766
767 fn new_mock_provider_with_heartbeat(
768 namespace: String,
769 task_queue: String,
770 heartbeat_enabled: bool,
771 build_id: Option<String>,
772 ) -> MockClientWorker {
773 let mut mock_provider = MockClientWorker::new();
774 mock_provider
775 .expect_try_reserve_wft_slot()
776 .returning(|| Some(new_mock_slot(false)));
777 mock_provider
778 .expect_namespace()
779 .return_const(namespace.clone());
780 mock_provider.expect_task_queue().return_const(task_queue);
781 mock_provider
782 .expect_heartbeat_enabled()
783 .return_const(heartbeat_enabled);
784 mock_provider
785 .expect_worker_instance_key()
786 .return_const(Uuid::new_v4());
787 let deployment_name = "test-deployment".to_string();
788 let build_id_for_closure = build_id.clone();
789 mock_provider
790 .expect_deployment_options()
791 .returning(move || {
792 build_id_for_closure
793 .as_ref()
794 .map(|build_id| WorkerDeploymentOptions {
795 version: temporalio_common::worker::WorkerDeploymentVersion {
796 deployment_name: deployment_name.clone(),
797 build_id: build_id.clone(),
798 },
799 use_worker_versioning: true,
800 default_versioning_behavior: None,
801 })
802 });
803
804 if heartbeat_enabled {
805 mock_provider
806 .expect_heartbeat_callback()
807 .returning(|| Some(Arc::new(WorkerHeartbeat::default)));
808
809 let namespace_clone = namespace.clone();
810 mock_provider
811 .expect_new_shared_namespace_worker()
812 .returning(move || {
813 Ok(Box::new(MockSharedNamespaceWorker::new(
814 namespace_clone.clone(),
815 )))
816 });
817 }
818
819 mock_provider
820 .expect_worker_task_types()
821 .return_const(WorkerTaskTypes {
822 enable_workflows: true,
823 enable_local_activities: true,
824 enable_remote_activities: true,
825 enable_nexus: true,
826 });
827
828 mock_provider
829 }
830
831 #[test]
832 fn duplicate_namespace_task_queue_registration_fails() {
833 let manager = ClientWorkerSet::new();
834
835 let worker1 = new_mock_provider_with_heartbeat(
836 "test_namespace".to_string(),
837 "test_queue".to_string(),
838 true,
839 None,
840 );
841
842 let worker2 = new_mock_provider_with_heartbeat(
844 "test_namespace".to_string(),
845 "test_queue".to_string(),
846 true,
847 None,
848 );
849
850 manager.register_worker(Arc::new(worker1), false).unwrap();
851
852 let result = manager.register_worker(Arc::new(worker2), false);
854 assert!(result.is_err());
855 assert!(
856 result
857 .unwrap_err()
858 .to_string()
859 .contains("Registration of multiple workers with overlapping worker task types")
860 );
861
862 assert_eq!(1, manager.num_providers());
863 assert_eq!(manager.num_heartbeat_workers(), 1);
864
865 let impl_ref = manager.worker_manager.read();
866 assert_eq!(impl_ref.shared_worker.len(), 1);
867 assert!(impl_ref.shared_worker.contains_key("test_namespace"));
868 }
869
870 #[test]
871 fn duplicate_namespace_with_different_build_ids_succeeds() {
872 let manager = ClientWorkerSet::new();
873 let namespace = "test_namespace".to_string();
874 let task_queue = "test_queue".to_string();
875
876 let worker1 =
877 new_mock_provider_with_heartbeat(namespace.clone(), task_queue.clone(), false, None);
878 let worker1_instance_key = worker1.worker_instance_key();
879 let worker2 = new_mock_provider_with_heartbeat(
880 namespace.clone(),
881 task_queue.clone(),
882 false,
883 Some("build-1".to_string()),
884 );
885 let worker2_instance_key = worker2.worker_instance_key();
886 let worker3 =
887 new_mock_provider_with_heartbeat(namespace.clone(), task_queue.clone(), false, None);
888 let worker4 = new_mock_provider_with_heartbeat(
889 namespace.clone(),
890 task_queue.clone(),
891 false,
892 Some("build-1".to_string()),
893 );
894
895 manager.register_worker(Arc::new(worker1), false).unwrap();
896
897 manager
898 .register_worker(Arc::new(worker2), false)
899 .expect("worker with new build ID should register");
900 assert_eq!(2, manager.num_providers());
901
902 assert!(
903 manager
904 .register_worker(Arc::new(worker3), false)
905 .unwrap_err()
906 .to_string()
907 .contains("Registration of multiple workers with overlapping worker task types")
908 );
909
910 assert!(
911 manager
912 .register_worker(Arc::new(worker4), false)
913 .unwrap_err()
914 .to_string()
915 .contains("Registration of multiple workers with overlapping worker task types")
916 );
917 assert_eq!(2, manager.num_providers());
918
919 {
920 let impl_ref = manager.worker_manager.read();
921 let slot_key = SlotKey::new(namespace.clone(), task_queue.clone());
922 let providers = impl_ref
923 .slot_providers
924 .get(&slot_key)
925 .expect("slot providers should exist for namespace/task queue");
926 assert_eq!(2, providers.len());
927
928 assert_eq!(providers[0].worker_id, worker1_instance_key);
929 assert_eq!(providers[0].build_id, None);
930 assert_eq!(providers[1].worker_id, worker2_instance_key);
931 assert_eq!(providers[1].build_id, Some("build-1".to_string()));
932 }
933
934 manager
935 .unregister_slot_provider(worker2_instance_key)
936 .unwrap();
937 manager.finalize_unregister(worker2_instance_key).unwrap();
938
939 {
940 let impl_ref = manager.worker_manager.read();
941 let slot_key = SlotKey::new(namespace.clone(), task_queue.clone());
942 let providers = impl_ref
943 .slot_providers
944 .get(&slot_key)
945 .expect("slot providers should exist for namespace/task queue");
946
947 assert_eq!(1, providers.len());
948 assert_eq!(providers[0].worker_id, worker1_instance_key);
949 assert_eq!(providers[0].build_id, None);
950 }
951 }
952
953 #[test]
954 fn multiple_workers_same_namespace_share_heartbeat_manager() {
955 let manager = ClientWorkerSet::new();
956
957 let worker1 = new_mock_provider_with_heartbeat(
958 "shared_namespace".to_string(),
959 "queue1".to_string(),
960 true,
961 None,
962 );
963
964 let worker2 = new_mock_provider_with_heartbeat(
966 "shared_namespace".to_string(),
967 "queue2".to_string(),
968 true,
969 None,
970 );
971
972 manager.register_worker(Arc::new(worker1), false).unwrap();
973 manager.register_worker(Arc::new(worker2), false).unwrap();
974
975 assert_eq!(2, manager.num_providers());
976 assert_eq!(manager.num_heartbeat_workers(), 2);
977
978 let impl_ref = manager.worker_manager.read();
979 assert_eq!(impl_ref.shared_worker.len(), 1);
980 assert!(impl_ref.shared_worker.contains_key("shared_namespace"));
981
982 let shared_worker = impl_ref.shared_worker.get("shared_namespace").unwrap();
983 assert_eq!(shared_worker.namespace(), "shared_namespace");
984 }
985
986 #[test]
987 fn different_namespaces_get_separate_heartbeat_managers() {
988 let manager = ClientWorkerSet::new();
989 let worker1 = new_mock_provider_with_heartbeat(
990 "namespace1".to_string(),
991 "queue1".to_string(),
992 true,
993 None,
994 );
995 let worker2 = new_mock_provider_with_heartbeat(
996 "namespace2".to_string(),
997 "queue1".to_string(),
998 true,
999 None,
1000 );
1001
1002 manager.register_worker(Arc::new(worker1), false).unwrap();
1003 manager.register_worker(Arc::new(worker2), false).unwrap();
1004
1005 assert_eq!(2, manager.num_providers());
1006 assert_eq!(manager.num_heartbeat_workers(), 2);
1007
1008 let impl_ref = manager.worker_manager.read();
1009 assert_eq!(impl_ref.num_heartbeat_workers(), 2);
1010 assert!(impl_ref.shared_worker.contains_key("namespace1"));
1011 assert!(impl_ref.shared_worker.contains_key("namespace2"));
1012 }
1013
1014 #[test]
1015 fn unregister_heartbeat_workers_cleans_up_shared_worker_when_last_removed() {
1016 let manager = ClientWorkerSet::new();
1017
1018 let worker1 = new_mock_provider_with_heartbeat(
1020 "test_namespace".to_string(),
1021 "queue1".to_string(),
1022 true,
1023 None,
1024 );
1025 let worker2 = new_mock_provider_with_heartbeat(
1026 "test_namespace".to_string(),
1027 "queue2".to_string(),
1028 true,
1029 None,
1030 );
1031 let worker_instance_key1 = worker1.worker_instance_key();
1032 let worker_instance_key2 = worker2.worker_instance_key();
1033
1034 assert_ne!(worker_instance_key1, worker_instance_key2);
1035
1036 manager.register_worker(Arc::new(worker1), false).unwrap();
1037 manager.register_worker(Arc::new(worker2), false).unwrap();
1038
1039 assert_eq!(2, manager.num_providers());
1041 assert_eq!(manager.num_heartbeat_workers(), 2);
1042
1043 let impl_ref = manager.worker_manager.read();
1044 assert_eq!(impl_ref.shared_worker.len(), 1);
1045 assert!(impl_ref.shared_worker.contains_key("test_namespace"));
1046 assert_eq!(
1047 impl_ref
1048 .shared_worker
1049 .get("test_namespace")
1050 .unwrap()
1051 .num_workers(),
1052 2
1053 );
1054 drop(impl_ref);
1055
1056 manager
1058 .unregister_slot_provider(worker_instance_key1)
1059 .unwrap();
1060 manager.finalize_unregister(worker_instance_key1).unwrap();
1061
1062 assert_eq!(1, manager.num_providers());
1064 assert_eq!(manager.num_heartbeat_workers(), 1);
1065
1066 let impl_ref = manager.worker_manager.read();
1067 assert_eq!(impl_ref.num_heartbeat_workers(), 1); assert!(impl_ref.shared_worker.contains_key("test_namespace"));
1069 assert_eq!(
1070 impl_ref
1071 .shared_worker
1072 .get("test_namespace")
1073 .unwrap()
1074 .num_workers(),
1075 1
1076 );
1077 drop(impl_ref);
1078
1079 manager
1081 .unregister_slot_provider(worker_instance_key2)
1082 .unwrap();
1083 manager.finalize_unregister(worker_instance_key2).unwrap();
1084
1085 assert_eq!(0, manager.num_providers());
1087 assert_eq!(manager.num_heartbeat_workers(), 0);
1088
1089 let impl_ref = manager.worker_manager.read();
1090 assert_eq!(impl_ref.shared_worker.len(), 0); assert!(!impl_ref.shared_worker.contains_key("test_namespace"));
1092 }
1093
1094 #[test]
1095 fn workflow_and_activity_only_workers_coexist() {
1096 let manager = ClientWorkerSet::new();
1097 let namespace = "test_namespace".to_string();
1098 let task_queue = "test_queue".to_string();
1099
1100 let mut workflow_nexus_worker = MockClientWorker::new();
1101 workflow_nexus_worker
1102 .expect_namespace()
1103 .return_const(namespace.clone());
1104 workflow_nexus_worker
1105 .expect_task_queue()
1106 .return_const(task_queue.clone());
1107 workflow_nexus_worker
1108 .expect_deployment_options()
1109 .return_const(None);
1110 workflow_nexus_worker
1111 .expect_worker_instance_key()
1112 .return_const(Uuid::new_v4());
1113 workflow_nexus_worker
1114 .expect_heartbeat_enabled()
1115 .return_const(false);
1116 workflow_nexus_worker
1117 .expect_worker_task_types()
1118 .return_const(WorkerTaskTypes {
1119 enable_workflows: true,
1120 enable_local_activities: false,
1121 enable_remote_activities: false,
1122 enable_nexus: true,
1123 });
1124
1125 let mut activity_worker = MockClientWorker::new();
1126 activity_worker
1127 .expect_namespace()
1128 .return_const(namespace.clone());
1129 activity_worker
1130 .expect_task_queue()
1131 .return_const(task_queue.clone());
1132 activity_worker
1133 .expect_deployment_options()
1134 .return_const(None);
1135 activity_worker
1136 .expect_worker_instance_key()
1137 .return_const(Uuid::new_v4());
1138 activity_worker
1139 .expect_heartbeat_enabled()
1140 .return_const(false);
1141 activity_worker
1142 .expect_worker_task_types()
1143 .return_const(WorkerTaskTypes {
1144 enable_workflows: false,
1145 enable_local_activities: false,
1146 enable_remote_activities: true,
1147 enable_nexus: false,
1148 });
1149 activity_worker.expect_try_reserve_wft_slot().times(0); manager
1152 .register_worker(Arc::new(workflow_nexus_worker), false)
1153 .expect("workflow-nexus worker should register");
1154 manager
1155 .register_worker(Arc::new(activity_worker), false)
1156 .expect("activity-only worker should register");
1157
1158 assert_eq!(2, manager.num_providers());
1159 }
1160
/// A second worker on the same namespace and task queue is rejected when any of its
/// enabled task types is already enabled on a registered worker.
#[test]
fn overlapping_capabilities_rejected() {
    /// Builds a mock worker on the given namespace/queue advertising `task_types`.
    fn mock_worker(
        namespace: &str,
        task_queue: &str,
        task_types: WorkerTaskTypes,
    ) -> MockClientWorker {
        let mut worker = MockClientWorker::new();
        worker
            .expect_namespace()
            .return_const(namespace.to_string());
        worker
            .expect_task_queue()
            .return_const(task_queue.to_string());
        worker.expect_deployment_options().return_const(None);
        worker
            .expect_worker_instance_key()
            .return_const(Uuid::new_v4());
        worker.expect_heartbeat_enabled().return_const(false);
        worker.expect_worker_task_types().return_const(task_types);
        worker
    }

    // Workflows plus both activity kinds, no nexus.
    let workflows_and_activities = || WorkerTaskTypes {
        enable_workflows: true,
        enable_local_activities: true,
        enable_remote_activities: true,
        enable_nexus: false,
    };

    let namespace = "test_namespace";
    let task_queue = "test_queue";
    let manager = ClientWorkerSet::new();

    let first = mock_worker(namespace, task_queue, workflows_and_activities());
    let identical = mock_worker(namespace, task_queue, workflows_and_activities());

    manager
        .register_worker(Arc::new(first), false)
        .expect("first worker should register");

    // Exactly the same capability set as the first worker: must be refused.
    let outcome = manager.register_worker(Arc::new(identical), false);
    assert!(outcome.is_err());
    let message = outcome.unwrap_err().to_string();
    assert!(message.contains("overlapping worker task types"));

    // Only remote activities enabled, which the first worker already enables.
    let remote_activity_only = mock_worker(
        namespace,
        task_queue,
        WorkerTaskTypes {
            enable_workflows: false,
            enable_local_activities: false,
            enable_remote_activities: true,
            enable_nexus: false,
        },
    );

    let outcome = manager.register_worker(Arc::new(remote_activity_only), false);
    assert!(outcome.is_err());
    let message = outcome.unwrap_err().to_string();
    assert!(message.contains("overlapping worker task types"));
}
1243
/// Workflow-task slot reservation yields nothing while only activity and nexus workers
/// are registered, and yields a reservation once a workflow-enabled worker is added.
#[test]
fn wft_slot_reservation_ignores_non_workflow_workers() {
    /// Builds a mock worker on the given namespace/queue advertising `task_types`.
    fn mock_worker(
        namespace: &str,
        task_queue: &str,
        task_types: WorkerTaskTypes,
    ) -> MockClientWorker {
        let mut worker = MockClientWorker::new();
        worker
            .expect_namespace()
            .return_const(namespace.to_string());
        worker
            .expect_task_queue()
            .return_const(task_queue.to_string());
        worker.expect_deployment_options().return_const(None);
        worker
            .expect_worker_instance_key()
            .return_const(Uuid::new_v4());
        worker.expect_heartbeat_enabled().return_const(false);
        worker.expect_worker_task_types().return_const(task_types);
        worker
    }

    let namespace = "test_namespace".to_string();
    let task_queue = "test_queue".to_string();
    let mut manager_impl = ClientWorkerSetImpl::new();

    let activity_worker = mock_worker(
        &namespace,
        &task_queue,
        WorkerTaskTypes {
            enable_workflows: false,
            enable_local_activities: false,
            enable_remote_activities: true,
            enable_nexus: false,
        },
    );
    let nexus_worker = mock_worker(
        &namespace,
        &task_queue,
        WorkerTaskTypes {
            enable_workflows: false,
            enable_local_activities: false,
            enable_remote_activities: false,
            enable_nexus: true,
        },
    );

    manager_impl
        .register(Arc::new(activity_worker), false)
        .expect("activity worker should register");
    manager_impl
        .register(Arc::new(nexus_worker), false)
        .expect("nexus worker should register");

    // Neither registered worker has workflows enabled, so no slot may be handed out.
    let before = manager_impl.try_reserve_wft_slot(namespace.clone(), task_queue.clone());
    assert!(
        before.is_none(),
        "should not find workflow workers when only activity/nexus workers registered"
    );

    let mut workflow_worker = mock_worker(
        &namespace,
        &task_queue,
        WorkerTaskTypes {
            enable_workflows: true,
            enable_local_activities: true,
            enable_remote_activities: false,
            enable_nexus: false,
        },
    );
    // The workflow worker must be asked for a slot exactly once.
    workflow_worker
        .expect_try_reserve_wft_slot()
        .times(1)
        .returning(|| Some(new_mock_slot(false)));

    manager_impl
        .register(Arc::new(workflow_worker), false)
        .expect("workflow worker should register");

    let after = manager_impl.try_reserve_wft_slot(namespace.clone(), task_queue.clone());
    assert!(
        after.is_some(),
        "should find workflow worker after it's registered"
    );
}
1349
/// Registration rejects task-type configurations that are invalid on their own: one
/// with nothing enabled, and one enabling local activities without workflows.
#[test]
fn worker_invalid_type_config_rejected() {
    /// Builds a mock worker on the fixed test namespace/queue advertising `task_types`.
    fn mock_worker(task_types: WorkerTaskTypes) -> MockClientWorker {
        let mut worker = MockClientWorker::new();
        worker
            .expect_namespace()
            .return_const("test_namespace".to_string());
        worker
            .expect_task_queue()
            .return_const("test_queue".to_string());
        worker.expect_deployment_options().return_const(None);
        worker
            .expect_worker_instance_key()
            .return_const(Uuid::new_v4());
        worker.expect_heartbeat_enabled().return_const(false);
        worker.expect_worker_task_types().return_const(task_types);
        worker
    }

    let manager = ClientWorkerSet::new();

    // Case 1: every capability switched off.
    let nothing_enabled = mock_worker(WorkerTaskTypes {
        enable_workflows: false,
        enable_local_activities: false,
        enable_remote_activities: false,
        enable_nexus: false,
    });
    let outcome = manager.register_worker(Arc::new(nothing_enabled), false);
    assert!(outcome.is_err());
    let message = outcome.unwrap_err().to_string();
    assert!(message.contains("must have at least one capability enabled"));

    // Case 2: local activities enabled while workflows are disabled.
    let local_without_workflows = mock_worker(WorkerTaskTypes {
        enable_workflows: false,
        enable_local_activities: true,
        enable_remote_activities: true,
        enable_nexus: false,
    });
    let outcome = manager.register_worker(Arc::new(local_without_workflows), false);
    assert!(outcome.is_err());
    let message = outcome.unwrap_err().to_string();
    assert_eq!(
        message,
        "Local activities cannot be enabled without workflows"
    );
}
1414
/// With a workflow worker and an activity worker sharing one namespace and task queue,
/// unregistering them one at a time lowers the provider count step by step and stops
/// workflow-task slot reservation once the workflow worker is gone.
#[test]
fn unregister_with_multiple_workers() {
    /// Builds a mock worker with a caller-chosen instance key and `task_types`.
    fn mock_worker(
        namespace: &str,
        task_queue: &str,
        instance_key: Uuid,
        task_types: WorkerTaskTypes,
    ) -> MockClientWorker {
        let mut worker = MockClientWorker::new();
        worker
            .expect_namespace()
            .return_const(namespace.to_string());
        worker
            .expect_task_queue()
            .return_const(task_queue.to_string());
        worker.expect_deployment_options().return_const(None);
        worker
            .expect_worker_instance_key()
            .return_const(instance_key);
        worker.expect_heartbeat_enabled().return_const(false);
        worker.expect_worker_task_types().return_const(task_types);
        worker
    }

    let namespace = "test_namespace".to_string();
    let task_queue = "test_queue".to_string();
    let manager = ClientWorkerSet::new();

    // Keys are kept so each worker can be unregistered individually later.
    let wf_worker_key = Uuid::new_v4();
    let act_worker_key = Uuid::new_v4();

    let mut workflow_worker = mock_worker(
        &namespace,
        &task_queue,
        wf_worker_key,
        WorkerTaskTypes {
            enable_workflows: true,
            enable_local_activities: true,
            enable_remote_activities: false,
            enable_nexus: false,
        },
    );
    workflow_worker
        .expect_try_reserve_wft_slot()
        .returning(|| Some(new_mock_slot(false)));

    let activity_worker = mock_worker(
        &namespace,
        &task_queue,
        act_worker_key,
        WorkerTaskTypes {
            enable_workflows: false,
            enable_local_activities: false,
            enable_remote_activities: true,
            enable_nexus: false,
        },
    );

    manager
        .register_worker(Arc::new(workflow_worker), false)
        .expect("workflow worker should register");
    manager
        .register_worker(Arc::new(activity_worker), false)
        .expect("activity worker should register");
    assert_eq!(2, manager.num_providers());

    let while_registered = manager.try_reserve_wft_slot(namespace.clone(), task_queue.clone());
    assert!(
        while_registered.is_some(),
        "should be able to reserve slot from workflow worker"
    );

    // Remove the workflow worker: slot provider first, then finalize.
    manager
        .unregister_slot_provider(wf_worker_key)
        .expect("should unregister slot provider for workflow worker");
    manager
        .finalize_unregister(wf_worker_key)
        .expect("should finalize unregister for workflow worker");
    assert_eq!(1, manager.num_providers());

    let after_removal = manager.try_reserve_wft_slot(namespace.clone(), task_queue.clone());
    assert!(
        after_removal.is_none(),
        "should not find workflow worker after unregistration"
    );

    // Remove the remaining activity worker the same way.
    manager
        .unregister_slot_provider(act_worker_key)
        .expect("should unregister slot provider for activity worker");
    manager
        .finalize_unregister(act_worker_key)
        .expect("should finalize unregister for activity worker");
    assert_eq!(0, manager.num_providers());
}
1518
/// Calling `finalize_unregister` before `unregister_slot_provider` is an error; running
/// the two steps in the proper order succeeds.
#[test]
fn worker_unregister_order() {
    let manager = ClientWorkerSet::new();

    let namespace = "namespace1".to_string();
    let task_queue = "queue1".to_string();
    let worker = new_mock_provider_with_heartbeat(namespace, task_queue, true, None);
    let worker_instance_key = worker.worker_instance_key();
    manager.register_worker(Arc::new(worker), false).unwrap();

    // Out of order: the worker has not been removed from the slot providers yet.
    let premature = manager.finalize_unregister(worker_instance_key);
    assert!(premature.is_err());
    let err_string = premature.err().unwrap().to_string();
    assert!(err_string.contains("Worker still in slot_providers during finalize"));

    // Correct order: drop the slot provider first, then finalize.
    manager
        .unregister_slot_provider(worker_instance_key)
        .unwrap();
    manager.finalize_unregister(worker_instance_key).unwrap();
}
1543}