Skip to main content

temporalio_client/
worker.rs

1//! Contains types and logic for interactions between clients and Core/SDK workers
2
3use anyhow::bail;
4use parking_lot::RwLock;
5use rand::seq::SliceRandom;
6use std::{
7    collections::{
8        HashMap,
9        hash_map::Entry::{Occupied, Vacant},
10    },
11    sync::Arc,
12};
13use temporalio_common::{
14    protos::temporal::api::{
15        worker::v1::WorkerHeartbeat, workflowservice::v1::PollWorkflowTaskQueueResponse,
16    },
17    worker::{WorkerDeploymentOptions, WorkerTaskTypes},
18};
19use uuid::Uuid;
20
21/// This trait represents a slot reserved for processing a WFT by a worker.
22#[cfg_attr(test, mockall::automock)]
23pub trait Slot {
24    /// Consumes this slot by dispatching a WFT to its worker. This can only be called once.
25    fn schedule_wft(
26        self: Box<Self>,
27        task: PollWorkflowTaskQueueResponse,
28    ) -> Result<(), anyhow::Error>;
29}
30
31/// Result of reserving a workflow task slot, including deployment options if applicable.
32pub(crate) struct SlotReservation {
33    /// The reserved slot for processing the workflow task
34    pub slot: Box<dyn Slot + Send>,
35    /// Worker deployment options, if the worker is using deployment-based versioning
36    pub deployment_options: Option<WorkerDeploymentOptions>,
37}
38
/// Bucket key under which registered workers are grouped: a namespace plus a task queue.
#[derive(PartialEq, Eq, Hash, Debug, Clone)]
struct SlotKey {
    namespace: String,
    task_queue: String,
}

impl SlotKey {
    /// Builds a key from an owned namespace and task queue name.
    fn new(namespace: String, task_queue: String) -> SlotKey {
        Self {
            task_queue,
            namespace,
        }
    }
}
53
54/// Information about a registered worker in the slot provider registry
55#[derive(Debug, Clone)]
56struct RegisteredWorkerInfo {
57    /// Unique identifier for this worker instance
58    worker_id: Uuid,
59    /// Optional deployment build ID for versioning
60    build_id: Option<String>,
61    /// Task types this worker can handle
62    task_types: WorkerTaskTypes,
63}
64
65impl RegisteredWorkerInfo {
66    fn new(worker_id: Uuid, build_id: Option<String>, task_types: WorkerTaskTypes) -> Self {
67        Self {
68            worker_id,
69            build_id,
70            task_types,
71        }
72    }
73}
74
75/// This is an inner class for [ClientWorkerSet] needed to hide the mutex.
76struct ClientWorkerSetImpl {
77    /// Maps slot keys to registered worker information
78    slot_providers: HashMap<SlotKey, Vec<RegisteredWorkerInfo>>,
79    /// Maps worker_instance_key to registered workers
80    all_workers: HashMap<Uuid, Arc<dyn ClientWorker + Send + Sync>>,
81    /// Maps namespace to shared worker for worker heartbeating
82    shared_worker: HashMap<String, Box<dyn SharedNamespaceWorkerTrait + Send + Sync>>,
83}
84
85impl ClientWorkerSetImpl {
86    /// Factory method.
87    fn new() -> Self {
88        Self {
89            slot_providers: Default::default(),
90            all_workers: Default::default(),
91            shared_worker: Default::default(),
92        }
93    }
94
95    fn try_reserve_wft_slot(
96        &self,
97        namespace: String,
98        task_queue: String,
99    ) -> Option<SlotReservation> {
100        let key = SlotKey::new(namespace, task_queue);
101        if let Some(worker_list) = self.slot_providers.get(&key) {
102            let workflow_workers: Vec<&RegisteredWorkerInfo> = worker_list
103                .iter()
104                .filter(|info| info.task_types.enable_workflows)
105                .collect();
106
107            for worker_id in Self::worker_ids_in_selection_order(&workflow_workers) {
108                if let Some(worker) = self.all_workers.get(&worker_id)
109                    && let Some(slot) = worker.try_reserve_wft_slot()
110                {
111                    let deployment_options = worker.deployment_options();
112                    return Some(SlotReservation {
113                        slot,
114                        deployment_options,
115                    });
116                }
117            }
118        }
119        None
120    }
121
122    fn worker_ids_in_selection_order(worker_list: &[&RegisteredWorkerInfo]) -> Vec<Uuid> {
123        // For tests we return workers in the order they're registered, so we can test
124        // the retry mechanism deterministically
125        if cfg!(test) {
126            worker_list.iter().map(|info| info.worker_id).collect()
127        } else {
128            let mut rng = rand::rng();
129            let mut shuffled: Vec<_> = worker_list.to_vec();
130            shuffled.shuffle(&mut rng);
131            shuffled.iter().map(|info| info.worker_id).collect()
132        }
133    }
134
135    fn register(
136        &mut self,
137        worker: Arc<dyn ClientWorker + Send + Sync>,
138        skip_client_worker_set_check: bool,
139    ) -> Result<(), anyhow::Error> {
140        let slot_key = SlotKey::new(
141            worker.namespace().to_string(),
142            worker.task_queue().to_string(),
143        );
144        let build_id = worker
145            .deployment_options()
146            .map(|opts| opts.version.build_id);
147        let task_types = worker.worker_task_types();
148
149        if !task_types.enable_workflows
150            && !task_types.enable_local_activities
151            && !task_types.enable_remote_activities
152            && !task_types.enable_nexus
153        {
154            bail!(
155                "Worker must have at least one capability enabled (workflows, activities, or nexus)"
156            );
157        }
158
159        if !task_types.enable_workflows && task_types.enable_local_activities {
160            bail!("Local activities cannot be enabled without workflows")
161        }
162
163        if !skip_client_worker_set_check
164            && let Some(existing_workers) = self.slot_providers.get(&slot_key)
165        {
166            for existing_worker_info in existing_workers {
167                if existing_worker_info.build_id.as_ref() == build_id.as_ref()
168                    && task_types.overlaps_with(&existing_worker_info.task_types)
169                {
170                    bail!(
171                        "Registration of multiple workers with overlapping worker task types \
172                        on the same namespace, task queue, and deployment build ID not allowed: \
173                        {slot_key:?}, worker_instance_key: {:?} \
174                        build_id: {build_id:?}, \
175                        new task types: {task_types:?}, \
176                        existing task types: {:?}.",
177                        existing_worker_info.task_types,
178                        worker.worker_instance_key()
179                    );
180                }
181            }
182        }
183
184        if worker.heartbeat_enabled()
185            && let Some(heartbeat_callback) = worker.heartbeat_callback()
186        {
187            let worker_instance_key = worker.worker_instance_key();
188            let namespace = worker.namespace().to_string();
189
190            let shared_worker = match self.shared_worker.entry(namespace.clone()) {
191                Occupied(o) => o.into_mut(),
192                Vacant(v) => {
193                    let shared_worker = worker.new_shared_namespace_worker()?;
194                    v.insert(shared_worker)
195                }
196            };
197            shared_worker.register_callback(worker_instance_key, heartbeat_callback);
198        }
199
200        let worker_info =
201            RegisteredWorkerInfo::new(worker.worker_instance_key(), build_id, task_types);
202
203        match self.slot_providers.entry(slot_key.clone()) {
204            Occupied(o) => o.into_mut().push(worker_info),
205            Vacant(v) => {
206                v.insert(vec![worker_info]);
207            }
208        };
209
210        self.all_workers
211            .insert(worker.worker_instance_key(), worker);
212
213        Ok(())
214    }
215
216    /// Slot provider should be unregistered at the beginning of worker shutdown, in order to disable
217    /// eager workflow start.
218    fn unregister_slot_provider(&mut self, worker_instance_key: Uuid) -> Result<(), anyhow::Error> {
219        let worker = self.all_workers.get(&worker_instance_key).ok_or_else(|| {
220            anyhow::anyhow!("Worker not in all_workers during slot provider unregister")
221        })?;
222
223        let slot_key = SlotKey::new(
224            worker.namespace().to_string(),
225            worker.task_queue().to_string(),
226        );
227        if let Some(slot_vec) = self.slot_providers.get_mut(&slot_key) {
228            slot_vec.retain(|info| info.worker_id != worker_instance_key);
229            if slot_vec.is_empty() {
230                self.slot_providers.remove(&slot_key);
231            }
232        }
233        Ok(())
234    }
235
236    fn finalize_unregister(
237        &mut self,
238        worker_instance_key: Uuid,
239    ) -> Result<Arc<dyn ClientWorker + Send + Sync>, anyhow::Error> {
240        if let Some(worker) = self.all_workers.get(&worker_instance_key)
241            && let Some(slot_vec) = self.slot_providers.get(&SlotKey::new(
242                worker.namespace().to_string(),
243                worker.task_queue().to_string(),
244            ))
245            && slot_vec
246                .iter()
247                .any(|info| info.worker_id == worker_instance_key)
248        {
249            return Err(anyhow::anyhow!(
250                "Worker still in slot_providers during finalize"
251            ));
252        }
253
254        let worker = self
255            .all_workers
256            .remove(&worker_instance_key)
257            .ok_or_else(|| anyhow::anyhow!("Worker not found in all_workers"))?;
258
259        if let Some(w) = self.shared_worker.get_mut(worker.namespace()) {
260            let (callback, is_empty) = w.unregister_callback(worker.worker_instance_key());
261            if callback.is_some() && is_empty {
262                self.shared_worker.remove(worker.namespace());
263            }
264        }
265
266        Ok(worker)
267    }
268
269    #[cfg(test)]
270    fn num_providers(&self) -> usize {
271        self.slot_providers.values().map(|v| v.len()).sum()
272    }
273
274    #[cfg(test)]
275    fn num_heartbeat_workers(&self) -> usize {
276        self.shared_worker.values().map(|v| v.num_workers()).sum()
277    }
278}
279
280/// This trait represents a shared namespace worker that sends worker heartbeats and
281/// receives worker commands.
282pub trait SharedNamespaceWorkerTrait {
283    /// Namespace that the shared namespace worker is connected to.
284    fn namespace(&self) -> String;
285
286    /// Registers a heartbeat callback.
287    fn register_callback(&self, worker_instance_key: Uuid, heartbeat_callback: HeartbeatCallback);
288
289    /// Unregisters a heartbeat callback. Returns the callback removed, as well as a bool that
290    /// indicates if there are no remaining callbacks in the SharedNamespaceWorker, indicating
291    /// the shared worker itself can be shut down.
292    fn unregister_callback(&self, worker_instance_key: Uuid) -> (Option<HeartbeatCallback>, bool);
293
294    /// Returns the number of workers registered to this shared worker.
295    fn num_workers(&self) -> usize;
296}
297
298/// Enables local workers to make themselves visible to a shared client instance.
299///
300/// For slot managing, there can only be one worker registered per
301/// namespace+queue_name+connection, others will return an error.
302/// It also provides a convenient method to find compatible slots within the collection.
303pub struct ClientWorkerSet {
304    worker_grouping_key: Uuid,
305    worker_manager: RwLock<ClientWorkerSetImpl>,
306}
307
308impl Default for ClientWorkerSet {
309    fn default() -> Self {
310        Self::new()
311    }
312}
313
314impl ClientWorkerSet {
315    /// Factory method.
316    pub fn new() -> Self {
317        Self {
318            worker_grouping_key: Uuid::new_v4(),
319            worker_manager: RwLock::new(ClientWorkerSetImpl::new()),
320        }
321    }
322
323    /// Try to reserve a compatible processing slot in any of the registered workers.
324    /// Returns the slot and the worker's deployment options (if using deployment-based versioning).
325    pub(crate) fn try_reserve_wft_slot(
326        &self,
327        namespace: String,
328        task_queue: String,
329    ) -> Option<SlotReservation> {
330        self.worker_manager
331            .read()
332            .try_reserve_wft_slot(namespace, task_queue)
333    }
334
335    /// Register a local worker that can provide WFT processing slots and potentially worker heartbeating.
336    pub fn register_worker(
337        &self,
338        worker: Arc<dyn ClientWorker + Send + Sync>,
339        skip_client_worker_set_check: bool,
340    ) -> Result<(), anyhow::Error> {
341        self.worker_manager
342            .write()
343            .register(worker, skip_client_worker_set_check)
344    }
345
346    /// Disables Eager Workflow Start for this worker. This must be called before
347    /// `finalize_unregister`, otherwise `finalize_unregister` will return an err.
348    pub fn unregister_slot_provider(&self, worker_instance_key: Uuid) -> Result<(), anyhow::Error> {
349        self.worker_manager
350            .write()
351            .unregister_slot_provider(worker_instance_key)
352    }
353
354    /// Finalizes unregistering of worker from client. This must be called at the end of worker
355    /// shutdown in order to finalize shutdown for worker heartbeat properly. Must call after
356    /// `unregister_slot_provider`, otherwise an err will be returned.
357    pub fn finalize_unregister(
358        &self,
359        worker_instance_key: Uuid,
360    ) -> Result<Arc<dyn ClientWorker + Send + Sync>, anyhow::Error> {
361        self.worker_manager
362            .write()
363            .finalize_unregister(worker_instance_key)
364    }
365
366    /// Returns the worker grouping key, which is unique for each worker.
367    pub fn worker_grouping_key(&self) -> Uuid {
368        self.worker_grouping_key
369    }
370
371    #[cfg(test)]
372    /// Returns (num_providers, num_buckets), where a bucket key is namespace+task_queue.
373    /// There is only one provider per bucket so `num_providers` should be equal to `num_buckets`.
374    pub fn num_providers(&self) -> usize {
375        self.worker_manager.read().num_providers()
376    }
377
378    #[cfg(test)]
379    /// Returns the total number of heartbeat workers registered across all namespaces.
380    pub fn num_heartbeat_workers(&self) -> usize {
381        self.worker_manager.read().num_heartbeat_workers()
382    }
383}
384
385impl std::fmt::Debug for ClientWorkerSet {
386    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
387        f.debug_struct("ClientWorkerSet")
388            .field("worker_grouping_key", &self.worker_grouping_key)
389            .finish()
390    }
391}
392
393/// Contains a worker heartbeat callback, wrapped for mocking
394pub type HeartbeatCallback = Arc<dyn Fn() -> WorkerHeartbeat + Send + Sync>;
395
396/// Represents a complete worker that can handle both slot management
397/// and worker heartbeat functionality.
398#[cfg_attr(test, mockall::automock)]
399pub trait ClientWorker: Send + Sync {
400    /// The namespace this worker operates in
401    fn namespace(&self) -> &str;
402
403    /// The task queue this worker listens to
404    fn task_queue(&self) -> &str;
405
406    /// Try to reserve a slot for workflow task processing.
407    ///
408    /// This method should return `Some(slot)` if a workflow task slot is available,
409    /// or `None` if all slots are currently in use. The returned slot will be used
410    /// to process exactly one workflow task.
411    fn try_reserve_wft_slot(&self) -> Option<Box<dyn Slot + Send>>;
412
413    /// Get the worker deployment options for this worker, if using deployment-based versioning.
414    fn deployment_options(&self) -> Option<WorkerDeploymentOptions>;
415
416    /// Unique identifier for this worker instance.
417    /// This must be stable across the worker's lifetime and unique per instance.
418    fn worker_instance_key(&self) -> Uuid;
419
420    /// Indicates if worker heartbeating is enabled for this client worker.
421    fn heartbeat_enabled(&self) -> bool;
422
423    /// Returns the heartbeat callback that can be used to get WorkerHeartbeat data.
424    fn heartbeat_callback(&self) -> Option<HeartbeatCallback>;
425
426    /// Creates a new worker that implements the [SharedNamespaceWorkerTrait]
427    fn new_shared_namespace_worker(
428        &self,
429    ) -> Result<Box<dyn SharedNamespaceWorkerTrait + Send + Sync>, anyhow::Error>;
430
431    /// Returns the task types this worker can handle
432    fn worker_task_types(&self) -> WorkerTaskTypes;
433}
434
435#[cfg(test)]
436mod tests {
437    use super::*;
438
439    fn new_mock_slot(with_error: bool) -> Box<MockSlot> {
440        let mut mock_slot = MockSlot::new();
441        if with_error {
442            mock_slot
443                .expect_schedule_wft()
444                .returning(|_| Err(anyhow::anyhow!("Changed my mind")));
445        } else {
446            mock_slot.expect_schedule_wft().returning(|_| Ok(()));
447        }
448        Box::new(mock_slot)
449    }
450
451    fn new_mock_provider(
452        namespace: String,
453        task_queue: String,
454        with_error: bool,
455        no_slots: bool,
456        heartbeat_enabled: bool,
457    ) -> MockClientWorker {
458        let mut mock_provider = MockClientWorker::new();
459        mock_provider
460            .expect_try_reserve_wft_slot()
461            .returning(move || {
462                if no_slots {
463                    None
464                } else {
465                    Some(new_mock_slot(with_error))
466                }
467            });
468        mock_provider.expect_namespace().return_const(namespace);
469        mock_provider.expect_task_queue().return_const(task_queue);
470        mock_provider.expect_deployment_options().return_const(None);
471        mock_provider
472            .expect_heartbeat_enabled()
473            .return_const(heartbeat_enabled);
474        mock_provider
475            .expect_worker_instance_key()
476            .return_const(Uuid::new_v4());
477        mock_provider
478            .expect_worker_task_types()
479            .return_const(WorkerTaskTypes {
480                enable_workflows: true,
481                enable_local_activities: true,
482                enable_remote_activities: true,
483                enable_nexus: true,
484            });
485        mock_provider
486    }
487
488    #[test]
489    fn reserve_wft_slot_retries_another_worker_when_first_has_no_slot() {
490        let mut manager = ClientWorkerSetImpl::new();
491        let namespace = "retry_namespace".to_string();
492        let task_queue = "retry_queue".to_string();
493
494        let failing_worker_id = Uuid::new_v4();
495        let mut failing_worker = MockClientWorker::new();
496        failing_worker
497            .expect_try_reserve_wft_slot()
498            .times(1)
499            .returning(|| None);
500        failing_worker
501            .expect_namespace()
502            .return_const(namespace.clone());
503        failing_worker
504            .expect_task_queue()
505            .return_const(task_queue.clone());
506        failing_worker
507            .expect_deployment_options()
508            .return_const(WorkerDeploymentOptions {
509                version: temporalio_common::worker::WorkerDeploymentVersion {
510                    deployment_name: "test-deployment".to_string(),
511                    build_id: "build-fail".to_string(),
512                },
513                use_worker_versioning: true,
514                default_versioning_behavior: None,
515            });
516        failing_worker
517            .expect_worker_instance_key()
518            .return_const(failing_worker_id);
519        failing_worker
520            .expect_heartbeat_enabled()
521            .return_const(false);
522        failing_worker
523            .expect_worker_task_types()
524            .return_const(WorkerTaskTypes {
525                enable_workflows: true,
526                enable_local_activities: true,
527                enable_remote_activities: true,
528                enable_nexus: true,
529            });
530
531        let succeeding_worker_id = Uuid::new_v4();
532        let mut succeeding_worker = MockClientWorker::new();
533        succeeding_worker
534            .expect_try_reserve_wft_slot()
535            .times(1)
536            .returning(|| Some(new_mock_slot(false)));
537        succeeding_worker
538            .expect_namespace()
539            .return_const(namespace.clone());
540        succeeding_worker
541            .expect_task_queue()
542            .return_const(task_queue.clone());
543        let success_deployment_options = WorkerDeploymentOptions {
544            version: temporalio_common::worker::WorkerDeploymentVersion {
545                deployment_name: "test-deployment".to_string(),
546                build_id: "build-success".to_string(),
547            },
548            use_worker_versioning: true,
549            default_versioning_behavior: None,
550        };
551        succeeding_worker
552            .expect_deployment_options()
553            .return_const(success_deployment_options.clone());
554        succeeding_worker
555            .expect_worker_instance_key()
556            .return_const(succeeding_worker_id);
557        succeeding_worker
558            .expect_heartbeat_enabled()
559            .return_const(false);
560        succeeding_worker
561            .expect_worker_task_types()
562            .return_const(WorkerTaskTypes {
563                enable_workflows: true,
564                enable_local_activities: true,
565                enable_remote_activities: true,
566                enable_nexus: true,
567            });
568
569        manager
570            .register(Arc::new(failing_worker), false)
571            .expect("failing worker registration succeeds");
572        manager
573            .register(Arc::new(succeeding_worker), false)
574            .expect("succeeding worker registration succeeds");
575
576        let reservation = manager.try_reserve_wft_slot(namespace.clone(), task_queue.clone());
577
578        let reservation_deployment_options = reservation
579            .expect("succeeding worker was used after failing worker failed")
580            .deployment_options
581            .unwrap();
582        assert_eq!(
583            reservation_deployment_options, success_deployment_options,
584            "deployment options bubble through from succeeding worker"
585        );
586    }
587
588    #[test]
589    fn reserve_wft_slot_retries_respects_slot_boundary() {
590        let mut manager = ClientWorkerSetImpl::new();
591        let namespace = "retry_namespace".to_string();
592        let task_queue = "retry_queue".to_string();
593
594        let failing_worker_id = Uuid::new_v4();
595        let mut failing_worker = MockClientWorker::new();
596        failing_worker
597            .expect_try_reserve_wft_slot()
598            .times(1)
599            .returning(|| None);
600        failing_worker
601            .expect_namespace()
602            .return_const(namespace.clone());
603        failing_worker
604            .expect_task_queue()
605            .return_const(task_queue.clone());
606        failing_worker
607            .expect_deployment_options()
608            .return_const(WorkerDeploymentOptions {
609                version: temporalio_common::worker::WorkerDeploymentVersion {
610                    deployment_name: "test-deployment".to_string(),
611                    build_id: "build-fail".to_string(),
612                },
613                use_worker_versioning: true,
614                default_versioning_behavior: None,
615            });
616        failing_worker
617            .expect_worker_instance_key()
618            .return_const(failing_worker_id);
619        failing_worker
620            .expect_heartbeat_enabled()
621            .return_const(false);
622        failing_worker
623            .expect_worker_task_types()
624            .return_const(WorkerTaskTypes {
625                enable_workflows: true,
626                enable_local_activities: true,
627                enable_remote_activities: true,
628                enable_nexus: true,
629            });
630
631        // On a separate task queue
632        let succeeding_worker_id = Uuid::new_v4();
633        let mut succeeding_worker = MockClientWorker::new();
634        succeeding_worker.expect_try_reserve_wft_slot().times(0);
635        succeeding_worker
636            .expect_namespace()
637            .return_const(namespace.clone());
638        succeeding_worker
639            .expect_task_queue()
640            .return_const("other_task_queue".to_string());
641        succeeding_worker
642            .expect_deployment_options()
643            .return_const(None);
644        succeeding_worker
645            .expect_worker_instance_key()
646            .return_const(succeeding_worker_id);
647        succeeding_worker
648            .expect_heartbeat_enabled()
649            .return_const(false);
650        succeeding_worker
651            .expect_worker_task_types()
652            .return_const(WorkerTaskTypes {
653                enable_workflows: true,
654                enable_local_activities: true,
655                enable_remote_activities: true,
656                enable_nexus: true,
657            });
658
659        manager
660            .register(Arc::new(failing_worker), false)
661            .expect("failing worker registration succeeds");
662        manager
663            .register(Arc::new(succeeding_worker), false)
664            .expect("succeeding worker registration succeeds");
665
666        let reservation = manager.try_reserve_wft_slot(namespace.clone(), task_queue.clone());
667        assert!(
668            reservation.is_none(),
669            "succeeding_worker should not be picked due to it being on a separate task queue"
670        );
671    }
672
673    #[test]
674    fn registry_keeps_one_provider_per_namespace() {
675        let manager = ClientWorkerSet::new();
676        let mut worker_keys = vec![];
677        let mut successful_registrations = 0;
678
679        for i in 0..10 {
680            let namespace = format!("myId{}", i % 3);
681            let mock_provider =
682                new_mock_provider(namespace, "bar_q".to_string(), false, false, false);
683            let worker_instance_key = mock_provider.worker_instance_key();
684
685            let result = manager.register_worker(Arc::new(mock_provider), false);
686            if let Err(err) = result {
687                // Should get error for overlapping worker task types
688                assert!(err.to_string().contains(
689                    "Registration of multiple workers with overlapping worker task types"
690                ));
691            } else {
692                successful_registrations += 1;
693                worker_keys.push(worker_instance_key);
694            }
695        }
696
697        assert_eq!(successful_registrations, 3);
698        assert_eq!(3, manager.num_providers());
699
700        let count = worker_keys.iter().fold(0, |count, key| {
701            manager.unregister_slot_provider(*key).unwrap();
702            manager.finalize_unregister(*key).unwrap();
703            // expect error since worker is already unregistered
704            let result = manager.unregister_slot_provider(*key);
705            assert!(result.is_err());
706            let result = manager.finalize_unregister(*key);
707            assert!(result.is_err());
708            count + 1
709        });
710        assert_eq!(3, count);
711        assert_eq!(0, manager.num_providers());
712    }
713
714    struct MockSharedNamespaceWorker {
715        namespace: String,
716        callbacks: Arc<RwLock<HashMap<Uuid, HeartbeatCallback>>>,
717    }
718
719    impl std::fmt::Debug for MockSharedNamespaceWorker {
720        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
721            f.debug_struct("MockSharedNamespaceWorker")
722                .field("namespace", &self.namespace)
723                .field("callbacks_count", &self.callbacks.read().len())
724                .finish()
725        }
726    }
727
728    impl MockSharedNamespaceWorker {
729        fn new(namespace: String) -> Self {
730            Self {
731                namespace,
732                callbacks: Arc::new(RwLock::new(HashMap::new())),
733            }
734        }
735    }
736
737    impl SharedNamespaceWorkerTrait for MockSharedNamespaceWorker {
738        fn namespace(&self) -> String {
739            self.namespace.clone()
740        }
741
742        fn register_callback(
743            &self,
744            worker_instance_key: Uuid,
745            heartbeat_callback: HeartbeatCallback,
746        ) {
747            self.callbacks
748                .write()
749                .insert(worker_instance_key, heartbeat_callback);
750        }
751
752        fn unregister_callback(
753            &self,
754            worker_instance_key: Uuid,
755        ) -> (Option<HeartbeatCallback>, bool) {
756            let mut callbacks = self.callbacks.write();
757            let callback = callbacks.remove(&worker_instance_key);
758            let is_empty = callbacks.is_empty();
759            (callback, is_empty)
760        }
761
762        fn num_workers(&self) -> usize {
763            self.callbacks.read().len()
764        }
765    }
766
767    fn new_mock_provider_with_heartbeat(
768        namespace: String,
769        task_queue: String,
770        heartbeat_enabled: bool,
771        build_id: Option<String>,
772    ) -> MockClientWorker {
773        let mut mock_provider = MockClientWorker::new();
774        mock_provider
775            .expect_try_reserve_wft_slot()
776            .returning(|| Some(new_mock_slot(false)));
777        mock_provider
778            .expect_namespace()
779            .return_const(namespace.clone());
780        mock_provider.expect_task_queue().return_const(task_queue);
781        mock_provider
782            .expect_heartbeat_enabled()
783            .return_const(heartbeat_enabled);
784        mock_provider
785            .expect_worker_instance_key()
786            .return_const(Uuid::new_v4());
787        let deployment_name = "test-deployment".to_string();
788        let build_id_for_closure = build_id.clone();
789        mock_provider
790            .expect_deployment_options()
791            .returning(move || {
792                build_id_for_closure
793                    .as_ref()
794                    .map(|build_id| WorkerDeploymentOptions {
795                        version: temporalio_common::worker::WorkerDeploymentVersion {
796                            deployment_name: deployment_name.clone(),
797                            build_id: build_id.clone(),
798                        },
799                        use_worker_versioning: true,
800                        default_versioning_behavior: None,
801                    })
802            });
803
804        if heartbeat_enabled {
805            mock_provider
806                .expect_heartbeat_callback()
807                .returning(|| Some(Arc::new(WorkerHeartbeat::default)));
808
809            let namespace_clone = namespace.clone();
810            mock_provider
811                .expect_new_shared_namespace_worker()
812                .returning(move || {
813                    Ok(Box::new(MockSharedNamespaceWorker::new(
814                        namespace_clone.clone(),
815                    )))
816                });
817        }
818
819        mock_provider
820            .expect_worker_task_types()
821            .return_const(WorkerTaskTypes {
822                enable_workflows: true,
823                enable_local_activities: true,
824                enable_remote_activities: true,
825                enable_nexus: true,
826            });
827
828        mock_provider
829    }
830
831    #[test]
832    fn duplicate_namespace_task_queue_registration_fails() {
833        let manager = ClientWorkerSet::new();
834
835        let worker1 = new_mock_provider_with_heartbeat(
836            "test_namespace".to_string(),
837            "test_queue".to_string(),
838            true,
839            None,
840        );
841
842        // Same namespace+task_queue but different worker instance
843        let worker2 = new_mock_provider_with_heartbeat(
844            "test_namespace".to_string(),
845            "test_queue".to_string(),
846            true,
847            None,
848        );
849
850        manager.register_worker(Arc::new(worker1), false).unwrap();
851
852        // second worker register should fail due to overlapping worker task types
853        let result = manager.register_worker(Arc::new(worker2), false);
854        assert!(result.is_err());
855        assert!(
856            result
857                .unwrap_err()
858                .to_string()
859                .contains("Registration of multiple workers with overlapping worker task types")
860        );
861
862        assert_eq!(1, manager.num_providers());
863        assert_eq!(manager.num_heartbeat_workers(), 1);
864
865        let impl_ref = manager.worker_manager.read();
866        assert_eq!(impl_ref.shared_worker.len(), 1);
867        assert!(impl_ref.shared_worker.contains_key("test_namespace"));
868    }
869
870    #[test]
871    fn duplicate_namespace_with_different_build_ids_succeeds() {
872        let manager = ClientWorkerSet::new();
873        let namespace = "test_namespace".to_string();
874        let task_queue = "test_queue".to_string();
875
876        let worker1 =
877            new_mock_provider_with_heartbeat(namespace.clone(), task_queue.clone(), false, None);
878        let worker1_instance_key = worker1.worker_instance_key();
879        let worker2 = new_mock_provider_with_heartbeat(
880            namespace.clone(),
881            task_queue.clone(),
882            false,
883            Some("build-1".to_string()),
884        );
885        let worker2_instance_key = worker2.worker_instance_key();
886        let worker3 =
887            new_mock_provider_with_heartbeat(namespace.clone(), task_queue.clone(), false, None);
888        let worker4 = new_mock_provider_with_heartbeat(
889            namespace.clone(),
890            task_queue.clone(),
891            false,
892            Some("build-1".to_string()),
893        );
894
895        manager.register_worker(Arc::new(worker1), false).unwrap();
896
897        manager
898            .register_worker(Arc::new(worker2), false)
899            .expect("worker with new build ID should register");
900        assert_eq!(2, manager.num_providers());
901
902        assert!(
903            manager
904                .register_worker(Arc::new(worker3), false)
905                .unwrap_err()
906                .to_string()
907                .contains("Registration of multiple workers with overlapping worker task types")
908        );
909
910        assert!(
911            manager
912                .register_worker(Arc::new(worker4), false)
913                .unwrap_err()
914                .to_string()
915                .contains("Registration of multiple workers with overlapping worker task types")
916        );
917        assert_eq!(2, manager.num_providers());
918
919        {
920            let impl_ref = manager.worker_manager.read();
921            let slot_key = SlotKey::new(namespace.clone(), task_queue.clone());
922            let providers = impl_ref
923                .slot_providers
924                .get(&slot_key)
925                .expect("slot providers should exist for namespace/task queue");
926            assert_eq!(2, providers.len());
927
928            assert_eq!(providers[0].worker_id, worker1_instance_key);
929            assert_eq!(providers[0].build_id, None);
930            assert_eq!(providers[1].worker_id, worker2_instance_key);
931            assert_eq!(providers[1].build_id, Some("build-1".to_string()));
932        }
933
934        manager
935            .unregister_slot_provider(worker2_instance_key)
936            .unwrap();
937        manager.finalize_unregister(worker2_instance_key).unwrap();
938
939        {
940            let impl_ref = manager.worker_manager.read();
941            let slot_key = SlotKey::new(namespace.clone(), task_queue.clone());
942            let providers = impl_ref
943                .slot_providers
944                .get(&slot_key)
945                .expect("slot providers should exist for namespace/task queue");
946
947            assert_eq!(1, providers.len());
948            assert_eq!(providers[0].worker_id, worker1_instance_key);
949            assert_eq!(providers[0].build_id, None);
950        }
951    }
952
953    #[test]
954    fn multiple_workers_same_namespace_share_heartbeat_manager() {
955        let manager = ClientWorkerSet::new();
956
957        let worker1 = new_mock_provider_with_heartbeat(
958            "shared_namespace".to_string(),
959            "queue1".to_string(),
960            true,
961            None,
962        );
963
964        // Same namespace but different task queue
965        let worker2 = new_mock_provider_with_heartbeat(
966            "shared_namespace".to_string(),
967            "queue2".to_string(),
968            true,
969            None,
970        );
971
972        manager.register_worker(Arc::new(worker1), false).unwrap();
973        manager.register_worker(Arc::new(worker2), false).unwrap();
974
975        assert_eq!(2, manager.num_providers());
976        assert_eq!(manager.num_heartbeat_workers(), 2);
977
978        let impl_ref = manager.worker_manager.read();
979        assert_eq!(impl_ref.shared_worker.len(), 1);
980        assert!(impl_ref.shared_worker.contains_key("shared_namespace"));
981
982        let shared_worker = impl_ref.shared_worker.get("shared_namespace").unwrap();
983        assert_eq!(shared_worker.namespace(), "shared_namespace");
984    }
985
986    #[test]
987    fn different_namespaces_get_separate_heartbeat_managers() {
988        let manager = ClientWorkerSet::new();
989        let worker1 = new_mock_provider_with_heartbeat(
990            "namespace1".to_string(),
991            "queue1".to_string(),
992            true,
993            None,
994        );
995        let worker2 = new_mock_provider_with_heartbeat(
996            "namespace2".to_string(),
997            "queue1".to_string(),
998            true,
999            None,
1000        );
1001
1002        manager.register_worker(Arc::new(worker1), false).unwrap();
1003        manager.register_worker(Arc::new(worker2), false).unwrap();
1004
1005        assert_eq!(2, manager.num_providers());
1006        assert_eq!(manager.num_heartbeat_workers(), 2);
1007
1008        let impl_ref = manager.worker_manager.read();
1009        assert_eq!(impl_ref.num_heartbeat_workers(), 2);
1010        assert!(impl_ref.shared_worker.contains_key("namespace1"));
1011        assert!(impl_ref.shared_worker.contains_key("namespace2"));
1012    }
1013
1014    #[test]
1015    fn unregister_heartbeat_workers_cleans_up_shared_worker_when_last_removed() {
1016        let manager = ClientWorkerSet::new();
1017
1018        // Create two workers with same namespace but different task queues
1019        let worker1 = new_mock_provider_with_heartbeat(
1020            "test_namespace".to_string(),
1021            "queue1".to_string(),
1022            true,
1023            None,
1024        );
1025        let worker2 = new_mock_provider_with_heartbeat(
1026            "test_namespace".to_string(),
1027            "queue2".to_string(),
1028            true,
1029            None,
1030        );
1031        let worker_instance_key1 = worker1.worker_instance_key();
1032        let worker_instance_key2 = worker2.worker_instance_key();
1033
1034        assert_ne!(worker_instance_key1, worker_instance_key2);
1035
1036        manager.register_worker(Arc::new(worker1), false).unwrap();
1037        manager.register_worker(Arc::new(worker2), false).unwrap();
1038
1039        // Verify initial state: 2 slot providers, 2 heartbeat workers, 1 shared worker
1040        assert_eq!(2, manager.num_providers());
1041        assert_eq!(manager.num_heartbeat_workers(), 2);
1042
1043        let impl_ref = manager.worker_manager.read();
1044        assert_eq!(impl_ref.shared_worker.len(), 1);
1045        assert!(impl_ref.shared_worker.contains_key("test_namespace"));
1046        assert_eq!(
1047            impl_ref
1048                .shared_worker
1049                .get("test_namespace")
1050                .unwrap()
1051                .num_workers(),
1052            2
1053        );
1054        drop(impl_ref);
1055
1056        // Unregister first worker
1057        manager
1058            .unregister_slot_provider(worker_instance_key1)
1059            .unwrap();
1060        manager.finalize_unregister(worker_instance_key1).unwrap();
1061
1062        // After unregistering first worker: 1 slot provider, 1 heartbeat worker, shared worker still exists
1063        assert_eq!(1, manager.num_providers());
1064        assert_eq!(manager.num_heartbeat_workers(), 1);
1065
1066        let impl_ref = manager.worker_manager.read();
1067        assert_eq!(impl_ref.num_heartbeat_workers(), 1); // SharedNamespaceWorker still exists
1068        assert!(impl_ref.shared_worker.contains_key("test_namespace"));
1069        assert_eq!(
1070            impl_ref
1071                .shared_worker
1072                .get("test_namespace")
1073                .unwrap()
1074                .num_workers(),
1075            1
1076        );
1077        drop(impl_ref);
1078
1079        // Unregister second worker
1080        manager
1081            .unregister_slot_provider(worker_instance_key2)
1082            .unwrap();
1083        manager.finalize_unregister(worker_instance_key2).unwrap();
1084
1085        // After unregistering last worker: 0 slot providers, 0 heartbeat workers, shared worker is removed
1086        assert_eq!(0, manager.num_providers());
1087        assert_eq!(manager.num_heartbeat_workers(), 0);
1088
1089        let impl_ref = manager.worker_manager.read();
1090        assert_eq!(impl_ref.shared_worker.len(), 0); // SharedNamespaceWorker is cleaned up
1091        assert!(!impl_ref.shared_worker.contains_key("test_namespace"));
1092    }
1093
1094    #[test]
1095    fn workflow_and_activity_only_workers_coexist() {
1096        let manager = ClientWorkerSet::new();
1097        let namespace = "test_namespace".to_string();
1098        let task_queue = "test_queue".to_string();
1099
1100        let mut workflow_nexus_worker = MockClientWorker::new();
1101        workflow_nexus_worker
1102            .expect_namespace()
1103            .return_const(namespace.clone());
1104        workflow_nexus_worker
1105            .expect_task_queue()
1106            .return_const(task_queue.clone());
1107        workflow_nexus_worker
1108            .expect_deployment_options()
1109            .return_const(None);
1110        workflow_nexus_worker
1111            .expect_worker_instance_key()
1112            .return_const(Uuid::new_v4());
1113        workflow_nexus_worker
1114            .expect_heartbeat_enabled()
1115            .return_const(false);
1116        workflow_nexus_worker
1117            .expect_worker_task_types()
1118            .return_const(WorkerTaskTypes {
1119                enable_workflows: true,
1120                enable_local_activities: false,
1121                enable_remote_activities: false,
1122                enable_nexus: true,
1123            });
1124
1125        let mut activity_worker = MockClientWorker::new();
1126        activity_worker
1127            .expect_namespace()
1128            .return_const(namespace.clone());
1129        activity_worker
1130            .expect_task_queue()
1131            .return_const(task_queue.clone());
1132        activity_worker
1133            .expect_deployment_options()
1134            .return_const(None);
1135        activity_worker
1136            .expect_worker_instance_key()
1137            .return_const(Uuid::new_v4());
1138        activity_worker
1139            .expect_heartbeat_enabled()
1140            .return_const(false);
1141        activity_worker
1142            .expect_worker_task_types()
1143            .return_const(WorkerTaskTypes {
1144                enable_workflows: false,
1145                enable_local_activities: false,
1146                enable_remote_activities: true,
1147                enable_nexus: false,
1148            });
1149        activity_worker.expect_try_reserve_wft_slot().times(0); // Should not be called for activity-only worker
1150
1151        manager
1152            .register_worker(Arc::new(workflow_nexus_worker), false)
1153            .expect("workflow-nexus worker should register");
1154        manager
1155            .register_worker(Arc::new(activity_worker), false)
1156            .expect("activity-only worker should register");
1157
1158        assert_eq!(2, manager.num_providers());
1159    }
1160
1161    #[test]
1162    fn overlapping_capabilities_rejected() {
1163        let manager = ClientWorkerSet::new();
1164        let namespace = "test_namespace".to_string();
1165        let task_queue = "test_queue".to_string();
1166
1167        // workflow+activity worker
1168        let mut worker1 = MockClientWorker::new();
1169        worker1.expect_namespace().return_const(namespace.clone());
1170        worker1.expect_task_queue().return_const(task_queue.clone());
1171        worker1.expect_deployment_options().return_const(None);
1172        worker1
1173            .expect_worker_instance_key()
1174            .return_const(Uuid::new_v4());
1175        worker1.expect_heartbeat_enabled().return_const(false);
1176        worker1
1177            .expect_worker_task_types()
1178            .return_const(WorkerTaskTypes {
1179                enable_workflows: true,
1180                enable_local_activities: true,
1181                enable_remote_activities: true,
1182                enable_nexus: false,
1183            });
1184
1185        // workflow+activity worker
1186        let mut worker2 = MockClientWorker::new();
1187        worker2.expect_namespace().return_const(namespace.clone());
1188        worker2.expect_task_queue().return_const(task_queue.clone());
1189        worker2.expect_deployment_options().return_const(None);
1190        worker2
1191            .expect_worker_instance_key()
1192            .return_const(Uuid::new_v4());
1193        worker2.expect_heartbeat_enabled().return_const(false);
1194        worker2
1195            .expect_worker_task_types()
1196            .return_const(WorkerTaskTypes {
1197                enable_workflows: true,
1198                enable_local_activities: true,
1199                enable_remote_activities: true,
1200                enable_nexus: false,
1201            });
1202
1203        manager
1204            .register_worker(Arc::new(worker1), false)
1205            .expect("first worker should register");
1206
1207        let result = manager.register_worker(Arc::new(worker2), false);
1208        assert!(result.is_err());
1209        assert!(
1210            result
1211                .unwrap_err()
1212                .to_string()
1213                .contains("overlapping worker task types")
1214        );
1215
1216        // activity-only worker
1217        let mut worker3 = MockClientWorker::new();
1218        worker3.expect_namespace().return_const(namespace.clone());
1219        worker3.expect_task_queue().return_const(task_queue.clone());
1220        worker3.expect_deployment_options().return_const(None);
1221        worker3
1222            .expect_worker_instance_key()
1223            .return_const(Uuid::new_v4());
1224        worker3.expect_heartbeat_enabled().return_const(false);
1225        worker3
1226            .expect_worker_task_types()
1227            .return_const(WorkerTaskTypes {
1228                enable_workflows: false,
1229                enable_local_activities: false,
1230                enable_remote_activities: true,
1231                enable_nexus: false,
1232            });
1233
1234        let result = manager.register_worker(Arc::new(worker3), false);
1235        assert!(result.is_err());
1236        assert!(
1237            result
1238                .unwrap_err()
1239                .to_string()
1240                .contains("overlapping worker task types")
1241        );
1242    }
1243
1244    #[test]
1245    fn wft_slot_reservation_ignores_non_workflow_workers() {
1246        let mut manager_impl = ClientWorkerSetImpl::new();
1247        let namespace = "test_namespace".to_string();
1248        let task_queue = "test_queue".to_string();
1249
1250        let mut activity_worker = MockClientWorker::new();
1251        activity_worker
1252            .expect_namespace()
1253            .return_const(namespace.clone());
1254        activity_worker
1255            .expect_task_queue()
1256            .return_const(task_queue.clone());
1257        activity_worker
1258            .expect_deployment_options()
1259            .return_const(None);
1260        activity_worker
1261            .expect_worker_instance_key()
1262            .return_const(Uuid::new_v4());
1263        activity_worker
1264            .expect_heartbeat_enabled()
1265            .return_const(false);
1266        activity_worker
1267            .expect_worker_task_types()
1268            .return_const(WorkerTaskTypes {
1269                enable_workflows: false,
1270                enable_local_activities: false,
1271                enable_remote_activities: true,
1272                enable_nexus: false,
1273            });
1274
1275        let mut nexus_worker = MockClientWorker::new();
1276        nexus_worker
1277            .expect_namespace()
1278            .return_const(namespace.clone());
1279        nexus_worker
1280            .expect_task_queue()
1281            .return_const(task_queue.clone());
1282        nexus_worker.expect_deployment_options().return_const(None);
1283        nexus_worker
1284            .expect_worker_instance_key()
1285            .return_const(Uuid::new_v4());
1286        nexus_worker.expect_heartbeat_enabled().return_const(false);
1287        nexus_worker
1288            .expect_worker_task_types()
1289            .return_const(WorkerTaskTypes {
1290                enable_workflows: false,
1291                enable_local_activities: false,
1292                enable_remote_activities: false,
1293                enable_nexus: true,
1294            });
1295
1296        manager_impl
1297            .register(Arc::new(activity_worker), false)
1298            .expect("activity worker should register");
1299        manager_impl
1300            .register(Arc::new(nexus_worker), false)
1301            .expect("nexus worker should register");
1302
1303        let reservation = manager_impl.try_reserve_wft_slot(namespace.clone(), task_queue.clone());
1304        assert!(
1305            reservation.is_none(),
1306            "should not find workflow workers when only activity/nexus workers registered"
1307        );
1308
1309        // Now register a workflow worker
1310        let mut workflow_worker = MockClientWorker::new();
1311        workflow_worker
1312            .expect_namespace()
1313            .return_const(namespace.clone());
1314        workflow_worker
1315            .expect_task_queue()
1316            .return_const(task_queue.clone());
1317        workflow_worker
1318            .expect_deployment_options()
1319            .return_const(None);
1320        workflow_worker
1321            .expect_worker_instance_key()
1322            .return_const(Uuid::new_v4());
1323        workflow_worker
1324            .expect_heartbeat_enabled()
1325            .return_const(false);
1326        workflow_worker
1327            .expect_worker_task_types()
1328            .return_const(WorkerTaskTypes {
1329                enable_workflows: true,
1330                enable_local_activities: true,
1331                enable_remote_activities: false,
1332                enable_nexus: false,
1333            });
1334        workflow_worker
1335            .expect_try_reserve_wft_slot()
1336            .times(1)
1337            .returning(|| Some(new_mock_slot(false)));
1338
1339        manager_impl
1340            .register(Arc::new(workflow_worker), false)
1341            .expect("workflow worker should register");
1342
1343        let reservation = manager_impl.try_reserve_wft_slot(namespace.clone(), task_queue.clone());
1344        assert!(
1345            reservation.is_some(),
1346            "should find workflow worker after it's registered"
1347        );
1348    }
1349
1350    #[test]
1351    fn worker_invalid_type_config_rejected() {
1352        let manager = ClientWorkerSet::new();
1353
1354        // no types enabled
1355        let mut worker = MockClientWorker::new();
1356        worker
1357            .expect_namespace()
1358            .return_const("test_namespace".to_string());
1359        worker
1360            .expect_task_queue()
1361            .return_const("test_queue".to_string());
1362        worker.expect_deployment_options().return_const(None);
1363        worker
1364            .expect_worker_instance_key()
1365            .return_const(Uuid::new_v4());
1366        worker.expect_heartbeat_enabled().return_const(false);
1367        worker
1368            .expect_worker_task_types()
1369            .return_const(WorkerTaskTypes {
1370                enable_workflows: false,
1371                enable_local_activities: false,
1372                enable_remote_activities: false,
1373                enable_nexus: false,
1374            });
1375
1376        let result = manager.register_worker(Arc::new(worker), false);
1377        assert!(result.is_err());
1378        assert!(
1379            result
1380                .unwrap_err()
1381                .to_string()
1382                .contains("must have at least one capability enabled")
1383        );
1384
1385        // local activities enabled without workflows
1386        let mut worker = MockClientWorker::new();
1387        worker
1388            .expect_namespace()
1389            .return_const("test_namespace".to_string());
1390        worker
1391            .expect_task_queue()
1392            .return_const("test_queue".to_string());
1393        worker.expect_deployment_options().return_const(None);
1394        worker
1395            .expect_worker_instance_key()
1396            .return_const(Uuid::new_v4());
1397        worker.expect_heartbeat_enabled().return_const(false);
1398        worker
1399            .expect_worker_task_types()
1400            .return_const(WorkerTaskTypes {
1401                enable_workflows: false,
1402                enable_local_activities: true,
1403                enable_remote_activities: true,
1404                enable_nexus: false,
1405            });
1406
1407        let result = manager.register_worker(Arc::new(worker), false);
1408        assert!(result.is_err());
1409        assert_eq!(
1410            result.unwrap_err().to_string(),
1411            "Local activities cannot be enabled without workflows".to_string()
1412        );
1413    }
1414
1415    #[test]
1416    fn unregister_with_multiple_workers() {
1417        let manager = ClientWorkerSet::new();
1418        let namespace = "test_namespace".to_string();
1419        let task_queue = "test_queue".to_string();
1420
1421        // workflow-only worker
1422        let mut workflow_worker = MockClientWorker::new();
1423        workflow_worker
1424            .expect_namespace()
1425            .return_const(namespace.clone());
1426        workflow_worker
1427            .expect_task_queue()
1428            .return_const(task_queue.clone());
1429        workflow_worker
1430            .expect_deployment_options()
1431            .return_const(None);
1432        let wf_worker_key = Uuid::new_v4();
1433        workflow_worker
1434            .expect_worker_instance_key()
1435            .return_const(wf_worker_key);
1436        workflow_worker
1437            .expect_heartbeat_enabled()
1438            .return_const(false);
1439        workflow_worker
1440            .expect_worker_task_types()
1441            .return_const(WorkerTaskTypes {
1442                enable_workflows: true,
1443                enable_local_activities: true,
1444                enable_remote_activities: false,
1445                enable_nexus: false,
1446            });
1447        workflow_worker
1448            .expect_try_reserve_wft_slot()
1449            .returning(|| Some(new_mock_slot(false)));
1450
1451        // activity-only worker
1452        let mut activity_worker = MockClientWorker::new();
1453        activity_worker
1454            .expect_namespace()
1455            .return_const(namespace.clone());
1456        activity_worker
1457            .expect_task_queue()
1458            .return_const(task_queue.clone());
1459        activity_worker
1460            .expect_deployment_options()
1461            .return_const(None);
1462        let act_worker_key = Uuid::new_v4();
1463        activity_worker
1464            .expect_worker_instance_key()
1465            .return_const(act_worker_key);
1466        activity_worker
1467            .expect_heartbeat_enabled()
1468            .return_const(false);
1469        activity_worker
1470            .expect_worker_task_types()
1471            .return_const(WorkerTaskTypes {
1472                enable_workflows: false,
1473                enable_local_activities: false,
1474                enable_remote_activities: true,
1475                enable_nexus: false,
1476            });
1477
1478        manager
1479            .register_worker(Arc::new(workflow_worker), false)
1480            .expect("workflow worker should register");
1481        manager
1482            .register_worker(Arc::new(activity_worker), false)
1483            .expect("activity worker should register");
1484
1485        assert_eq!(2, manager.num_providers());
1486
1487        let reservation = manager.try_reserve_wft_slot(namespace.clone(), task_queue.clone());
1488        assert!(
1489            reservation.is_some(),
1490            "should be able to reserve slot from workflow worker"
1491        );
1492
1493        manager
1494            .unregister_slot_provider(wf_worker_key)
1495            .expect("should unregister slot provider for workflow worker");
1496        manager
1497            .finalize_unregister(wf_worker_key)
1498            .expect("should finalize unregister for workflow worker");
1499
1500        // Activity worker should still be registered
1501        assert_eq!(1, manager.num_providers());
1502
1503        let reservation = manager.try_reserve_wft_slot(namespace.clone(), task_queue.clone());
1504        assert!(
1505            reservation.is_none(),
1506            "should not find workflow worker after unregistration"
1507        );
1508
1509        manager
1510            .unregister_slot_provider(act_worker_key)
1511            .expect("should unregister slot provider for activity worker");
1512        manager
1513            .finalize_unregister(act_worker_key)
1514            .expect("should finalize unregister for activity worker");
1515
1516        assert_eq!(0, manager.num_providers());
1517    }
1518
1519    #[test]
1520    fn worker_unregister_order() {
1521        let manager = ClientWorkerSet::new();
1522        let worker = new_mock_provider_with_heartbeat(
1523            "namespace1".to_string(),
1524            "queue1".to_string(),
1525            true,
1526            None,
1527        );
1528        let worker_instance_key = worker.worker_instance_key();
1529        manager.register_worker(Arc::new(worker), false).unwrap();
1530
1531        let res = manager.finalize_unregister(worker_instance_key);
1532        assert!(res.is_err());
1533        let err_string = res.err().map(|e| e.to_string()).unwrap();
1534        assert!(err_string.contains("Worker still in slot_providers during finalize"));
1535
1536        // previous incorrect call to finalize_unregister should not cause any state leaks when
1537        // properly removed later
1538        manager
1539            .unregister_slot_provider(worker_instance_key)
1540            .unwrap();
1541        manager.finalize_unregister(worker_instance_key).unwrap();
1542    }
1543}