1use anyhow::bail;
4use parking_lot::RwLock;
5use rand::seq::SliceRandom;
6use std::{
7 collections::{
8 HashMap,
9 hash_map::Entry::{Occupied, Vacant},
10 },
11 sync::Arc,
12};
13use temporalio_common::{
14 protos::{
15 TaskToken,
16 temporal::api::{
17 worker::v1::WorkerHeartbeat, workflowservice::v1::PollWorkflowTaskQueueResponse,
18 },
19 },
20 worker::{WorkerDeploymentOptions, WorkerTaskTypes},
21};
22use uuid::Uuid;
23
24#[cfg_attr(test, mockall::automock)]
26pub trait Slot {
27 fn schedule_wft(
29 self: Box<Self>,
30 task: PollWorkflowTaskQueueResponse,
31 ) -> Result<(), anyhow::Error>;
32}
33
/// Result of a successful workflow-task slot reservation: the slot itself plus the
/// deployment options reported by the worker that provided it.
pub(crate) struct SlotReservation {
    /// The reserved slot; consumed when a task is scheduled on it.
    pub slot: Box<dyn Slot + Send>,
    /// Value of `deployment_options()` from the worker that supplied `slot`, if any.
    pub deployment_options: Option<WorkerDeploymentOptions>,
}
41
/// Map key identifying a `(namespace, task queue)` pair.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
struct SlotKey {
    namespace: String,
    task_queue: String,
}

impl SlotKey {
    /// Builds a key from an owned namespace and task queue name.
    fn new(namespace: String, task_queue: String) -> Self {
        Self {
            namespace,
            task_queue,
        }
    }
}
56
57#[derive(Debug, Clone)]
59struct RegisteredWorkerInfo {
60 worker_id: Uuid,
62 build_id: Option<String>,
64 task_types: WorkerTaskTypes,
66}
67
68impl RegisteredWorkerInfo {
69 fn new(worker_id: Uuid, build_id: Option<String>, task_types: WorkerTaskTypes) -> Self {
70 Self {
71 worker_id,
72 build_id,
73 task_types,
74 }
75 }
76}
77
/// Unsynchronized state behind [`ClientWorkerSet`]; callers wrap it in a lock.
struct ClientWorkerSetImpl {
    /// Registration records grouped by namespace + task queue.
    slot_providers: HashMap<SlotKey, Vec<RegisteredWorkerInfo>>,
    /// Every registered worker, keyed by its worker instance key.
    all_workers: HashMap<Uuid, Arc<dyn ClientWorker + Send + Sync>>,
    /// One shared namespace worker per namespace (keyed by namespace name), created when
    /// the first heartbeat-enabled worker for that namespace registers.
    shared_worker: HashMap<String, Box<dyn SharedNamespaceWorkerTrait + Send + Sync>>,
}
87
88impl ClientWorkerSetImpl {
89 fn new() -> Self {
91 Self {
92 slot_providers: Default::default(),
93 all_workers: Default::default(),
94 shared_worker: Default::default(),
95 }
96 }
97
98 fn try_reserve_wft_slot(
99 &self,
100 namespace: String,
101 task_queue: String,
102 ) -> Option<SlotReservation> {
103 let key = SlotKey::new(namespace, task_queue);
104 if let Some(worker_list) = self.slot_providers.get(&key) {
105 let workflow_workers: Vec<&RegisteredWorkerInfo> = worker_list
106 .iter()
107 .filter(|info| info.task_types.enable_workflows)
108 .collect();
109
110 for worker_id in Self::worker_ids_in_selection_order(&workflow_workers) {
111 if let Some(worker) = self.all_workers.get(&worker_id)
112 && let Some(slot) = worker.try_reserve_wft_slot()
113 {
114 let deployment_options = worker.deployment_options();
115 return Some(SlotReservation {
116 slot,
117 deployment_options,
118 });
119 }
120 }
121 }
122 None
123 }
124
125 fn worker_ids_in_selection_order(worker_list: &[&RegisteredWorkerInfo]) -> Vec<Uuid> {
126 if cfg!(test) {
129 worker_list.iter().map(|info| info.worker_id).collect()
130 } else {
131 let mut rng = rand::rng();
132 let mut shuffled: Vec<_> = worker_list.to_vec();
133 shuffled.shuffle(&mut rng);
134 shuffled.iter().map(|info| info.worker_id).collect()
135 }
136 }
137
138 fn register(
139 &mut self,
140 worker: Arc<dyn ClientWorker + Send + Sync>,
141 skip_client_worker_set_check: bool,
142 ) -> Result<(), anyhow::Error> {
143 let slot_key = SlotKey::new(
144 worker.namespace().to_string(),
145 worker.task_queue().to_string(),
146 );
147 let build_id = worker
148 .deployment_options()
149 .map(|opts| opts.version.build_id);
150 let task_types = worker.worker_task_types();
151
152 if !task_types.enable_workflows
153 && !task_types.enable_local_activities
154 && !task_types.enable_remote_activities
155 && !task_types.enable_nexus
156 {
157 bail!(
158 "Worker must have at least one capability enabled (workflows, activities, or nexus)"
159 );
160 }
161
162 if !task_types.enable_workflows && task_types.enable_local_activities {
163 bail!("Local activities cannot be enabled without workflows")
164 }
165
166 if !skip_client_worker_set_check
167 && let Some(existing_workers) = self.slot_providers.get(&slot_key)
168 {
169 for existing_worker_info in existing_workers {
170 if existing_worker_info.build_id.as_ref() == build_id.as_ref()
171 && task_types.overlaps_with(&existing_worker_info.task_types)
172 {
173 bail!(
174 "Registration of multiple workers with overlapping worker task types \
175 on the same namespace, task queue, and deployment build ID not allowed: \
176 {slot_key:?}, worker_instance_key: {:?} \
177 build_id: {build_id:?}, \
178 new task types: {task_types:?}, \
179 existing task types: {:?}.",
180 existing_worker_info.task_types,
181 worker.worker_instance_key()
182 );
183 }
184 }
185 }
186
187 if worker.heartbeat_enabled()
188 && let Some(heartbeat_callback) = worker.heartbeat_callback()
189 {
190 let worker_instance_key = worker.worker_instance_key();
191 let namespace = worker.namespace().to_string();
192
193 let shared_worker = match self.shared_worker.entry(namespace.clone()) {
194 Occupied(o) => o.into_mut(),
195 Vacant(v) => {
196 let shared_worker = worker.new_shared_namespace_worker()?;
197 v.insert(shared_worker)
198 }
199 };
200 shared_worker.register_callback(
201 worker_instance_key,
202 WorkerCallbacks {
203 heartbeat: heartbeat_callback,
204 cancel_activity: worker.cancel_activity_callback(),
205 },
206 );
207 }
208
209 let worker_info =
210 RegisteredWorkerInfo::new(worker.worker_instance_key(), build_id, task_types);
211
212 match self.slot_providers.entry(slot_key.clone()) {
213 Occupied(o) => o.into_mut().push(worker_info),
214 Vacant(v) => {
215 v.insert(vec![worker_info]);
216 }
217 };
218
219 self.all_workers
220 .insert(worker.worker_instance_key(), worker);
221
222 Ok(())
223 }
224
225 fn unregister_slot_provider(&mut self, worker_instance_key: Uuid) -> Result<(), anyhow::Error> {
228 let worker = self.all_workers.get(&worker_instance_key).ok_or_else(|| {
229 anyhow::anyhow!("Worker not in all_workers during slot provider unregister")
230 })?;
231
232 let slot_key = SlotKey::new(
233 worker.namespace().to_string(),
234 worker.task_queue().to_string(),
235 );
236 if let Some(slot_vec) = self.slot_providers.get_mut(&slot_key) {
237 slot_vec.retain(|info| info.worker_id != worker_instance_key);
238 if slot_vec.is_empty() {
239 self.slot_providers.remove(&slot_key);
240 }
241 }
242 Ok(())
243 }
244
245 fn finalize_unregister(
246 &mut self,
247 worker_instance_key: Uuid,
248 ) -> Result<Arc<dyn ClientWorker + Send + Sync>, anyhow::Error> {
249 if let Some(worker) = self.all_workers.get(&worker_instance_key)
250 && let Some(slot_vec) = self.slot_providers.get(&SlotKey::new(
251 worker.namespace().to_string(),
252 worker.task_queue().to_string(),
253 ))
254 && slot_vec
255 .iter()
256 .any(|info| info.worker_id == worker_instance_key)
257 {
258 return Err(anyhow::anyhow!(
259 "Worker still in slot_providers during finalize"
260 ));
261 }
262
263 let worker = self
264 .all_workers
265 .remove(&worker_instance_key)
266 .ok_or_else(|| anyhow::anyhow!("Worker not found in all_workers"))?;
267
268 if let Some(w) = self.shared_worker.get_mut(worker.namespace()) {
269 let (callback, is_empty) = w.unregister_callback(worker.worker_instance_key());
270 if callback.is_some() && is_empty {
271 self.shared_worker.remove(worker.namespace());
272 }
273 }
274
275 Ok(worker)
276 }
277
278 #[cfg(test)]
279 fn num_providers(&self) -> usize {
280 self.slot_providers.values().map(|v| v.len()).sum()
281 }
282
283 #[cfg(test)]
284 fn num_heartbeat_workers(&self) -> usize {
285 self.shared_worker.values().map(|v| v.num_workers()).sum()
286 }
287}
288
/// Interface of the per-namespace shared worker that holds the callbacks of registered
/// workers.
pub trait SharedNamespaceWorkerTrait {
    /// Namespace this shared worker belongs to.
    fn namespace(&self) -> String;

    /// Stores `callbacks` for the worker identified by `worker_instance_key`.
    fn register_callback(&self, worker_instance_key: Uuid, callbacks: WorkerCallbacks);

    /// Removes the callbacks stored for `worker_instance_key`.
    ///
    /// Returns the removed callbacks (if any were stored) and a flag that callers in this
    /// file treat as "no callbacks remain after the removal".
    fn unregister_callback(&self, worker_instance_key: Uuid) -> (Option<WorkerCallbacks>, bool);

    /// Number of workers currently registered with this shared worker.
    fn num_workers(&self) -> usize;
}
306
/// Lock-protected registry of the workers attached to a client.
pub struct ClientWorkerSet {
    /// Random identifier generated once when the set is created.
    worker_grouping_key: Uuid,
    /// The registry state; reads take the shared lock, mutations the exclusive lock.
    worker_manager: RwLock<ClientWorkerSetImpl>,
}
316
317impl Default for ClientWorkerSet {
318 fn default() -> Self {
319 Self::new()
320 }
321}
322
323impl ClientWorkerSet {
324 pub fn new() -> Self {
326 Self {
327 worker_grouping_key: Uuid::new_v4(),
328 worker_manager: RwLock::new(ClientWorkerSetImpl::new()),
329 }
330 }
331
332 pub(crate) fn try_reserve_wft_slot(
335 &self,
336 namespace: String,
337 task_queue: String,
338 ) -> Option<SlotReservation> {
339 self.worker_manager
340 .read()
341 .try_reserve_wft_slot(namespace, task_queue)
342 }
343
344 pub fn register_worker(
346 &self,
347 worker: Arc<dyn ClientWorker + Send + Sync>,
348 skip_client_worker_set_check: bool,
349 ) -> Result<(), anyhow::Error> {
350 self.worker_manager
351 .write()
352 .register(worker, skip_client_worker_set_check)
353 }
354
355 pub fn unregister_slot_provider(&self, worker_instance_key: Uuid) -> Result<(), anyhow::Error> {
358 self.worker_manager
359 .write()
360 .unregister_slot_provider(worker_instance_key)
361 }
362
363 pub fn finalize_unregister(
367 &self,
368 worker_instance_key: Uuid,
369 ) -> Result<Arc<dyn ClientWorker + Send + Sync>, anyhow::Error> {
370 self.worker_manager
371 .write()
372 .finalize_unregister(worker_instance_key)
373 }
374
375 pub fn worker_grouping_key(&self) -> Uuid {
377 self.worker_grouping_key
378 }
379
380 #[cfg(test)]
381 pub fn num_providers(&self) -> usize {
384 self.worker_manager.read().num_providers()
385 }
386
387 #[cfg(test)]
388 pub fn num_heartbeat_workers(&self) -> usize {
390 self.worker_manager.read().num_heartbeat_workers()
391 }
392}
393
394impl std::fmt::Debug for ClientWorkerSet {
395 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
396 f.debug_struct("ClientWorkerSet")
397 .field("worker_grouping_key", &self.worker_grouping_key)
398 .finish()
399 }
400}
401
/// Shareable, thread-safe closure that produces a [`WorkerHeartbeat`] when called.
pub type HeartbeatCallback = Arc<dyn Fn() -> WorkerHeartbeat + Send + Sync>;
404
/// Shareable, thread-safe closure that takes an activity's [`TaskToken`] and returns a
/// `bool` (presumably whether a matching activity was found/cancelled — confirm with
/// implementors).
pub type CancelActivityCallback = Arc<dyn Fn(TaskToken) -> bool + Send + Sync>;
407
/// Callbacks a worker hands to its namespace's shared worker at registration.
pub struct WorkerCallbacks {
    /// Produces the worker's heartbeat payload.
    pub heartbeat: HeartbeatCallback,
    /// Optional activity-cancellation hook; `None` when the worker provides none.
    pub cancel_activity: Option<CancelActivityCallback>,
}
415
416#[cfg_attr(test, mockall::automock)]
419pub trait ClientWorker: Send + Sync {
420 fn namespace(&self) -> &str;
422
423 fn task_queue(&self) -> &str;
425
426 fn try_reserve_wft_slot(&self) -> Option<Box<dyn Slot + Send>>;
432
433 fn deployment_options(&self) -> Option<WorkerDeploymentOptions>;
435
436 fn worker_instance_key(&self) -> Uuid;
439
440 fn heartbeat_enabled(&self) -> bool;
442
443 fn heartbeat_callback(&self) -> Option<HeartbeatCallback>;
445
446 fn cancel_activity_callback(&self) -> Option<CancelActivityCallback>;
448
449 fn new_shared_namespace_worker(
451 &self,
452 ) -> Result<Box<dyn SharedNamespaceWorkerTrait + Send + Sync>, anyhow::Error>;
453
454 fn worker_task_types(&self) -> WorkerTaskTypes;
456}
457
458#[cfg(test)]
459mod tests {
460 use super::*;
461
462 fn new_mock_slot(with_error: bool) -> Box<MockSlot> {
463 let mut mock_slot = MockSlot::new();
464 if with_error {
465 mock_slot
466 .expect_schedule_wft()
467 .returning(|_| Err(anyhow::anyhow!("Changed my mind")));
468 } else {
469 mock_slot.expect_schedule_wft().returning(|_| Ok(()));
470 }
471 Box::new(mock_slot)
472 }
473
474 fn new_mock_provider(
475 namespace: String,
476 task_queue: String,
477 with_error: bool,
478 no_slots: bool,
479 heartbeat_enabled: bool,
480 ) -> MockClientWorker {
481 let mut mock_provider = MockClientWorker::new();
482 mock_provider
483 .expect_try_reserve_wft_slot()
484 .returning(move || {
485 if no_slots {
486 None
487 } else {
488 Some(new_mock_slot(with_error))
489 }
490 });
491 mock_provider.expect_namespace().return_const(namespace);
492 mock_provider.expect_task_queue().return_const(task_queue);
493 mock_provider.expect_deployment_options().return_const(None);
494 mock_provider
495 .expect_heartbeat_enabled()
496 .return_const(heartbeat_enabled);
497 mock_provider
498 .expect_worker_instance_key()
499 .return_const(Uuid::new_v4());
500 mock_provider
501 .expect_worker_task_types()
502 .return_const(WorkerTaskTypes {
503 enable_workflows: true,
504 enable_local_activities: true,
505 enable_remote_activities: true,
506 enable_nexus: true,
507 });
508 mock_provider
509 }
510
511 #[test]
512 fn reserve_wft_slot_retries_another_worker_when_first_has_no_slot() {
513 let mut manager = ClientWorkerSetImpl::new();
514 let namespace = "retry_namespace".to_string();
515 let task_queue = "retry_queue".to_string();
516
517 let failing_worker_id = Uuid::new_v4();
518 let mut failing_worker = MockClientWorker::new();
519 failing_worker
520 .expect_try_reserve_wft_slot()
521 .times(1)
522 .returning(|| None);
523 failing_worker
524 .expect_namespace()
525 .return_const(namespace.clone());
526 failing_worker
527 .expect_task_queue()
528 .return_const(task_queue.clone());
529 failing_worker
530 .expect_deployment_options()
531 .return_const(WorkerDeploymentOptions {
532 version: temporalio_common::worker::WorkerDeploymentVersion {
533 deployment_name: "test-deployment".to_string(),
534 build_id: "build-fail".to_string(),
535 },
536 use_worker_versioning: true,
537 default_versioning_behavior: None,
538 });
539 failing_worker
540 .expect_worker_instance_key()
541 .return_const(failing_worker_id);
542 failing_worker
543 .expect_heartbeat_enabled()
544 .return_const(false);
545 failing_worker
546 .expect_worker_task_types()
547 .return_const(WorkerTaskTypes {
548 enable_workflows: true,
549 enable_local_activities: true,
550 enable_remote_activities: true,
551 enable_nexus: true,
552 });
553
554 let succeeding_worker_id = Uuid::new_v4();
555 let mut succeeding_worker = MockClientWorker::new();
556 succeeding_worker
557 .expect_try_reserve_wft_slot()
558 .times(1)
559 .returning(|| Some(new_mock_slot(false)));
560 succeeding_worker
561 .expect_namespace()
562 .return_const(namespace.clone());
563 succeeding_worker
564 .expect_task_queue()
565 .return_const(task_queue.clone());
566 let success_deployment_options = WorkerDeploymentOptions {
567 version: temporalio_common::worker::WorkerDeploymentVersion {
568 deployment_name: "test-deployment".to_string(),
569 build_id: "build-success".to_string(),
570 },
571 use_worker_versioning: true,
572 default_versioning_behavior: None,
573 };
574 succeeding_worker
575 .expect_deployment_options()
576 .return_const(success_deployment_options.clone());
577 succeeding_worker
578 .expect_worker_instance_key()
579 .return_const(succeeding_worker_id);
580 succeeding_worker
581 .expect_heartbeat_enabled()
582 .return_const(false);
583 succeeding_worker
584 .expect_worker_task_types()
585 .return_const(WorkerTaskTypes {
586 enable_workflows: true,
587 enable_local_activities: true,
588 enable_remote_activities: true,
589 enable_nexus: true,
590 });
591
592 manager
593 .register(Arc::new(failing_worker), false)
594 .expect("failing worker registration succeeds");
595 manager
596 .register(Arc::new(succeeding_worker), false)
597 .expect("succeeding worker registration succeeds");
598
599 let reservation = manager.try_reserve_wft_slot(namespace.clone(), task_queue.clone());
600
601 let reservation_deployment_options = reservation
602 .expect("succeeding worker was used after failing worker failed")
603 .deployment_options
604 .unwrap();
605 assert_eq!(
606 reservation_deployment_options, success_deployment_options,
607 "deployment options bubble through from succeeding worker"
608 );
609 }
610
611 #[test]
612 fn reserve_wft_slot_retries_respects_slot_boundary() {
613 let mut manager = ClientWorkerSetImpl::new();
614 let namespace = "retry_namespace".to_string();
615 let task_queue = "retry_queue".to_string();
616
617 let failing_worker_id = Uuid::new_v4();
618 let mut failing_worker = MockClientWorker::new();
619 failing_worker
620 .expect_try_reserve_wft_slot()
621 .times(1)
622 .returning(|| None);
623 failing_worker
624 .expect_namespace()
625 .return_const(namespace.clone());
626 failing_worker
627 .expect_task_queue()
628 .return_const(task_queue.clone());
629 failing_worker
630 .expect_deployment_options()
631 .return_const(WorkerDeploymentOptions {
632 version: temporalio_common::worker::WorkerDeploymentVersion {
633 deployment_name: "test-deployment".to_string(),
634 build_id: "build-fail".to_string(),
635 },
636 use_worker_versioning: true,
637 default_versioning_behavior: None,
638 });
639 failing_worker
640 .expect_worker_instance_key()
641 .return_const(failing_worker_id);
642 failing_worker
643 .expect_heartbeat_enabled()
644 .return_const(false);
645 failing_worker
646 .expect_worker_task_types()
647 .return_const(WorkerTaskTypes {
648 enable_workflows: true,
649 enable_local_activities: true,
650 enable_remote_activities: true,
651 enable_nexus: true,
652 });
653
654 let succeeding_worker_id = Uuid::new_v4();
656 let mut succeeding_worker = MockClientWorker::new();
657 succeeding_worker.expect_try_reserve_wft_slot().times(0);
658 succeeding_worker
659 .expect_namespace()
660 .return_const(namespace.clone());
661 succeeding_worker
662 .expect_task_queue()
663 .return_const("other_task_queue".to_string());
664 succeeding_worker
665 .expect_deployment_options()
666 .return_const(None);
667 succeeding_worker
668 .expect_worker_instance_key()
669 .return_const(succeeding_worker_id);
670 succeeding_worker
671 .expect_heartbeat_enabled()
672 .return_const(false);
673 succeeding_worker
674 .expect_worker_task_types()
675 .return_const(WorkerTaskTypes {
676 enable_workflows: true,
677 enable_local_activities: true,
678 enable_remote_activities: true,
679 enable_nexus: true,
680 });
681
682 manager
683 .register(Arc::new(failing_worker), false)
684 .expect("failing worker registration succeeds");
685 manager
686 .register(Arc::new(succeeding_worker), false)
687 .expect("succeeding worker registration succeeds");
688
689 let reservation = manager.try_reserve_wft_slot(namespace.clone(), task_queue.clone());
690 assert!(
691 reservation.is_none(),
692 "succeeding_worker should not be picked due to it being on a separate task queue"
693 );
694 }
695
696 #[test]
697 fn registry_keeps_one_provider_per_namespace() {
698 let manager = ClientWorkerSet::new();
699 let mut worker_keys = vec![];
700 let mut successful_registrations = 0;
701
702 for i in 0..10 {
703 let namespace = format!("myId{}", i % 3);
704 let mock_provider =
705 new_mock_provider(namespace, "bar_q".to_string(), false, false, false);
706 let worker_instance_key = mock_provider.worker_instance_key();
707
708 let result = manager.register_worker(Arc::new(mock_provider), false);
709 if let Err(err) = result {
710 assert!(err.to_string().contains(
712 "Registration of multiple workers with overlapping worker task types"
713 ));
714 } else {
715 successful_registrations += 1;
716 worker_keys.push(worker_instance_key);
717 }
718 }
719
720 assert_eq!(successful_registrations, 3);
721 assert_eq!(3, manager.num_providers());
722
723 let count = worker_keys.iter().fold(0, |count, key| {
724 manager.unregister_slot_provider(*key).unwrap();
725 manager.finalize_unregister(*key).unwrap();
726 let result = manager.unregister_slot_provider(*key);
728 assert!(result.is_err());
729 let result = manager.finalize_unregister(*key);
730 assert!(result.is_err());
731 count + 1
732 });
733 assert_eq!(3, count);
734 assert_eq!(0, manager.num_providers());
735 }
736
    /// Hand-written test double for [`SharedNamespaceWorkerTrait`] that stores callbacks in
    /// an in-memory map.
    struct MockSharedNamespaceWorker {
        /// Namespace reported by `namespace()`.
        namespace: String,
        /// Registered callbacks keyed by worker instance key.
        callbacks: Arc<RwLock<HashMap<Uuid, WorkerCallbacks>>>,
    }
741
742 impl std::fmt::Debug for MockSharedNamespaceWorker {
743 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
744 f.debug_struct("MockSharedNamespaceWorker")
745 .field("namespace", &self.namespace)
746 .field("callbacks_count", &self.callbacks.read().len())
747 .finish()
748 }
749 }
750
751 impl MockSharedNamespaceWorker {
752 fn new(namespace: String) -> Self {
753 Self {
754 namespace,
755 callbacks: Arc::new(RwLock::new(HashMap::new())),
756 }
757 }
758 }
759
760 impl SharedNamespaceWorkerTrait for MockSharedNamespaceWorker {
761 fn namespace(&self) -> String {
762 self.namespace.clone()
763 }
764
765 fn register_callback(&self, worker_instance_key: Uuid, callbacks: WorkerCallbacks) {
766 self.callbacks
767 .write()
768 .insert(worker_instance_key, callbacks);
769 }
770
771 fn unregister_callback(
772 &self,
773 worker_instance_key: Uuid,
774 ) -> (Option<WorkerCallbacks>, bool) {
775 let mut callbacks = self.callbacks.write();
776 let callback = callbacks.remove(&worker_instance_key);
777 let is_empty = callbacks.is_empty();
778 (callback, is_empty)
779 }
780
781 fn num_workers(&self) -> usize {
782 self.callbacks.read().len()
783 }
784 }
785
786 fn new_mock_provider_with_heartbeat(
787 namespace: String,
788 task_queue: String,
789 heartbeat_enabled: bool,
790 build_id: Option<String>,
791 ) -> MockClientWorker {
792 let mut mock_provider = MockClientWorker::new();
793 mock_provider
794 .expect_try_reserve_wft_slot()
795 .returning(|| Some(new_mock_slot(false)));
796 mock_provider
797 .expect_namespace()
798 .return_const(namespace.clone());
799 mock_provider.expect_task_queue().return_const(task_queue);
800 mock_provider
801 .expect_heartbeat_enabled()
802 .return_const(heartbeat_enabled);
803 mock_provider
804 .expect_worker_instance_key()
805 .return_const(Uuid::new_v4());
806 let deployment_name = "test-deployment".to_string();
807 let build_id_for_closure = build_id.clone();
808 mock_provider
809 .expect_deployment_options()
810 .returning(move || {
811 build_id_for_closure
812 .as_ref()
813 .map(|build_id| WorkerDeploymentOptions {
814 version: temporalio_common::worker::WorkerDeploymentVersion {
815 deployment_name: deployment_name.clone(),
816 build_id: build_id.clone(),
817 },
818 use_worker_versioning: true,
819 default_versioning_behavior: None,
820 })
821 });
822
823 if heartbeat_enabled {
824 mock_provider
825 .expect_heartbeat_callback()
826 .returning(|| Some(Arc::new(WorkerHeartbeat::default)));
827 mock_provider
828 .expect_cancel_activity_callback()
829 .returning(|| None);
830
831 let namespace_clone = namespace.clone();
832 mock_provider
833 .expect_new_shared_namespace_worker()
834 .returning(move || {
835 Ok(Box::new(MockSharedNamespaceWorker::new(
836 namespace_clone.clone(),
837 )))
838 });
839 }
840
841 mock_provider
842 .expect_worker_task_types()
843 .return_const(WorkerTaskTypes {
844 enable_workflows: true,
845 enable_local_activities: true,
846 enable_remote_activities: true,
847 enable_nexus: true,
848 });
849
850 mock_provider
851 }
852
853 #[test]
854 fn duplicate_namespace_task_queue_registration_fails() {
855 let manager = ClientWorkerSet::new();
856
857 let worker1 = new_mock_provider_with_heartbeat(
858 "test_namespace".to_string(),
859 "test_queue".to_string(),
860 true,
861 None,
862 );
863
864 let worker2 = new_mock_provider_with_heartbeat(
866 "test_namespace".to_string(),
867 "test_queue".to_string(),
868 true,
869 None,
870 );
871
872 manager.register_worker(Arc::new(worker1), false).unwrap();
873
874 let result = manager.register_worker(Arc::new(worker2), false);
876 assert!(result.is_err());
877 assert!(
878 result
879 .unwrap_err()
880 .to_string()
881 .contains("Registration of multiple workers with overlapping worker task types")
882 );
883
884 assert_eq!(1, manager.num_providers());
885 assert_eq!(manager.num_heartbeat_workers(), 1);
886
887 let impl_ref = manager.worker_manager.read();
888 assert_eq!(impl_ref.shared_worker.len(), 1);
889 assert!(impl_ref.shared_worker.contains_key("test_namespace"));
890 }
891
892 #[test]
893 fn duplicate_namespace_with_different_build_ids_succeeds() {
894 let manager = ClientWorkerSet::new();
895 let namespace = "test_namespace".to_string();
896 let task_queue = "test_queue".to_string();
897
898 let worker1 =
899 new_mock_provider_with_heartbeat(namespace.clone(), task_queue.clone(), false, None);
900 let worker1_instance_key = worker1.worker_instance_key();
901 let worker2 = new_mock_provider_with_heartbeat(
902 namespace.clone(),
903 task_queue.clone(),
904 false,
905 Some("build-1".to_string()),
906 );
907 let worker2_instance_key = worker2.worker_instance_key();
908 let worker3 =
909 new_mock_provider_with_heartbeat(namespace.clone(), task_queue.clone(), false, None);
910 let worker4 = new_mock_provider_with_heartbeat(
911 namespace.clone(),
912 task_queue.clone(),
913 false,
914 Some("build-1".to_string()),
915 );
916
917 manager.register_worker(Arc::new(worker1), false).unwrap();
918
919 manager
920 .register_worker(Arc::new(worker2), false)
921 .expect("worker with new build ID should register");
922 assert_eq!(2, manager.num_providers());
923
924 assert!(
925 manager
926 .register_worker(Arc::new(worker3), false)
927 .unwrap_err()
928 .to_string()
929 .contains("Registration of multiple workers with overlapping worker task types")
930 );
931
932 assert!(
933 manager
934 .register_worker(Arc::new(worker4), false)
935 .unwrap_err()
936 .to_string()
937 .contains("Registration of multiple workers with overlapping worker task types")
938 );
939 assert_eq!(2, manager.num_providers());
940
941 {
942 let impl_ref = manager.worker_manager.read();
943 let slot_key = SlotKey::new(namespace.clone(), task_queue.clone());
944 let providers = impl_ref
945 .slot_providers
946 .get(&slot_key)
947 .expect("slot providers should exist for namespace/task queue");
948 assert_eq!(2, providers.len());
949
950 assert_eq!(providers[0].worker_id, worker1_instance_key);
951 assert_eq!(providers[0].build_id, None);
952 assert_eq!(providers[1].worker_id, worker2_instance_key);
953 assert_eq!(providers[1].build_id, Some("build-1".to_string()));
954 }
955
956 manager
957 .unregister_slot_provider(worker2_instance_key)
958 .unwrap();
959 manager.finalize_unregister(worker2_instance_key).unwrap();
960
961 {
962 let impl_ref = manager.worker_manager.read();
963 let slot_key = SlotKey::new(namespace.clone(), task_queue.clone());
964 let providers = impl_ref
965 .slot_providers
966 .get(&slot_key)
967 .expect("slot providers should exist for namespace/task queue");
968
969 assert_eq!(1, providers.len());
970 assert_eq!(providers[0].worker_id, worker1_instance_key);
971 assert_eq!(providers[0].build_id, None);
972 }
973 }
974
975 #[test]
976 fn multiple_workers_same_namespace_share_heartbeat_manager() {
977 let manager = ClientWorkerSet::new();
978
979 let worker1 = new_mock_provider_with_heartbeat(
980 "shared_namespace".to_string(),
981 "queue1".to_string(),
982 true,
983 None,
984 );
985
986 let worker2 = new_mock_provider_with_heartbeat(
988 "shared_namespace".to_string(),
989 "queue2".to_string(),
990 true,
991 None,
992 );
993
994 manager.register_worker(Arc::new(worker1), false).unwrap();
995 manager.register_worker(Arc::new(worker2), false).unwrap();
996
997 assert_eq!(2, manager.num_providers());
998 assert_eq!(manager.num_heartbeat_workers(), 2);
999
1000 let impl_ref = manager.worker_manager.read();
1001 assert_eq!(impl_ref.shared_worker.len(), 1);
1002 assert!(impl_ref.shared_worker.contains_key("shared_namespace"));
1003
1004 let shared_worker = impl_ref.shared_worker.get("shared_namespace").unwrap();
1005 assert_eq!(shared_worker.namespace(), "shared_namespace");
1006 }
1007
1008 #[test]
1009 fn different_namespaces_get_separate_heartbeat_managers() {
1010 let manager = ClientWorkerSet::new();
1011 let worker1 = new_mock_provider_with_heartbeat(
1012 "namespace1".to_string(),
1013 "queue1".to_string(),
1014 true,
1015 None,
1016 );
1017 let worker2 = new_mock_provider_with_heartbeat(
1018 "namespace2".to_string(),
1019 "queue1".to_string(),
1020 true,
1021 None,
1022 );
1023
1024 manager.register_worker(Arc::new(worker1), false).unwrap();
1025 manager.register_worker(Arc::new(worker2), false).unwrap();
1026
1027 assert_eq!(2, manager.num_providers());
1028 assert_eq!(manager.num_heartbeat_workers(), 2);
1029
1030 let impl_ref = manager.worker_manager.read();
1031 assert_eq!(impl_ref.num_heartbeat_workers(), 2);
1032 assert!(impl_ref.shared_worker.contains_key("namespace1"));
1033 assert!(impl_ref.shared_worker.contains_key("namespace2"));
1034 }
1035
1036 #[test]
1037 fn unregister_heartbeat_workers_cleans_up_shared_worker_when_last_removed() {
1038 let manager = ClientWorkerSet::new();
1039
1040 let worker1 = new_mock_provider_with_heartbeat(
1042 "test_namespace".to_string(),
1043 "queue1".to_string(),
1044 true,
1045 None,
1046 );
1047 let worker2 = new_mock_provider_with_heartbeat(
1048 "test_namespace".to_string(),
1049 "queue2".to_string(),
1050 true,
1051 None,
1052 );
1053 let worker_instance_key1 = worker1.worker_instance_key();
1054 let worker_instance_key2 = worker2.worker_instance_key();
1055
1056 assert_ne!(worker_instance_key1, worker_instance_key2);
1057
1058 manager.register_worker(Arc::new(worker1), false).unwrap();
1059 manager.register_worker(Arc::new(worker2), false).unwrap();
1060
1061 assert_eq!(2, manager.num_providers());
1063 assert_eq!(manager.num_heartbeat_workers(), 2);
1064
1065 let impl_ref = manager.worker_manager.read();
1066 assert_eq!(impl_ref.shared_worker.len(), 1);
1067 assert!(impl_ref.shared_worker.contains_key("test_namespace"));
1068 assert_eq!(
1069 impl_ref
1070 .shared_worker
1071 .get("test_namespace")
1072 .unwrap()
1073 .num_workers(),
1074 2
1075 );
1076 drop(impl_ref);
1077
1078 manager
1080 .unregister_slot_provider(worker_instance_key1)
1081 .unwrap();
1082 manager.finalize_unregister(worker_instance_key1).unwrap();
1083
1084 assert_eq!(1, manager.num_providers());
1086 assert_eq!(manager.num_heartbeat_workers(), 1);
1087
1088 let impl_ref = manager.worker_manager.read();
1089 assert_eq!(impl_ref.num_heartbeat_workers(), 1); assert!(impl_ref.shared_worker.contains_key("test_namespace"));
1091 assert_eq!(
1092 impl_ref
1093 .shared_worker
1094 .get("test_namespace")
1095 .unwrap()
1096 .num_workers(),
1097 1
1098 );
1099 drop(impl_ref);
1100
1101 manager
1103 .unregister_slot_provider(worker_instance_key2)
1104 .unwrap();
1105 manager.finalize_unregister(worker_instance_key2).unwrap();
1106
1107 assert_eq!(0, manager.num_providers());
1109 assert_eq!(manager.num_heartbeat_workers(), 0);
1110
1111 let impl_ref = manager.worker_manager.read();
1112 assert_eq!(impl_ref.shared_worker.len(), 0); assert!(!impl_ref.shared_worker.contains_key("test_namespace"));
1114 }
1115
/// Registers a workflow+nexus worker and a remote-activity-only worker on the
/// same namespace and task queue, and checks that both registrations succeed
/// and that the set then reports two providers.
#[test]
fn workflow_and_activity_only_workers_coexist() {
    let ns = String::from("test_namespace");
    let tq = String::from("test_queue");

    // Builds a mock worker on the shared namespace/task queue. Only the
    // enabled task types differ between the workers used in this test.
    let mock_with = |task_types: WorkerTaskTypes| -> MockClientWorker {
        let mut mock = MockClientWorker::new();
        mock.expect_namespace().return_const(ns.clone());
        mock.expect_task_queue().return_const(tq.clone());
        mock.expect_deployment_options().return_const(None);
        mock.expect_worker_instance_key()
            .return_const(Uuid::new_v4());
        mock.expect_heartbeat_enabled().return_const(false);
        mock.expect_worker_task_types().return_const(task_types);
        mock
    };

    let wf_nexus = mock_with(WorkerTaskTypes {
        enable_workflows: true,
        enable_local_activities: false,
        enable_remote_activities: false,
        enable_nexus: true,
    });

    let mut activity_only = mock_with(WorkerTaskTypes {
        enable_workflows: false,
        enable_local_activities: false,
        enable_remote_activities: true,
        enable_nexus: false,
    });
    // The activity-only worker is expected to receive zero workflow task slot
    // reservation calls.
    activity_only.expect_try_reserve_wft_slot().times(0);

    let set = ClientWorkerSet::new();
    set.register_worker(Arc::new(wf_nexus), false)
        .expect("workflow-nexus worker should register");
    set.register_worker(Arc::new(activity_only), false)
        .expect("activity-only worker should register");

    assert_eq!(2, set.num_providers());
}
1182
/// Once a worker with workflows, local activities and remote activities is
/// registered for a namespace/task queue, registering another worker whose
/// enabled task types intersect with it must fail with an
/// "overlapping worker task types" error.
#[test]
fn overlapping_capabilities_rejected() {
    /// Creates a mock worker for the given namespace/task queue with a fresh
    /// random instance key, heartbeating disabled and no deployment options.
    fn mock_worker(
        namespace: &str,
        task_queue: &str,
        task_types: WorkerTaskTypes,
    ) -> MockClientWorker {
        let mut mock = MockClientWorker::new();
        mock.expect_namespace().return_const(namespace.to_owned());
        mock.expect_task_queue().return_const(task_queue.to_owned());
        mock.expect_deployment_options().return_const(None);
        mock.expect_worker_instance_key()
            .return_const(Uuid::new_v4());
        mock.expect_heartbeat_enabled().return_const(false);
        mock.expect_worker_task_types().return_const(task_types);
        mock
    }

    let namespace = "test_namespace".to_string();
    let task_queue = "test_queue".to_string();
    let manager = ClientWorkerSet::new();

    // Task types shared by the first two workers: everything except nexus.
    let everything_but_nexus = || WorkerTaskTypes {
        enable_workflows: true,
        enable_local_activities: true,
        enable_remote_activities: true,
        enable_nexus: false,
    };

    let first = mock_worker(&namespace, &task_queue, everything_but_nexus());
    manager
        .register_worker(Arc::new(first), false)
        .expect("first worker should register");

    // An identical set of task types is rejected.
    let duplicate = mock_worker(&namespace, &task_queue, everything_but_nexus());
    let outcome = manager.register_worker(Arc::new(duplicate), false);
    assert!(outcome.is_err());
    let message = outcome.unwrap_err().to_string();
    assert!(message.contains("overlapping worker task types"));

    // A remote-activity-only worker is rejected as well, since the first
    // worker already has remote activities enabled.
    let activity_only = mock_worker(
        &namespace,
        &task_queue,
        WorkerTaskTypes {
            enable_workflows: false,
            enable_local_activities: false,
            enable_remote_activities: true,
            enable_nexus: false,
        },
    );
    let outcome = manager.register_worker(Arc::new(activity_only), false);
    assert!(outcome.is_err());
    let message = outcome.unwrap_err().to_string();
    assert!(message.contains("overlapping worker task types"));
}
1265
/// With only an activity worker and a nexus worker registered, reserving a
/// workflow task slot yields nothing. After a workflow-capable worker is
/// registered on the same namespace/task queue, the reservation succeeds.
#[test]
fn wft_slot_reservation_ignores_non_workflow_workers() {
    let namespace = "test_namespace".to_string();
    let task_queue = "test_queue".to_string();

    // Builds a mock worker on the shared namespace/task queue with a fresh
    // random instance key; callers choose which task types are enabled.
    let build = |task_types: WorkerTaskTypes| -> MockClientWorker {
        let mut mock = MockClientWorker::new();
        mock.expect_namespace().return_const(namespace.clone());
        mock.expect_task_queue().return_const(task_queue.clone());
        mock.expect_deployment_options().return_const(None);
        mock.expect_worker_instance_key()
            .return_const(Uuid::new_v4());
        mock.expect_heartbeat_enabled().return_const(false);
        mock.expect_worker_task_types().return_const(task_types);
        mock
    };

    let remote_activity_worker = build(WorkerTaskTypes {
        enable_workflows: false,
        enable_local_activities: false,
        enable_remote_activities: true,
        enable_nexus: false,
    });
    let nexus_only_worker = build(WorkerTaskTypes {
        enable_workflows: false,
        enable_local_activities: false,
        enable_remote_activities: false,
        enable_nexus: true,
    });

    let mut worker_set = ClientWorkerSetImpl::new();
    worker_set
        .register(Arc::new(remote_activity_worker), false)
        .expect("activity worker should register");
    worker_set
        .register(Arc::new(nexus_only_worker), false)
        .expect("nexus worker should register");

    // Phase 1: no workflow-capable worker is registered yet.
    let before = worker_set.try_reserve_wft_slot(namespace.clone(), task_queue.clone());
    assert!(
        before.is_none(),
        "should not find workflow workers when only activity/nexus workers registered"
    );

    // Phase 2: add a workflow worker that hands out exactly one slot.
    let mut wf_worker = build(WorkerTaskTypes {
        enable_workflows: true,
        enable_local_activities: true,
        enable_remote_activities: false,
        enable_nexus: false,
    });
    wf_worker
        .expect_try_reserve_wft_slot()
        .times(1)
        .returning(|| Some(new_mock_slot(false)));
    worker_set
        .register(Arc::new(wf_worker), false)
        .expect("workflow worker should register");

    let after = worker_set.try_reserve_wft_slot(namespace.clone(), task_queue.clone());
    assert!(
        after.is_some(),
        "should find workflow worker after it's registered"
    );
}
1371
/// Registration must reject invalid task type combinations:
/// a worker with nothing enabled, and a worker that enables local activities
/// without workflows.
#[test]
fn worker_invalid_type_config_rejected() {
    let manager = ClientWorkerSet::new();

    // Builds a mock worker with the given task types and attempts to register
    // it, handing back the registration result for inspection.
    let try_register = |task_types: WorkerTaskTypes| {
        let mut mock = MockClientWorker::new();
        mock.expect_namespace()
            .return_const("test_namespace".to_string());
        mock.expect_task_queue()
            .return_const("test_queue".to_string());
        mock.expect_deployment_options().return_const(None);
        mock.expect_worker_instance_key()
            .return_const(Uuid::new_v4());
        mock.expect_heartbeat_enabled().return_const(false);
        mock.expect_worker_task_types().return_const(task_types);
        manager.register_worker(Arc::new(mock), false)
    };

    // Case 1: no capability enabled at all.
    let nothing_enabled = try_register(WorkerTaskTypes {
        enable_workflows: false,
        enable_local_activities: false,
        enable_remote_activities: false,
        enable_nexus: false,
    });
    assert!(nothing_enabled.is_err());
    let message = nothing_enabled.unwrap_err().to_string();
    assert!(message.contains("must have at least one capability enabled"));

    // Case 2: local activities enabled while workflows are disabled. Here the
    // full error text is compared, not just a substring.
    let local_without_workflows = try_register(WorkerTaskTypes {
        enable_workflows: false,
        enable_local_activities: true,
        enable_remote_activities: true,
        enable_nexus: false,
    });
    assert!(local_without_workflows.is_err());
    assert_eq!(
        local_without_workflows.unwrap_err().to_string(),
        "Local activities cannot be enabled without workflows".to_string()
    );
}
1436
/// Registers a workflow worker and an activity worker on the same
/// namespace/task queue, then unregisters them one at a time, checking the
/// provider count and workflow task slot availability after each step.
#[test]
fn unregister_with_multiple_workers() {
    let namespace = "test_namespace".to_string();
    let task_queue = "test_queue".to_string();

    // Builds a mock worker on the shared namespace/task queue. The instance
    // key is supplied by the caller because it is needed again to unregister.
    let build = |instance_key: Uuid, task_types: WorkerTaskTypes| -> MockClientWorker {
        let mut mock = MockClientWorker::new();
        mock.expect_namespace().return_const(namespace.clone());
        mock.expect_task_queue().return_const(task_queue.clone());
        mock.expect_deployment_options().return_const(None);
        mock.expect_worker_instance_key().return_const(instance_key);
        mock.expect_heartbeat_enabled().return_const(false);
        mock.expect_worker_task_types().return_const(task_types);
        mock
    };

    let wf_key = Uuid::new_v4();
    let mut wf_worker = build(
        wf_key,
        WorkerTaskTypes {
            enable_workflows: true,
            enable_local_activities: true,
            enable_remote_activities: false,
            enable_nexus: false,
        },
    );
    wf_worker
        .expect_try_reserve_wft_slot()
        .returning(|| Some(new_mock_slot(false)));

    let act_key = Uuid::new_v4();
    let act_worker = build(
        act_key,
        WorkerTaskTypes {
            enable_workflows: false,
            enable_local_activities: false,
            enable_remote_activities: true,
            enable_nexus: false,
        },
    );

    let set = ClientWorkerSet::new();
    set.register_worker(Arc::new(wf_worker), false)
        .expect("workflow worker should register");
    set.register_worker(Arc::new(act_worker), false)
        .expect("activity worker should register");
    assert_eq!(2, set.num_providers());

    // While the workflow worker is registered a slot can be reserved.
    let slot = set.try_reserve_wft_slot(namespace.clone(), task_queue.clone());
    assert!(
        slot.is_some(),
        "should be able to reserve slot from workflow worker"
    );

    // Remove the workflow worker: slot provider first, then finalize.
    set.unregister_slot_provider(wf_key)
        .expect("should unregister slot provider for workflow worker");
    set.finalize_unregister(wf_key)
        .expect("should finalize unregister for workflow worker");
    assert_eq!(1, set.num_providers());

    // Only the activity worker remains, so no workflow task slot is available.
    let slot = set.try_reserve_wft_slot(namespace.clone(), task_queue.clone());
    assert!(
        slot.is_none(),
        "should not find workflow worker after unregistration"
    );

    // Remove the activity worker the same way, leaving the set empty.
    set.unregister_slot_provider(act_key)
        .expect("should unregister slot provider for activity worker");
    set.finalize_unregister(act_key)
        .expect("should finalize unregister for activity worker");
    assert_eq!(0, set.num_providers());
}
1540
/// Finalizing an unregister before the worker's slot provider has been
/// removed must fail. Doing the two steps in the proper order succeeds.
#[test]
fn worker_unregister_order() {
    let mock = new_mock_provider_with_heartbeat(
        "namespace1".to_string(),
        "queue1".to_string(),
        true,
        None,
    );
    let key = mock.worker_instance_key();

    let set = ClientWorkerSet::new();
    set.register_worker(Arc::new(mock), false).unwrap();

    // Wrong order: the worker is still listed as a slot provider.
    let premature = set.finalize_unregister(key);
    assert!(premature.is_err());
    let message = premature.err().unwrap().to_string();
    assert!(message.contains("Worker still in slot_providers during finalize"));

    // Correct order: drop the slot provider, then finalize.
    set.unregister_slot_provider(key).unwrap();
    set.finalize_unregister(key).unwrap();
}
1565}